手写实现一个简单的线程池,设计思路与 JDK 的 ExecutorService 一致。

ThreadPool 接口:定义一个线程池应该具备的基本操作和方法。
public interface ThreadPool {

    /**
     * Submits a task to the thread pool.
     *
     * @param runnable the task to execute
     */
    void execute(Runnable runnable);

    /**
     * Shuts the thread pool down.
     */
    void shutdown();

    /**
     * Returns the number of worker threads the pool is initialized with.
     *
     * @return the initial thread count
     */
    int getInitSize();

    /**
     * Returns the maximum number of worker threads the pool may have.
     *
     * @return the maximum thread count
     */
    int getMaxsize();

    /**
     * Returns the core number of worker threads of the pool.
     *
     * @return the core thread count
     */
    int getCoreSize();

    /**
     * Returns the size of the queue that buffers submitted tasks.
     * The method name keeps its original spelling ("Quen") for compatibility with existing callers.
     *
     * @return the task queue size
     */
    int getQuenSize();

    /**
     * Returns the number of active worker threads in the pool.
     *
     * @return the active thread count
     */
    int getActiveCount();

    /**
     * Tells whether the thread pool has been shut down.
     *
     * @return {@code true} if the pool has been shut down
     */
    boolean isShutdown();


}
RunnableQueue 接口:用于存放提交的 Runnable 任务,队列容量受 limit 限制。
public interface RunnableQueue {

    /**
     * 当有新的任务进来时,放到队列的末尾
     *
     * @param offer
     */
    void offer(Runnable offer);

    /**
     * 工作线程通过task获取任务
     *
     * @return
     */
    Runnable task();

    /**
     * 获取线程池缓存任务的数量
     *
     * @return
     */
    int size();
}
DenyPolicy 拒绝策略:当 RunnableQueue 中的任务数达到上限时,对新提交的任务执行拒绝策略。
@FunctionalInterface
public interface DenyPolicy {

    /**
     * Handles a task that the pool's task queue refused to accept.
     *
     * @param runnable   the rejected task
     * @param threadPool the pool the task was submitted to
     */
    void reject(Runnable runnable, ThreadPool threadPool);

    /**
     * Silently discards the rejected task.
     */
    class DiscardDenyPolicy implements DenyPolicy {

        @Override
        public void reject(Runnable runnable, ThreadPool threadPool) {
            // Intentionally empty: dropping the task is the whole policy.
        }
    }

    /**
     * Throws a {@link RunnableDenyPolicyException} back to the task submitter.
     */
    class AbortDenyPolicy implements DenyPolicy {

        @Override
        public void reject(Runnable runnable, ThreadPool threadPool) {
            // Name the rejected task so the failure can be diagnosed (the message used to be empty).
            throw new RunnableDenyPolicyException(
                    "The runnable " + runnable + " was rejected by the thread pool");
        }
    }

    /**
     * Runs the rejected task directly in the submitter's thread, unless the pool is shut down.
     */
    class RunnerDenyPolicy implements DenyPolicy {

        @Override
        public void reject(Runnable runnable, ThreadPool threadPool) {
            if (!threadPool.isShutdown()) {
                runnable.run();
            }
        }
    }


}
/**
 * Unchecked exception used to report that a submitted task was rejected.
 */
public class RunnableDenyPolicyException extends RuntimeException {

    /**
     * Creates the exception.
     *
     * @param msg detail message describing the rejection
     */
    public RunnableDenyPolicyException(final String msg) {
        super(msg);
    }
}
InternalTask:在线程池内部使用,不断从队列中取出 Runnable 并执行其 run 方法。
/**
 * Worker loop executed by each pool thread: repeatedly takes a task from the queue and runs it.
 */
public class InternalTask implements Runnable {

    /**
     * Queue the worker pulls tasks from.
     */
    private final RunnableQueue runnableQueue;

    /**
     * Cleared by {@link #stop()}; the loop checks it before taking the next task.
     */
    private volatile boolean running = true;

    public InternalTask(RunnableQueue runnableQueue) {
        this.runnableQueue = runnableQueue;
    }


    /**
     * Takes and runs tasks until {@link #stop()} is called, the thread is interrupted,
     * or the queue fails to hand out a task.
     *
     * <p>An exception thrown by a task no longer terminates the worker. It is reported to the
     * thread's uncaught-exception handler and the worker moves on to the next task.
     */
    @Override
    public void run() {
        Thread current = Thread.currentThread();
        while (running && !current.isInterrupted()) {
            Runnable task;
            try {
                task = runnableQueue.task();
            } catch (RuntimeException e) {
                // The queue could not supply a task (e.g. the wait was interrupted): retire the worker.
                running = false;
                break;
            }
            try {
                task.run();
            } catch (RuntimeException e) {
                // A failing task must not silently kill the worker; report it like an uncaught exception.
                current.getUncaughtExceptionHandler().uncaughtException(current, e);
            }
        }
    }

    /**
     * Asks the worker to stop before taking its next task.
     */
    public void stop() {
        this.running = false;
    }
}

public class LinkendRunnableQueue implements RunnableQueue {

    /**
     * Maximum number of buffered tasks; further offers go to the deny policy.
     */
    private final int limit;
    /**
     * Policy applied to tasks offered while the queue is full.
     */
    private final DenyPolicy denyPolicy;
    /**
     * Buffered tasks; also used as the monitor for wait/notify.
     */
    private final LinkedList<Runnable> runnableList = new LinkedList<>();

    /**
     * Pool passed to the deny policy.
     */
    private final ThreadPool threadPool;

    public LinkendRunnableQueue(int limit, DenyPolicy denyPolicy, ThreadPool threadPool) {
        this.limit = limit;
        this.denyPolicy = denyPolicy;
        this.threadPool = threadPool;
    }

    /**
     * Appends the task and wakes waiting workers. If the queue already holds {@code limit}
     * tasks, the task is handed to the deny policy instead.
     */
    @Override
    public void offer(Runnable runnable) {
        synchronized (runnableList) {
            if (runnableList.size() < limit) {
                runnableList.addLast(runnable);
                runnableList.notifyAll();
                return;
            }
        }
        // The queue is full. The deny policy runs outside the lock: a policy that executes the task
        // in the caller's thread would otherwise block every worker from taking tasks meanwhile.
        this.denyPolicy.reject(runnable, threadPool);
    }

    /**
     * Blocks until a task is available, then removes and returns the head of the queue.
     *
     * @throws IllegalStateException if the calling thread is interrupted while waiting; the
     *                               thread's interrupt status is restored first
     */
    @Override
    public Runnable task() {
        synchronized (runnableList) {
            while (runnableList.isEmpty()) {
                try {
                    runnableList.wait();
                } catch (InterruptedException e) {
                    // Waiting again would make the worker unstoppable (shutdown relies on interrupts),
                    // so restore the flag and bail out.
                    Thread.currentThread().interrupt();
                    throw new IllegalStateException("Interrupted while waiting for a task", e);
                }
            }
            return runnableList.removeFirst();
        }
    }

    /**
     * Returns the number of tasks currently buffered.
     */
    @Override
    public int size() {
        synchronized (runnableList) {
            return runnableList.size();
        }
    }
}
/**
 * Strategy used by the thread pool to build its worker threads.
 */
@FunctionalInterface
public interface ThreadFactory {

    /**
     * Builds a thread that will execute the given runnable once started.
     *
     * @param runnable the work the new thread should run
     * @return the newly created thread
     */
    Thread createThread(Runnable runnable);
}
BasicThreadPool:线程池的具体实现,负责线程数量控制,并组合线程工厂、任务队列、拒绝策略等组件。

import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
 
public class BasicThreadPool extends Thread implements ThreadPool {

    /**
     * Number of worker threads created when the pool is constructed.
     */
    private final int initSize;

    /**
     * Upper bound on the number of worker threads.
     */
    private final int maxSize;

    /**
     * Number of worker threads the pool grows to first and shrinks back to when idle.
     */
    private final int coreSize;

    /**
     * Number of worker threads registered in {@link #threadQueue}. Written only by the constructor
     * (before the maintenance thread starts) and under {@code synchronized (this)}.
     */
    private int activeCount;

    /**
     * Factory used to create worker threads.
     */
    private final ThreadFactory threadFactory;

    /**
     * Queue of submitted tasks.
     */
    private final RunnableQueue runnableQueue;
    /**
     * Whether the pool has been shut down.
     */
    private volatile boolean isShutdown = false;

    /**
     * Interval between two maintenance passes, in {@link #timeUnit}.
     */
    private final long keepAliveTime;
    /**
     * Unit of {@link #keepAliveTime}, also used to sleep.
     */
    private final TimeUnit timeUnit;

    /**
     * Worker threads paired with their task loops.
     */
    private final Queue<ThreadTask> threadQueue = new ArrayDeque<>();

    private final static DenyPolicy DEFAULT_DENY_POLICY = new DenyPolicy.DiscardDenyPolicy();

    private final static ThreadFactory DEFAULT_THREAD_FACTORY = new DefaultThreadFactory();

    public BasicThreadPool(int initSize, int maxSize, int coreSize, int queueSize) {
        this(initSize, maxSize, coreSize, DEFAULT_THREAD_FACTORY, queueSize, DEFAULT_DENY_POLICY, 10, TimeUnit.SECONDS);
    }

    /**
     * Creates the pool, starts {@code initSize} workers and then the maintenance thread.
     *
     * @throws IllegalArgumentException unless {@code 0 <= initSize <= coreSize <= maxSize}
     */
    public BasicThreadPool(int initSize, int maxSize, int coreSize, ThreadFactory defaultThreadFactory,
                           int queueSize, DenyPolicy defaultDenyPolicy, int keepAliveTime, TimeUnit timeUnit) {
        if (initSize < 0 || initSize > coreSize || coreSize > maxSize) {
            throw new IllegalArgumentException("Require 0 <= initSize <= coreSize <= maxSize, got initSize="
                    + initSize + ", coreSize=" + coreSize + ", maxSize=" + maxSize);
        }
        this.initSize = initSize;
        this.maxSize = maxSize;
        this.coreSize = coreSize;
        this.threadFactory = defaultThreadFactory;
        this.runnableQueue = new LinkendRunnableQueue(queueSize, defaultDenyPolicy, this);
        this.keepAliveTime = keepAliveTime;
        this.timeUnit = timeUnit;
        this.init();
    }

    /**
     * Creates the initial workers, then starts the maintenance thread.
     */
    private void init() {
        for (int i = 0; i < initSize; i++) {
            newThread();
        }
        // Start maintenance only after the initial workers are registered. Thread.start()
        // publishes them, and the maintenance thread never races the constructor on threadQueue.
        start();
    }

    /**
     * Creates, registers and starts one worker thread.
     */
    private void newThread() {
        InternalTask internalTask = new InternalTask(runnableQueue);
        Thread thread = this.threadFactory.createThread(internalTask);
        ThreadTask threadTask = new ThreadTask(thread, internalTask);
        threadQueue.offer(threadTask);
        this.activeCount++;
        thread.start();
    }

    /**
     * Unregisters the oldest worker and asks it to stop before taking its next task.
     */
    private void removeThread() {
        ThreadTask threadTask = threadQueue.remove();
        threadTask.internalTask.stop();
        this.activeCount--;
    }

    /**
     * Maintenance loop. Every {@code keepAliveTime} it grows the pool while tasks are queued,
     * first to {@code coreSize} and then to {@code maxSize}. When the queue is empty it shrinks
     * the pool back to {@code coreSize}.
     */
    @Override
    public void run() {
        while (!isShutdown && !isInterrupted()) {
            try {
                timeUnit.sleep(keepAliveTime);
            } catch (InterruptedException e) {
                isShutdown = true;
                break;
            }

            synchronized (this) {
                if (isShutdown) {
                    break;
                }
                int queued = runnableQueue.size();
                if (queued > 0 && activeCount < coreSize) {
                    // Backlog: grow to the core size first and re-evaluate on the next pass.
                    while (activeCount < coreSize) {
                        newThread();
                    }
                    continue;
                }

                if (queued > 0 && activeCount < maxSize) {
                    // Still a backlog at core size: grow to the maximum.
                    while (activeCount < maxSize) {
                        newThread();
                    }
                }

                if (queued == 0 && activeCount > coreSize) {
                    // Idle: shrink back to the core size. The loop condition is re-read each time,
                    // because removeThread() decrements activeCount.
                    while (activeCount > coreSize) {
                        removeThread();
                    }
                }
            }
        }
    }

    @Override
    public void execute(Runnable runnable) {
        if (this.isShutdown)
            throw new IllegalStateException("The thread pool is destroy");
        // Hand the task to the queue (which applies the deny policy when full).
        this.runnableQueue.offer(runnable);
    }

    /**
     * Marks the pool as shut down, stops and interrupts every worker, and interrupts the
     * maintenance thread. Calling it more than once has no further effect.
     */
    @Override
    public void shutdown() {
        synchronized (this) {
            if (isShutdown) {
                return;
            }
            isShutdown = true;

            threadQueue.forEach(threadTask -> {
                threadTask.internalTask.stop();
                threadTask.thread.interrupt();
            });
            this.interrupt();
        }
    }

    @Override
    public int getInitSize() {
        if (isShutdown)
            throw new IllegalStateException("The thread pool is destroy");
        return this.initSize;
    }

    @Override
    public int getMaxsize() {
        if (isShutdown)
            throw new IllegalStateException("The thread pool is destroy");
        return this.maxSize;
    }

    @Override
    public int getCoreSize() {
        if (isShutdown)
            throw new IllegalStateException("The thread pool is destroy");
        return this.coreSize;
    }

    @Override
    public int getQuenSize() {
        if (isShutdown)
            throw new IllegalStateException("The thread pool is destroy");
        return this.runnableQueue.size();
    }

    @Override
    public int getActiveCount() {
        if (isShutdown)
            throw new IllegalStateException("The thread pool is destroy");
        // activeCount is written by the maintenance thread under this lock; read it under the same lock.
        synchronized (this) {
            return this.activeCount;
        }
    }

    /**
     * Returns whether the pool has been shut down. Unlike the other getters this never throws:
     * callers such as a caller-runs deny policy query it precisely to detect shutdown.
     */
    @Override
    public boolean isShutdown() {
        return this.isShutdown;
    }

    /**
     * Pairs a worker thread with the InternalTask it runs.
     */
    private static class ThreadTask {
        final Thread thread;
        final InternalTask internalTask;

        public ThreadTask(Thread thread, InternalTask internalTask) {
            this.thread = thread;
            this.internalTask = internalTask;
        }
    }


    /**
     * Default thread factory: places workers in one shared thread group and names them
     * "thread-pool-0", "thread-pool-1", ...
     */
    private static class DefaultThreadFactory implements ThreadFactory {

        private static final AtomicInteger GROUP_COUNTER = new AtomicInteger(1);

        private static final ThreadGroup GROUP = new ThreadGroup("MyThreadPool-" + GROUP_COUNTER.getAndIncrement());

        private static final AtomicInteger COUNTER = new AtomicInteger(0);

        @Override
        public Thread createThread(Runnable runnable) {
            // getAndIncrement (not getAndDecrement) so names count up instead of "thread-pool--1".
            return new Thread(GROUP, runnable, "thread-pool-" + COUNTER.getAndIncrement());
        }
    }
}

测试类

/**
 * Demo: submits 20 sleeping tasks to a BasicThreadPool, then prints the pool's state every
 * 5 seconds until the process is killed.
 */
public class ThreadPoolTest {
    public static void main(String[] args) throws InterruptedException {
        final ThreadPool threadPool = new BasicThreadPool(10, 100, 10, 1000);

        // Submit every task first. The monitoring loop below never returns, so it must not be
        // nested in this loop (that used to submit only the first task).
        for (int i = 0; i < 20; i++) {
            threadPool.execute(() -> {
                try {
                    TimeUnit.SECONDS.sleep(10);
                    System.out.println(Thread.currentThread().getName() + " is running and done.");
                } catch (InterruptedException e) {
                    // Interrupted (e.g. by shutdown): restore the flag so the worker loop can see it.
                    Thread.currentThread().interrupt();
                }
            });
        }

        // Periodically print the pool's state.
        for (; ; ) {
            System.out.println("getActiveCount:" + threadPool.getActiveCount());
            System.out.println("getCoreSize:" + threadPool.getCoreSize());
            System.out.println("getMaxSize:" + threadPool.getMaxsize());
            System.out.println("getRunnableQueue:" + threadPool.getQuenSize());
            System.out.println("=================================");
            TimeUnit.SECONDS.sleep(5);
        }
    }
}

这里只是提供思路,仍存在很多缺点,不再一一优化:

1. BasicThreadPool 不应该继承 Thread,采用组合的方式更好。

2. 线程池销毁时没有返回尚未执行完成的任务,导致任务丢失。

3. BasicThreadPool 的构造参数太多,建议使用建造者(Builder)模式。

4. 线程池没有对线程数量参数进行校验,例如 initSize 不应该大于 maxSize。

5. 其他问题请读者自行考虑。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章