Guava-RateLimiter實現令牌桶控制接口限流方案(註解方式)

常見限流方案

1.計數器法

  • 原理:在單位時間段內,對請求數進行計數,如果數量超過了單位時間的限制,則執行限流策略,當單位時間結束後,計數器清零,這個過程週而復始,就是計數器法。
  • 缺點:不能均衡限流,在一個單位時間的末尾和下一個單位時間的開始,很可能會有兩個訪問的峯值,導致系統崩潰。
  • 改進方式:可以通過減小單位時間來提高精度。

2.漏桶算法

  • 原理:假設有一個水桶,水桶有一定的容量,所有請求不論速度都會注入到水桶中,然後水桶以一個恆定的速度向外將請求放出,當水桶滿了的時候,新的請求被丟棄。
  • 優點:可以平滑請求,削減峯值。
  • 缺點:瓶頸會在漏出的速度,可能會拖慢整個系統,且不能有效地利用系統的資源。  

3.令牌桶算法(推薦)

  • 原理:有一個令牌桶,單位時間內令牌會以恆定的數量(即令牌的加入速度)加入到令牌桶中,所有請求都需要獲取令牌纔可正常訪問。當令牌桶中沒有令牌可取的時候,則拒絕請求。
  • 優點:相比漏桶算法,令牌桶算法允許一定的突發流量,但是又不會讓突發流量超過我們給定的限制(單位時間窗口內的令牌數)。即限制了我們所說的 QPS(每秒查詢率)。

實現

1.依賴

<dependency>
    <groupId>com.google.guava</groupId>
    <artifactId>guava</artifactId>
    <version>27.1-jre</version>
</dependency>

2.註解類


import java.lang.annotation.*;
import java.util.concurrent.TimeUnit;

/**
 * @author EDZ
 * @Classname RequestLimiter
 * @Description TODO
 * @Date 2020/6/29 17:46
 * @Created zzf
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RequestLimiter {

    /**
     * 每秒創建令牌個數,默認:10
     */
    double QPS() default 10D;

    /**
     * 獲取令牌等待超時時間 默認:500
     */
    long timeout() default 500;

    /**
     * 超時時間單位 默認:毫秒
     */
    TimeUnit timeunit() default TimeUnit.MILLISECONDS;

    /**
     * 無法獲取令牌返回提示信息
     */
    String msg() default "服務器繁忙,請稍後重試!";
}

3.定義一個枚舉類用於狀態返回

/**
 * @Classname ResponseEnum
 * @Description TODO
 * @Date 2020/6/29 17:52
 * @Created zzf
 */
public enum ResponseEnum {

    SUCCESS("200", "000000", "請求成功"),
    FAIL("200", "100000", "請求失敗"),
    FAIL_BY_PARAMS("200", "200000", "請求參數異常"),
    FAIL_IN_SERVER("200", "300000", "服務器內部異常"),
    RATE_LIMIT("200", "400000", "限流中");

    public String status;
    public String code;
    public String message;

    ResponseEnum(String s, String s1, String s2) {
        this.status = s;
        this.code = s1;
        this.message = s2;
    }
}

4.自定義攔截器,並在攔截器中實現限流

import com.alibaba.fastjson.JSON;
import com.ils.intelab.common.result.CommonResult;
import com.ils.intelab.openapi.client.enums.ResponseEnum;
import com.ils.intelab.openapi.client.result.OpenApiResultCode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.MediaType;
import org.springframework.web.servlet.handler.HandlerInterceptorAdapter;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.PrintWriter;

/**
 * 抽象攔截器
 */
public abstract class AbstractInterceptor extends HandlerInterceptorAdapter {
    private Logger logger = LoggerFactory.getLogger(AbstractInterceptor.class);

    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        ResponseEnum result;
        try {
            result = preFilter(request, response, handler);
        } catch (Exception e) {
            logger.error("preHandle catch a exception:" + e.getMessage());
            result = ResponseEnum.FAIL;
        }
        if (ResponseEnum.SUCCESS.code.equals(result.code)) {
            return true;
        }
        handlerResponse(result, response);
        return false;
    }


    /**
     * 自定義pre處理
     *
     * @param request
     * @return
     */
    protected abstract ResponseEnum preFilter(HttpServletRequest request, HttpServletResponse response, Object handler);

    /**
     * 錯誤處理事件
     *
     * @param result
     * @param response
     */
    private void handlerResponse(ResponseEnum result, HttpServletResponse response) {
        CommonResult commonResult = new CommonResult();
        commonResult.setData(null);
        CommonResult<Object> errorInfo = new CommonResult<>();
        errorInfo.setCodes(OpenApiResultCode.ERROR_DEFAULT_NOT_CATCH.getCode());
        errorInfo.setSuccess(false);
        errorInfo.setMessage(OpenApiResultCode.errorMessageByLanguage.get("zh").get(OpenApiResultCode.ERROR_DEFAULT_NOT_CATCH.getMessage()));
        errorInfo.setData(null);
        response.setStatus(HttpServletResponse.SC_OK);
        response.setContentType(MediaType.APPLICATION_JSON_UTF8_VALUE);
        PrintWriter printWriter = null;
        try {
            printWriter = response.getWriter();
            printWriter.write(JSON.toJSONString(errorInfo));
        } catch (Exception e) {
            logger.error("handlerResponse catch a exception:" + e.getMessage());
        } finally {
            if (printWriter != null) {
                printWriter.close();
            }
        }
    }
}

實現


import com.google.common.util.concurrent.RateLimiter;
import com.ils.intelab.openapi.annotation.RequestLimiter;
import com.ils.intelab.openapi.client.enums.ResponseEnum;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.web.method.HandlerMethod;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * 請求限流器
 * @Classname RequestLimiterInterceptor
 * @Description TODO
 * @Date 2020/6/29 17:47
 * @Created zzf
 */
@Slf4j
@Component
public class RequestLimiterInterceptor extends AbstractInterceptor {

    /**
     * 不同的方法存放不同的令牌桶
     */
    private final Map<String, RateLimiter> rateLimiterMap = new ConcurrentHashMap<>();


    @Override
    protected ResponseEnum preFilter(HttpServletRequest request, HttpServletResponse response, Object handler) {
        try {
            if (handler instanceof HandlerMethod) {
                HandlerMethod handlerMethod = (HandlerMethod) handler;
                RequestLimiter rateLimit = handlerMethod.getMethodAnnotation(RequestLimiter.class);
                //判斷是否有註解
                if (rateLimit != null) {
                    // 獲取請求url
                    String key = request.getMethod() + request.getRequestURI();
                    RateLimiter rateLimiter;
                    // 判斷map集合中是否有創建好的令牌桶
                    if (!rateLimiterMap.containsKey(key)) {
                        // 創建令牌桶,以n r/s往桶中放入令牌
                        rateLimiter = RateLimiter.create(rateLimit.QPS());
                        rateLimiterMap.put(key, rateLimiter);
                    }
                    rateLimiter = rateLimiterMap.get(key);
                    // 獲取令牌
                    boolean acquire = rateLimiter.tryAcquire(rateLimit.timeout(), rateLimit.timeunit());
                    if (acquire) {
                        //獲取令牌成功
                        return ResponseEnum.SUCCESS;
                    } else {
                        log.warn("請求被限流,url:{}", request.getServletPath());
                        return ResponseEnum.RATE_LIMIT;
                    }
                }
            }
            return ResponseEnum.SUCCESS;
        } catch (Exception var6) {
            var6.printStackTrace();
            return ResponseEnum.RATE_LIMIT;
        }
    }

}

5.繼承WebMvcConfigurerAdapter來添加自定義攔截器


import com.ils.intelab.openapi.aspect.RequestLimiterInterceptor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.servlet.config.annotation.InterceptorRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

/**
 * 註冊攔截器
 */
@Configuration
public class WebMvcConfig implements WebMvcConfigurer {

    /**
     * 請求限流攔截器
     */
    @Autowired
    protected RequestLimiterInterceptor requestLimiterInterceptor;

    public WebMvcConfig() {
    }

    @Override
    public void addInterceptors(InterceptorRegistry registry) {
        // 請求限流
        registry.addInterceptor(requestLimiterInterceptor).addPathPatterns("/**");
    }

}

6.測試

import com.ils.intelab.common.result.CommonResult;
import com.ils.intelab.openapi.annotation.RequestLimiter;
import io.swagger.annotations.*;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RestController;

import java.util.concurrent.TimeUnit;

/**
 * @Classname TestController
 * @Description TODO
 * @Date 2020/5/27 9:50
 * @Created zzf
 */
@RestController
@RequestMapping("/api/v3/")
@Api(value = "open_api", description = "open_api基礎接口", tags = {"open_api"})
public class TestController {

    @ApiOperation(value = "測試", notes = "測試")
    @ApiResponses({
            @ApiResponse(code = 400, message = "參數非法"),
            @ApiResponse(code = 500, message = "服務器錯誤"),
            @ApiResponse(code = 200, message = "成功")
    })
    @ApiImplicitParams({
    })
    @RequestLimiter(QPS = 1, timeout = 200, timeunit = TimeUnit.MILLISECONDS, msg = "服務器繁忙,請稍後再試")
    @RequestMapping(value = "/test", method = RequestMethod.GET)
    public CommonResult<String> getUserInfo() {
        return new CommonResult<>("999");
    }
}

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章