MyBatis (八)—— 自定义一个小MyBatis

最近研究了一下Mybatis的底层代码,准备写一个操作数据库的小工具,实现了Mybatis的部分功能:

1. SQL语句在mapper.xml中配置。
2. 支持int,String,自定义数据类型的入参。
3. 根据mapper.xml动态创建接口的代理实现对象。

功能有限,目的是搞清楚MyBatis框架的底层思想,多学习研究优秀框架的实现思路,对提升自己的编码能力大有裨益。

小工具使用到的核心技术点:xml解析+反射+jdk动态代理

接下来,一步一步来实现。

首先来说为什么要使用jdk动态代理。

传统的开发方式:

  1. 接口定义业务方法。
  2. 实现类实现业务方法。
  3. 实例化实现类对象来完成业务操作。

Mybatis的方式:

  • 开发者只需要创建接口,定义业务方法。
  • 不需要创建实现类。
  • 具体的业务操作通过配置xml来完成。

MyBatis的方式省去了实现类的创建,改为用xml来定义业务方法的具体实现。

那么问题来了。

我们知道Java是面向对象的编程语言,程序在运行时执行业务方法,必须要有实例化的对象。但是,接口是不能被实例化的,而且也没有接口的实现类,那么此时这个对象从哪来呢?

程序在运行时,动态创建代理对象。

所以我们要用JDK动态代理,运行时结合接口和mapper.xml来动态创建一个代理对象,程序调用该代理对象的方法来完成业务。

动态代理参考→动态代理

代码实现

这是一个根据 MyBatis 源码设计的一个简单自定义 MyBatis:

一、数据准备

首先说一下我用到的 Bean 类和 Dao 层接口、以及对应的数据库表,需要解析的配置文件:

POJO类 Student:

public class Student {
    private int id;  //编号
    private String name;  //名字
    private String sex;   //性别
    private int age;      //年龄
    private String grade;    //班级

     //get、set、constructor、toString略
}

Dao层接口:StudentMapper,我只写了两个查询方法:

public interface StudentMapper {
    //根据编号查询单个学生并返回
    Student getStudentInId(int id);
    //返回所有的学生列表
    List<Student> getStudentList();
}

对应的表 student:
在这里插入图片描述
需要解析的配置文件:

在这里插入图片描述

db.properties:

jdbc.driver=com.mysql.jdbc.Driver
jdbc.url=jdbc:mysql://localhost:3306/test1?useSSL=false
jdbc.username=root
jdbc.password=123456

mappers文件下我只放了一个 xml 文件:student-mapper.xml:

<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper
        PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
        "http://mybatis.org/dtd/mybatis-3-mapper.dtd">

<mapper namespace="com.lzq.dao.StudentMapper">
    <select id="getStudentInId" parameterType="java.lang.Integer" resultType="com.lzq.bean.Student">
        select * from student where id = ?
    </select>

    <select id="getStudentList" resultType="com.lzq.bean.Student">
        select * from student
    </select>
</mapper>

注意:为了方便,我后面用 PreparedStatement 执行 SQL 语句,所以我直接将 #{} 换成了 ?,不然的话,后面需要一个方法转换,麻烦:
在这里插入图片描述
好了,就是这些数据和配置,现在就开始写 MyBatis 的代码了!

二、自定义的 MyBatis 实现

我们将 实现 MyBatis 的大概分三个阶段,数据的初始化(就是解析那些配置文件),动态代理(根据传的 class 生成对应的代理对象给用户),使用阶段;

1、初始化阶段
在这里插入图片描述
以上就是一个 mapper.xml 中最重要的属性了,一个命名空间(需要用这个反射来创建对象),多个方法 id (后面要用这个方法 id 去反射方法,所以这个地方的 id 必须和接口中方法的名字一样,否则人家底层反射的时候就反射不到了!)、以及返回值、SQL语句,所以我们需要将这些信息转换成一个对象,Java中万物皆对象嘛,如下 MappedStatement :

/**
 * @类名 MappedStatement
 * @类说明 存储 SQL 语句信息
 * @作者 中都
 * @时间 2020/2/9 16:43
 */
public class MappedStatement {
    private String namespase;  //命名空间
    private String sourceId;    //命名空间+方法名
    private String resultType;  //返回值
    private String sql;         //SQL语句
   
     //get、set略
}

这个对象储存的是一个 mapper.xml 文件中的一条 SQL 语句信息,但是一个 mapper.xml 中是会有多条 SQL 语句的,并且很多时候会有多个 mapper.xml 文件,所以我们需要一个类把这些配置信息(包括所有的 mapper.xml 文件信息、每个 mapper.xml 文件的每条 SQL 信息以及数据库的配置信息等等)汇总,即对象 Configuration:

/**
 * @类名 Configuration
 * @类说明 把所有的配置信息糅合在一起,一般在一个项目中这个的对象是一个单例
 * @作者 中都
 * @时间 2020/2/9 17:12
 */
public class Configuration {
    private String jdbcDriver;   //数据库驱动
    private String jdbcUrl;      //数据库密码
    private String jdbcUserName;  //用户名
    private String jdbcPassword;  //密码

    //SQL信息
    private Map<String,MappedStatement> mappedStatementMaps = new HashMap<>();

    //get、set略    
}

因为这个配置信息对象需要储存多个 mapper.xml 文件中的多条 SQL 语句信息,所以拿一个 map 对象去储存,key 是 namespace + id,即接口全路径名+方法名,value 就是对应的要执行的 SQL 语句;

我们平时在使用 MyBatis 的时候会有一个 SqlSessionFactory 的工厂类创建我们需要的 SqlSession 对象供我们使用,所以我模仿他们的实现,也实现了一个 SqlSessionFactory 工厂类,这个类用来初始化配置信息和创建 SQLSession 对象:

/**
 * @类名 SqlSessionFactory
 * @类说明 完成 Configuration 实例化,生产 SqlSession
 * @作者 中都
 * @时间 2020/2/9 17:20
 */
public class SqlSessionFactory {
    //final修饰,单例,配置信息
    private final Configuration configuration = new Configuration();

    //记录mapper.xml文件存放的位置,这个 resources 文件下的 mappers 访问不到,我给了全路径
    public static final String MAPPER_CONFIG = "G:\\idea工程\\框架复习\\src\\main\\resources\\mappers";
    //记录数据库连接信息文件存放的位置
    public static final String DB_CONFIG_FILE = "db.properties";


    public SqlSessionFactory() {
        loadDbInfo();
        loadMappersInfo();
    }

    /**
     * 加载数据库信息
     */
    private void loadDbInfo() {
        //加载数据库配置文件
        InputStream dbin = SqlSessionFactory.class.getClassLoader().getResourceAsStream(DB_CONFIG_FILE);
        Properties p = new Properties();
        try {
            p.load(dbin); //将配置信息写入 Properties 对象
        } catch (IOException e) {
            e.printStackTrace();
        }
        //将数据库配置信息写入 config 对象
        configuration.setJdbcDriver(p.get("jdbc.driver").toString());
        configuration.setJdbcUrl(p.get("jdbc.url").toString());
        configuration.setJdbcUserName(p.get("jdbc.username").toString());
        configuration.setJdbcPassword(p.get("jdbc.password").toString());
    }

    /**
     * 加载所有文件夹下的 xml 信息
     */
    private void loadMappersInfo() {
        File mappers = new File(MAPPER_CONFIG);
        if(mappers.isDirectory()) {
            File[] files = mappers.listFiles();
            for (File f : files) {
                loadMapperInfo(f);
            }
        }
    }

    /**
     * 解析单个 xml 文件
     * @param f
     */
    private void loadMapperInfo(File f) {
        //创建 saxReader 对象
        SAXReader reader = new SAXReader();
        //通过 read 方法读取一个文件,转换成 Document 对象
        Document document = null;
        try {
            document = reader.read(f);
        } catch (DocumentException e) {
            e.printStackTrace();
        }
        //获取根节点元素
        Element root = document.getRootElement();
        //获取命名空间
        String namespace = root.attribute("namespace").getData().toString();
        //获取 select 子节点列表
        List<Element> selects = root.elements("select");
        //遍历 select 节点,将信息记录到 MappedStatement 对象,并登记到 configuration
        for (Element e : selects) {
            MappedStatement mappedStatement = new MappedStatement();  //实例化一条 sql 语句记录
            String id = e.attribute("id").getData().toString();
            String resultType = e.attribute("resultType").getData().toString();
            String sql = e.getData().toString();  //去取 sql 语句
            String sourceId = namespace +"."+id;
            //给 MappedStatement 对象 赋值
            mappedStatement.setNamespase(namespace);
            mappedStatement.setResultType(resultType);
            mappedStatement.setSourceId(sourceId);
            mappedStatement.setSql(sql);
            //注册到 configuration
            configuration.getMappedStatementMaps().put(sourceId,mappedStatement);
        }
    }


    /**
     * 创建 SqlSession 对象
     * @return
     */
    public SqlSession openSession() {
        return new DefaultSqlSession(configuration);
    }
}

至此,第一阶段配置信息的初始化完成;

2、动态代理
既然要创建 SqlSession对象,那就需要 SqlSession 了,里面也只是实现了两个简单的查询方法:

/**
 * @类名 SqlSession
 * @类说明
 * 1、对外提供方法的接口
 * 2、对内将请求转发给 executor 执行
 * @作者 中都
 * @时间 2020/2/9 18:10
 */
public interface SqlSession {
    /**
     * 根据传入的条件查询单一结果
     * @param statement sql语句
     * @param parameter 传入的参数
     * @param <T> 返回值
     * @return
     */
    <T> T selectOne(String statement,Object parameter);

    <E> List<E> selectList(String statement, Object parameter);

    <T> T getMapper(Class<T> type);
}

上面只是一个接口,这是他的实现类:

/**
 * @类名 DefaultSqlSession
 * @类说明
 *  1、对外提供方法的接口
 *  2、对内将请求转发给 executor 执行
 * @作者 中都
 * @时间 2020/2/9 18:17
 */
public class DefaultSqlSession implements SqlSession {
    //final修饰,单例,配置信息
    private final Configuration conf;
    //真正的执行者,委托者
    private Executor executor;

    public DefaultSqlSession(Configuration conf) {
        this.conf = conf;
        this.executor = new DefaultExecutor(conf);
    }

    @Override
    public <T> T selectOne(String statement, Object parameter) {
        List<Object> selectList = this.selectList(statement,parameter);
        if(selectList == null || selectList.size() == 0) {
            return null;
        }else if(selectList.size() == 1) {
            return (T)selectList.get(0);
        }else {
            throw new RuntimeException("这不是单条记录!");
        }
    }

    @Override
    public <E> List<E> selectList(String statement, Object parameter) {
        MappedStatement mappedStatement = conf.getMappedStatementMaps().get(statement);
        return executor.query(mappedStatement,parameter);
    }

    @Override
    public <T> T getMapper(Class<T> type) {
        MapperProxy mapperProxy = new MapperProxy(this);
        return (T)Proxy.newProxyInstance(type.getClassLoader(),new Class[]{type},mapperProxy);
    }
}

因为 getMapper 是通过动态代理实现的,根据单一职责原则,SqlSession 主要是完成数据查询的,那动态代理的实现就需要另外一个类来完成了,即 MapperProxy :

/**
 * @类名 MapperProxy
 * @类说明 完成动态代理
 * @作者 中都
 * @时间 2020/2/9 20:39
 */
public class MapperProxy implements InvocationHandler {
    private SqlSession sqlSession;

    public MapperProxy(SqlSession sqlSession) {
        this.sqlSession = sqlSession;
    }

    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        //判断方法的返回值是否是集合类型
        if(Collection.class.isAssignableFrom(method.getReturnType())) {
            //method.getDeclaringClass().getName()+"."+method.getName() 即类名.方法名,即配置文件中的命名空间
            return sqlSession.selectList(method.getDeclaringClass().getName()+"."+method.getName(),args == null?null:args[0]);
        }else {
            return sqlSession.selectOne(method.getDeclaringClass().getName()+"."+method.getName(),args == null?null:args[0]);
        }
    }
}

在上面 SqlSession 的实现类中我们也看到,SqlSession 就是那个代理类,它每次操作都是调用 Executor 去实现的:

/**
 * @类名 Executor
 * @类说明 核心接口之一,定义了数据库操作的最基本方法,SqlSession的功能都基于它来实现
 * @作者 中都
 * @时间 2020/2/9 20:01
 */
public interface Executor {
    /**
     * 查询接口
     * @param ms 封装有 SQL 语句的 MappedStatement 对象
     * @param parameter 传入的 SQL 参数
     * @param <E> 将参数转化成指定的结果集返回
     * @return
     */
    <E> List<E> query(MappedStatement ms, Object parameter);
}

下面是 Executor 的具体实现类,主要完成数据库操作,并将结果映射成需要返回的对象类型:

public class DefaultExecutor implements Executor {
    private final Configuration configuration;

    public DefaultExecutor(Configuration configuration) {
        this.configuration = configuration;
    }

    /**
     * 用于查询
     * @param ms 封装有 SQL 语句的 MappedStatement 对象
     * @param parameter 传入的 SQL 参数
     * @param <E>
     * @return
     */
    @Override
    public <E> List<E> query(MappedStatement ms, Object parameter) {
        List<E> ret = new ArrayList<>();//定义结果集;
        try {
            Class.forName(configuration.getJdbcDriver());
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
        Connection connection = null;
        PreparedStatement preparedStatement = null;
        ResultSet resultSet = null;
        try {
            //获取连接,从 MappedStatement 获取数据库信息
            connection = DriverManager.getConnection(configuration.getJdbcUrl(),configuration.getJdbcUserName(),configuration.getJdbcPassword());
            //创建 preparedStatement 对象,从MappedStatement获取SQL语句
            preparedStatement = connection.prepareStatement(ms.getSql());
            //处理SQL语句中的占位符
            parameterize(preparedStatement,parameter);
            //执行查询操作获取 resultSet
            resultSet = preparedStatement.executeQuery();
            //将结果集通过反射技术,填充到list集合中
            handlerResultSet(resultSet,ret,ms.getResultType());
        } catch (SQLException e) {
            e.printStackTrace();
        } finally {
            try {
                resultSet.close();
                preparedStatement.close();
                connection.close();
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
        return ret;
    }

    /**
     * 对占位符进行处理
     * @param preparedStatement
     * @param parameter
     */
    private void parameterize(PreparedStatement preparedStatement,Object parameter) throws SQLException {
        if(parameter instanceof Integer) {
            preparedStatement.setInt(1,(int)parameter);
        }else if(parameter instanceof Long) {
            preparedStatement.setLong(1,(long)parameter);
        }else if(parameter instanceof String) {
            preparedStatement.setString(1,(String)parameter);
        }
    }

    /**
     * 读取 resultset 中的数据,并转换成目标对象
     * @param resultSet
     * @param ret
     * @param className
     * @param <E>
     */
    private <E> void handlerResultSet(ResultSet resultSet,List<E> ret,String className) {
        Class<E> eClass = null;
        try {
            //通过反射获取类对象;
            eClass = (Class<E>)Class.forName(className);
        } catch (ClassNotFoundException e) {
            e.printStackTrace();
        }
        try {
            while (resultSet.next()) {
                //通过反射实例化对象
                Object o = eClass.newInstance();
                //使用反射将 resultset 中的数据填充到 o 对象
                ReflectionUtil.setPropToBeanFromResultSet(o,resultSet);
                //将对象加入到返回集中
                ret.add((E)o);
            }
        } catch (SQLException e) {
            e.printStackTrace();
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        } catch (InstantiationException e) {
            e.printStackTrace();
        }
    }

它用到的工具类:

/**
 * @类名 ReflectionUtil
 * @类说明 利用反射帮助创建对象
 * @作者 中都
 * @时间 2020/2/9 21:37
 */
public class ReflectionUtil {

    /**
     * 为指定的 bean 的 proName 属性赋值 value
     * @param bean 目标对象
     * @param propName  对象的属性名
     * @param value 值
     */
    public static void setPropToBean(Object bean,String propName,Object value) {
        Field f;
        try {
            f = bean.getClass().getDeclaredField(propName);//获得对象指定的属性
            f.setAccessible(true); //可以访问
            f.set(bean,value);//为属性赋值
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        } catch (NoSuchFieldException e) {
            e.printStackTrace();
        }
    }

    /**
     * 将结果集映射到对象
     * @param entity
     * @param resultSet
     * @throws SQLException
     */
    public static void setPropToBeanFromResultSet(Object entity, ResultSet resultSet) throws SQLException {
        Field[] declaredFields = entity.getClass().getDeclaredFields(); //得到对象的所有属性
        for (int i = 0; i < declaredFields.length; i++) {
            if(declaredFields[i].getType().getSimpleName().equals("String")) {  //字符串类型
                setPropToBean(entity,declaredFields[i].getName(),resultSet.getString(declaredFields[i].getName()));
            }else if(declaredFields[i].getType().getSimpleName().equals("int")) {
                setPropToBean(entity,declaredFields[i].getName(),resultSet.getInt(declaredFields[i].getName()));
            }else if(declaredFields[i].getType().getSimpleName().equals("long")) {
                setPropToBean(entity,declaredFields[i].getName(),resultSet.getLong(declaredFields[i].getName()));
            }
        }

    }
}

至此,整个自定义 MyBatis 就写完了,现在只需要测试一下就好:

3、测试

public class Test2 {
    @Test
    public void Test2() {
        SqlSessionFactory factory = new SqlSessionFactory();
        SqlSession sqlSession =  factory.openSession();
        StudentMapper mapper = sqlSession.getMapper(StudentMapper.class);
        Student student = mapper.getStudentInId(3);
        System.out.println(student);

        System.out.println(" ======================== ");
        List<Student> studentList = mapper.getStudentList();
        for (Student stu : studentList) {
            System.out.println(stu);
        }
    }
}

结果如下:
在这里插入图片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章