目錄
一、線程安全的數據源切換類(DataSourceSwitch.java)
二、多數據源類(MultiDataSource.java)
三、SpringBoot的數據源配置
四、使用過濾器在線程訪問數據前設置線程數據源
在實際場景中,會遇到不同用戶擁有不同的數據源,這些數據源信息配置在數據庫表裏面,需要我們根據用戶切換成相應的數據源。在本文中,會介紹如何在SpringBoot + Mybatis中根據用戶切換數據源的配置。
一、線程安全的數據源切換類(DataSourceSwitch.java)
將ThreadLocal封裝成一個工具類,用於設置、獲取和清空當前線程(即當前用戶請求)所使用的數據源,具體代碼如下:
import javax.sql.DataSource;
/**
 * Holds the data source bound to the current thread.
 *
 * <p>The value is stored in a {@link ThreadLocal}, so each thread only sees what it set itself.
 * A binding stays on the thread until it is removed. Code running on pooled threads (for example
 * servlet request threads) must call {@link #clearDataSource()} when it is done. Otherwise the
 * binding is still present for the next task that reuses the thread.
 *
 * @author hrc
 * @date 2019-01-29
 */
public class DataSourceSwitch {

    /**
     * Per-thread storage for the selected data source.
     */
    private static final ThreadLocal<DataSource> dataSourceThreadLocal = new ThreadLocal<DataSource>();

    /**
     * Binds a data source to the current thread.
     *
     * @param dataSource the data source to use on this thread; {@code null} removes the binding
     */
    public static void setDataSource(DataSource dataSource) {
        if (dataSource == null) {
            // "No data source" is represented by the absence of an entry, not by a stored null.
            dataSourceThreadLocal.remove();
        } else {
            dataSourceThreadLocal.set(dataSource);
        }
    }

    /**
     * Returns the data source bound to the current thread.
     *
     * @return the bound data source, or {@code null} if none is bound
     */
    public static DataSource getDataSource() {
        return dataSourceThreadLocal.get();
    }

    /**
     * Removes the current thread's data source binding.
     */
    public static void clearDataSource() {
        dataSourceThreadLocal.remove();
    }
}
二、多數據源類(MultiDataSource.java)
多數據源類實現了DataSource接口,並且該類是一個單例。MyBatis的sqlSessionFactory在創建sqlSession時會調用DataSource中的方法,並從中獲取數據庫連接。所以要實現多數據源,只需要把多數據源類的getDataSource()方法寫成優先返回當前線程綁定的數據源(當前線程沒有綁定時返回默認數據源),並把DataSource接口的各個方法都委託給getDataSource().xxx實現即可。多數據源的具體實現代碼如下:
import java.io.PrintWriter;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.util.logging.Logger;
import javax.sql.DataSource;
/**
 * A {@link DataSource} facade that forwards every call to the data source selected for the
 * current thread.
 *
 * <p>The target is looked up on every call. It is the value from
 * {@link DataSourceSwitch#getDataSource()}; if the current thread has none bound, the default set
 * via {@link #setDataSource(DataSource)} is used instead. A single shared instance is obtained
 * through {@link #getInstance()}.
 *
 * <p>If neither a thread binding nor a default is present, the delegating methods fail with a
 * {@link NullPointerException}.
 *
 * @author hrc
 * @date 2019-01-29
 */
public class MultiDataSource implements DataSource {

    /**
     * Lazily created singleton.
     *
     * <p>{@code volatile} is required for the double-checked locking in {@link #getInstance()} to be
     * safe. Without it, another thread may observe a non-null reference to an incompletely
     * published object.
     */
    private static volatile MultiDataSource multiDataSource = null;

    /**
     * Fallback used when no data source is bound to the current thread.
     *
     * <p>It is set by configuration code and read by request threads, hence {@code volatile}.
     */
    private volatile DataSource dataSource = null;

    private MultiDataSource() {
    }

    /**
     * Returns the shared instance, creating it on first use.
     *
     * @return the singleton MultiDataSource
     */
    public static MultiDataSource getInstance() {
        // Read the volatile field once so the common (already initialized) path does a single read.
        MultiDataSource instance = multiDataSource;
        if (instance == null) {
            synchronized (MultiDataSource.class) {
                instance = multiDataSource;
                if (instance == null) {
                    instance = new MultiDataSource();
                    multiDataSource = instance;
                }
            }
        }
        return instance;
    }

    /**
     * Returns the data source that calls on this thread are forwarded to.
     *
     * @return the thread-bound data source if any, otherwise the default (may be {@code null})
     */
    public DataSource getDataSource() {
        DataSource current = DataSourceSwitch.getDataSource();
        return current != null ? current : this.dataSource;
    }

    /**
     * Sets the fallback data source used by threads that have no binding of their own.
     *
     * @param dataSource the default data source
     */
    public void setDataSource(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    @Override
    public PrintWriter getLogWriter() throws SQLException {
        return getDataSource().getLogWriter();
    }

    @Override
    public void setLogWriter(PrintWriter out) throws SQLException {
        getDataSource().setLogWriter(out);
    }

    @Override
    public void setLoginTimeout(int seconds) throws SQLException {
        getDataSource().setLoginTimeout(seconds);
    }

    @Override
    public int getLoginTimeout() throws SQLException {
        return getDataSource().getLoginTimeout();
    }

    @Override
    public Logger getParentLogger() throws SQLFeatureNotSupportedException {
        return getDataSource().getParentLogger();
    }

    @Override
    public <T> T unwrap(Class<T> iface) throws SQLException {
        return getDataSource().unwrap(iface);
    }

    @Override
    public boolean isWrapperFor(Class<?> iface) throws SQLException {
        return getDataSource().isWrapperFor(iface);
    }

    @Override
    public Connection getConnection() throws SQLException {
        return getDataSource().getConnection();
    }

    @Override
    public Connection getConnection(String username, String password) throws SQLException {
        return getDataSource().getConnection(username, password);
    }
}
三、SpringBoot的數據源配置
在SpringBoot中手動配置數據源的具體代碼如下:
import javax.sql.DataSource;
import org.apache.ibatis.session.SqlSessionFactory;
import org.mybatis.spring.SqlSessionFactoryBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.core.env.Environment;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.jdbc.datasource.DataSourceTransactionManager;
import com.alibaba.druid.pool.DruidDataSource;
import com.cheng.common.util.JdbcConfigUtil;
/**
 * Spring configuration that wires MyBatis and the transaction manager to the
 * {@link MultiDataSource} singleton.
 *
 * @author hrc
 * @date 2018-10-09
 */
@Configuration
public class DataSourceConfig {

    @Autowired
    private Environment env;

    /**
     * Default pool installed as the MultiDataSource fallback.
     *
     * <p>It is built on the first call to {@link #dataSource()} and reused afterwards.
     * NOTE(review): this pool is not registered as a Spring bean, so Spring will not close it on
     * shutdown. Confirm whether that matters for this deployment.
     */
    private DruidDataSource defaultDataSource;

    /**
     * Returns the MultiDataSource singleton, creating and installing its default pool on first use.
     *
     * <p>This method is not a {@code @Bean}, so Spring does not intercept calls to it. Both
     * {@link #sqlSessionFactory()} and {@link #transactionManager()} invoke it directly. The
     * default pool is therefore cached here; otherwise each call would build a new DruidDataSource
     * and replace the previous one.
     *
     * @return the shared MultiDataSource instance
     */
    public DataSource dataSource() {
        MultiDataSource multiDataSource = MultiDataSource.getInstance();
        if (defaultDataSource == null) {
            defaultDataSource = createDefaultDataSource();
            multiDataSource.setDataSource(defaultDataSource);
        }
        return multiDataSource;
    }

    /**
     * Builds the default pool from the {@code jdbc.default.*} settings.
     *
     * @return a new, not yet initialized DruidDataSource
     */
    private DruidDataSource createDefaultDataSource() {
        String driverClassName = JdbcConfigUtil.getValue("jdbc.default.driverClassName");
        String url = JdbcConfigUtil.getValue("jdbc.default.url");
        String username = JdbcConfigUtil.getValue("jdbc.default.username");
        String password = JdbcConfigUtil.getValue("jdbc.default.password");
        DruidDataSource dataSource = new DruidDataSource();
        dataSource.setDriverClassName(driverClassName);
        dataSource.setUrl(url);
        dataSource.setUsername(username);
        dataSource.setPassword(password);
        return dataSource;
    }

    /**
     * Creates the MyBatis SqlSessionFactory on top of the MultiDataSource.
     *
     * @return the session factory
     * @throws Exception if the factory bean fails to build
     */
    @Primary
    @Bean("sqlSessionFactory")
    public SqlSessionFactory sqlSessionFactory() throws Exception {
        SqlSessionFactoryBean sessionFactory = new SqlSessionFactoryBean();
        sessionFactory.setDataSource(this.dataSource());
        sessionFactory.setTypeAliasesPackage(env.getProperty("mybatis.typeAliasesPackage"));
        sessionFactory.setMapperLocations(new PathMatchingResourcePatternResolver().getResources(env.getProperty("mybatis.mapper-locations")));
        return sessionFactory.getObject();
    }

    /**
     * Creates the transaction manager for the MultiDataSource.
     *
     * @return the transaction manager
     * @throws Exception kept for signature compatibility
     */
    @Bean("transactionManager")
    public DataSourceTransactionManager transactionManager() throws Exception {
        return new DataSourceTransactionManager(this.dataSource());
    }
}
四、使用過濾器在線程訪問數據前設置線程數據源
該多數據源的設計需要在線程訪問數據之前,將當前線程的數據源設置成當前用戶對應的數據源,所以需要用到過濾器(Filter)。使用過濾器對相應的請求進行攔截,在其到達controller層之前調用DataSourceSwitch.setDataSource(dataSource)來設置當前線程的數據源,並在請求處理完成後調用DataSourceSwitch.clearDataSource()清除。下面是我寫的一個應用多數據源的過濾器例子,代碼如下:
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.sql.DataSource;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.WebApplicationContextUtils;
import com.alibaba.druid.pool.DruidDataSource;
import com.alibaba.fastjson.JSON;
import com.cheng.common.datasourse.DataSourceSwitch;
import com.cheng.common.util.JdbcConfigUtil;
import com.cheng.common.util.ResMapBuilder;
import com.cheng.common.util.Util;
import com.cheng.system.dto.User;
import com.cheng.system.service.UserService;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
 * Base servlet filter.
 *
 * <p>For each request it does the following:
 * <ul>
 *   <li>adds CORS headers;</li>
 *   <li>lets excluded URLs through untouched;</li>
 *   <li>answers a JSON "not logged in" message when the session has no {@code User};</li>
 *   <li>otherwise binds that user's HANA data source to the current thread via
 *       {@link DataSourceSwitch} for the rest of the filter chain.</li>
 * </ul>
 *
 * @author hrc
 * @date 2018-12-07
 */
public class BaseFilter implements Filter {

    private UserService userService;

    /**
     * URI suffixes that bypass the login check and the data source binding.
     */
    private final static Set<String> EXCLUDE_PATTERN = new HashSet<String>();

    /**
     * Data sources keyed by account code.
     *
     * <p>Shared by all request threads, so every access is synchronized on the map itself. Entries
     * are never removed, so a changed account password is not picked up until restart.
     */
    private final static HashMap<String, DataSource> DATA_SOURCE_MAP = new HashMap<String, DataSource>(50);

    /**
     * Registers the excluded URL suffixes, resolves UserService from the Spring root context, and
     * preloads one data source per HANA account.
     *
     * @throws ServletException if the Spring WebApplicationContext is not available
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        EXCLUDE_PATTERN.add("hanaAuth.action");
        if (userService == null) {
            ServletContext sc = filterConfig.getServletContext();
            WebApplicationContext cxt = WebApplicationContextUtils.getWebApplicationContext(sc);
            if (cxt == null) {
                // Without the context there is no UserService; fail with a clear message instead of an NPE below.
                throw new ServletException("Spring WebApplicationContext is not available; cannot initialize BaseFilter");
            }
            userService = cxt.getBean(UserService.class);
        }
        // Preload the data sources of all HANA accounts into the cache.
        List<User> userList = userService.getAcctInfos("HANA");
        if (!Util.isEmpty(userList)) {
            for (User user : userList) {
                String username = user.getAcctCode();
                DataSource dataSource = configDataSource(username, user.getAcctPwd());
                synchronized (DATA_SOURCE_MAP) {
                    DATA_SOURCE_MAP.put(username, dataSource);
                }
            }
        }
    }

    /**
     * Applies CORS headers, enforces login, and binds the user's data source for the rest of the
     * chain.
     *
     * <p>The binding is always cleared afterwards, even when the chain throws.
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain)
            throws IOException, ServletException {
        HttpServletRequest httpServletRequest = (HttpServletRequest) request;
        HttpServletResponse httpServletResponse = (HttpServletResponse) response;
        applyCorsHeaders(httpServletRequest, httpServletResponse);

        String path = httpServletRequest.getRequestURI();
        if (isExcludePattern(path)) {
            filterChain.doFilter(request, response);
            return;
        }

        // Login check: the session must carry a User.
        Object obj = httpServletRequest.getSession().getAttribute("user");
        if (!(obj instanceof User)) {
            writeNotLoggedIn(httpServletResponse);
            return;
        }
        User user = (User) obj;

        DataSourceSwitch.setDataSource(resolveDataSource(user));
        try {
            filterChain.doFilter(request, response);
        } finally {
            // Always unbind. Request threads are pooled, and a leftover binding would be seen by the
            // next request on this thread (e.g. an excluded URL, which never rebinds).
            DataSourceSwitch.clearDataSource();
        }
    }

    @Override
    public void destroy() {
        // NOTE(review): the cached pools in DATA_SOURCE_MAP are not closed here.
    }

    /**
     * Adds the headers that allow cross-origin AJAX calls.
     *
     * <p>NOTE(review): the request's Origin is echoed back together with
     * {@code Access-Control-Allow-Credentials: true}. That lets any site make credentialed requests
     * on behalf of a logged-in user. Consider validating the origin against an allow-list.
     */
    private void applyCorsHeaders(HttpServletRequest request, HttpServletResponse response) {
        String origin = request.getHeader("Origin");
        response.setHeader("Access-Control-Allow-Origin", origin);
        response.setHeader("Access-Control-Allow-Credentials", "true");
        response.setHeader("Access-Control-Allow-Headers", "X-Requested-With, accept, content-type");
        response.setHeader("Access-Control-Allow-Methods", "GET, HEAD, POST, PUT, DELETE, TRACE, OPTIONS, PATCH");
    }

    /**
     * Writes the JSON "not logged in" (code 403) message to the response body.
     */
    private void writeNotLoggedIn(HttpServletResponse response) throws IOException {
        response.setContentType("application/json; charset=utf-8");
        response.setHeader("cache-control", "no-cache");
        String msg = JSON.toJSONString(new ResMapBuilder().code(403).status(0).msg("未檢測到登錄狀態,請登錄").build());
        PrintWriter out = response.getWriter();
        out.println(msg);
        out.flush();
        out.close();
    }

    /**
     * Returns the cached data source for the user's account, creating and caching it on first use.
     *
     * @param user the logged-in user
     * @return the user's data source
     * @throws ServletException if no HANA account information exists for the user
     */
    private DataSource resolveDataSource(User user) throws ServletException {
        String username = user.getAcctCode();
        synchronized (DATA_SOURCE_MAP) {
            DataSource cached = DATA_SOURCE_MAP.get(username);
            if (cached != null) {
                return cached;
            }
        }
        // Cache miss: query the account outside the lock so a slow service call does not block
        // other requests.
        User acctInfo = userService.getAcctInfo(user.getUamCode(), "HANA");
        if (acctInfo == null) {
            throw new ServletException("No HANA account info found for user " + username);
        }
        String password = acctInfo.getAcctPwd();
        synchronized (DATA_SOURCE_MAP) {
            // Another request for the same account may have filled the cache meanwhile; keep that one
            // so only a single pool exists per account.
            DataSource dataSource = DATA_SOURCE_MAP.get(username);
            if (dataSource == null) {
                dataSource = configDataSource(username, password);
                DATA_SOURCE_MAP.put(username, dataSource);
            }
            return dataSource;
        }
    }

    /**
     * Tells whether the URL ends with one of the excluded suffixes.
     *
     * @param url the request URI
     * @return {@code true} if the URL should bypass this filter's checks
     */
    private boolean isExcludePattern(String url) {
        if (Util.isEmpty(url)) {
            return false;
        }
        for (String str : EXCLUDE_PATTERN) {
            if (url.endsWith(str)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Creates a DruidDataSource for the given account using the shared {@code jdbc.*} pool
     * settings.
     *
     * <p>Missing settings fall back to initialSize 5, minIdle 5, maxActive 20 and maxWait 60000.
     *
     * @param username database user name
     * @param password database password
     * @return a new, not yet initialized data source
     */
    private DataSource configDataSource(String username, String password) {
        // Pool configuration parameters.
        String driverClassName = JdbcConfigUtil.getValue("jdbc.driverClassName");
        String url = JdbcConfigUtil.getValue("jdbc.url");
        String initialSize = JdbcConfigUtil.getValue("jdbc.initialSize");
        String minIdle = JdbcConfigUtil.getValue("jdbc.minIdle");
        String maxActive = JdbcConfigUtil.getValue("jdbc.maxActive");
        String maxWait = JdbcConfigUtil.getValue("jdbc.maxWait");
        String validationQuery = JdbcConfigUtil.getValue("jdbc.validationQuery");
        String testOnBorrow = JdbcConfigUtil.getValue("jdbc.testOnBorrow");

        DruidDataSource dataSource = new DruidDataSource();
        dataSource.setDriverClassName(driverClassName);
        dataSource.setUrl(url);
        dataSource.setUsername(username);
        dataSource.setPassword(password);
        dataSource.setInitialSize(Util.isEmpty(initialSize) ? 5 : Integer.parseInt(initialSize));
        dataSource.setMinIdle(Util.isEmpty(minIdle) ? 5 : Integer.parseInt(minIdle));
        dataSource.setMaxActive(Util.isEmpty(maxActive) ? 20 : Integer.parseInt(maxActive));
        dataSource.setMaxWait(Util.isEmpty(maxWait) ? 60000 : Long.parseLong(maxWait));
        // Borrow-time validation only makes sense when a validation query is configured.
        if (!Util.isEmpty(validationQuery)) {
            dataSource.setValidationQuery(validationQuery);
            dataSource.setTestOnBorrow(!Util.isEmpty(testOnBorrow) && Boolean.parseBoolean(testOnBorrow));
        }
        return dataSource;
    }
}