// Mapper 層 SQL 校驗:在項目啓動前對 mapper 註解中的 SQL 進行語法校驗(否則通常要到實際執行該 mapper 方法時纔會報錯)。
package ix.account.util;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.ibatis.annotations.*;
import org.springframework.beans.factory.BeanDefinitionStoreException;
import org.springframework.context.ResourceLoaderAware;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternUtils;
import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
import org.springframework.core.type.classreading.MetadataReader;
import org.springframework.core.type.classreading.MetadataReaderFactory;
import org.springframework.core.type.filter.AnnotationTypeFilter;
import org.springframework.core.type.filter.TypeFilter;
import org.springframework.util.StringUtils;
import org.springframework.util.SystemPropertyUtils;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.*;
import java.util.stream.Collectors;
/**
* spring scaner
*/
@Slf4j
public class ClassScaner implements ResourceLoaderAware {
private final List<TypeFilter> includeFilters = new LinkedList<>();
private final List<TypeFilter> excludeFilters = new LinkedList<>();
private ResourcePatternResolver resourcePatternResolver = new PathMatchingResourcePatternResolver();
private MetadataReaderFactory metadataReaderFactory = new CachingMetadataReaderFactory(this.resourcePatternResolver);
public static Set<Class> scan(String[] basePackages,
Class<? extends Annotation>... annotations) {
ClassScaner classScaner = new ClassScaner();
if (ArrayUtils.isNotEmpty(annotations)) {
for (Class annotation : annotations) {
classScaner.addIncludeFilter(new AnnotationTypeFilter(annotation));
}
}
Set<Class> classes = new HashSet<>();
for (String s : basePackages) {
classes.addAll(classScaner.doScan(s));
}
return classes;
}
/**
* spring 指定包掃描
*
* @param basePackages 掃描包基本路徑
* @param annotations 具體掃描什麼註解 例如{@link Mapper}
* @return
*/
public static Set<Class> scan(String basePackages, Class<? extends Annotation>... annotations) {
return ClassScaner.scan(StringUtils.tokenizeToStringArray(basePackages, ",; \t\n"), annotations);
}
public final ResourceLoader getResourceLoader() {
return this.resourcePatternResolver;
}
@Override
public void setResourceLoader(ResourceLoader resourceLoader) {
this.resourcePatternResolver = ResourcePatternUtils
.getResourcePatternResolver(resourceLoader);
this.metadataReaderFactory = new CachingMetadataReaderFactory(
resourceLoader);
}
public void addIncludeFilter(TypeFilter includeFilter) {
this.includeFilters.add(includeFilter);
}
public void addExcludeFilter(TypeFilter excludeFilter) {
this.excludeFilters.add(0, excludeFilter);
}
public void resetFilters(boolean defaultFilters) {
this.includeFilters.clear();
this.excludeFilters.clear();
}
public Set<Class> doScan(String basePackage) {
Set<Class> classes = new HashSet<>();
try {
String packageSearchPath = ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX
+ org.springframework.util.ClassUtils
.convertClassNameToResourcePath(SystemPropertyUtils
.resolvePlaceholders(basePackage))
+ "/**/*.class";
Resource[] resources = this.resourcePatternResolver
.getResources(packageSearchPath);
for (int i = 0; i < resources.length; i++) {
Resource resource = resources[i];
if (resource.isReadable()) {
MetadataReader metadataReader = this.metadataReaderFactory.getMetadataReader(resource);
boolean b = (includeFilters.size() == 0 && excludeFilters.size() == 0)
|| matches(metadataReader);
if (b) {
try {
classes.add(Class.forName(metadataReader
.getClassMetadata().getClassName()));
} catch (ClassNotFoundException e) {
e.printStackTrace();
}
}
}
}
} catch (IOException ex) {
throw new BeanDefinitionStoreException(
"I/O failure during classpath scanning", ex);
}
return classes;
}
protected boolean matches(MetadataReader metadataReader) throws IOException {
for (TypeFilter tf : this.excludeFilters) {
if (tf.match(metadataReader, this.metadataReaderFactory)) {
return false;
}
}
for (TypeFilter tf : this.includeFilters) {
if (tf.match(metadataReader, this.metadataReaderFactory)) {
return true;
}
}
return false;
}
public static boolean getMethodAnnotation(String basePackages,
Class<? extends Annotation>... annotations) {
Set<Class> scan = scan(basePackages, annotations);
List<SqlErrorInfo> sqlErrorInfos = new ArrayList<>();
for (Class mapperClass : scan) {
Method[] methods = mapperClass.getMethods();
for (Method method : methods) {
Annotation[] annotations1 = method.getAnnotations();
for (Annotation annotation : annotations1) {
if (annotation instanceof Insert) {
List<String> collect = Arrays.stream(((Insert) annotation).value()).collect(Collectors.toList());
String sql = sqlAnnotValue(collect);
boolean b = crudCheck(sql);
if (b == false) {
sqlErrorInfos.add(SqlErrorInfo.builder().clazz(mapperClass.toString()).method(method.toString()).sql(sql).build());
}
} else if (annotation instanceof Select) {
List<String> collect = Arrays.stream(((Select) annotation).value()).collect(Collectors.toList());
String sql = sqlAnnotValue(collect);
boolean b = crudCheck(sql);
if (b == false) {
sqlErrorInfos.add(SqlErrorInfo.builder().clazz(mapperClass.toString()).method(method.toString()).sql(sql).build());
}
} else if (annotation instanceof Update) {
List<String> collect = Arrays.stream(((Update) annotation).value()).collect(Collectors.toList());
String sql = sqlAnnotValue(collect);
boolean b = crudCheck(sql);
if (b == false) {
sqlErrorInfos.add(SqlErrorInfo.builder().clazz(mapperClass.toString()).method(method.toString()).sql(sql).build());
}
} else if (annotation instanceof Delete) {
List<String> collect = Arrays.stream(((Delete) annotation).value()).collect(Collectors.toList());
String sql = sqlAnnotValue(collect);
boolean b = crudCheck(sql);
if (b == false) {
sqlErrorInfos.add(SqlErrorInfo.builder().clazz(mapperClass.toString()).method(method.toString()).sql(sql).build());
}
}
}
}
}
// System.out.println(sqlErrorInfos.size());
sqlErrorInfos.forEach(
info -> {
log.error("不正確的sql,不校驗<script>包裝,錯誤sql : " + info);
}
);
if (sqlErrorInfos.size() == 0) {
return true;
} else {
return false;
}
}
/**
* 將Mapper層註解中的sql獲取
*
* @param collect sqlCollect
* @return sql
*/
private static String sqlAnnotValue(List<String> collect) {
String sql;
if (collect.size() != 1) {
StringBuilder sbd = new StringBuilder();
collect.forEach(s -> {
sbd.append(s);
sbd.append(" ");
});
sql = sbd.toString();
} else {
sql = collect.get(0);
}
return sql;
}
/**
* crud sql校驗
*
* @param sql
*/
private static boolean crudCheck(String sql) {
// System.out.println("準備校驗的sql = " + sql);
if (sql.startsWith("<script>")) {
return true;
} else {
try {
MySqlStatementParser parser = new MySqlStatementParser(sql);
List<SQLStatement> stmtList = parser.parseStatementList();
int size = stmtList.size();
if (size != 0) {
return true;
} else {
return false;
}
} catch (Exception e) {
return false;
}
}
}
public static void main(String[] args) {
String basePackages = "ix.account.mapper";
Set<Class> scan = ClassScaner.scan(basePackages, Mapper.class);
getMethodAnnotation(basePackages, Mapper.class);
}
@Data
@Builder
private static class SqlErrorInfo {
private String method;
private String sql;
private String clazz;
}
}