AtomicInteger源码级模仿拓展实践

前言

目前就职于国内比较大的一家保险公司, 昨天所有技术人员减半, 原因是技术部大领导换人… 作为组内唯一的幸存者, 心里还是有一丝丝的窃喜, 下面我教大家怎么写出别人难以维护的代码.

骚操作开始了

今早接到了一个需求, 大概就是写个拦截器做限流, 项目背景java常规springboot, 双节点服务器, 要求每个节点单独限流, 对其中三个接口分别限流200,200,100, 总限流500, 在接口不繁忙的情况下其他接口可以占用流量, 这里可以用三个AtomicInteger清晰的完成要求, 但是总限流这里就需要加锁了, 完整的需求比较复杂, 不一一细说, 这里分析下用单个AtomicInteger实现限流的功能, 其中有参考线程池的源码.

重点分析, 下面的讲解需要一定的java基础, 个别点不懂的可以一一百度

    private final AtomicInteger ai = new AtomicInteger(stateOf(RUNNING, 0));
    private static final int STATE_BITS = Integer.SIZE - 3;
    private static final int CONTLIST_BITS = Integer.SIZE - 8 * 3;
    private static final int QUESTION_BITS = Integer.SIZE - 8 * 2;
    private static final int RVRESULT_BITS = Integer.SIZE - 8 * 1;

AtomicInteger是并发包下提供的线程安全的计数工具, 原理大概用到了cas, 比锁更轻量化. 我们先new 一个AtomicInteger, 一个int值是32位的01组合, 在内存中大概就是下面这个样子

00000000 00000000 00000000 00000000

首位是符号位, 如果对位运算和二进制转换不太熟悉的同学可以去百度了解下, 我们取前三位来表示状态STATE_BITS = Integer.SIZE - 3 = 29, 这个值后面位运算会用到, 采取位运算一是因为cpu处理比较快, 第二点主要是很少有人用, 不好理解, 你看我的代码里一个注释都没有.

111 00000 11001000 11001000 01100100

接下来我们分析需求, 两个200一个100, 200用二进制表示就是128+64+8(11001000), 100=64+32+4(01100100), 我们把32位每8位分开, 每一段用来记录一组数据, 前三个数字用来表示状态, 有如下几种状态可用:

    private static final int RUNNING = -1 << STATE_BITS;
    private static final int SHUTDOWN = 0 << STATE_BITS;
    private static final int STOP = 1 << STATE_BITS;
    private static final int STATE = SHUTDOWN;

接下来是一些之后运算要用的变量, 用来做与或非位运算的

    private static final int STATE_CAPACITY = (1 << STATE_BITS) - 1;
    private static final int CONTLIST_CAPACITY = (1 << CONTLIST_BITS) - 1;
    private static final int QUESTION_CAPACITY = (((1 << QUESTION_BITS) - 1) >> CONTLIST_BITS) << CONTLIST_BITS;
    private static final int RVRESULT_CAPACITY = (((1 << RVRESULT_BITS) - 1) >> QUESTION_BITS) << QUESTION_BITS;

首先写一个拦截器, 继承HandlerInterceptorAdapter , 实现里面的方法, preHandle在请求来之前会执行, afterCompletion是在请求离开时执行的, 当请求来我们判断是哪个url然后对应的计数+1,离开的时候-1, 当数量达到最大就熔断限流

@Component
public class AccessInterceptor extends HandlerInterceptorAdapter {
 @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        if(handler instanceof HandlerMethod){
            AccessLimit accessLimit = ((HandlerMethod) handler).getMethodAnnotation(AccessLimit.class);
            if(accessLimit == null){
                return true;
            }
            if (accessLimit.processFlag()) {
                int cap = getCap(accessLimit.value());
                if (countOf(incrementAndGet(cap), cap)>accessLimit.processQuantity()) {
                    decrementAndGet(cap);
                    sendDefaultMsg(response);
                    return false;
                }
            }
            return true;
        }
        return true;
    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex)
            throws Exception {
        AccessLimit accessLimit = ((HandlerMethod) handler).getMethodAnnotation(AccessLimit.class);
        if(accessLimit != null){
            decrementAndGet(getCap(accessLimit.value()));
        }
    }
}

加在url上的自定义注解, 用来配置限流大小的, value是用来区分url的, processQuantity是限流大小,processFlag是开启限流的标记

import java.lang.annotation.Retention;
import java.lang.annotation.Target;
import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.RetentionPolicy.RUNTIME;

@Retention(RUNTIME)
@Target(METHOD)
public @interface AccessLimit {
    String value();
    int processQuantity() default 200;
    boolean processFlag() default true;
}

大概用法像下面这样, 把注解加到controller的url上, default 默认的值可以省略不写, 这里对于不熟悉注解的同学也是个迷惑, 又可以劝退一批大学刚毕业的小年轻

    @AccessLimit(value = "/getRvOnlineContList")
	@PostMapping(value = "/getRvOnlineContList")
	public String getRvOnlineContList(String inputJson) throws Exception {
		String outputJson = "";
		...
		return result;
	}

下面是拦截器配置类, 对相应的url规则启用我们的拦截器


@Configuration
public class WebConfig extends WebMvcConfigurerAdapter {

    @Autowired
    private AccessInterceptor interceptor;

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        registry.addInterceptor(interceptor)
                .addPathPatterns("/returnVisit/getRvOnlineContList")
                .addPathPatterns("/returnVisit/getRvOnlineQuestion")
                .addPathPatterns("/returnVisit/getRVStatusResult");
    }
}

为了友好, 要有熔断返回的异常信息

private void sendDefaultMsg(HttpServletResponse response) throws IOException {
        PrintWriter pw = response.getWriter();
        pw.write("{\"success\":\"false\",\"resultCode\":\"-1\",\"resultMsg\":\"服务器繁忙,请稍后访问!\"}");
        pw.flush();
        pw.close();
    }

这是两个对于分段+1-1用的位运算, 可以说是整个代码最精华的部分, 注意括号的使用

	private int incrementAndGet(int cap) {
        int prev, next;
        do {
            prev = ai.get();
            next = applyAsIntGo(prev, cap);
        } while (!ai.compareAndSet(prev, next));
        return prev;
    }

    private int decrementAndGet(int cap) {
        int prev, next;
        do {
            prev = ai.get();
            next = minusAsIntGo(prev, cap);
        } while (!ai.compareAndSet(prev, next));
        return prev;
    }
	private static int applyAsIntGo(int prev, int cap){
        int w = getCapW(cap);
        return (prev & ~cap) | ((((prev & cap) >> w) + 1) << w);
    }

    private static int minusAsIntGo(int prev, int cap){
        int w = getCapW(cap);
        return (prev & ~cap) | ((((prev & cap) >> w) - 1) << w);
    }

完整代码

import org.springframework.stereotype.Component;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.concurrent.atomic.AtomicInteger;

@Component
public class AccessInterceptor extends HandlerInterceptorAdapter {
    private final AtomicInteger ai = new AtomicInteger(stateOf(RUNNING, 0));
    private static final int STATE_BITS = Integer.SIZE - 3;
    private static final int CONTLIST_BITS = Integer.SIZE - 8 * 3;
    private static final int QUESTION_BITS = Integer.SIZE - 8 * 2;
    private static final int RVRESULT_BITS = Integer.SIZE - 8 * 1;
    private static final int STATE_CAPACITY = (1 << STATE_BITS) - 1;
    private static final int CONTLIST_CAPACITY = (1 << CONTLIST_BITS) - 1;
    private static final int QUESTION_CAPACITY = (((1 << QUESTION_BITS) - 1) >> CONTLIST_BITS) << CONTLIST_BITS;
    private static final int RVRESULT_CAPACITY = (((1 << RVRESULT_BITS) - 1) >> QUESTION_BITS) << QUESTION_BITS;
    private static final int RUNNING = 0 << STATE_BITS;
    private static final int STOP = 1 << STATE_BITS;
    private static final int STATE = STOP;

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        if(handler instanceof HandlerMethod){
            AccessLimit accessLimit = ((HandlerMethod) handler).getMethodAnnotation(AccessLimit.class);
            if (accessLimit.processFlag()) {
                int cap = getCap(accessLimit.value());
                if (countOf(incrementAndGet(cap), cap) > accessLimit.processQuantity()) {
                    decrementAndGet(cap);
                    sendDefaultMsg(response);
                    return false;
                }
                request.setAttribute("aicap", cap);
            }
            return true;
        }
        return true;
    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex)
            throws Exception {
        Object cap;
        if((cap = request.getAttribute("aicap")) != null){
            decrementAndGet((int) cap);
        }
    }

    private static int goState(int c) { return c & ~STATE_CAPACITY; }

    private static int goCount(int c) { return c & STATE_CAPACITY; }

    private static int countOf(int rs, int wc) { return (rs & wc) >> getCapW(wc); }

    private static int stateOf(int rs, int wc) { return rs | wc; }

    private static boolean runStateLessThan(int c, int s) {
        return c < s;
    }

    private static boolean runStateAtLeast(int c, int s) {
        return c >= s;
    }

    private static boolean isRunning(int c) {
        return c < STOP;
    }

    private boolean casIncrement(int n) {
        return ai.compareAndSet(n, n + 1);
    }

    private boolean casDecrement(int n) {
        return ai.compareAndSet(n, n - 1);
    }

    private int incrementAndGet(int cap) {
        int prev, next;
        do {
            prev = ai.get();
            next = applyAsIntGo(prev, cap);
        } while (!ai.compareAndSet(prev, next));
        return next;
    }

    private int decrementAndGet(int cap) {
        int prev, next;
        do {
            prev = ai.get();
            next = minusAsIntGo(prev, cap);
        } while (!ai.compareAndSet(prev, next));
        return next;
    }

    private static int getCap(String pathInfo) {
        switch (pathInfo) {
            case "/go":
                return QUESTION_CAPACITY;
            case "/getRvOnlineContList":
                return CONTLIST_CAPACITY;
            case "/getRvOnlineQuestion":
                return QUESTION_CAPACITY;
            case "/getRVStatusResult":
                return RVRESULT_CAPACITY;
            default:
                return STATE_CAPACITY;
        }
    }

    private static int getCapW(int cap) {
        switch (cap) {
            case CONTLIST_CAPACITY:
                return 0;
            case QUESTION_CAPACITY:
                return CONTLIST_BITS;
            case RVRESULT_CAPACITY:
                return QUESTION_BITS;
            default:
                return STATE_BITS;
        }
    }

    private static void sendDefaultMsg(HttpServletResponse response) throws IOException {
        PrintWriter pw = response.getWriter();
        pw.write("{\"success\":\"false\",\"resultCode\":\"-1\",\"resultMsg\":\"Please Try Again!\"}");
        pw.flush();
        pw.close();
    }

    private static int applyAsIntGo(int prev, int cap){
        int w = getCapW(cap);
        return (prev & ~cap) | ((((prev & cap) >> w) + 1) << w);
    }

    private static int minusAsIntGo(int prev, int cap){
        int w = getCapW(cap);
        return (prev & ~cap) | ((((prev & cap) >> w) - 1) << w);
    }
}


总结

嘿嘿

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章