Redis 限流器的設計

1. 定義註解

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Repeatable;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Declares a Redis-backed rate limit on a method.
 *
 * <p>The annotation is repeatable: several limits may be stacked on the same method, in which
 * case the compiler stores them inside the {@link RateLimits} container.
 */
@Documented
@Repeatable(RateLimits.class)
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface RedisRateLimitAttribute {

    /**
     * Intended as an alias for {@link #key()}.
     *
     * <p>NOTE: no {@code @AliasFor} is declared here, so the alias only takes effect if the code
     * consuming this annotation resolves it explicitly.
     *
     * @return the alias of {@code key()}
     */
    String value() default "";

    /**
     * Key of the rate-limit counter; may be a SpEL expression.
     *
     * @return the counter key (expression)
     */
    String key() default "";

    /**
     * Priority of this limit among the limits declared on the same method.
     *
     * @return the priority of this limit
     */
    int order() default 0;

    /**
     * SpEL condition deciding whether an invocation is counted; the counter is incremented only
     * when it evaluates to true.
     *
     * @return the increment condition expression
     */
    String incrCondition() default "true";

    /**
     * Maximum allowed count; may reference configuration properties via {@code ${...}}.
     *
     * @return the limit value
     */
    String limit() default "1";

    /**
     * Length of the counting window (milliseconds, as the element name states); may reference
     * configuration properties via {@code ${...}}.
     *
     * @return the window length
     */
    String intervalInMilliseconds() default "1000";

    /**
     * Name of the fallback (degradation) method. Its parameters must be the same as the original
     * method's, or the same followed by one extra parameter of the original method's return type.
     *
     * @return the fallback method name
     */
    String fallbackMethod() default "";
}

 

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Container annotation holding several {@link RedisRateLimitAttribute} declarations placed on
 * the same method.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimits {

    /**
     * The contained rate-limit declarations.
     *
     * @return the list of annotations
     */
    RedisRateLimitAttribute[] value() default {};
}

 

2. 切面方法

 

import com.google.common.base.Strings;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.EnableAspectJAutoProxy;
import org.springframework.core.DefaultParameterNameDiscoverer;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.Order;
import org.springframework.core.env.Environment;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import redis.clients.jedis.JedisCluster;

/**
 * Enforces {@link RedisRateLimitAttribute} limits on annotated methods using counters stored in a
 * Redis cluster.
 *
 * <p>Limits are processed in ascending {@link RedisRateLimitAttribute#order()}; each invocation
 * goes through two phases:
 * <ol>
 *   <li><b>Pre-check</b>: if a limit's counter is already above {@code limit}, the fallback method
 *       is invoked and the target method is not executed.</li>
 *   <li><b>Post-check</b>: after the target method returns, the counter is incremented when
 *       {@code incrCondition} holds; if the new count is above {@code limit}, the fallback method is
 *       invoked (it may receive the return value as an extra trailing argument).</li>
 * </ol>
 *
 * <p>NOTE(review): the pointcut uses {@code @annotation(RedisRateLimitAttribute)}. When a method
 * carries several repeated limits, the compiler wraps them in {@code @RateLimits}, which this
 * pointcut does not match; consider also matching the container annotation.
 */
// Enables AspectJ auto-proxying; proxyTargetClass = true requests class-based proxies
// (the attribute defaults to false when omitted).
@EnableAspectJAutoProxy(proxyTargetClass = true)
@Component
@Order(-1)
@Aspect
public class RedisRateLimitAspect {
    /**
     * Logger.
     */
    private static final Logger logger = LoggerFactory.getLogger(RedisRateLimitAspect.class);

    /**
     * Parser for the SpEL expressions used in {@code key} and {@code incrCondition}.
     */
    private static final ExpressionParser EXPRESSION_PARSER = new SpelExpressionParser();

    /**
     * Discovers method parameter names so they can be exposed as SpEL variables.
     */
    private static final ParameterNameDiscoverer PARAMETER_NAME_DISCOVERER = new DefaultParameterNameDiscoverer();

    /**
     * Redis cluster holding the rate-limit counters.
     */
    @Autowired
    private JedisCluster jedisCluster;

    /**
     * Spring environment used to resolve {@code ${...}} placeholders in annotation values.
     */
    @Autowired
    private Environment environment;

    /**
     * Pointcut: every method annotated with {@link RedisRateLimitAttribute}.
     */
    @Pointcut("@annotation(com.g2.order.server.annotation.RedisRateLimitAttribute)")
    public void rateLimit() {

    }

    /**
     * Around advice applying every rate limit declared on the intercepted method.
     *
     * @param proceedingJoinPoint the intercepted invocation
     * @return the method's result, or the fallback method's result when a limit is exceeded
     * @throws Throwable anything thrown by the target method or by the fallback method
     */
    @Around("rateLimit()")
    public Object handleControllerMethod(ProceedingJoinPoint proceedingJoinPoint) throws Throwable {
        MethodSignature methodSignature = (MethodSignature) proceedingJoinPoint.getSignature();
        Method method = methodSignature.getMethod();

        // All limits declared on the method, sorted by ascending order().
        List<RedisRateLimitAttribute> redisRateLimitAttributes =
                AnnotatedElementUtils.findMergedRepeatableAnnotations(method, RedisRateLimitAttribute.class)
                        .stream()
                        .sorted(Comparator.comparingInt(RedisRateLimitAttribute::order))
                        .collect(Collectors.collectingAndThen(Collectors.toList(), Collections::unmodifiableList));

        if (CollectionUtils.isEmpty(redisRateLimitAttributes)) {
            return proceedingJoinPoint.proceed();
        }

        // Target instance and arguments are needed to invoke the fallback method.
        Object target = proceedingJoinPoint.getTarget();
        Object[] args = proceedingJoinPoint.getArgs();
        if (args == null) {
            args = new Object[0];
        }

        // Pre-check: refuse to run the method when a counter is already above its limit.
        for (RedisRateLimitAttribute rateLimit : redisRateLimitAttributes) {
            // The key may contain configuration placeholders and SpEL.
            String key = computeExpress(formatKey(resolveKey(rateLimit)), proceedingJoinPoint, String.class);
            long limitV = Long.parseLong(formatKey(rateLimit.limit()));
            // Read the counter once: a second GET could observe the key expiring and return null.
            long currentV = parseCounter(jedisCluster.get(key));
            if (currentV > limitV) {
                Method fallbackMethod = getFallbackMethod(proceedingJoinPoint, rateLimit.fallbackMethod());
                // The method has not run, so a fallback taking the return value receives null.
                return invokeFallbackMethod(fallbackMethod, target,
                        buildFallbackArgs(fallbackMethod, method, args, null));
            }
        }

        Object result = proceedingJoinPoint.proceed();

        // Post-check: count this invocation and fall back if it pushed a counter above its limit.
        for (RedisRateLimitAttribute rateLimit : redisRateLimitAttributes) {
            // The key may reference the return value (#returnValue / #response).
            String key = computeExpress(formatKey(resolveKey(rateLimit)), proceedingJoinPoint, String.class, result);
            long limitV = Long.parseLong(formatKey(rateLimit.limit()));
            long interval = Long.parseLong(formatKey(rateLimit.intervalInMilliseconds()));
            if (!match(proceedingJoinPoint, rateLimit, result)) {
                continue;
            }
            long count = incrementCounter(key, interval);
            if (count > limitV) {
                // TODO cache resolved fallback methods so reflection is not repeated on every call.
                Method fallbackMethod = getFallbackMethod(proceedingJoinPoint, rateLimit.fallbackMethod());
                return invokeFallbackMethod(fallbackMethod, target,
                        buildFallbackArgs(fallbackMethod, method, args, result));
            }
        }

        return result;
    }

    /**
     * Returns the configured key, falling back to {@code value()}, which is documented as its alias
     * but carries no {@code @AliasFor}, so the alias must be resolved here.
     *
     * @param rateLimit the limit declaration
     * @return {@code key()} if non-empty, otherwise {@code value()}
     */
    private static String resolveKey(RedisRateLimitAttribute rateLimit) {
        return Strings.isNullOrEmpty(rateLimit.key()) ? rateLimit.value() : rateLimit.key();
    }

    /**
     * Parses a counter value read from Redis.
     *
     * @param value raw value, possibly null or empty when the key does not exist
     * @return the counter, or 0 when absent
     */
    private static long parseCounter(String value) {
        return Strings.isNullOrEmpty(value) ? 0L : Long.parseLong(value);
    }

    /**
     * Increments the counter for {@code key} and starts its expiry window on the first hit.
     *
     * @param key                    Redis key of the counter
     * @param intervalInMilliseconds window length in milliseconds
     * @return the counter value after this increment
     */
    private long incrementCounter(String key, long intervalInMilliseconds) {
        // INCR is atomic, so concurrent callers never lose a count and each sees its own value
        // (unlike GET followed by SET NX / INCRBY).
        long count = jedisCluster.incr(key);
        if (count == 1L) {
            // First hit of this window. PEXPIRE takes milliseconds, matching intervalInMilliseconds.
            // TODO a Lua script would make INCR + PEXPIRE one atomic step.
            jedisCluster.pexpire(key, intervalInMilliseconds);
        }
        return count;
    }

    /**
     * Builds the argument array for the fallback method.
     *
     * @param fallbackMethod the resolved fallback method
     * @param method         the intercepted method
     * @param args           the intercepted method's arguments
     * @param returnValue    value for the extra trailing parameter, if the fallback declares one
     * @return {@code args} unchanged when the fallback has the same arity as the intercepted method,
     *         otherwise {@code args} followed by {@code returnValue}
     */
    private static Object[] buildFallbackArgs(Method fallbackMethod, Method method, Object[] args, Object returnValue) {
        if (fallbackMethod.getParameterCount() == method.getParameterCount()) {
            return args;
        }
        Object[] argsWithResult = Arrays.copyOf(args, args.length + 1);
        argsWithResult[argsWithResult.length - 1] = returnValue;
        return argsWithResult;
    }

    /**
     * Evaluates a SpEL expression against the method arguments (no return value).
     *
     * @param expression the expression
     * @param context    the join point supplying the arguments
     * @param tClass     expected result type
     * @return the evaluated value
     */
    private <T> T computeExpress(String expression, JoinPoint context, Class<T> tClass) {
        return computeExpress(expression, context, tClass, null);
    }

    /**
     * Evaluates a SpEL expression against the method arguments and the return value.
     *
     * @param expression  the expression
     * @param context     the join point supplying the arguments
     * @param tClass      expected result type
     * @param returnValue value exposed as {@code #returnValue} and {@code #response}
     * @return the evaluated value
     */
    private <T> T computeExpress(String expression, JoinPoint context, Class<T> tClass, Object returnValue) {
        EvaluationContext evaluationContext = buildEvaluationContext(returnValue, context);
        return EXPRESSION_PARSER.parseExpression(expression).getValue(evaluationContext, tClass);
    }

    /**
     * Evaluates the limit's {@code incrCondition}.
     *
     * @param context     the join point supplying the arguments
     * @param rateLimit   the limit declaration
     * @param returnValue the intercepted method's return value
     * @return true only if the condition evaluates to {@code Boolean.TRUE}; a null result counts as false
     */
    private boolean match(JoinPoint context, RedisRateLimitAttribute rateLimit, Object returnValue) {
        return Boolean.TRUE.equals(computeExpress(rateLimit.incrCondition(), context, Boolean.class, returnValue));
    }

    /**
     * Resolves {@code ${...}} placeholders in an annotation value.
     *
     * @param v the raw value
     * @return the value with resolvable placeholders replaced
     * @throws IllegalStateException if {@code v} is null or empty
     */
    private String formatKey(String v) {
        if (Strings.isNullOrEmpty(v)) {
            throw new IllegalStateException("key配置不能爲空");
        }
        return environment.resolvePlaceholders(v);
    }

    /**
     * Registers each method argument as a SpEL variable named after its parameter.
     */
    private static void addParameterVariable(StandardEvaluationContext evaluationContext, JoinPoint context) {
        MethodSignature methodSignature = (MethodSignature) context.getSignature();
        Method method = methodSignature.getMethod();
        String[] parameterNames = PARAMETER_NAME_DISCOVERER.getParameterNames(method);
        if (parameterNames != null && parameterNames.length > 0) {
            Object[] args = context.getArgs();
            for (int i = 0; i < parameterNames.length; i++) {
                evaluationContext.setVariable(parameterNames[i], args[i]);
            }
        }
    }

    /**
     * Registers the return value under both {@code #returnValue} and {@code #response}.
     */
    private static void addReturnValue(StandardEvaluationContext evaluationContext, Object returnValue) {
        evaluationContext.setVariable("returnValue", returnValue);
        evaluationContext.setVariable("response", returnValue);
    }

    /**
     * Builds the SpEL evaluation context holding the arguments and the return value.
     */
    private static EvaluationContext buildEvaluationContext(Object returnValue, JoinPoint context) {
        StandardEvaluationContext evaluationContext = new StandardEvaluationContext();
        addParameterVariable(evaluationContext, context);
        addReturnValue(evaluationContext, returnValue);
        return evaluationContext;
    }

    /**
     * Looks up the fallback method on the target class. It is searched first with the intercepted
     * method's parameter types, then with those types plus one trailing parameter of the
     * intercepted method's return type.
     *
     * @param context        the join point
     * @param fallbackMethod name of the fallback method
     * @return the accessible fallback method
     * @throws RuntimeException if neither signature is declared on the target class
     */
    private static Method getFallbackMethod(JoinPoint context, String fallbackMethod) {
        MethodSignature methodSignature = (MethodSignature) context.getSignature();
        Class<?>[] parameterTypes = methodSignature.getParameterTypes();
        if (parameterTypes == null) {
            parameterTypes = new Class<?>[0];
        }
        Class<?> targetClass = context.getTarget().getClass();

        try {
            Method method = targetClass.getDeclaredMethod(fallbackMethod, parameterTypes);
            method.setAccessible(true);
            return method;
        } catch (NoSuchMethodException ignored) {
            // Not declared with the same parameters; try the variant taking the return value.
        }

        Class<?>[] parameterTypesWithResult = Arrays.copyOf(parameterTypes, parameterTypes.length + 1);
        parameterTypesWithResult[parameterTypesWithResult.length - 1] = methodSignature.getReturnType();
        try {
            Method method = targetClass.getDeclaredMethod(fallbackMethod, parameterTypesWithResult);
            method.setAccessible(true);
            return method;
        } catch (NoSuchMethodException ignored) {
            // Neither signature exists; reported below with context.
        }

        String message = String.format("獲取fallbackMethod失敗, context: %s, fallbackMethod: %s",
                context, fallbackMethod);
        throw new RuntimeException(message);
    }

    /**
     * Invokes the fallback method, unwrapping reflection wrappers so callers see the original exception.
     *
     * @param fallbackMethod the fallback method
     * @param fallbackTarget instance to invoke it on
     * @param fallbackArgs   arguments to pass
     * @return the fallback method's result
     * @throws Throwable the exception thrown by the fallback method, or a reflection failure
     */
    private static Object invokeFallbackMethod(Method fallbackMethod, Object fallbackTarget, Object[] fallbackArgs)
            throws Throwable {
        try {
            return fallbackMethod.invoke(fallbackTarget, fallbackArgs);
        } catch (InvocationTargetException e) {
            if (e.getCause() != null) {
                throw e.getCause();
            }
            throw e;
        }
    }
}

 

3. 調用示例

/**
 * Sample controller showing {@link RedisRateLimitAttribute} applied to a login endpoint.
 */
@Slf4j
@Api(value = "HomeController", description = "用戶登錄登出接口")
@RestController
@RequestMapping("/home")
public class HomeController {
    private static final Logger logger = LoggerFactory.getLogger(HomeController.class);

    /**
     * Logs a user in.
     *
     * <p>Rate limited per user id (counter key {@code 'login' + userId}). The counter is only
     * incremented when the response reports {@code success == false}, i.e. it counts failed logins,
     * as the {@code login.maxFailedTimes} property name intends. When the limit is exceeded the
     * response is produced by {@link #loginFallback(UserLoginReq, UserLoginResp)}.
     *
     * @param req login request
     * @return login result
     */
    @ApiOperation(value = "用戶登錄", notes = "用戶登錄接口")
    @RequestMapping(value = "/login",
            method = RequestMethod.POST,
            consumes = MediaType.APPLICATION_JSON_VALUE,
            produces = MediaType.APPLICATION_JSON_VALUE)
    @ResponseBody
    @RedisRateLimitAttribute(key = "'login'+#req.userId"
            , limit = "${login.maxFailedTimes:3}"
            , incrCondition = "#response.success == false"
            , intervalInMilliseconds = "${login.limit.millseconds:3600}"
            , fallbackMethod = "loginFallback"
    )
    public UserLoginResp login(@RequestBody UserLoginReq req) {
        logger.info("進入登陸業務");

        UserModel userModel = new UserModel();
        userModel.setRoleId(123);
        userModel.setUserId(req.getUserId());
        userModel.setMustValidateCode(false);

        return new UserLoginResp(userModel);
    }

    /**
     * Fallback used when the login rate limit is exceeded.
     *
     * @param req  the original login request
     * @param resp the login result, or null when the call was rejected before login ran
     * @return an empty response when {@code resp} is null; otherwise {@code resp}, with its payload
     *         (if any) flagged as requiring a verification code
     */
    private UserLoginResp loginFallback(UserLoginReq req, UserLoginResp resp) {
        if (resp == null) {
            return new UserLoginResp();
        }
        // Failed responses may carry no payload, so only flag it when present.
        if (resp.getPayload() != null) {
            resp.getPayload().setMustValidateCode(true);
        }
        return resp;
    }
}

 

/**
 * User information carried as the payload of the login response.
 * Accessors, equals/hashCode and toString are generated by Lombok {@code @Data}.
 */
@Data
public class UserModel {
    /**
     * User id.
     */
    private String userId;

    /**
     * Role name.
     */
    private String roleName;

    /**
     * Role id.
     */
    private Integer roleId;

    /**
     * Whether the login must be submitted with a verification code.
     * Intended to be set once the failure count reaches the threshold, to make repeated
     * submissions harder.
     */
    private Boolean mustValidateCode;
}

 

import lombok.Data;

/**
 * Generic API response envelope: a success flag, an error message and an optional payload.
 * Accessors, equals/hashCode and toString are generated by Lombok {@code @Data}.
 *
 * @param <T> payload type
 */
@Data
public class Response<T> {
    private Boolean success;
    private String errorMessage;
    private T payload;

    /**
     * Canonical constructor; every other constructor delegates here.
     *
     * @param succ success flag
     * @param msg  error message
     * @param data payload
     */
    public Response(boolean succ, String msg, T data) {
        this.success = succ;
        this.errorMessage = msg;
        this.payload = data;
    }

    /** Successful response with an empty message and no payload. */
    public Response() {
        this(true, "", null);
    }

    /** Response with the given flag, an empty message and no payload. */
    public Response(boolean succ) {
        this(succ, "", null);
    }

    /** Response with the given flag and message and no payload. */
    public Response(boolean succ, String msg) {
        this(succ, msg, null);
    }

    /** Successful response with an empty message carrying {@code data}. */
    public Response(T data) {
        this(true, "", data);
    }
}

 

/**
 * Login response whose payload is a {@link UserModel}.
 * {@code toString()} is inherited from {@link Response}; no override is needed.
 */
public class UserLoginResp extends Response<UserModel> {

    /**
     * Creates a successful response with an empty message and no payload.
     */
    public UserLoginResp() {
        super();
    }

    /**
     * Creates a successful response carrying the given user.
     *
     * @param userModel the logged-in user
     */
    public UserLoginResp(UserModel userModel) {
        super(userModel);
    }
}

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章