1. MyBatis 執行增刪改的底層核心接口是 Executor.update，如果我們想對增刪改操作進行攔截處理，則需要攔截此方法。
2. MyBatis 處理查詢結果集的底層核心接口是 ResultSetHandler.handleResultSets，如果我們想對查詢結果進行攔截處理，則需要攔截此方法。
定義增刪改的攔截器：
@Intercepts({
@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
})
@Slf4j
public class IbatisUpdateInterceptor implements Interceptor {
@Override
public Object intercept(Invocation invocation) throws Throwable {
try {
if (invocation.getTarget() instanceof Executor && invocation.getArgs().length == 2) {
return invokeUpdate(invocation);
}
} finally {
// 回滾到原生對象
MybatisContextHolder.rollback();
}
return invocation.proceed();
}
private Object invokeUpdate(Invocation invocation) throws Exception {
Executor executor = (Executor) invocation.getTarget();
// 獲取第一個參數
MappedStatement ms = (MappedStatement) invocation.getArgs()[0];
// 非insert/update,不處理
if (ms.getSqlCommandType() != SqlCommandType.INSERT && ms.getSqlCommandType() != SqlCommandType.UPDATE) {
return invocation.proceed();
}
Object paramObj = invocation.getArgs()[1];
// 無參數不處理
if (paramObj == null) {
return invocation.proceed();
}
if (!needEncryptDecrypt(paramObj) && !isBatchUpdate(paramObj)) {
return invocation.proceed();
}
// 加密
if (paramObj instanceof Map) {
Map map = (Map) paramObj;
// 批量更新
if (map.containsKey("list")) {
encryptList((Collection) map.get("list"));
} else {
for (Object updateObj : map.values()) {
if (needEncryptDecrypt(updateObj)) {
encryptSensitiveFields(updateObj);
}
}
}
} else if(paramObj instanceof Collection) {
encryptList((Collection) paramObj);
} else {
encryptSensitiveFields(paramObj);
}
return executor.update(ms, paramObj);
}
private static boolean needEncryptDecrypt(Object obj) {
Class<?> clazz = obj.getClass();
EncryptDecrypt annotation = AnnotationUtils.findAnnotation(clazz, EncryptDecrypt.class);
return annotation != null && annotation.needEncryptDecrypt();
}
private void encryptList(Collection params) throws Exception {
for (Object obj : params) {
encryptSensitiveFields(obj);
}
}
private boolean isBatchUpdate(Object paramObj) {
if (!(paramObj instanceof Map)) {
return false;
}
Map map = (Map) paramObj;
if (map.containsKey("list")) {
// 針對批量更新
return needEncryptDecrypt(map.get("list"));
} else {
for (Object updateObj : map.values()) {
// 針對updateByExample和updateByExampleSelective
if (needEncryptDecrypt(updateObj)) {
return true;
}
}
}
return false;
}
public static List<Field> getAllFields(Class clazz) {
List<Field> allFields = Lists.newArrayList();
while (clazz != null && clazz != Object.class) {
Field[] tmpFields = clazz.getDeclaredFields();
if (tmpFields != null) {
allFields.addAll(Arrays.asList(tmpFields));
}
clazz = clazz.getSuperclass();
}
return allFields;
}
public static boolean isStatic(Field f){
return (f.getModifiers() & Modifier.STATIC) == Modifier.STATIC;
}
private Object encryptSensitiveFields(Object paramObj) throws InvocationTargetException, IllegalAccessException {
List<Field> allFields = getAllFields(paramObj.getClass());
for (Field field : allFields) {
Sensitive sensitive = field.getAnnotation(Sensitive.class);
if (sensitive == null) {
continue;
}
PropertyDescriptor ps = BeanUtils.getPropertyDescriptor(paramObj.getClass(), field.getName());
if (ps.getReadMethod() == null || ps.getWriteMethod() == null) {
continue;
}
Object value = ps.getReadMethod().invoke(paramObj);
if (value != null) {
ps.getWriteMethod().invoke(paramObj, encryptDbField(value));
// register rollback
MybatisContextHolder.registerRollbackEvent(() -> ps.getWriteMethod().invoke(paramObj, value));
}
}
return paramObj;
}
private Object encryptDbField(Object value) {
//加密
return null;
}
@Override
public Object plugin(Object target) {
return null;
}
@Override
public void setProperties(Properties properties) {
}
}
定義查詢結果攔截器：
@Intercepts({
@Signature(type = ResultSetHandler.class, method = "handleResultSets", args = Statement.class)
})
@Slf4j
public class IbatisResultInterceptor implements Interceptor {
@Override
public Object intercept(Invocation invocation) throws Throwable {
Object result = invocation.proceed();
// 解密
if (result instanceof Collection) {
Collection<Object> objList = (Collection) result;
List<Object> decryptList = Lists.newArrayList();
for (Object obj : objList) {
decryptList.add(decrypt(obj));
}
return decryptList;
} else {
return decrypt(result);
}
}
private Object decrypt(Object obj) throws InvocationTargetException, IllegalAccessException {
List<Field> allFields = getAllFields(obj.getClass());
for (Field field : allFields) {
Sensitive sensitive = field.getAnnotation(Sensitive.class);
// 非敏感信息
if (sensitive == null) {
continue;
}
PropertyDescriptor ps = BeanUtils.getPropertyDescriptor(obj.getClass(), field.getName());
if (ps.getReadMethod() == null || ps.getWriteMethod() == null) {
continue;
}
Object value = ps.getReadMethod().invoke(obj);
if (value != null) {
try {
ps.getWriteMethod().invoke(obj, decryptDbField((String) value));
} catch (Exception e) {
log.error("decrypt error of field:{}", field.getName());
throw e;
}
}
}
return obj;
}
public List<Field> getAllFields(Class clazz) {
List<Field> allFields = Lists.newArrayList();
while (clazz != null && clazz != Object.class) {
Field[] tmpFields = clazz.getDeclaredFields();
if (tmpFields != null) {
allFields.addAll(Arrays.asList(tmpFields));
}
clazz = clazz.getSuperclass();
}
return allFields;
}
/**
* 對字段解密
* @param value
* @return
*/
private Object decryptDbField(String value) {
//解密,rsa,aes,des
return null;
}
@Override
public Object plugin(Object target) {
return null;
}
@Override
public void setProperties(Properties properties) {
}
}
註冊攔截器（以下以查詢結果攔截器為例，增刪改攔截器需以相同方式註冊）：
/**
 * Creates the query-result decryption interceptor and adds it to the MyBatis configuration held by
 * the given {@link SqlSessionFactory}. The bean only exists when
 * {@code mybatis.encrypt.enabled=true}.
 *
 * <p>NOTE(review): only {@code IbatisResultInterceptor} is registered here;
 * {@code IbatisUpdateInterceptor} presumably needs an equivalent registration. Confirm.
 *
 * <p>NOTE(review): if a MyBatis starter also auto-registers {@code Interceptor} beans, this explicit
 * {@code addInterceptor} call could register the plugin twice. Verify.
 *
 * @param sqlSessionFactory the factory whose configuration receives the interceptor
 * @return the registered interceptor
 */
@Bean
@ConditionalOnProperty(value = "mybatis.encrypt.enabled", havingValue = "true")
public IbatisResultInterceptor myIbatisResultInterceptor(SqlSessionFactory sqlSessionFactory) {
    IbatisResultInterceptor resultInterceptor = new IbatisResultInterceptor();
    // No settings are needed; an empty Properties keeps the plugin lifecycle call explicit.
    resultInterceptor.setProperties(new Properties());
    sqlSessionFactory.getConfiguration().addInterceptor(resultInterceptor);
    return resultInterceptor;
}