Java-多文件多线程下载器的设计与实现(线程池)

1、多线程下载的原因和我们的目标

多线程多文件下载是一个常见的需求,一些服务器为了负载均衡,往往会给每个请求线程设置最大的带宽,因此线程数量有时候也成为制约我们下载网络的原因之一。这类限制我们其实经常遇到,比如说百度网盘就是通过限制单线程下载速度来实现限速。

我们知道,线程的创建是比较消耗系统资源的操作,频繁的,无限制的创建线程可能会导致计算机资源被耗尽等极端情况。Java里面则提供了JUC库为我们方便的实现并发操作带来了可能性。使用线程池技术能够很好的管理线程,因此本方案采用线程池管理下载线程

我们的最终目的是实现只需要指定下载链接和线程数量即可实现多线程下载,并且可以同时对多个文件进行多线程下载,同时也支持断点续传能力。

2、下载派发器设计

下载派发器的作用是调度每个任务,在里面我们定义一个统一的线程池。

对该下载框架,我们需要尽量的简洁,相信每个接触Android开发的人都会被EventBus框架的简洁所震撼,而实现这样简洁的操作,我们需要使用一些设计模式。

我所设计的ConcurrentDownloader即采用了单例模式,以此对线程进行统一的管理,很好的简化了使用流程,降低使用门槛。

public class DownloadDispatcher {
    /**
     * 线程池最大运行线程数量
     */
    private static int DEFAULT_THREAD_COUNT=128;
    private ExecutorService executorService;
    /**
     * 双检锁单例模式
     */
    private volatile static DownloadDispatcher DEFAULT_INSTANCE;
    public static DownloadDispatcher getDefault(){
        if(DEFAULT_INSTANCE==null){
            synchronized (DownloadDispatcher.class){
                if(DEFAULT_INSTANCE==null){
                    DEFAULT_INSTANCE=new DownloadDispatcher();
                }
            }
        }
        return DEFAULT_INSTANCE;
    }

    private DownloadDispatcher() {
        this(DEFAULT_THREAD_COUNT);
    }

    /**
     * 创建一个线程池
     * @param threadCount
     */
    public DownloadDispatcher(int threadCount) {
        executorService=new ThreadPoolExecutor(threadCount
                ,threadCount
                ,0
                , TimeUnit.SECONDS
                ,new ArrayBlockingQueue<Runnable>(threadCount)
                ,new ConcurrentThreadFactory()
                ,new ExceedHandler());
    }

    /**
     * 线程工厂
     */
    private class ConcurrentThreadFactory implements ThreadFactory {
        public Thread newThread(Runnable r) {
            return new Thread(r);
        }
    }

    /**
     * 包和策略
     */
    private class ExceedHandler implements RejectedExecutionHandler {

        public void rejectedExecution(Runnable r, ThreadPoolExecutor executor) {
            System.out.println("超过最大线程数");
        }
    }

    /**
     * 创建一个下载任务
     * @param downloadLink
     * @param threadNum
     * @param listener
     */
    public void dispatchNewTask(String downloadLink,int threadNum,DownloadListener listener){
        final ConcurrentDownloader downloader=new ConcurrentDownloader(this.executorService);
        downloader.startDownload(downloadLink,threadNum,listener);
    }

}


3、下载接口的设计

接口设计应当体现下载过程的特色,以及我们的行为规范,为了方便对每个任务进行控制,每个任务都需要自己的接口,每个任务的子线程共用同一个接口,我们实现下载完成通知,以及断点续传功能就需要体现在接口中。

public interface ConcurrentDownloadListener {
    void onSuccess(int threadId);
    void onProgress(long progress);

    /**
     * 暂停后记录下载长度
     * @param threadId
     * @param startPos
     * @param downloadedLen
     * @param endPos
     */
    void onPause(String fileName,int threadId,long startPos,long downloadedLen,long endPos);
    void onFailed(int reason);
    void onCanceled(File file);
}

对下载状态的控制我们还需要一个状态机表示状态切换,这一点在下一节给出代码,本节给出下载的各种状态

public enum DownloadStatus{
    SUCCESS,
    FAILED,
    CANCELED,
    PAUSE,
    PROGRESS
}

4、单文件多线程下载的实现

在2里面,我们对文件的下载请求进行派发处理,在这里,我们将针对某个具体的文件的下载进行实现。

public class ConcurrentDownloader {
    private static final String TAG="ConcurrentDownloader";
    protected DownloadStatus downloadStatus;

    private ExecutorService executorService;

    public ConcurrentDownloader(ExecutorService executorService) {
        this.executorService=executorService;
    }



    private DownloadListener mainDownloadListener;
    /**
     * 执行当前任务的总线程数量
     */
    private int totalThreadCount;
    /**
     * 当前完成操作的线程
     */
    private int finishedThreadCount=0;
    /**
     * 暂停列表
     */
    private List<LocalFileInfo> pauseList=new ArrayList<>();
    /**
     * 已下载的文件长度
     */
    private volatile long downloadedLength=0;
    /**
     * 同步锁
     */
    private static final String LOCK="downloadLock";

    /**
     * 每个线程共用这一个回调接口
     */
    private ConcurrentDownloadListener partDownloadListener=new ConcurrentDownloadListener() {
        public void onSuccess(int threadId) {
            System.out.println("线程:"+threadId+"下载完成");
            finishedThreadCount++;
            if(totalThreadCount==finishedThreadCount){
                mainDownloadListener.onSuccess(-1);
                executorService.shutdown();
            }

        }

        public void onProgress(long progress) {
            synchronized (LOCK){
                downloadedLength+=progress;
            }

            if(downloadedLength<1024){
                Log.i(TAG,"已下载:"+downloadedLength+"B");
            }else if(downloadedLength<1024*1024){
                Log.i(TAG,"已下载:"+(downloadedLength>>>10)+"KB");
            }else if(downloadedLength<1024*1024*1024){
                Log.i(TAG,"已下载:"+(downloadedLength>>>20)+"MB");
            }else {
                Log.i(TAG,"已下载:"+(downloadedLength>>>30)+"GB");
            }

        }

        /**
         * 暂停下载后,讲下载信息写入本地文件,等待恢复
         * @param filename
         * @param threadId
         * @param startPos
         * @param downloadedLen
         * @param endPos
         */
        public synchronized void onPause(String filename,int threadId,long startPos,long downloadedLen,long endPos) {
            LocalFileInfo localFileInfo=new LocalFileInfo(filename,threadId,startPos,downloadedLen,endPos);
            pauseList.add(localFileInfo);
            if(pauseList.size()==totalThreadCount){
                SerializationHelper.writeToDisk(pauseList);
                System.out.println("全部暂停,序列化进度到本地文件");
                executorService.shutdown();
            }
        }

        public void onFailed(int reason) {
            executorService.shutdown();
        }

        public void onCanceled(File file) {
            finishedThreadCount++;
            if(finishedThreadCount==totalThreadCount){
                file.delete();
            }
            executorService.shutdown();
        }
    };

    /**
     * 暴露给派发器的接口,开始下载
     * @param links
     * @param threadCount
     * @param listener
     */
    public void startDownload(String links,int threadCount,DownloadListener listener){
        mainDownloadListener=listener;
        downloadStatus=PROGRESS;
        totalThreadCount=threadCount;
        /**
         * 获取需要下载的文件长度
         */
        long totalLen=getTotalLength(links);
        File file=new File("D:/"+ links.substring(links.lastIndexOf("/")));
        /**
         * 如果本地存在文件,那就从本地恢复信息,否则新创建下载信息
         */
        if(file.exists()){
            if(file.length()!=totalLen){
                Log.i(TAG,"本地文件与服务器文件不一致");
            }else {
                restartDownload(links,SerializationHelper.restoreDownloadInfo(file.getAbsolutePath()),file,partDownloadListener);
            }
        }else {
            file=createFileByLength(file,totalLen);
            requestNewDownload(links,totalLen,threadCount,file,partDownloadListener);
        }
    }

    /**
     * 本地创建一个文件
     * @param file
     * @param fileLen
     * @return
     */
    private File createFileByLength(File file,long fileLen){
        try {
            file.createNewFile();
            RandomAccessFile raf=new RandomAccessFile(file,"rwd");
            raf.setLength(fileLen);
            raf.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return file;
    }

    /**
     * 获取文件总长度
     * @param link
     * @return
     */
    private long getTotalLength(String link){
        long len=-1;
        try {
            URL url = new URL(link);
            HttpURLConnection conn=(HttpURLConnection)url.openConnection();
            conn.setRequestMethod("GET");
            len=conn.getContentLengthLong();
            conn.disconnect();
        } catch (Exception e) {
            e.printStackTrace();
        }
        return len;
    }

    /**
     * 从本地恢复暂停的线程信息
     * @param downloadLink
     * @param infos
     * @param file
     * @param listener
     */
    private void restartDownload(String downloadLink,List<LocalFileInfo> infos,File file,ConcurrentDownloadListener listener){
        for (LocalFileInfo info:infos){
            executorService.submit(new PartDownloadRunnable(info.getThreadId(),downloadLink,info.getStartPos()+info.getDownloadedLen(),info.getEndPos(),file,listener));
        }
    }

    /**
     * 新创建
     * @param downloadLink
     * @param totalLen
     * @param threadCount
     * @param file
     * @param listener
     */
    
    private void requestNewDownload(String downloadLink,long totalLen,int threadCount,File file,ConcurrentDownloadListener listener){
        long partLen=(totalLen+threadCount)/threadCount;
        long startPos;
        long endPos;
        System.out.println("总长度:"+totalLen+"\t块长度:"+partLen);
        for (int i=0;i<threadCount;i++){
            startPos=i*partLen;
            endPos=startPos+partLen-1>totalLen?totalLen-1:startPos+partLen-1;
            System.out.println("线程:"+i+"\tstart:"+startPos+"\tend:"+endPos);
            executorService.submit(new PartDownloadRunnable(i,downloadLink,startPos,endPos,file,listener));
        }
    }

    /**
     * 执行下载的线程实体类
     */
    class PartDownloadRunnable implements Runnable{
        private final int threadId;
        private final String downloadLink;
        private final long startPos;
        private final long endPos;
        private ConcurrentDownloadListener listener;
        private File file;

        public PartDownloadRunnable(int threadId, String downloadLink, long startPos, long endPos, File file,ConcurrentDownloadListener listener) {
            this.threadId = threadId;
            this.downloadLink = downloadLink;
            this.startPos = startPos;
            this.endPos = endPos;
            this.listener=listener;
            this.file = file;
        }

        public void run() {
            long downloadedLength=0;
            try {
                URL url=new URL(downloadLink);
                HttpURLConnection conn=(HttpURLConnection)url.openConnection();
                conn.setRequestMethod("GET");
                conn.setRequestProperty("Range","bytes="+startPos+"-"+endPos);
                RandomAccessFile raf=new RandomAccessFile(file,"rwd");
                raf.seek(startPos);
                BufferedInputStream reader=new BufferedInputStream(conn.getInputStream());
                byte[] b=new byte[4096];
                int len;
                while ((len=reader.read(b))!=-1){
                    if(downloadStatus==CANCELED){
                        raf.close();
                        reader.close();
                        conn.disconnect();
                        listener.onCanceled(file);
                        return;
                    }else if(downloadStatus==PAUSE){
                        listener.onPause(file.getAbsolutePath(),threadId,startPos,downloadedLength,endPos);
                        raf.close();
                        reader.close();
                        conn.disconnect();
                        return;
                    }else {
                        raf.write(b,0,len);
                    }
                    downloadedLength+=len;
                    listener.onProgress(len);
                }
                raf.close();
                reader.close();
                conn.disconnect();
                listener.onSuccess(threadId);

            } catch (Exception e) {
                listener.onFailed(-1);
            }
        }
    }


    /**
     * 取消下载,暴露给外部调用
     */
    public void cancelDownload(){
        downloadStatus=CANCELED;
    }

    /**
     * 暂停下载
     */
    public void pauseDownload(){
        downloadStatus=PAUSE;
    }
    
}

5、断点续传的实现

断点续传是下载器一个很有必要的功能,我们观察百度网盘,迅雷等下载工具,可以发现,其在下载时都是生成两个文件,一个是要下载的文件,另一个是记录下载信息的文件,本设计也采用相同的思路,专门用一个文件来保存下载进度。

写一个文件实现信息保存我们可以有很多方式,比如大家常用的JSON,解析JSON是一个很好的解决办法,但我这里并未采用,因为JSON数据格式很容易被读懂和篡改,我采用的方式是利用对象序列化来实现保存下载进度。

首先是我们需要保存哪些信息,在下面的类里面已经给出。

public class LocalFileInfo implements Serializable {
    private static final long serialVersionUID=0x123453245145L;
    private String filename;
    private int threadId;
    private long startPos;
    private long downloadedLen;
    private long endPos;

    public LocalFileInfo(String filename, int threadId, long startPos, long downloadedLen, long endPos) {
        this.filename = filename;
        this.threadId = threadId;
        this.startPos = startPos;
        this.downloadedLen = downloadedLen;
        this.endPos = endPos;
    }

    public String getFilename() {
        return filename;
    }

    public void setFilename(String filename) {
        this.filename = filename;
    }

    public int getThreadId() {
        return threadId;
    }

    public void setThreadId(int threadId) {
        this.threadId = threadId;
    }

    public long getStartPos() {
        return startPos;
    }

    public void setStartPos(long startPos) {
        this.startPos = startPos;
    }

    public long getDownloadedLen() {
        return downloadedLen;
    }

    public void setDownloadedLen(long downloadedLen) {
        this.downloadedLen = downloadedLen;
    }

    public long getEndPos() {
        return endPos;
    }

    public void setEndPos(long endPos) {
        this.endPos = endPos;
    }

    @Override
    public String toString() {
        return "{" +
                "\"filename\":\"" + filename + "\"" +
                ", \"threadId\":\"" + threadId + "\"" +
                ", \"startPos\":\"" + startPos + "\"" +
                ", \"downloadedLen\":\"" + downloadedLen + "\"" +
                ", \"endPos\":\"" + endPos + "\"" +
                "}";
    }

然后是如何实现序列化。经过考虑,我决定写一个工具类实现帮我们实现序列化,一个是将对象写入本地文件,另一个是从本地文件读取对象。

public class SerializationHelper {
    public static void writeToDisk(List<LocalFileInfo> info){
        File file=new File(info.get(0).getFilename()+".download_config.json");
        try {
            if(file.exists()){
                file.delete();
            }
            file.createNewFile();
            LocalFileInfo[] localFileInfos=new LocalFileInfo[info.size()];
            info.toArray(localFileInfos);
            ObjectOutputStream oos=new ObjectOutputStream(new FileOutputStream(file));
            oos.writeObject(localFileInfos);
            oos.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public static List<LocalFileInfo> restoreDownloadInfo(String filename){
        List<LocalFileInfo> localFileInfoList=new ArrayList<LocalFileInfo>();
        filename+=".download_config.json";
        File file=new File(filename);
        try {
            ObjectInputStream ois=new ObjectInputStream(new FileInputStream(file));
            LocalFileInfo[] infos=(LocalFileInfo[])ois.readObject();
            localFileInfoList= Arrays.asList(infos);
            ois.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
        localFileInfoList.sort((o1, o2) -> o2.getThreadId()-o1.getThreadId());
        return localFileInfoList;
    }
}

6、如何使用

ConcurrentDownloader的使用是非常简单的,下面给大家一个测试用例

public class Main {
    public static void main(String[] args) {
        String link2004="https://mirrors.tuna.tsinghua.edu.cn/mysql/downloads/MySQL-8.0/mysql-community-server-core_8.0.20-2ubuntu20.04_amd64.deb";
        String link1804="https://mirrors.tuna.tsinghua.edu.cn/mysql/downloads/MySQL-8.0/mysql-community-server-core_8.0.20-1ubuntu18.04_amd64.deb";
        DownloadDispatcher.getDefault().dispatchNewTask(link2004,10,listener);
        DownloadDispatcher.getDefault().dispatchNewTask(link1804,10,listener);
    }

    private static DownloadListener listener=new DownloadListener() {

        @Override
        public void onSuccess(int reason) {
            System.out.println("全部完成");
        }

        @Override
        public void onPause(LocalFileInfo info) {

        }
    };
}

7、总结

上述代码并非全部代码,完整工程是一个maven工程,大家可以去我的GitHub下载:

https://github.com/bestyize/ConcurrentDownloader

(欢迎star)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章