mybatis的sql攔截器應用

mybatis的sql攔截器應用

前言

在項目上線前的開發及調試中,我們需要清楚地知道每條sql的執行耗時,以避免在生產環境出現慢查詢;對於業務邏輯較複雜的應用,還需防止產生大對象(比如千萬級別的集合)。此時完全可以使用mybatis的sql攔截器(Interceptor插件)來實現以上效果。下面是一個sql攔截器的簡單記錄,以加深印象。

作用

攔截器實現的功能有:

  1. 攔截全表查詢sql(可配置白名單)
  2. 防止大對象的產生(返回集合超過4萬條記錄的sql只告警、不攔截;閾值可通過allowMaxReturn屬性配置)
  3. 輸出完整的sql
  4. 輸出sql的大約用時

示例代碼

一、引入相關mybatis的maven

<dependency>
    <groupId>org.mybatis</groupId>
    <artifactId>mybatis</artifactId>
    <version>3.4.5</version>
</dependency>

<!-- 用於解析sql -->
<dependency>
    <groupId>com.github.jsqlparser</groupId>
    <artifactId>jsqlparser</artifactId>
    <version>0.9.5</version>
</dependency>

(攔截器的註冊配置此處省略,不熟悉的讀者可參閱mybatis官方文件中plugins一節。下文第一個攔截器通過allowMaxReturn屬性設定返回條數的告警閾值,第二個攔截器通過minLogSize屬性設定耗時日誌的輸出閾值。)

二、sql攔截器具體實現

@Slf4j
@Intercepts({
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})}
)
@Component
public class SqlInterceptor implements Interceptor {
    /**
     * Row-count alert threshold used when the "allowMaxReturn" plugin property is not configured.
     */
    private static final int DEFAULT_ALLOW_MAX_RETURN = 40000;

    /**
     * Whitelist of tables that may be fully scanned. In the original setup it is backed by disconf;
     * a hard-coded string would work as well. The bean is optional, so this field may be null.
     */
    @Autowired(required = false)
    private SqlIntercepterWhiteListConf sqlIntercepterWhiteListConf;

    /**
     * Maximum number of rows a DQL statement may return before an alert is raised (the query itself
     * is not blocked). Overridable through the "allowMaxReturn" plugin property.
     */
    private int allowMaxReturn = DEFAULT_ALLOW_MAX_RETURN;

    /**
     * Blocks non-whitelisted full-table scans and alerts on oversized result sets.
     *
     * <p>A query is blocked (a default value is returned without executing it) when it is a
     * single-table SELECT whose WHERE clause is missing or only {@code 1 = 1} and whose main table
     * is not in the whitelist. Otherwise the query runs; if it returns a collection larger than
     * {@link #allowMaxReturn}, a warning containing the rendered SQL is logged and queued for email.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        EmailPool emailPool = EmailPool.getInstance();
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        Object parameter = invocation.getArgs()[1];
        BoundSql boundSql = mappedStatement.getBoundSql(parameter);
        String sql = boundSql.getSql();

        String blockedTable = findBlockedFullScanTable(sql);
        if (blockedTable != null) {
            String emailMes = "vin-sql攔截器【已攔截】:出現不在配置內的全表掃描:" + blockedTable + "--->" + beautifySql(sql);
            log.warn(emailMes);
            emailPool.putEmailToQueue(emailMes);
            return getDefaultReturnValue(invocation);
        }

        Object result = invocation.proceed();
        if (result instanceof Collection) {
            int size = ((Collection<?>) result).size();
            if (size > allowMaxReturn) {
                // Alert only: the (large) result is still handed back to the caller.
                try {
                    String fullSql = getSql(mappedStatement.getConfiguration(), boundSql);
                    String emailMes = "vin-sql攔截器【未攔截】:發現DQL返回結果:" + size + "大於最大返回值數量" + allowMaxReturn + "---> " + fullSql;
                    log.warn(emailMes);
                    emailPool.putEmailToQueue(emailMes);
                } catch (Exception e) {
                    // Alerting is best-effort and must never break the query; keep the cause for diagnosis.
                    log.warn("vin-sql攔截器發送郵件失敗!", e);
                }
            }
        }
        return result;
    }

    /**
     * Decides whether {@code sql} is a full-table scan that must be blocked.
     *
     * @param sql the raw SQL of the statement
     * @return the main table name when the statement is a JOIN-free SELECT with no effective WHERE
     *         clause on a non-whitelisted table; {@code null} when it may run
     * @throws JSQLParserException declared by the parsing helpers (they catch failures internally)
     */
    private String findBlockedFullScanTable(String sql) throws JSQLParserException {
        // The whitelist bean is optional; without a configured whitelist nothing is blocked.
        if (sqlIntercepterWhiteListConf == null) {
            return null;
        }
        String whiteList = sqlIntercepterWhiteListConf.getSqlWhiteList();
        if (StringUtils.isEmpty(whiteList)) {
            return null;
        }
        Map<String, String> whereResult = getWhere(sql);
        String where = whereResult.get("success");
        String join = whereResult.get("join");
        // Unparseable SQL (where == null) and JOIN queries are never treated as full scans.
        if (Objects.isNull(where) || StringUtils.isNotBlank(join)) {
            return null;
        }
        // "1 = 1" is the usual dynamic-SQL filler and does not restrict anything.
        if (StringUtils.isNotBlank(where.replace("1 = 1", ""))) {
            return null;
        }
        String tableName = getMainTableName(sql);
        // NOTE(review): this is a substring match, so a whitelisted "user" also admits "user_order".
        return StringUtils.containsIgnoreCase(whiteList, tableName) ? null : tableName;
    }

    @Override
    public Object plugin(Object target) {
        if (target instanceof Executor) {
            return Plugin.wrap(target, this);
        } else {
            return target;
        }
    }

    /**
     * Reads the optional "allowMaxReturn" property; blank values keep the default threshold.
     *
     * @param properties plugin properties from the MyBatis configuration, may be null
     * @throws NumberFormatException if "allowMaxReturn" is not an integer
     */
    @Override
    public void setProperties(Properties properties) {
        if (Objects.isNull(properties)) {
            return;
        }
        String configured = properties.getProperty("allowMaxReturn");
        if (StringUtils.isNotBlank(configured)) {
            allowMaxReturn = Integer.parseInt(configured.trim());
        }
    }

    /**
     * Renders the statement with every {@code ?} replaced by its bound value, for logging only.
     *
     * <p>Values are resolved in the same order MyBatis' DefaultParameterHandler uses: additional
     * (dynamic-SQL) parameters first, then a simple parameter object, then a property of the
     * parameter object. Each mapping consumes exactly one placeholder, scanning forward from the
     * previous one, so {@code ?} characters inside substituted values are never re-substituted.
     * Placeholders whose value cannot be resolved are left as {@code ?}.
     *
     * @param configuration the MyBatis configuration (type handlers, meta objects)
     * @param boundSql      the bound statement
     * @return the whitespace-collapsed SQL with parameters inlined, or "" for blank SQL
     */
    private String getSql(Configuration configuration, BoundSql boundSql) {
        String sql = boundSql.getSql();
        if (StringUtils.isBlank(sql)) {
            return "";
        }
        sql = beautifySql(sql);
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        if (parameterMappings == null || parameterMappings.isEmpty()) {
            return sql;
        }
        Object parameterObject = boundSql.getParameterObject();
        TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
        MetaObject metaObject = null;
        StringBuilder filled = new StringBuilder(sql.length() + 16 * parameterMappings.size());
        int cursor = 0;
        for (ParameterMapping parameterMapping : parameterMappings) {
            int placeholder = sql.indexOf('?', cursor);
            if (placeholder < 0) {
                break;
            }
            filled.append(sql, cursor, placeholder);
            cursor = placeholder + 1;

            String propertyName = parameterMapping.getProperty();
            Object value;
            if (boundSql.hasAdditionalParameter(propertyName)) {
                value = boundSql.getAdditionalParameter(propertyName);
            } else if (parameterObject == null) {
                value = null;
            } else if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
                // A simple parameter (String, Long, ...) is bound to every placeholder.
                value = parameterObject;
            } else {
                if (metaObject == null) {
                    metaObject = configuration.newMetaObject(parameterObject);
                }
                if (!metaObject.hasGetter(propertyName)) {
                    filled.append('?');
                    continue;
                }
                value = metaObject.getValue(propertyName);
            }
            filled.append(formatParameter(value));
        }
        filled.append(sql.substring(cursor));
        return filled.toString();
    }

    /**
     * Formats a bound value as a SQL literal for logging: strings and dates are single-quoted,
     * {@code null} becomes {@code null}, everything else uses {@code toString()}.
     */
    private String formatParameter(Object value) {
        if (value == null) {
            return "null";
        }
        if (value instanceof String) {
            return "'" + value + "'";
        }
        if (value instanceof Date) {
            return "'" + getDate2String((Date) value) + "'";
        }
        return value.toString();
    }

    /**
     * Formats a date as "yyyy-MM-dd HH:mm:ss" (a fresh formatter per call, so it is thread-safe).
     */
    private String getDate2String(Date parameterObject) {
        return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(parameterObject);
    }

    /**
     * Parses {@code sql} as a plain SELECT and extracts its WHERE clause and whether it has JOINs.
     *
     * @param sql the SQL text to parse
     * @return a map with "success" -> WHERE clause text ("" when absent) and "join" -> "1" when the
     *         query has joins ("" otherwise); an empty map when the SQL is not a parseable plain SELECT
     * @throws JSQLParserException declared for compatibility; parse failures are caught internally
     */
    public Map<String, String> getWhere(String sql) throws JSQLParserException {
        CCJSqlParserManager parserManager = new CCJSqlParserManager();
        Map<String, String> result = Maps.newHashMap();
        try {
            Select select = (Select) parserManager.parse(new StringReader(sql));
            // Non-PlainSelect bodies fail this cast and fall into the catch below.
            PlainSelect plain = (PlainSelect) select.getSelectBody();
            Expression whereExpression = plain.getWhere();
            List<Join> joins = plain.getJoins();
            result.put("success", Objects.isNull(whereExpression) ? "" : whereExpression.toString());
            result.put("join", CollectionUtils.isEmpty(joins) ? "" : "1");
            return result;
        } catch (Exception e) {
            return result;
        }
    }

    /**
     * Returns the first table name found in {@code sql}.
     *
     * @param sql the SQL text to parse
     * @return the first table reported by jsqlparser's TablesNamesFinder, or "" when parsing fails
     * @throws JSQLParserException declared for compatibility; parse failures are caught internally
     */
    public String getMainTableName(String sql) throws JSQLParserException {
        try {
            Statement statement = CCJSqlParserUtil.parse(sql);
            Select selectStatement = (Select) statement;
            List<String> tableList = new TablesNamesFinder().getTableList(selectStatement);
            return tableList.get(0);
        } catch (Exception e) {
            log.warn("vin-sql攔截器無法解析sql:{}", sql);
            return "";
        }
    }

    /**
     * Returns the value handed back in place of a blocked query: an empty list or set, zero of the
     * matching numeric type, or {@code null}. (Executor#query returns List, so in practice this is
     * an empty list.)
     *
     * @param invocation the intercepted invocation
     * @return a harmless default for the intercepted method's return type
     */
    private Object getDefaultReturnValue(Invocation invocation) {
        Class<?> returnType = invocation.getMethod().getReturnType();
        if (returnType.equals(List.class)) {
            return Lists.newArrayList();
        } else if (returnType.equals(Set.class)) {
            return Sets.newHashSet();
        } else if (returnType.equals(Integer.TYPE) || returnType.equals(Integer.class)) {
            return 0;
        } else if (returnType.equals(Long.TYPE) || returnType.equals(Long.class)) {
            return 0L;
        }
        return null;
    }

    /**
     * Collapses every run of whitespace (including newlines) into a single space.
     */
    private String beautifySql(String sql) {
        return sql.replaceAll("\\s+", " ");
    }
}

效果:

三、估算sql耗時攔截器具體實現(注:示例中的類名與上一節相同,若兩個攔截器同時使用,需改用不同的類名)

@Slf4j
@Intercepts({
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})}
)
public class SqlInterceptor implements Interceptor {

    /**
     * Statements whose cost in milliseconds does not exceed this value are not logged; the default
     * of -1 logs every statement. Configured through the "minLogSize" plugin property.
     */
    private long minLogMillis = -1;

    /**
     * Times the intercepted query/update and logs its id, cost and rendered SQL.
     *
     * <p>The statement's own result or exception is always propagated unchanged; failures while
     * building the log line are reported at debug level only.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement mappedStatement = (MappedStatement) args[0];
        Object parameter = args.length > 1 ? args[1] : null;
        // The 6-argument query overload already carries the BoundSql that will be executed;
        // reuse it instead of re-evaluating (possibly dynamic) SQL.
        BoundSql boundSql = (args.length > 5 && args[5] instanceof BoundSql)
                ? (BoundSql) args[5]
                : mappedStatement.getBoundSql(parameter);
        String sqlId = mappedStatement.getId();
        Configuration configuration = mappedStatement.getConfiguration();

        // nanoTime is monotonic, so wall-clock adjustments cannot distort the measured cost.
        long startNanos = System.nanoTime();
        Object result = null;
        try {
            result = invocation.proceed();
        } finally {
            long sqlCostTime = (System.nanoTime() - startNanos) / 1_000_000L;
            // Check the threshold first so the SQL is only rendered when it will actually be logged.
            if (sqlCostTime > minLogMillis) {
                try {
                    String sql = getSql(configuration, boundSql);
                    formatSqlLog(mappedStatement.getSqlCommandType(), sqlId, sql, sqlCostTime, result);
                } catch (Exception e) {
                    // Logging is best-effort and must never mask the statement's own outcome.
                    log.debug("vin-sql cost logging failed for {}", sqlId, e);
                }
            }
        }
        return result;
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Reads the optional "minLogSize" property (milliseconds); blank values keep the default.
     *
     * @param properties plugin properties from the MyBatis configuration, may be null
     * @throws NumberFormatException if "minLogSize" is not an integer
     */
    @Override
    public void setProperties(Properties properties) {
        if (properties == null) {
            return;
        }
        String configured = properties.getProperty("minLogSize");
        if (StringUtils.isNotBlank(configured)) {
            minLogMillis = Long.parseLong(configured.trim());
        }
    }

    /**
     * Renders the statement with every {@code ?} replaced by its bound value, for logging only.
     *
     * <p>Values are resolved in the same order MyBatis' DefaultParameterHandler uses: additional
     * (dynamic-SQL) parameters first, then a simple parameter object, then a property of the
     * parameter object. Each mapping consumes exactly one placeholder, scanning forward from the
     * previous one, so {@code ?} characters inside substituted values are never re-substituted.
     * Placeholders whose value cannot be resolved are left as {@code ?}. Stored-procedure OUT
     * parameters are not treated specially, since procedure calls are rarely issued through MyBatis here.
     *
     * @param configuration the MyBatis configuration (type handlers, meta objects)
     * @param boundSql      the bound statement
     * @return the whitespace-collapsed SQL with parameters inlined, or "" for blank SQL
     */
    private String getSql(Configuration configuration, BoundSql boundSql) {
        String sql = boundSql.getSql();
        if (StringUtils.isBlank(sql)) {
            return "";
        }
        sql = beautifySql(sql);
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        if (parameterMappings == null || parameterMappings.isEmpty()) {
            return sql;
        }
        Object parameterObject = boundSql.getParameterObject();
        TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
        MetaObject metaObject = null;
        StringBuilder filled = new StringBuilder(sql.length() + 16 * parameterMappings.size());
        int cursor = 0;
        for (ParameterMapping parameterMapping : parameterMappings) {
            int placeholder = sql.indexOf('?', cursor);
            if (placeholder < 0) {
                break;
            }
            filled.append(sql, cursor, placeholder);
            cursor = placeholder + 1;

            String propertyName = parameterMapping.getProperty();
            Object value;
            if (boundSql.hasAdditionalParameter(propertyName)) {
                value = boundSql.getAdditionalParameter(propertyName);
            } else if (parameterObject == null) {
                value = null;
            } else if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
                // A simple parameter (String, Long, ...) is bound to every placeholder.
                value = parameterObject;
            } else {
                if (metaObject == null) {
                    metaObject = configuration.newMetaObject(parameterObject);
                }
                if (!metaObject.hasGetter(propertyName)) {
                    filled.append('?');
                    continue;
                }
                value = metaObject.getValue(propertyName);
            }
            filled.append(formatParameter(value));
        }
        filled.append(sql.substring(cursor));
        return filled.toString();
    }

    /**
     * Collapses every run of whitespace (including newlines) into a single space.
     */
    private String beautifySql(String sql) {
        return sql.replaceAll("\\s+", " ");
    }

    /**
     * Formats a bound value as a SQL literal for logging: strings and dates are single-quoted,
     * {@code null} becomes {@code null}, everything else uses {@code toString()}.
     */
    private String formatParameter(Object value) {
        if (value == null) {
            return "null";
        }
        if (value instanceof String) {
            return "'" + value + "'";
        }
        if (value instanceof Date) {
            return "'" + getDate2String((Date) value) + "'";
        }
        return value.toString();
    }

    /**
     * Formats a date as "yyyy-MM-dd HH:mm:ss" (a fresh formatter per call, so it is thread-safe).
     */
    private String getDate2String(Date parameterObject) {
        return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(parameterObject);
    }

    /**
     * Logs the statement id, cost and SQL at WARN level. For INSERT/UPDATE/DELETE, the affected
     * row count (the executor's return value) is appended.
     */
    private void formatSqlLog(SqlCommandType sqlCommandType, String sqlId, String sql, long costTime, Object obj) {
        StringBuilder message = new StringBuilder(String.format("DAO [%s]\n[%dms] ===> %s\n", sqlId, costTime, sql));
        if (sqlCommandType == SqlCommandType.UPDATE || sqlCommandType == SqlCommandType.INSERT || sqlCommandType == SqlCommandType.DELETE) {
            message.append("Count ===> ").append(obj);
        }
        log.warn(message.toString());
    }
}

效果:

總結:

以上是sql攔截器的基本作用,不管在測試環境還是生產環境,效果都挺明顯,媽媽再也不用擔心發生FGC(Full GC)了。此外,sql攔截器還可以實現更多功能,如果關注的人多,會繼續更新。

ps:如有錯誤,歡迎指正。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章