AtomicInteger源碼級模仿拓展實踐

前言

目前就職於國內比較大的一家保險公司, 昨天所有技術人員減半, 原因是技術部大領導換人… 作爲組內唯一的倖存者, 心裏還是有一絲絲的竊喜, 下面我教大家怎麼寫出別人難以維護的代碼.

騷操作開始了

今早接到了一個需求, 大概就是寫個攔截器做限流, 項目背景java常規springboot, 雙節點服務器, 要求每個節點單獨限流, 對其中三個接口分別限流200,200,100, 總限流500, 在接口不繁忙的情況下其他接口可以佔用流量, 這裏可以用三個AtomicInteger清晰的完成要求, 但是總限流這裏就需要加鎖了, 完整的需求比較複雜, 不一一細說, 這裏分析下用單個AtomicInteger實現限流的功能, 其中有參考線程池的源碼.

重點分析, 下面的講解需要一定的java基礎, 個別點不懂的可以一一百度

    private final AtomicInteger ai = new AtomicInteger(stateOf(RUNNING, 0));
    private static final int STATE_BITS = Integer.SIZE - 3;
    private static final int CONTLIST_BITS = Integer.SIZE - 8 * 3;
    private static final int QUESTION_BITS = Integer.SIZE - 8 * 2;
    private static final int RVRESULT_BITS = Integer.SIZE - 8 * 1;

AtomicInteger是併發包下提供的線程安全的計數工具, 原理大概用到了cas, 比鎖更輕量化. 我們先new 一個AtomicInteger, 一個int值是32位的01組合, 在內存中大概就是下面這個樣子

00000000 00000000 00000000 00000000

首位是符號位, 如果對位運算和二進制轉換不太熟悉的同學可以去百度瞭解下, 我們取前三位來表示狀態STATE_BITS = Integer.SIZE - 3 = 29, 這個值後面位運算會用到, 採取位運算一是因爲cpu處理比較快, 第二點主要是很少有人用, 不好理解, 你看我的代碼裏一個註釋都沒有.

111 00000 11001000 11001000 01100100

接下來我們分析需求, 兩個200一個100, 200用二進制表示就是128+64+8(11001000), 100=64+32+4(01100100), 我們把32位每8位分開, 每一段用來記錄一組數據, 前三個數字用來表示狀態, 有如下幾種狀態可用:

    private static final int RUNNING = -1 << STATE_BITS;
    private static final int SHUTDOWN = 0 << STATE_BITS;
    private static final int STOP = 1 << STATE_BITS;
    private static final int STATE = SHUTDOWN;

接下來是一些之後運算要用的變量, 用來做與或非位運算的

    private static final int STATE_CAPACITY = (1 << STATE_BITS) - 1;
    private static final int CONTLIST_CAPACITY = (1 << CONTLIST_BITS) - 1;
    private static final int QUESTION_CAPACITY = (((1 << QUESTION_BITS) - 1) >> CONTLIST_BITS) << CONTLIST_BITS;
    private static final int RVRESULT_CAPACITY = (((1 << RVRESULT_BITS) - 1) >> QUESTION_BITS) << QUESTION_BITS;

首先寫一個攔截器, 繼承HandlerInterceptorAdapter , 實現裏面的方法, preHandle在請求來之前會執行, afterCompletion是在請求離開時執行的, 當請求來我們判斷是哪個url然後對應的計數+1,離開的時候-1, 當數量達到最大就熔斷限流

@Component
public class AccessInterceptor extends HandlerInterceptorAdapter {
 @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        if(handler instanceof HandlerMethod){
            AccessLimit accessLimit = ((HandlerMethod) handler).getMethodAnnotation(AccessLimit.class);
            if(accessLimit == null){
                return true;
            }
            if (accessLimit.processFlag()) {
                int cap = getCap(accessLimit.value());
                if (countOf(incrementAndGet(cap), cap)>accessLimit.processQuantity()) {
                    decrementAndGet(cap);
                    sendDefaultMsg(response);
                    return false;
                }
            }
            return true;
        }
        return true;
    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex)
            throws Exception {
        AccessLimit accessLimit = ((HandlerMethod) handler).getMethodAnnotation(AccessLimit.class);
        if(accessLimit != null){
            decrementAndGet(getCap(accessLimit.value()));
        }
    }
}

加在url上的自定義註解, 用來配置限流大小的, value是用來區分url的, processQuantity是限流大小,processFlag是開啓限流的標記

import java.lang.annotation.Retention;
import java.lang.annotation.Target;
import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.RetentionPolicy.RUNTIME;

@Retention(RUNTIME)
@Target(METHOD)
public @interface AccessLimit {
    String value();
    int processQuantity() default 200;
    boolean processFlag() default true;
}

大概用法像下面這樣, 把註解加到controller的url上, default 默認的值可以省略不寫, 這裏對於不熟悉註解的同學也是個迷惑, 又可以勸退一批大學剛畢業的小年輕

    @AccessLimit(value = "/getRvOnlineContList")
	@PostMapping(value = "/getRvOnlineContList")
	public String getRvOnlineContList(String inputJson) throws Exception {
		String outputJson = "";
		...
		return result;
	}

下面是攔截器配置類, 對相應的url規則啓用我們的攔截器


@Configuration
public class WebConfig extends WebMvcConfigurerAdapter {

    @Autowired
    private AccessInterceptor interceptor;

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(interceptor)
                .addPathPatterns("/returnVisit/getRvOnlineContList")
                .addPathPatterns("/returnVisit/getRvOnlineQuestion")
                .addPathPatterns("/returnVisit/getRVStatusResult");
    }
}

爲了友好, 要有熔斷返回的異常信息

private void sendDefaultMsg(HttpServletResponse response) throws IOException {
        PrintWriter pw = response.getWriter();
        pw.write("{\"success\":\"false\",\"resultCode\":\"-1\",\"resultMsg\":\"服務器繁忙,請稍後訪問!\"}");
        pw.flush();
        pw.close();
    }

這是兩個對於分段+1-1用的位運算, 可以說是整個代碼最精華的部分, 注意括號的使用

	private int incrementAndGet(int cap) {
        int prev, next;
        do {
            prev = ai.get();
            next = applyAsIntGo(prev, cap);
        } while (!ai.compareAndSet(prev, next));
        return prev;
    }

    private int decrementAndGet(int cap) {
        int prev, next;
        do {
            prev = ai.get();
            next = minusAsIntGo(prev, cap);
        } while (!ai.compareAndSet(prev, next));
        return prev;
    }
	private static int applyAsIntGo(int prev, int cap){
        int w = getCapW(cap);
        return (prev & ~cap) | ((((prev & cap) >> w) + 1) << w);
    }

    private static int minusAsIntGo(int prev, int cap){
        int w = getCapW(cap);
        return (prev & ~cap) | ((((prev & cap) >> w) - 1) << w);
    }

完整代碼

import org.springframework.stereotype.Component;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.concurrent.atomic.AtomicInteger;

@Component
public class AccessInterceptor extends HandlerInterceptorAdapter {
    private final AtomicInteger ai = new AtomicInteger(stateOf(RUNNING, 0));
    private static final int STATE_BITS = Integer.SIZE - 3;
    private static final int CONTLIST_BITS = Integer.SIZE - 8 * 3;
    private static final int QUESTION_BITS = Integer.SIZE - 8 * 2;
    private static final int RVRESULT_BITS = Integer.SIZE - 8 * 1;
    private static final int STATE_CAPACITY = (1 << STATE_BITS) - 1;
    private static final int CONTLIST_CAPACITY = (1 << CONTLIST_BITS) - 1;
    private static final int QUESTION_CAPACITY = (((1 << QUESTION_BITS) - 1) >> CONTLIST_BITS) << CONTLIST_BITS;
    private static final int RVRESULT_CAPACITY = (((1 << RVRESULT_BITS) - 1) >> QUESTION_BITS) << QUESTION_BITS;
    private static final int RUNNING = 0 << STATE_BITS;
    private static final int STOP = 1 << STATE_BITS;
    private static final int STATE = STOP;

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        if(handler instanceof HandlerMethod){
            AccessLimit accessLimit = ((HandlerMethod) handler).getMethodAnnotation(AccessLimit.class);
            if (accessLimit.processFlag()) {
                int cap = getCap(accessLimit.value());
                if (countOf(incrementAndGet(cap), cap) > accessLimit.processQuantity()) {
                    decrementAndGet(cap);
                    sendDefaultMsg(response);
                    return false;
                }
                request.setAttribute("aicap", cap);
            }
            return true;
        }
        return true;
    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex)
            throws Exception {
        Object cap;
        if((cap = request.getAttribute("aicap")) != null){
            decrementAndGet((int) cap);
        }
    }

    private static int goState(int c) { return c & ~STATE_CAPACITY; }

    private static int goCount(int c) { return c & STATE_CAPACITY; }

    private static int countOf(int rs, int wc) { return (rs & wc) >> getCapW(wc); }

    private static int stateOf(int rs, int wc) { return rs | wc; }

    private static boolean runStateLessThan(int c, int s) {
        return c < s;
    }

    private static boolean runStateAtLeast(int c, int s) {
        return c >= s;
    }

    private static boolean isRunning(int c) {
        return c < STOP;
    }

    private boolean casIncrement(int n) {
        return ai.compareAndSet(n, n + 1);
    }

    private boolean casDecrement(int n) {
        return ai.compareAndSet(n, n - 1);
    }

    private int incrementAndGet(int cap) {
        int prev, next;
        do {
            prev = ai.get();
            next = applyAsIntGo(prev, cap);
        } while (!ai.compareAndSet(prev, next));
        return next;
    }

    private int decrementAndGet(int cap) {
        int prev, next;
        do {
            prev = ai.get();
            next = minusAsIntGo(prev, cap);
        } while (!ai.compareAndSet(prev, next));
        return next;
    }

    private static int getCap(String pathInfo) {
        switch (pathInfo) {
            case "/go":
                return QUESTION_CAPACITY;
            case "/getRvOnlineContList":
                return CONTLIST_CAPACITY;
            case "/getRvOnlineQuestion":
                return QUESTION_CAPACITY;
            case "/getRVStatusResult":
                return RVRESULT_CAPACITY;
            default:
                return STATE_CAPACITY;
        }
    }

    private static int getCapW(int cap) {
        switch (cap) {
            case CONTLIST_CAPACITY:
                return 0;
            case QUESTION_CAPACITY:
                return CONTLIST_BITS;
            case RVRESULT_CAPACITY:
                return QUESTION_BITS;
            default:
                return STATE_BITS;
        }
    }

    private static void sendDefaultMsg(HttpServletResponse response) throws IOException {
        PrintWriter pw = response.getWriter();
        pw.write("{\"success\":\"false\",\"resultCode\":\"-1\",\"resultMsg\":\"Please Try Again!\"}");
        pw.flush();
        pw.close();
    }

    private static int applyAsIntGo(int prev, int cap){
        int w = getCapW(cap);
        return (prev & ~cap) | ((((prev & cap) >> w) + 1) << w);
    }

    private static int minusAsIntGo(int prev, int cap){
        int w = getCapW(cap);
        return (prev & ~cap) | ((((prev & cap) >> w) - 1) << w);
    }
}


總結

嘿嘿

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章