一、 架構分析
我們來實現一個簡單的線程池,先看一下都需要實現哪些部分。
我們對需求進行一下簡單的抽象分析,就可以發現這個跟生產者/消費者模型有點像,我們可以畫個圖,邊分析邊往上加
1. 以生產者/消費者模型爲基礎創建架構。需要一個隊列taskQueue 用以緩存要執行的任務,通過execute()方法從外部向線程池提交任務,封裝一個方法用來決策,是立即創建線程、放入隊列還是拒絕。
2.創建線程可以用工廠設計模式封裝一個ThreadFactory內部類。
3. 拒絕策略可以簡單封裝一個RejectedHandler類,當線程數已達maxPoolSize且taskQueue滿時拒絕添加新的任務。
4. 封裝一個runTask()方法。在方法中需要不斷地利用getTask()方法來從taskQueue中獲取任務並執行。一個線程執行完任務就會銷燬,爲了防止頻繁地創建和銷燬線程,這裏需要用while()循環不斷getTask()並執行。
5. 關閉並銷燬線程池。
二、具體代碼實現
package threadpooldemo;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
public class MyThreadPool {
final static int MAX_QUEUE_SIZE = 100; //默認的最大任務隊列長度
private volatile long corePoolSize; //加volatile是爲了後面支持set擴展
private volatile long maximumPoolSize;
private BlockingQueue<Runnable> taskQueue = new LinkedBlockingQueue<>(MAX_QUEUE_SIZE);
private volatile long keepAliveTime;
private static final ThreadFactory threadFactory = new MyThreadFactory();
private static final MyRejectedExecutionHandler defaultHandler = new MyRejectedExecutionHandler();
private volatile boolean running = true;
private volatile boolean allowCoreThreadTimeOut = false;
private final AtomicLong threadNum = new AtomicLong(0);
/**
*線程池構造方法
*/
public MyThreadPool(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit) {
//core 線程數不允許爲0
if (corePoolSize <= 0 ||
maximumPoolSize <= 0 ||
maximumPoolSize < corePoolSize ||
keepAliveTime < 0)
throw new IllegalArgumentException();
this.corePoolSize = corePoolSize;
this.maximumPoolSize = maximumPoolSize;
this.keepAliveTime = unit.toNanos(keepAliveTime);
}
/**
* 對Runnable task進行包裝,在其run()方法中調用runTask方法,以便在runTask中可以反覆從taskQueue中取任務執行
*/
private class WrappedTask implements Runnable{
private Thread thread;
private Runnable task;
public WrappedTask(Runnable task) {
this.thread = threadFactory.newThread(this); //newThread(this),start時會自動調用run()
this.task = task;
}
public void startTask(){
this.thread.start();
}
public Runnable getTask(){
return task;
}
@Override
public void run() {
runTask(this);
}
}
/**
* 外部提交任務到線程池,選擇處理策略
*/
public void execute(Runnable task){
if (null == task) {
throw new NullPointerException();
}
//線程數未到corePoolSize,直接創建新線程去幹活
if (running && (threadNum.longValue() < corePoolSize)){
addTask(task);
return;
}
boolean added = taskQueue.offer(task);
if (added) {
return;
}
//添加到隊列不成功,threadNum與maxPoolNum比較,不超出則創建非core線程,超出則reject
if(running && (false == added)){
if(threadNum.longValue() < maximumPoolSize){
addTask(task);
return;
}
defaultHandler.rejectedExecution(task);
}
defaultHandler.otherRejectedExecution(task);
}
/**
* 線程執行過程
* @param wrappedTask
*/
public void runTask(WrappedTask wrappedTask){
Runnable task = wrappedTask.getTask();
while(running && ((null != task) || (null != (task = getTask())))){
task.run();
task = null; //將task置null,以免初始不爲null導致後面 null != task始終成立 將 ||後面邏輯短路
}
threadNum.decrementAndGet();
}
public void addTask(Runnable firstTask){
//前面已經檢查過參數,不可能爲null
threadNum.incrementAndGet();
WrappedTask wTask = new WrappedTask(firstTask);
wTask.startTask();
}
/**
* 從阻塞的任務隊列taskQueue中獲取待執行的任務
*/
public Runnable getTask(){
Runnable task = null;
boolean timedOut = false;
while(true){
//非核心線程是否已經超時
boolean NonCoretimed = allowCoreThreadTimeOut || threadNum.longValue() > corePoolSize;
if(threadNum.longValue() > maximumPoolSize || (NonCoretimed && timedOut)){
return null;
}
try {
task = allowCoreThreadTimeOut ? taskQueue.poll(keepAliveTime, TimeUnit.NANOSECONDS): taskQueue.take();
if(null != task){
return task;
}
timedOut = true;
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
public void terminate(){
running = false;
}
/**
* 是否允許core線程超時,若爲true,則當獲取task超時後,線程會銷燬
*/
public void allowCoreThreadTimeOut(boolean flag){
this.allowCoreThreadTimeOut = flag;
}
/**
* 拒絕策略
*/
private static class MyRejectedExecutionHandler{
public MyRejectedExecutionHandler(){}
public void rejectedExecution(Runnable r) {
throw new RejectedExecutionException("Task " + r.toString() +
" rejected from myThreadPool");
}
public void otherRejectedExecution(Runnable r){
throw new RejectedExecutionException("Task " + r.toString() +
" rejected from myThreadPool due to other reasons, maybe threadPool stopped.");
}
}
/**
* 線程工廠
*/
private static class MyThreadFactory implements ThreadFactory {
private final AtomicInteger threadNumber = new AtomicInteger(1);
private final String namePrefix = "myThreadPool-";
public MyThreadFactory() {}
@Override
public Thread newThread(Runnable r) {
Thread t = new Thread(r, namePrefix + threadNumber.getAndIncrement());
if (t.isDaemon()){
t.setDaemon(false);
}
if (t.getPriority() != Thread.NORM_PRIORITY){
t.setPriority(Thread.NORM_PRIORITY);
}
return t;
}
}
}