手寫MyBatis

編寫一個類似 MyBatis 的簡易持久層框架。

Mapper代理類:

/**
 * {@link InvocationHandler} behind the dynamic proxies created for mapper interfaces.
 *
 * <p>Methods annotated with {@link Select} run their SQL over plain JDBC. Each result row is mapped
 * by reflection: every column label is converted from snake_case to camelCase with
 * {@link StringToolkit#underlineToCamel(String)} and written into the declared field of that name.
 *
 * <p>The handler keeps no mutable state; all JDBC resources are opened and closed within a call.
 */
public class MapperInvocationHandler implements InvocationHandler {
    private static final String DRIVER_CLASS_NAME = "com.mysql.jdbc.Driver";
    private static final String URL = "jdbc:mysql://localhost:3306/taotao?useUnicode=true&characterEncoding=UTF-8&useSSL=false&rewriteBatchedStatements=true";
    private static final String USERNAME = "root";
    // NOTE(review): demo credentials are hard-coded; load them from configuration instead.
    private static final String PASSWORD = "123456";

    /**
     * Dispatches a call made on the mapper proxy.
     *
     * @return the mapped query result for {@link Select} methods, {@code null} for any other
     *     mapper method
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        if (method.getDeclaringClass() == Object.class) {
            return invokeObjectMethod(proxy, method, args);
        }
        Select select = method.getAnnotation(Select.class);
        if (select == null) {
            // Only @Select is supported; other mapper methods do nothing.
            return null;
        }
        return selectMybatis(select, method, args);
    }

    /**
     * Answers equals/hashCode/toString, which the JDK proxy also routes through the handler.
     * Returning null for them (as for unannotated methods) would make hashCode() throw a
     * NullPointerException when unboxed.
     */
    private Object invokeObjectMethod(Object proxy, Method method, Object[] args)
            throws ReflectiveOperationException {
        String name = method.getName();
        if ("equals".equals(name)) {
            return proxy == args[0];
        }
        if ("hashCode".equals(name)) {
            return System.identityHashCode(proxy);
        }
        return method.invoke(this, args);
    }

    /**
     * Executes the SQL of a {@link Select} annotation and maps the result.
     *
     * @param <E>    unused; kept for source compatibility with existing callers
     * @param select the annotation carrying the SQL; only {@code value()[0]} is used
     * @param method the mapper method being invoked; its return type decides how rows are mapped
     * @param args   the invocation arguments, bound to {@code #{name}} placeholders by name
     * @return for a {@link Collection} return type, a {@code List} with one object per row (empty
     *     when nothing matches); otherwise a single mapped object, or {@code null} when nothing
     *     matches
     * @throws IllegalStateException    if the driver cannot be loaded, the query fails, or a row
     *     cannot be mapped (e.g. no field matches a column)
     * @throws IllegalArgumentException if a placeholder has no matching method argument
     */
    public <E> Object selectMybatis(Select select, Method method, Object[] args) {
        String sql = select.value()[0];
        Class<?> returnType = method.getReturnType();
        // isAssignableFrom: true when returnType is Collection itself or a subtype (List, Set, ...).
        boolean returnsCollection = Collection.class.isAssignableFrom(returnType);
        Class<?> rowType = returnsCollection ? resolveElementType(method) : returnType;
        try {
            // Loading the driver class registers a Driver instance with DriverManager.
            Class.forName(DRIVER_CLASS_NAME);
            try (Connection conn = DriverManager.getConnection(URL, USERNAME, PASSWORD);
                 PreparedStatement ps = prepareStatement(conn, sql, method, args);
                 ResultSet rs = ps.executeQuery()) {
                // Rows are read strictly forward, so the default TYPE_FORWARD_ONLY result set is
                // enough (no beforeFirst()/previous()).
                return returnsCollection ? mapAllRows(rs, rowType) : mapSingleRow(rs, rowType);
            }
        } catch (SQLException | ReflectiveOperationException e) {
            throw new IllegalStateException("Failed to execute [" + sql + "] for " + method, e);
        }
    }

    /**
     * Prepares {@code sql}, replacing each {@code #{name}} placeholder with {@code ?} and binding
     * the argument of the same name. SQL without placeholders is prepared unchanged.
     */
    private PreparedStatement prepareStatement(Connection conn, String sql, Method method, Object[] args)
            throws SQLException {
        if (!sql.contains("#{")) {
            return conn.prepareStatement(sql);
        }
        java.util.Map<String, Object> parameterMap = getParams(method.getParameters(), args);
        List<String> parameterNames = SQLUtils.sqlSelectParameter(sql);
        List<Object> parameterValues = new ArrayList<>(parameterNames.size());
        for (String parameterName : parameterNames) {
            if (!parameterMap.containsKey(parameterName)) {
                throw new IllegalArgumentException("No argument named '" + parameterName + "' for "
                        + method + "; annotate it with @Param(\"" + parameterName + "\")");
            }
            // Bind the argument itself (not its toString()) so JDBC sees the real type and nulls work.
            parameterValues.add(parameterMap.get(parameterName));
        }
        String newSql = SQLUtils.parameQuestion(sql, parameterNames);
        System.out.println("執行SQL:" + newSql + "參數信息:" + parameterValues.toString());

        PreparedStatement ps = conn.prepareStatement(newSql);
        try {
            for (int i = 0; i < parameterValues.size(); i++) {
                ps.setObject(i + 1, parameterValues.get(i));
            }
        } catch (SQLException | RuntimeException e) {
            // The caller's try-with-resources never sees this statement, so close it here.
            try {
                ps.close();
            } catch (SQLException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
        return ps;
    }

    /** Returns the single type argument of a collection return type, e.g. X in {@code List<X>}. */
    private static Class<?> resolveElementType(Method method) {
        Type genericType = method.getGenericReturnType();
        if (genericType instanceof ParameterizedType) {
            Type argument = ((ParameterizedType) genericType).getActualTypeArguments()[0];
            if (argument instanceof Class) {
                return (Class<?>) argument;
            }
        }
        throw new IllegalStateException(
                "Cannot determine the element type of " + method.toGenericString());
    }

    /** Maps every row to a new {@code rowType} instance; returns an empty list when there are none. */
    private static List<Object> mapAllRows(ResultSet rs, Class<?> rowType)
            throws SQLException, ReflectiveOperationException {
        List<Object> rows = new ArrayList<>();
        while (rs.next()) {
            Object row = rowType.getDeclaredConstructor().newInstance();
            copyCurrentRow(rs, row, rowType);
            rows.add(row);
        }
        return rows;
    }

    /**
     * Maps the result set onto one {@code rowType} instance, or returns {@code null} when it is
     * empty. Every row is written into the same instance, so when several rows match the last
     * one wins.
     */
    private static Object mapSingleRow(ResultSet rs, Class<?> rowType)
            throws SQLException, ReflectiveOperationException {
        Object result = null;
        while (rs.next()) {
            if (result == null) {
                result = rowType.getDeclaredConstructor().newInstance();
            }
            copyCurrentRow(rs, result, rowType);
        }
        return result;
    }

    /** Copies each column of the current row into the camelCase-named declared field of {@code type}. */
    private static void copyCurrentRow(ResultSet rs, Object target, Class<?> type)
            throws SQLException, ReflectiveOperationException {
        ResultSetMetaData metaData = rs.getMetaData();
        int columnCount = metaData.getColumnCount();
        for (int column = 1; column <= columnCount; column++) {
            // getColumnLabel honours "AS" aliases; without an alias it is the column name.
            String fieldName = StringToolkit.underlineToCamel(metaData.getColumnLabel(column));
            Field field = type.getDeclaredField(fieldName);
            field.setAccessible(true); // fields are typically private
            field.set(target, rs.getObject(column));
        }
    }

    /**
     * Maps each argument to its {@link Param} name. When the annotation is missing, it falls back
     * to the reflective parameter name ({@code argN} unless compiled with {@code -parameters}).
     * A HashMap is used because arguments may be null, which ConcurrentHashMap rejects.
     */
    private java.util.Map<String, Object> getParams(Parameter[] parameters, Object[] args) {
        java.util.Map<String, Object> parameterMap = new java.util.HashMap<>();
        for (int i = 0; i < parameters.length; i++) {
            Param param = parameters[i].getDeclaredAnnotation(Param.class);
            String name = param != null ? param.value() : parameters[i].getName();
            parameterMap.put(name, args[i]);
        }
        return parameterMap;
    }
}

工具類:

/**
 * Helpers for turning MyBatis-style {@code #{name}} placeholders into JDBC {@code ?} markers.
 */
public class SQLUtils {
	/**
	 * Lists the names inside the {@code #{name}} placeholders of {@code sql}, in order of appearance.
	 *
	 * <p>Element {@code i} is the value for the {@code i + 1}-th {@code ?} produced by
	 * {@link #parameQuestion(String, List)}. A name occurring twice is listed twice.
	 * SQL without placeholders (for example with no {@code where} clause) yields an empty list,
	 * and an unterminated {@code #{} ends the scan.
	 *
	 * @param sql the SQL text containing placeholders
	 * @return the placeholder names, never {@code null}
	 */
	public static List<String> sqlSelectParameter(String sql) {
		List<String> names = new ArrayList<>();
		int start = sql.indexOf("#{");
		while (start >= 0) {
			int end = sql.indexOf('}', start + 2);
			if (end < 0) {
				break;
			}
			// Keep the name verbatim so parameQuestion can match "#{" + name + "}" exactly.
			names.add(sql.substring(start + 2, end));
			start = sql.indexOf("#{", end + 1);
		}
		return names;
	}

	/**
	 * Replaces every {@code #{name}} placeholder for the given names with {@code ?}.
	 *
	 * @param sql           the SQL text
	 * @param parameterName the placeholder names to replace
	 * @return the SQL with the placeholders replaced
	 */
	public static String parameQuestion(String sql, String[] parameterName) {
		for (String name : parameterName) {
			sql = sql.replace("#{" + name + "}", "?");
		}
		return sql;
	}

	/**
	 * Replaces every {@code #{name}} placeholder for the given names with {@code ?}.
	 *
	 * @param sql           the SQL text
	 * @param parameterName the placeholder names to replace
	 * @return the SQL with the placeholders replaced
	 */
	public static String parameQuestion(String sql, List<String> parameterName) {
		for (String name : parameterName) {
			sql = sql.replace("#{" + name + "}", "?");
		}
		return sql;
	}
}
/**
 * String helpers used when mapping database columns to Java fields.
 */
public class StringToolkit {
    /**
     * Converts snake_case to camelCase, e.g. {@code page_views -> pageViews}.
     *
     * <p>An underscore followed by an ASCII lower-case letter is dropped and the letter is
     * upper-cased. Any other underscore is kept, including a trailing one, which previously caused
     * an ArrayIndexOutOfBoundsException.
     *
     * @param str the text to convert; {@code null} or blank input is returned unchanged
     * @return the camelCase form of {@code str}
     */
    public static String underlineToCamel(String str) {
        if (str == null || "".equals(str.trim())) {
            return str;
        }
        int length = str.length();
        StringBuilder builder = new StringBuilder(length);
        for (int i = 0; i < length; i++) {
            char current = str.charAt(i);
            if (current == '_' && i + 1 < length) {
                char next = str.charAt(i + 1);
                if (next >= 'a' && next <= 'z') {
                    builder.append(Character.toUpperCase(next));
                    i++; // the next character has been consumed
                    continue;
                }
            }
            builder.append(current);
        }
        return builder.toString();
    }
}

自定義註解:

/**
 * Names a mapper-method parameter so SQL can refer to it through a {@code #{name}} placeholder.
 */
@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Param {

    /** The name under which the annotated argument is bound. */
    String value();
}
/**
 * Marks a mapper method as a query whose SQL is given by {@link #value()}.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Select {

    /** The SQL text; it may contain {@code #{name}} placeholders. */
    String[] value();
}

Mapper:

/**
 * Mapper for the {@code tbl_employee} table; it is implemented at runtime by a dynamic proxy.
 */
public interface TblEmployeeMapper {

    /** Returns all employee rows. */
    @Select("select id,last_name,gender,email from tbl_employee")
    List<TblEmployeePO> getEmps();

    /** Returns the employee with the given {@code id}. */
    @Select("select id,last_name,gender,email from tbl_employee where id = #{id} ")
    TblEmployeePO getEmpsById(@Param("id") Long id);
}

PO(省略 getter、setter 與 toString):

/**
 * Result object for rows of {@code tbl_employee}. Field names are the camelCase forms of the
 * snake_case columns selected by the mapper (e.g. {@code last_name} -> {@code lastName}).
 */
public class TblEmployeePO implements Serializable {
    // NOTE(review): Serializable without an explicit serialVersionUID; consider declaring one.
    private Integer id;

    private String lastName;

    private String gender;

    private String email;

    // No-arg constructor: the mapper handler instantiates result objects reflectively.
    public TblEmployeePO(){}

    // Builds an employee without an id (presumably assigned by the database — TODO confirm).
    public TblEmployeePO(String lastName, String gender, String email) {
        this.lastName = lastName;
        this.gender = gender;
        this.email = email;
    }
    // Getters, setters and toString are omitted in the article.
 。。。。。
}

測試:

 /**
  * Calls the mapper through a hand-built JDK dynamic proxy, without Spring.
  */
 @Test
 void test01() {
     Class<?>[] interfaces = new Class<?>[]{TblEmployeeMapper.class};
     TblEmployeeMapper tblEmployeeMapper = (TblEmployeeMapper) Proxy.newProxyInstance(
             DemoApplicationTests.class.getClassLoader(), interfaces, new MapperInvocationHandler());
     // To try the list query instead:
     // List<TblEmployeePO> emps = tblEmployeeMapper.getEmps();
     // for (TblEmployeePO emp : emps) {
     //     System.out.println(emp);
     // }
     // Upper-case L: a lower-case l suffix is easily misread as the digit 1.
     TblEmployeePO empsById = tblEmployeeMapper.getEmpsById(345190L);
     System.out.println(empsById);
 }

參考:
https://blog.csdn.net/qq_35393693/article/details/80556007
https://www.jianshu.com/p/43e7c828082d


改進

添加自定義註解 @MapperSan,模擬 MyBatis 的 @MapperScan 註解:

/**
 * Enables the hand-written mapper registration, mimicking MyBatis' {@code @MapperScan}.
 *
 * <p>Placing it on a configuration class imports {@link MybatisImportBeanDefinitionRegistrar},
 * which registers the mapper bean definitions. Like {@code @MapperScan}, it is restricted to type
 * declarations; elsewhere it would silently have no effect.
 */
@Target(ElementType.TYPE)
// RUNTIME retention so the annotation (and its @Import) is visible when Spring reads MyConfig.
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Import(MybatisImportBeanDefinitionRegistrar.class)
public @interface MapperSan {
}

配置類:

/**
 * Spring configuration for the demo. It component-scans the application packages and, through
 * {@link MapperSan}, registers the hand-written mapper beans.
 */
@Configuration
@MapperSan
// NOTE(review): with the trailing ".*" Spring appears to scan only the sub-packages of
// com.example.demo, not com.example.demo itself — confirm that this is intended.
@ComponentScan("com.example.demo.*")
public class MyConfig {
    // Intentionally empty: all behaviour comes from the annotations above.
}

ImportBeanDefinitionRegistrar:模擬 MyBatis 的 MapperScannerRegistrar,用來動態地註冊一個 MapperFactoryBean 的 bean 定義:

/**
 * Registers a {@link MapperFactoryBean} definition for the mapper interface, mimicking MyBatis'
 * {@code MapperScannerRegistrar}. It is imported through {@code @MapperSan}.
 *
 * <p>Only one hard-coded mapper is registered; there is no package scanning.
 */
public class MybatisImportBeanDefinitionRegistrar implements ImportBeanDefinitionRegistrar {
    private static final String MAPPER_INTERFACE_NAME = "com.example.demo.dao.TblEmployeeMapper";
    private static final String MAPPER_BEAN_NAME = "tblEmployeeMapper";

    @Override
    public void registerBeanDefinitions(AnnotationMetadata importingClassMetadata, BeanDefinitionRegistry registry) {
        GenericBeanDefinition factoryDefinition = (GenericBeanDefinition) BeanDefinitionBuilder
                .genericBeanDefinition(MapperFactoryBean.class)
                .getBeanDefinition();
        // The interface is given by name; Spring converts the String to a Class when it
        // resolves the MapperFactoryBean(Class) constructor.
        factoryDefinition.getConstructorArgumentValues().addGenericArgumentValue(MAPPER_INTERFACE_NAME);
        registry.registerBeanDefinition(MAPPER_BEAN_NAME, factoryDefinition);
    }
}

MapperFactoryBean:模擬 MyBatis 的 MapperFactoryBean,返回 mapper 接口的代理對象;mapper 接口以 Class 表示,由構造函數傳入:

/**
 * {@link FactoryBean} that exposes a JDK dynamic proxy for one mapper interface, mimicking MyBatis'
 * {@code MapperFactoryBean}. The interface is supplied through the constructor.
 *
 * @param <T> the mapper interface type
 */
public class MapperFactoryBean<T> implements FactoryBean<T> {
    /** The mapper interface to proxy; {@code null} only when the no-arg constructor was used. */
    Class<T> mapperInterface;

    public MapperFactoryBean() {
    }

    public MapperFactoryBean(Class<T> mapperInterface) {
        this.mapperInterface = mapperInterface;
    }

    /**
     * Creates a proxy for {@link #mapperInterface} that is backed by a new
     * {@link MapperInvocationHandler}.
     *
     * @throws IllegalStateException if no mapper interface has been configured
     */
    @Override
    public T getObject() throws Exception {
        if (mapperInterface == null) {
            throw new IllegalStateException("mapperInterface has not been set");
        }
        Class<?>[] interfaces = new Class<?>[]{mapperInterface};
        // Use the interface's own loader so the proxy can see it even under a different loader.
        Object proxy = Proxy.newProxyInstance(
                mapperInterface.getClassLoader(), interfaces, new MapperInvocationHandler());
        return mapperInterface.cast(proxy);
    }

    /**
     * Reports the configured interface rather than a hard-coded one, so autowiring by type works
     * for every mapper this factory produces. It returns {@code null} ("not known yet") when none
     * is set.
     */
    @Override
    public Class<?> getObjectType() {
        return mapperInterface;
    }

    /** A new proxy, with its own handler, is created for every lookup. */
    @Override
    public boolean isSingleton() {
        return false;
    }
}

測試:

/**
 * Demo service that uses the Spring-registered mapper proxy.
 */
@Component
public class TblEmployeeService {
    @Autowired
    TblEmployeeMapper tblEmployeeMapper;

    /**
     * Prints the employee with id 345190. Despite its name, this is a single-row lookup
     * ({@code getEmpsById}), not a list query.
     */
    public void queryAll() {
        // Upper-case L: a lower-case l suffix is easily misread as the digit 1.
        System.out.println(tblEmployeeMapper.getEmpsById(345190L));
    }
}
/**
 * Boots the Spring context from {@link MyConfig} and runs the service against the mapper bean.
 */
@Test
public void test() {
    // try-with-resources closes the context even when the query throws.
    try (AnnotationConfigApplicationContext applicationContext =
                 new AnnotationConfigApplicationContext(MyConfig.class)) {
        // To list every registered bean name:
        // for (String beanName : applicationContext.getBeanDefinitionNames()) {
        //     System.out.println(beanName);
        // }
        TblEmployeeService bean = applicationContext.getBean(TblEmployeeService.class);
        bean.queryAll();
    }
}
 結果:
 TblEmployeePO{id=345190, lastName='張三', gender='1', email='asd'}
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章