One annotation to solve ShardingJdbc's lack of support for complex SQL

Background

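Our company recently started a database and table sharding project and integrated Sharding JDBC. During regression testing after the integration, several SQL statements failed with errors, even though none of the tables involved are sharded. The error was as follows: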

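That's a problem. Anyone familiar with Sharding JDBC knows there are many SQL statements it does not support. Here is a screenshot from the official documentation: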

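Rewriting all of these complex SQL statements could take a lot of time. So what can be done? The only option left was to find a way around Sharding JDBC itself. After two days of research, I arrived at the solution below: adding a single annotation is enough to get around Sharding JDBC's lack of support for complex SQL.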

Reproducing the problem

I wrote a complex SQL query locally to test this:

public List<Map<String, Object>> queryOrder(){
    List<Map<String, Object>> orders = borderRepository.findOrders();
    return orders;
}
public interface BOrderRepository extends JpaRepository<BOrder,Long> {

    @Query(value = "SELECT * FROM (SELECT id,CASE WHEN company_id =1 THEN '小' WHEN company_id=4 THEN '中' ELSE '大' END AS com,user_id as userId FROM b_order0) t WHERE t.com ='中'",nativeQuery =true)
    List<Map<String, Object>> findOrders();
}

I wrote a test controller to call it, and sure enough, the call failed.

Approach

Since the tables in these complex queries are not sharded, could we simply stop these particular queries from going through the Sharding JDBC data source?

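  1. Intercept the point where the Sharding JDBC data source is injected, and inject a custom data source of our own instead.
  2. That custom data source can then hand out connections from the native (non-sharding) data source when needed.
  3. In addition, declare an annotation: calls marked with it get the native data source, and all other calls keep getting the Sharding data source.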
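The three steps come down to a single routing decision each time a connection is requested. A minimal JDK-only sketch of that decision, assuming a per-thread key that some interceptor sets; RoutingSketch, CURRENT_KEY, route and withKey are hypothetical names, and a type parameter T stands in for javax.sql.DataSource so the example runs without a database:

```java
import java.util.Map;
import java.util.function.Supplier;

// Hypothetical sketch: T stands in for javax.sql.DataSource.
final class RoutingSketch {

    // Per-thread key set by an annotation interceptor; null means "no annotation".
    static final ThreadLocal<String> CURRENT_KEY = new ThreadLocal<>();

    // Use the native target named by the current key, else fall back to the sharding target.
    static <T> T route(T shardingTarget, Map<String, T> nativeTargets) {
        String key = CURRENT_KEY.get();
        if (key == null || key.isEmpty()) {
            return shardingTarget;
        }
        T target = nativeTargets.get(key);
        if (target == null) {
            throw new IllegalStateException("No data source named '" + key + "'");
        }
        return target;
    }

    // Run an action with the key set, then restore whatever key was there before.
    static <R> R withKey(String key, Supplier<R> action) {
        String previous = CURRENT_KEY.get();
        CURRENT_KEY.set(key);
        try {
            return action.get();
        } finally {
            if (previous == null) {
                CURRENT_KEY.remove();
            } else {
                CURRENT_KEY.set(previous);
            }
        }
    }

    public static void main(String[] args) {
        Map<String, String> natives = Map.of("slave0", "native:slave0");
        System.out.println(route("sharding", natives)); // sharding
        String routed = withKey("slave0", () -> route("sharding", natives));
        System.out.println(routed);                     // native:slave0
        System.out.println(route("sharding", natives)); // sharding
    }
}
```

withKey saves and restores the previous key instead of just clearing it, so a nested call does not wipe out its caller's key. The DynamicDataSourceContextHolder in this article generalizes the same idea with a stack.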

Implementation

  1. Write an auto-configuration class that replaces ShardingSphereAutoConfiguration
/**
 * Core auto-configuration class for the dynamic data source
 */
@Configuration
@ComponentScan("org.apache.shardingsphere.spring.boot.converter")
@EnableConfigurationProperties(SpringBootPropertiesConfiguration.class)
@ConditionalOnProperty(prefix = "spring.shardingsphere", name = "enabled", havingValue = "true", matchIfMissing = true)
@AutoConfigureBefore(DataSourceAutoConfiguration.class)
public class DynamicDataSourceAutoConfiguration implements EnvironmentAware {

    private String databaseName;

    private final SpringBootPropertiesConfiguration props;

    private final Map<String, DataSource> dataSourceMap = new LinkedHashMap<>();

    public DynamicDataSourceAutoConfiguration(SpringBootPropertiesConfiguration props) {
        this.props = props;
    }

    /**
     * Get mode configuration.
     *
     * @return mode configuration
     */
    @Bean
    public ModeConfiguration modeConfiguration() {
        return null == props.getMode() ? null : new ModeConfigurationYamlSwapper().swapToObject(props.getMode());
    }

    /**
     * Get ShardingSphere data source bean.
     *
     * @param rules rules configuration
     * @param modeConfig mode configuration
     * @return data source bean
     * @throws SQLException SQL exception
     */
    @Bean
    @Conditional(LocalRulesCondition.class)
    @Autowired(required = false)
    public DataSource shardingSphereDataSource(final ObjectProvider<List<RuleConfiguration>> rules, final ObjectProvider<ModeConfiguration> modeConfig) throws SQLException {
        Collection<RuleConfiguration> ruleConfigs = Optional.ofNullable(rules.getIfAvailable()).orElseGet(Collections::emptyList);
        DataSource dataSource = ShardingSphereDataSourceFactory.createDataSource(databaseName, modeConfig.getIfAvailable(), dataSourceMap, ruleConfigs, props.getProps());
        return new WrapShardingDataSource((ShardingSphereDataSource) dataSource,dataSourceMap);
    }

    /**
     * Get data source bean from registry center.
     *
     * @param modeConfig mode configuration
     * @return data source bean
     * @throws SQLException SQL exception
     */
    @Bean
    @ConditionalOnMissingBean(DataSource.class)
    public DataSource dataSource(final ModeConfiguration modeConfig) throws SQLException {
        DataSource dataSource = !dataSourceMap.isEmpty() ? ShardingSphereDataSourceFactory.createDataSource(databaseName, modeConfig, dataSourceMap, Collections.emptyList(), props.getProps())
                : ShardingSphereDataSourceFactory.createDataSource(databaseName, modeConfig);
        return new WrapShardingDataSource((ShardingSphereDataSource) dataSource,dataSourceMap);
    }

    /**
     * Create transaction type scanner.
     *
     * @return transaction type scanner
     */
    @Bean
    public TransactionTypeScanner transactionTypeScanner() {
        return new TransactionTypeScanner();
    }

    @Override
    public final void setEnvironment(final Environment environment) {
        dataSourceMap.putAll(DataSourceMapSetter.getDataSourceMap(environment));
        databaseName = DatabaseNameSetter.getDatabaseName(environment);
    }

    @Role(BeanDefinition.ROLE_INFRASTRUCTURE)
    @Bean
    @ConditionalOnProperty(prefix = "spring.datasource.dynamic.aop", name = "enabled", havingValue = "true", matchIfMissing = true)
    public Advisor dynamicDatasourceAnnotationAdvisor() {
        DynamicDataSourceAnnotationInterceptor interceptor = new DynamicDataSourceAnnotationInterceptor(true);
        DynamicDataSourceAnnotationAdvisor advisor = new DynamicDataSourceAnnotationAdvisor(interceptor, DS.class);
        return advisor;
    }


}
  2. Custom data source
public class WrapShardingDataSource extends AbstractDataSourceAdapter implements AutoCloseable{

    private ShardingSphereDataSource dataSource;

    private Map<String, DataSource> dataSourceMap;

    public WrapShardingDataSource(ShardingSphereDataSource dataSource, Map<String, DataSource> dataSourceMap) {
        this.dataSource = dataSource;
        this.dataSourceMap = dataSourceMap;
    }

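    public DataSource getTargetDataSource(){
        String peek = DynamicDataSourceContextHolder.peek();
        // No @DS key for the current call: use the ShardingSphere data source
        if(StringUtils.isEmpty(peek)){
            return dataSource;
        }
        // @DS key present: use the native data source registered under that name
        DataSource target = dataSourceMap.get(peek);
        if (target == null) {
            throw new IllegalStateException("No data source named '" + peek + "' is configured");
        }
        return target;
    }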


    @Override
    public Connection getConnection() throws SQLException {
        return getTargetDataSource().getConnection();
    }

    @Override
    public Connection getConnection(final String username, final String password) throws SQLException {
        return getConnection();
    }



    @Override
    public void close() throws Exception {
        DataSource targetDataSource = getTargetDataSource();
        if (targetDataSource instanceof AutoCloseable) {
            ((AutoCloseable) targetDataSource).close();
        }
    }

    @Override
    public int getLoginTimeout() throws SQLException {
        DataSource targetDataSource = getTargetDataSource();
        return targetDataSource ==null ? 0 : targetDataSource.getLoginTimeout();
    }

    @Override
    public void setLoginTimeout(final int seconds) throws SQLException {
        DataSource targetDataSource = getTargetDataSource();
        targetDataSource.setLoginTimeout(seconds);
    }
}
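WrapShardingDataSource overrides the DataSource methods one at a time on top of ShardingSphere's AbstractDataSourceAdapter. As a swapped-in alternative (not the approach used in this article), the same per-call routing can be sketched with nothing but java.lang.reflect.Proxy, forwarding every DataSource method to the target chosen at call time. ProxyRoutingDataSource, routing and fake are hypothetical names, and the fake data sources only answer getLoginTimeout:

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.sql.SQLException;
import java.util.Map;
import java.util.function.Supplier;
import javax.sql.DataSource;

// Hypothetical JDK-only alternative to a hand-written wrapper class.
final class ProxyRoutingDataSource {

    // Returns a DataSource that forwards each call to the target selected by currentKey at call time.
    static DataSource routing(DataSource defaultTarget, Map<String, DataSource> named, Supplier<String> currentKey) {
        return (DataSource) Proxy.newProxyInstance(
                ProxyRoutingDataSource.class.getClassLoader(),
                new Class<?>[]{DataSource.class},
                (proxy, method, args) -> {
                    // Identity semantics for the proxy itself, not for whichever target is current.
                    if (method.getName().equals("equals") && method.getParameterCount() == 1) {
                        return proxy == args[0];
                    }
                    if (method.getName().equals("hashCode") && method.getParameterCount() == 0) {
                        return System.identityHashCode(proxy);
                    }
                    String key = currentKey.get();
                    DataSource target = (key == null || key.isEmpty()) ? defaultTarget : named.get(key);
                    if (target == null) {
                        throw new IllegalStateException("No data source named '" + key + "'");
                    }
                    try {
                        return method.invoke(target, args);
                    } catch (InvocationTargetException e) {
                        throw e.getCause(); // surface the target's own exception, e.g. SQLException
                    }
                });
    }

    // Stand-in data source that only reports a login timeout; everything else is unsupported.
    static DataSource fake(int loginTimeout) {
        return (DataSource) Proxy.newProxyInstance(
                ProxyRoutingDataSource.class.getClassLoader(),
                new Class<?>[]{DataSource.class},
                (proxy, method, args) -> {
                    if (method.getName().equals("getLoginTimeout")) {
                        return loginTimeout;
                    }
                    throw new UnsupportedOperationException(method.getName());
                });
    }

    public static void main(String[] args) throws SQLException {
        ThreadLocal<String> key = new ThreadLocal<>();
        DataSource ds = routing(fake(1), Map.of("slave0", fake(2)), key::get);
        System.out.println(ds.getLoginTimeout()); // 1 (no key: default target)
        key.set("slave0");
        System.out.println(ds.getLoginTimeout()); // 2 (routed to slave0)
        key.remove();
    }
}
```

The trade-off: with a proxy, methods such as unwrap and isWrapperFor also reach the selected target without extra code, but you lose the compile-time checking that an explicit subclass gives you.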
  3. Declare the annotation that names the data source
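import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface DS {

    /**
     * Data source name
     */
    String value();
}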
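The annotation does nothing by itself; something has to read it at runtime, which is why RetentionPolicy.RUNTIME matters. Here is a simplified, JDK-only sketch of that lookup (method first, then the declaring class), assuming a hypothetical OrderDao class. The article's real resolver further down also checks implementation methods, interfaces and MyBatis mapper proxies through Spring's AnnotatedElementUtils.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

final class DsLookupSketch {

    // Same shape as the article's @DS, redeclared so this sketch compiles on its own.
    @Target({ElementType.TYPE, ElementType.METHOD})
    @Retention(RetentionPolicy.RUNTIME)
    @interface DS {
        String value();
    }

    // Hypothetical DAO: a class-level default, with one method overriding it.
    @DS("master")
    static class OrderDao {
        public void save() { }

        @DS("slave0")
        public void findOrders() { }
    }

    // The method annotation wins; otherwise use the declaring class; "" means "use sharding".
    static String resolve(Method method) {
        DS onMethod = method.getAnnotation(DS.class);
        if (onMethod != null) {
            return onMethod.value();
        }
        DS onClass = method.getDeclaringClass().getAnnotation(DS.class);
        return onClass != null ? onClass.value() : "";
    }

    public static void main(String[] args) throws NoSuchMethodException {
        System.out.println(resolve(OrderDao.class.getMethod("findOrders")));         // slave0
        System.out.println(resolve(OrderDao.class.getMethod("save")));               // master
        System.out.println(resolve(Object.class.getMethod("toString")).isEmpty());   // true
    }
}
```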
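  4. Next, use AOP to intercept classes or methods that carry the annotation, and record the data source key they request so it can be checked when a connection is obtained. Because that check happens later on the same thread, this also requires a ThreadLocal.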

AOP interceptor:

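import org.aopalliance.intercept.MethodInterceptor;
import org.aopalliance.intercept.MethodInvocation;

public class DynamicDataSourceAnnotationInterceptor implements MethodInterceptor {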

    private final DataSourceClassResolver dataSourceClassResolver;

    public DynamicDataSourceAnnotationInterceptor(Boolean allowedPublicOnly) {
        dataSourceClassResolver = new DataSourceClassResolver(allowedPublicOnly);
    }

    @Override
    public Object invoke(MethodInvocation invocation) throws Throwable {
        String dsKey = determineDatasourceKey(invocation);
        DynamicDataSourceContextHolder.push(dsKey);
        try {
            return invocation.proceed();
        } finally {
            DynamicDataSourceContextHolder.poll();
        }
    }

    private String determineDatasourceKey(MethodInvocation invocation) {
        String key = dataSourceClassResolver.findKey(invocation.getMethod(), invocation.getThis());
        return key;
    }
}
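The interceptor's job is small: find the key, push it, proceed, and pop it in a finally block so the key never outlives the call. The sketch below mirrors that shape with plain java.lang.reflect.Proxy instead of Spring AOP, assuming a hypothetical OrderRepository whose annotated method sits on the interface, as findOrders does on BOrderRepository. Unlike the Spring version, every interface method is intercepted here, and methods without @DS push an empty key, which WrapShardingDataSource treats as "use the ShardingSphere data source".

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Proxy;
import java.util.ArrayDeque;
import java.util.Deque;

final class InterceptorSketch {

    @Target(ElementType.METHOD)
    @Retention(RetentionPolicy.RUNTIME)
    @interface DS {
        String value();
    }

    // Per-thread stack of data source keys, like DynamicDataSourceContextHolder.
    static final ThreadLocal<Deque<String>> KEYS = ThreadLocal.withInitial(ArrayDeque::new);

    static String currentKey() {
        return KEYS.get().peek(); // null when nothing has been pushed
    }

    // Hypothetical repository; the annotation sits on the interface method.
    interface OrderRepository {
        @DS("slave0")
        String findOrders();

        String save();
    }

    static class RecordingRepository implements OrderRepository {
        public String findOrders() { return "findOrders ran with key '" + currentKey() + "'"; }
        public String save() { return "save ran with key '" + currentKey() + "'"; }
    }

    // Wraps target so each call pushes its @DS key (or "") first and always pops it afterwards.
    @SuppressWarnings("unchecked")
    static <T> T intercept(Class<T> type, T target) {
        return (T) Proxy.newProxyInstance(type.getClassLoader(), new Class<?>[]{type},
                (proxy, method, args) -> {
                    DS ds = method.getAnnotation(DS.class);
                    KEYS.get().push(ds == null ? "" : ds.value());
                    try {
                        return method.invoke(target, args);
                    } catch (InvocationTargetException e) {
                        throw e.getCause();
                    } finally {
                        Deque<String> keys = KEYS.get();
                        keys.poll();
                        if (keys.isEmpty()) {
                            KEYS.remove(); // do not leave a deque behind on pooled threads
                        }
                    }
                });
    }

    public static void main(String[] args) {
        OrderRepository repo = intercept(OrderRepository.class, new RecordingRepository());
        System.out.println(repo.findOrders()); // findOrders ran with key 'slave0'
        System.out.println(repo.save());       // save ran with key ''
        System.out.println(currentKey());      // null
    }
}
```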

AOP advisor definition:

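import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

import org.aopalliance.aop.Advice;
import org.aopalliance.intercept.MethodInterceptor;
import org.springframework.aop.ClassFilter;
import org.springframework.aop.MethodMatcher;
import org.springframework.aop.Pointcut;
import org.springframework.aop.support.AbstractPointcutAdvisor;
import org.springframework.aop.support.AopUtils;
import org.springframework.aop.support.ComposablePointcut;
import org.springframework.aop.support.StaticMethodMatcher;
import org.springframework.aop.support.annotation.AnnotationMatchingPointcut;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.util.Assert;

/**
 * AOP advisor: applies the interceptor to classes or methods carrying the given annotation
 */
public class DynamicDataSourceAnnotationAdvisor extends AbstractPointcutAdvisor implements BeanFactoryAware {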

    private final Advice advice;

    private final Pointcut pointcut;

    private final Class<? extends Annotation> annotation;

    public DynamicDataSourceAnnotationAdvisor(MethodInterceptor advice,
                                               Class<? extends Annotation> annotation) {
        this.advice = advice;
        this.annotation = annotation;
        this.pointcut = buildPointcut();
    }

    @Override
    public Pointcut getPointcut() {
        return this.pointcut;
    }

    @Override
    public Advice getAdvice() {
        return this.advice;
    }

    @Override
    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
        if (this.advice instanceof BeanFactoryAware) {
            ((BeanFactoryAware) this.advice).setBeanFactory(beanFactory);
        }
    }

    private Pointcut buildPointcut() {
        Pointcut cpc = new AnnotationMatchingPointcut(annotation, true);
        Pointcut mpc = new AnnotationMethodPoint(annotation);
        return new ComposablePointcut(cpc).union(mpc);
    }

    /**
     * In order to be compatible with the spring lower than 5.0
     */
    private static class AnnotationMethodPoint implements Pointcut {

        private final Class<? extends Annotation> annotationType;

        public AnnotationMethodPoint(Class<? extends Annotation> annotationType) {
            Assert.notNull(annotationType, "Annotation type must not be null");
            this.annotationType = annotationType;
        }

        @Override
        public ClassFilter getClassFilter() {
            return ClassFilter.TRUE;
        }

        @Override
        public MethodMatcher getMethodMatcher() {
            return new AnnotationMethodMatcher(annotationType);
        }

        private static class AnnotationMethodMatcher extends StaticMethodMatcher {
            private final Class<? extends Annotation> annotationType;

            public AnnotationMethodMatcher(Class<? extends Annotation> annotationType) {
                this.annotationType = annotationType;
            }

            @Override
            public boolean matches(Method method, Class<?> targetClass) {
                if (matchesMethod(method)) {
                    return true;
                }
                // Proxy classes never have annotations on their redeclared methods.
                if (Proxy.isProxyClass(targetClass)) {
                    return false;
                }
                // The method may be on an interface, so let's check on the target class as well.
                Method specificMethod = AopUtils.getMostSpecificMethod(method, targetClass);
                return (specificMethod != method && matchesMethod(specificMethod));
            }

            private boolean matchesMethod(Method method) {
                return AnnotatedElementUtils.hasAnnotation(method, this.annotationType);
            }
        }
    }
}
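The advisor's pointcut is the union of two checks: the class carries @DS, or the method (or its override on the target class) carries @DS. A JDK-only sketch of that matching logic, assuming hypothetical OrderService, OrderServiceImpl and ReportService types. It uses Class.getMethod in place of Spring's AopUtils.getMostSpecificMethod, so unlike the real matcher it only sees public methods and ignores class annotations inherited from superclasses or interfaces.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

final class PointcutSketch {

    @Target({ElementType.TYPE, ElementType.METHOD})
    @Retention(RetentionPolicy.RUNTIME)
    @interface DS {
        String value();
    }

    interface OrderService {
        void findOrders();
        void save();
    }

    // The annotation sits on the implementation method, not on the interface method.
    static class OrderServiceImpl implements OrderService {
        @DS("slave0") public void findOrders() { }
        public void save() { }
    }

    @DS("slave0")
    static class ReportService {
        public void run() { }
    }

    // Method-level half of the union: annotated method, or its override on the target class.
    static boolean methodMatches(Method method, Class<?> targetClass) {
        if (method.isAnnotationPresent(DS.class)) {
            return true;
        }
        if (Proxy.isProxyClass(targetClass)) {
            return false; // JDK proxy classes never carry annotations on redeclared methods
        }
        try {
            Method specific = targetClass.getMethod(method.getName(), method.getParameterTypes());
            return !specific.equals(method) && specific.isAnnotationPresent(DS.class);
        } catch (NoSuchMethodException e) {
            return false;
        }
    }

    // Class-level half of the union.
    static boolean classMatches(Class<?> targetClass) {
        return targetClass.isAnnotationPresent(DS.class);
    }

    static boolean matches(Method method, Class<?> targetClass) {
        return classMatches(targetClass) || methodMatches(method, targetClass);
    }

    public static void main(String[] args) throws NoSuchMethodException {
        Method find = OrderService.class.getMethod("findOrders");
        Method save = OrderService.class.getMethod("save");
        System.out.println(matches(find, OrderServiceImpl.class)); // true: implementation method annotated
        System.out.println(matches(save, OrderServiceImpl.class)); // false
        System.out.println(matches(ReportService.class.getMethod("run"), ReportService.class)); // true: class annotated
    }
}
```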

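import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Proxy;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.springframework.aop.framework.AopProxyUtils;
import org.springframework.core.BridgeMethodResolver;
import org.springframework.core.MethodClassKey;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.AnnotationAttributes;
import org.springframework.util.ClassUtils;

/**
 * Data source resolver
 */
public class DataSourceClassResolver {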

    private static boolean mpEnabled = false;

    private static Field mapperInterfaceField;

    static {
        Class<?> proxyClass = null;
        try {
            proxyClass = Class.forName("com.baomidou.mybatisplus.core.override.MybatisMapperProxy");
        } catch (ClassNotFoundException e1) {
            try {
                proxyClass = Class.forName("com.baomidou.mybatisplus.core.override.PageMapperProxy");
            } catch (ClassNotFoundException e2) {
                try {
                    proxyClass = Class.forName("org.apache.ibatis.binding.MapperProxy");
                } catch (ClassNotFoundException ignored) {
                }
            }
        }
        if (proxyClass != null) {
            try {
                mapperInterfaceField = proxyClass.getDeclaredField("mapperInterface");
                mapperInterfaceField.setAccessible(true);
                mpEnabled = true;
            } catch (NoSuchFieldException e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Cache of the data source resolved for each method
     */
    private final Map<Object, String> dsCache = new ConcurrentHashMap<>();
    private final boolean allowedPublicOnly;

    /**
     * Extension point that lets callers change the AOP matching condition
     *
     * @param allowedPublicOnly only allow public methods; defaults to true
     */
    public DataSourceClassResolver(boolean allowedPublicOnly) {
        this.allowedPublicOnly = allowedPublicOnly;
    }

    /**
     * Gets the data source key for a method, using the cache
     *
     * @param method       method
     * @param targetObject target object
     * @return ds
     */
    public String findKey(Method method, Object targetObject) {
        if (method.getDeclaringClass() == Object.class) {
            return "";
        }
        Object cacheKey = new MethodClassKey(method, targetObject.getClass());
        String ds = this.dsCache.get(cacheKey);
        if (ds == null) {
            ds = computeDatasource(method, targetObject);
            if (ds == null) {
                ds = "";
            }
            this.dsCache.put(cacheKey, ds);
        }
        return ds;
    }

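    /**
     * Annotation lookup order:
     * 1. The current method
     * 2. The bridged method
     * 3. The current class, then up the hierarchy to Object
     * 4. mybatis-plus and mybatis-spring mapper interfaces
     *
     * @param method       method
     * @param targetObject target object
     * @return ds
     */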
    private String computeDatasource(Method method, Object targetObject) {
        if (allowedPublicOnly && !Modifier.isPublic(method.getModifiers())) {
            return null;
        }
        // 1. Look up on the current (interface) method
        String dsAttr = findDataSourceAttribute(method);
        if (dsAttr != null) {
            return dsAttr;
        }
        Class<?> targetClass = targetObject.getClass();
        Class<?> userClass = ClassUtils.getUserClass(targetClass);
        // With a JDK proxy, get the method declared on the implementation class. method: the interface method; specificMethod: the implementation method
        Method specificMethod = ClassUtils.getMostSpecificMethod(method, userClass);

        specificMethod = BridgeMethodResolver.findBridgedMethod(specificMethod);
        // 2. Look up on the bridged method
        dsAttr = findDataSourceAttribute(specificMethod);
        if (dsAttr != null) {
            return dsAttr;
        }
        // Look up on the class that declares the current method
        dsAttr = findDataSourceAttribute(userClass);
        if (dsAttr != null && ClassUtils.isUserLevelMethod(method)) {
            return dsAttr;
        }
        // since 3.4.1: look up on the interfaces, using only the first match
        for (Class<?> interfaceClazz : ClassUtils.getAllInterfacesForClassAsSet(userClass)) {
            dsAttr = findDataSourceAttribute(interfaceClazz);
            if (dsAttr != null) {
                return dsAttr;
            }
        }
        // If a bridge method exists
        if (specificMethod != method) {
            // Look up on the bridge method
            dsAttr = findDataSourceAttribute(method);
            if (dsAttr != null) {
                return dsAttr;
            }
            // Look up on the class that declares the bridge method
            dsAttr = findDataSourceAttribute(method.getDeclaringClass());
            if (dsAttr != null && ClassUtils.isUserLevelMethod(method)) {
                return dsAttr;
            }
        }
        return getDefaultDataSourceAttr(targetObject);
    }

    /**
     * Default way of obtaining the data source name
     *
     * @param targetObject target object
     * @return ds
     */
    private String getDefaultDataSourceAttr(Object targetObject) {
        Class<?> targetClass = targetObject.getClass();
        // If this is not a proxy class, start from the current class and keep walking up the superclasses
        if (!Proxy.isProxyClass(targetClass)) {
            Class<?> currentClass = targetClass;
            while (currentClass != Object.class) {
                String datasourceAttr = findDataSourceAttribute(currentClass);
                if (datasourceAttr != null) {
                    return datasourceAttr;
                }
                currentClass = currentClass.getSuperclass();
            }
        }
        // How mybatis-plus and mybatis-spring mappers are resolved
        if (mpEnabled) {
            final Class<?> clazz = getMapperInterfaceClass(targetObject);
            if (clazz != null) {
                String datasourceAttr = findDataSourceAttribute(clazz);
                if (datasourceAttr != null) {
                    return datasourceAttr;
                }
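                // Try the interfaces the mapper interface extends
                // (getSuperclass() always returns null for an interface)
                for (Class<?> parentInterface : clazz.getInterfaces()) {
                    datasourceAttr = findDataSourceAttribute(parentInterface);
                    if (datasourceAttr != null) {
                        return datasourceAttr;
                    }
                }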
            }
        }
        return null;
    }

    /**
     * Unwraps nested proxies and reads the MyBatis mapper interface
     *
     * @param target JDK proxy object
     * @return the mapper interface class, or null if it cannot be determined
     */
    private Class<?> getMapperInterfaceClass(Object target) {
        Object current = target;
        while (Proxy.isProxyClass(current.getClass())) {
            Object currentRefObject = AopProxyUtils.getSingletonTarget(current);
            if (currentRefObject == null) {
                break;
            }
            current = currentRefObject;
        }
        try {
            if (Proxy.isProxyClass(current.getClass())) {
                return (Class<?>) mapperInterfaceField.get(Proxy.getInvocationHandler(current));
            }
        } catch (IllegalAccessException ignore) {
        }
        return null;
    }

    /**
     * Looks up the @DS annotation on the given AnnotatedElement and returns its value
     *
     * @param ae AnnotatedElement
     * @return data source name, or null if the element is not annotated
     */
    private String findDataSourceAttribute(AnnotatedElement ae) {
        AnnotationAttributes attributes = AnnotatedElementUtils.getMergedAnnotationAttributes(ae, DS.class);
        if (attributes != null) {
            return attributes.getString("value");
        }
        return null;
    }
}
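DataSourceClassResolver caches its answer per method and stores an empty string when no annotation is found. That detail matters because ConcurrentHashMap does not accept null values. A minimal sketch of the same caching idea, assuming a hypothetical "Class#method" string key instead of Spring's MethodClassKey and a lambda in place of computeDatasource:

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

final class KeyCacheSketch {

    // Hypothetical stand-in for MethodClassKey: a "Class#method" string.
    private final Map<String, String> cache = new ConcurrentHashMap<>();
    // Expensive lookup; returns null when no @DS annotation is found.
    private final Function<String, String> resolver;
    // Counts real lookups so the caching effect is visible.
    final AtomicInteger resolverCalls = new AtomicInteger();

    KeyCacheSketch(Function<String, String> resolver) {
        this.resolver = resolver;
    }

    String findKey(String methodId) {
        // ConcurrentHashMap rejects null values, so "not annotated" is cached as "".
        // computeIfAbsent applies the lookup at most once per key, even under concurrent calls.
        return cache.computeIfAbsent(methodId, id -> {
            resolverCalls.incrementAndGet();
            String ds = resolver.apply(id);
            return ds == null ? "" : ds;
        });
    }

    public static void main(String[] args) {
        KeyCacheSketch cache = new KeyCacheSketch(id -> id.endsWith("findOrders") ? "slave0" : null);
        System.out.println(cache.findKey("BOrderRepository#findOrders"));      // slave0
        System.out.println(cache.findKey("BOrderRepository#save").isEmpty());  // true
        cache.findKey("BOrderRepository#save");                                // served from the cache
        System.out.println(cache.resolverCalls.get());                         // 2
    }
}
```

The article's resolver uses a get followed by a put instead, which can compute the same key twice under a race; that is harmless because the result is deterministic. Returning null from a computeIfAbsent mapping function would store nothing and repeat the lookup on every call, which is one more reason to cache an empty string.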

ThreadLocal holder:

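import java.util.ArrayDeque;
import java.util.Deque;

import org.springframework.core.NamedThreadLocal;
import org.springframework.util.StringUtils;

public final class DynamicDataSourceContextHolder {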

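    /**
     * Why a linked list (strictly speaking, a stack) is used
     * <pre>
     * To support nested switching: suppose services A, B and C each use a different data source,
     * and a method in A calls B, which in turn calls C. Each level switches data sources, forming a chain.
     * Simply setting one value per thread cannot handle this; a stack (last in, first out) is required.
     * </pre>
     */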
    private static final ThreadLocal<Deque<String>> LOOKUP_KEY_HOLDER = new NamedThreadLocal<Deque<String>>("dynamic-datasource") {
        @Override
        protected Deque<String> initialValue() {
            return new ArrayDeque<>();
        }
    };

    private DynamicDataSourceContextHolder() {
    }

    /**
     * Gets the current thread's data source
     *
     * @return data source name
     */
    public static String peek() {
        return LOOKUP_KEY_HOLDER.get().peek();
    }

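    /**
     * Sets the current thread's data source
     * <p>
     * Avoid calling this manually unless necessary; if you do, make sure it is cleared afterwards
     * </p>
     *
     * @param ds data source name
     * @return the name actually pushed (an empty string if ds was empty)
     */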
    public static String push(String ds) {
        String dataSourceStr = StringUtils.isEmpty(ds) ? "" : ds;
        LOOKUP_KEY_HOLDER.get().push(dataSourceStr);
        return dataSourceStr;
    }

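    /**
     * Clears the current thread's data source
     * <p>
     * If the thread has switched data sources several times in a row, only the most recent name is removed
     * </p>
     */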
    public static void poll() {
        Deque<String> deque = LOOKUP_KEY_HOLDER.get();
        deque.poll();
        if (deque.isEmpty()) {
            LOOKUP_KEY_HOLDER.remove();
        }
    }

    /**
     * Forcibly clears the thread-local
     * <p>
     * Prevents memory leaks; if push was called manually, call this to make sure the state is cleared
     * </p>
     */
    public static void clear() {
        LOOKUP_KEY_HOLDER.remove();
    }
}
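The comment on LOOKUP_KEY_HOLDER explains why a stack is used. A JDK-only sketch of the A, B, C scenario it describes, assuming hypothetical keys ds_a, ds_b and ds_c; push, peek and poll copy the holder's logic, and the hypothetical call method plays the role of the interceptor:

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

final class NestedSwitchSketch {

    static final ThreadLocal<Deque<String>> HOLDER = ThreadLocal.withInitial(ArrayDeque::new);
    // Records which key was active on entry to and exit from each level.
    static final List<String> trace = new ArrayList<>();

    static void push(String ds) { HOLDER.get().push(ds == null ? "" : ds); }

    static String peek() { return HOLDER.get().peek(); }

    static void poll() {
        Deque<String> deque = HOLDER.get();
        deque.poll();
        if (deque.isEmpty()) {
            HOLDER.remove(); // nothing left: drop the thread-local entirely
        }
    }

    // Simulates an intercepted call: switch, run the body, always switch back.
    static void call(String ds, Runnable body) {
        push(ds);
        try {
            trace.add("enter " + peek());
            body.run();
            trace.add("leave " + peek()); // the caller's key is still on top here
        } finally {
            poll();
        }
    }

    public static void main(String[] args) {
        // Service A (ds_a) calls B (ds_b), which calls C (ds_c).
        call("ds_a", () -> call("ds_b", () -> call("ds_c", () -> { })));
        System.out.println(trace); // [enter ds_a, enter ds_b, enter ds_c, leave ds_c, leave ds_b, leave ds_a]
        System.out.println(peek()); // null: nothing left on the stack
    }
}
```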
  5. Configure the startup class as follows:

Import our auto-configuration class and exclude Sharding JDBC's own auto-configuration class.

@SpringBootApplication(exclude = ShardingSphereAutoConfiguration.class)
@Import({DynamicDataSourceAutoConfiguration.class})
public class ShardingRunApplication {

    public static void main(String[] args) {
        SpringApplication.run(ShardingRunApplication.class, args);
    }
}

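Finally, add the annotation to the repository written earlier. The value has to match a key in dataSourceMap, that is, one of the data source names configured for ShardingSphere (slave0 here):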

public interface BOrderRepository extends JpaRepository<BOrder,Long> {

    @DS("slave0")
    @Query(value = "SELECT * FROM (SELECT id,CASE WHEN company_id =1 THEN '小' WHEN company_id=4 THEN '中' ELSE '大' END AS com,user_id as userId FROM b_order0) t WHERE t.com ='中'",nativeQuery =true)
    List<Map<String, Object>> findOrders();
}

Call it again, and this time the query succeeds!
