深入理解Java併發(3) 手寫實現簡單線程池

    一、 架構分析

    我們來實現一個簡單的線程池,先看一下都需要實現哪些部分。

    我們對需求進行一下簡單的抽象分析,就可以發現這個跟生產者/消費者模型有點像,我們可以畫個圖,邊分析邊往上加

     1. 以生產者/消費者模型爲基礎創建架構。需要一個隊列taskQueue 用以緩存要執行的任務,通過execute()方法從外部向線程池提交任務,封裝一個方法用來決策,是立即創建線程、放入隊列還是拒絕。

    2.創建線程可以用工廠設計模式封裝一個ThreadFactory內部類。

    3. 拒絕策略可以簡單封裝一個RejectedHandler類,當線程數已達maxPoolSize且taskQueue滿時拒絕添加新的任務。

    4. 封裝一個runTask()方法。在方法中需要不斷地利用getTask()方法來從taskQueue中獲取任務並執行。一個線程執行完任務就會銷燬,爲了防止頻繁地創建和銷燬線程,這裏需要用while()循環不斷getTask()並執行。 

    5. 關閉並銷燬線程池。


二、具體代碼實現

package threadpooldemo;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

public class MyThreadPool {
    final static int MAX_QUEUE_SIZE = 100; //默認的最大任務隊列長度
    private volatile long corePoolSize;   //加volatile是爲了後面支持set擴展
    private volatile long maximumPoolSize;
    private BlockingQueue<Runnable> taskQueue = new LinkedBlockingQueue<>(MAX_QUEUE_SIZE);
    private volatile long keepAliveTime;
    private static final ThreadFactory threadFactory = new MyThreadFactory();
    private static final MyRejectedExecutionHandler defaultHandler = new MyRejectedExecutionHandler();
    private volatile boolean running = true;
    private volatile boolean allowCoreThreadTimeOut = false;
    private final AtomicLong threadNum = new AtomicLong(0);
      
    /**
     *線程池構造方法
     */
    public MyThreadPool(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit) {
        //core 線程數不允許爲0
        if (corePoolSize <= 0 ||
                    maximumPoolSize <= 0 ||
                    maximumPoolSize < corePoolSize ||
                    keepAliveTime < 0)
                    throw new IllegalArgumentException();
           
        this.corePoolSize = corePoolSize;
        this.maximumPoolSize = maximumPoolSize;
        this.keepAliveTime = unit.toNanos(keepAliveTime);
    }
    
    /**
     * 對Runnable task進行包裝,在其run()方法中調用runTask方法,以便在runTask中可以反覆從taskQueue中取任務執行
     */
    private class WrappedTask implements Runnable{
        private Thread thread;
        private Runnable task;
        
        public WrappedTask(Runnable task) {
            this.thread = threadFactory.newThread(this); //newThread(this),start時會自動調用run()
            this.task = task;
        }
        
        public void startTask(){
            this.thread.start();
        }
        
        public Runnable getTask(){
            return task;
        }

        @Override
        public void run() {
            runTask(this);
        }       
    }

    /**
     * 外部提交任務到線程池,選擇處理策略
     */
    public void execute(Runnable task){
        if (null == task) {
            throw new NullPointerException();
        }
        
        //線程數未到corePoolSize,直接創建新線程去幹活
        if (running && (threadNum.longValue() < corePoolSize)){
            addTask(task);
            return;
        }
        
        boolean added = taskQueue.offer(task);
        if (added) {    
            return;
        }
        //添加到隊列不成功,threadNum與maxPoolNum比較,不超出則創建非core線程,超出則reject
        if(running && (false == added)){
            if(threadNum.longValue() < maximumPoolSize){
                addTask(task);
                return;
            }
            defaultHandler.rejectedExecution(task);
        }
        defaultHandler.otherRejectedExecution(task); 
    }
    
    /**
     * 線程執行過程
     * @param wrappedTask
     */
    public void runTask(WrappedTask wrappedTask){
        Runnable task = wrappedTask.getTask();
        while(running && ((null != task) || (null != (task = getTask())))){
            task.run();
            task = null; //將task置null,以免初始不爲null導致後面 null != task始終成立 將 ||後面邏輯短路
        }
        threadNum.decrementAndGet();
    }
    
    public void addTask(Runnable firstTask){
        //前面已經檢查過參數,不可能爲null
        threadNum.incrementAndGet();
        WrappedTask wTask = new WrappedTask(firstTask);
        wTask.startTask();
    }
    
    /**
     * 從阻塞的任務隊列taskQueue中獲取待執行的任務
     */
    public Runnable getTask(){
        Runnable task = null;
        boolean timedOut = false;
        
        while(true){
            //非核心線程是否已經超時
            boolean NonCoretimed = allowCoreThreadTimeOut || threadNum.longValue() > corePoolSize;
            if(threadNum.longValue() > maximumPoolSize || (NonCoretimed && timedOut)){
                return null;
            }
             
            try {
                task = allowCoreThreadTimeOut ? taskQueue.poll(keepAliveTime, TimeUnit.NANOSECONDS): taskQueue.take();
                if(null != task){
                    return task;
                }
                timedOut = true;
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }

    public void terminate(){
        running = false;
    }
    
    /**
     * 是否允許core線程超時,若爲true,則當獲取task超時後,線程會銷燬
     */
    public void allowCoreThreadTimeOut(boolean flag){
        this.allowCoreThreadTimeOut = flag;
    }
    
    /**
     * 拒絕策略
     */
    private static class MyRejectedExecutionHandler{
        
        public MyRejectedExecutionHandler(){}

        public void rejectedExecution(Runnable r) {
            throw new RejectedExecutionException("Task " + r.toString() +
                    " rejected from myThreadPool");
        }
        
        public void otherRejectedExecution(Runnable r){
            throw new RejectedExecutionException("Task " + r.toString() +
                    " rejected from myThreadPool due to other reasons, maybe threadPool stopped.");
        }
    }
    
    /**
     * 線程工廠
     */
    private static class MyThreadFactory implements ThreadFactory {
        private final AtomicInteger threadNumber = new AtomicInteger(1);
        private final String namePrefix = "myThreadPool-";
        
        public MyThreadFactory() {}
        
        @Override
        public Thread newThread(Runnable r) {
            Thread t = new Thread(r, namePrefix + threadNumber.getAndIncrement());
            if (t.isDaemon()){
                t.setDaemon(false);
            }

            if (t.getPriority() != Thread.NORM_PRIORITY){
                t.setPriority(Thread.NORM_PRIORITY);
            }
            return t;
        }
    }
}

 

 


 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章