Mini MyBatis-Plus(下)

发布时间:2023年12月26日
作者简介:大家好,我是smart哥,前中兴通讯、美团架构师,现某互联网公司CTO

联系QQ:184480602,加我进群,大家一起学习,一起进步,一起对抗互联网寒冬。

最核心的内容在前两篇已经讲完了,这一篇只贴代码:

先看demo目录下的三个文件:

DemoApplication.java

package com.example.demo;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

/**
 * Spring Boot entry point of the demo application.
 */
@SpringBootApplication
public class DemoApplication {

    /**
     * Boots the Spring application context with this class as the primary configuration source.
     *
     * @param args command-line arguments, forwarded unchanged to Spring
     */
    public static void main(String[] args) {
        Class<?> primarySource = DemoApplication.class;
        SpringApplication.run(primarySource, args);
    }
}

User.java

package com.example.demo;

import com.example.demo.mybatisplus.annotations.TableName;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.Date;

/**
 * Entity (DO) for the {@code t_user} table; the table name comes from {@link TableName}.
 *
 * <p>Accessors and the no-arg / all-args constructors are generated by Lombok's
 * {@code @Data}, {@code @AllArgsConstructor} and {@code @NoArgsConstructor}.
 *
 * <p>Field names double as column names when the mapper builds SQL. NOTE(review): the mapper's
 * insert emits {@code INSERT INTO t_user VALUES(?,?,...)} with one placeholder per declared
 * field, so the declaration order below presumably has to match the table's column order —
 * confirm against the schema before adding or reordering fields.
 *
 * @author mx
 */
@Data
@TableName("t_user")
@AllArgsConstructor
@NoArgsConstructor
public class User {
    // NOTE(review): presumably the primary key — confirm.
    private Long id;
    private String name;
    private Integer age;
    private Date birthday;
}

UserMapper.java

package com.example.demo;

import com.example.demo.mybatisplus.AbstractBaseMapper;

/**
 * Mapper for {@link User}. All operations (select / list / insert / updateSelective) are
 * inherited from AbstractBaseMapper.
 *
 * <p>It must extend AbstractBaseMapper directly with the concrete type argument {@code User}:
 * the base constructor reads the entity class from {@code getGenericSuperclass()}.
 *
 * @author mx
 */
public class UserMapper extends AbstractBaseMapper<User> {
}

mybatisplus包下的AbstractBaseMapper.java

package com.example.demo.mybatisplus;

import com.example.demo.mybatisplus.annotations.TableName;
import com.example.demo.mybatisplus.core.JdbcTemplate;
import com.example.demo.mybatisplus.query.QueryWrapper;
import com.example.demo.mybatisplus.query.SqlParam;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

/**
 * Base class for mappers.
 *
 * <p>The entity (DO) class is resolved from the type argument of the concrete mapper's generic
 * superclass. The table name comes from the entity's {@link TableName} annotation. Simple
 * parameterized SELECT / INSERT / UPDATE statements are assembled via reflection and executed
 * through {@link JdbcTemplate}.
 *
 * <p>Column names are taken to be identical to the entity's field names.
 *
 * @param <T> entity type; must be annotated with {@link TableName}
 * @author mx
 */
public abstract class AbstractBaseMapper<T> {

    private static final Logger logger = LoggerFactory.getLogger(AbstractBaseMapper.class);

    /** Connective placed between WHERE conditions: all conditions are AND-ed. */
    private static final String DEFAULT_LOGICAL_TYPE = " and ";

    private final JdbcTemplate<T> jdbcTemplate = new JdbcTemplate<T>();

    /** Entity class, e.g. {@code User} for {@code UserMapper extends AbstractBaseMapper<User>}. */
    private final Class<T> beanClass;

    /** Table name taken from the entity's {@link TableName#value()}. */
    private final String tableName;

    /**
     * Resolves the entity class and its table name.
     *
     * @throws IllegalStateException if the concrete mapper does not directly extend
     *     {@code AbstractBaseMapper} with a concrete class as type argument, or if that class is
     *     not annotated with {@link TableName}
     */
    @SuppressWarnings("unchecked") // the cast is guarded by the instanceof Class check
    public AbstractBaseMapper() {
        Type superType = getClass().getGenericSuperclass();
        if (!(superType instanceof ParameterizedType)) {
            throw new IllegalStateException(getClass().getName()
                    + " must directly extend AbstractBaseMapper<T> with a concrete entity type");
        }
        Type entityType = ((ParameterizedType) superType).getActualTypeArguments()[0];
        if (!(entityType instanceof Class)) {
            throw new IllegalStateException(getClass().getName()
                    + " must bind T to a concrete class, but found " + entityType);
        }
        beanClass = (Class<T>) entityType;

        TableName tableAnnotation = beanClass.getAnnotation(TableName.class);
        if (tableAnnotation == null) {
            throw new IllegalStateException(beanClass.getName() + " is not annotated with @TableName");
        }
        tableName = tableAnnotation.value();
    }

    /**
     * Returns the first row matching the conditions.
     *
     * @return the first matching row, or {@code null} if nothing matches or the query failed
     */
    public T select(QueryWrapper<T> queryWrapper) {
        List<T> list = this.list(queryWrapper);
        return list.isEmpty() ? null : list.get(0);
    }

    /**
     * Returns all rows matching the AND-ed conditions of {@code queryWrapper}. With no
     * conditions (or a {@code null} wrapper), every row of the table is returned.
     *
     * @return the mapped rows, or an empty list if the query fails (the failure is logged)
     */
    public List<T> list(QueryWrapper<T> queryWrapper) {
        List<Object> paramList = new ArrayList<>();
        String sql = "SELECT * FROM " + tableName + buildWhereClause(queryWrapper, paramList);

        try {
            logger.info("sql: {}", sql);
            logger.info("params: {}", paramList);
            return jdbcTemplate.queryForList(sql, paramList, beanClass);
        } catch (Exception e) {
            // jdbcTemplate.queryForList(.., Class) declares `throws Exception`
            logger.error("query failed", e);
        }

        return Collections.emptyList();
    }

    /**
     * Inserts {@code bean} as {@code INSERT INTO table VALUES(?, ...)}, binding one parameter
     * per persistent field in declaration order.
     *
     * @return the number of affected rows, or 0 if the insert failed (the failure is logged)
     */
    public int insert(T bean) {
        List<Field> fields = persistentFields();

        List<Object> paramList = new ArrayList<>(fields.size());
        try {
            for (Field field : fields) {
                paramList.add(field.get(bean));
            }
        } catch (IllegalAccessException e) {
            // Executing with a partial parameter list would bind wrong values; give up instead.
            logger.error("insert failed: cannot read fields of {}", beanClass.getName(), e);
            return 0;
        }

        String placeholders = String.join(",", Collections.nCopies(fields.size(), "?"));
        String sql = "INSERT INTO " + tableName + " VALUES(" + placeholders + ")";

        try {
            logger.info("sql: {}", sql);
            logger.info("params: {}", paramList);
            int affectedRows = jdbcTemplate.update(sql, paramList);
            logger.info("insert success, affectedRows: {}", affectedRows);
            return affectedRows;
        } catch (SQLException e) {
            logger.error("insert failed", e);
        }

        return 0;
    }

    /**
     * Updates the rows matching {@code queryWrapper}, setting only the non-null fields of
     * {@code bean}: {@code UPDATE table SET a = ?, b = ? WHERE ...}.
     *
     * <p>Nothing is executed, and 0 is returned, when {@code bean} has no non-null field or
     * when no condition is given. The second guard prevents an unconditional UPDATE from
     * overwriting every row of the table.
     *
     * @return the number of affected rows, or 0 if nothing was executed or the update failed
     */
    public int updateSelective(T bean, QueryWrapper<T> queryWrapper) {
        List<Object> paramList = new ArrayList<>();

        // SET part: "name = ?, age = ?" for every non-null field
        StringBuilder setClause = new StringBuilder();
        try {
            for (Field field : persistentFields()) {
                Object fieldValue = field.get(bean);
                if (fieldValue == null) {
                    continue; // selective update: null fields are left untouched
                }
                if (setClause.length() > 0) {
                    setClause.append(", ");
                }
                setClause.append(field.getName()).append(" = ?");
                paramList.add(fieldValue);
            }
        } catch (IllegalAccessException e) {
            logger.error("update failed: cannot read fields of {}", beanClass.getName(), e);
            return 0;
        }
        if (setClause.length() == 0) {
            logger.warn("update skipped: no non-null field to set on table {}", tableName);
            return 0;
        }

        // WHERE parameters are appended after the SET parameters, matching placeholder order.
        String whereClause = buildWhereClause(queryWrapper, paramList);
        if (whereClause.isEmpty()) {
            logger.warn("update skipped: no WHERE condition given for table {}", tableName);
            return 0;
        }

        String sql = "UPDATE " + tableName + " SET " + setClause + whereClause;

        try {
            logger.info("sql: {}", sql);
            logger.info("params: {}", paramList);
            int affectedRows = jdbcTemplate.update(sql, paramList);
            logger.info("update success, affectedRows: {}", affectedRows);
            return affectedRows;
        } catch (SQLException e) {
            logger.error("update failed", e);
        }

        return 0;
    }

    /**
     * Builds {@code " WHERE col op ? and col op ? ..."} from the wrapper's conditions.
     *
     * <p>The bound values are appended to {@code paramList} in the same order as the placeholders.
     *
     * @return the WHERE clause, or {@code ""} when the wrapper is {@code null} or has no conditions
     */
    private static String buildWhereClause(QueryWrapper<?> queryWrapper, List<Object> paramList) {
        if (queryWrapper == null) {
            return "";
        }
        Map<String, SqlParam> conditionMap = queryWrapper.build();
        if (conditionMap == null || conditionMap.isEmpty()) {
            return "";
        }

        StringBuilder where = new StringBuilder(" WHERE ");
        String connective = "";
        for (Map.Entry<String, SqlParam> condition : conditionMap.entrySet()) {
            // key = operator such as " = ", value = column name + bound value
            SqlParam param = condition.getValue();
            where.append(connective)
                    .append(param.getColumnName())
                    .append(condition.getKey())
                    .append('?');
            paramList.add(param.getValue());
            connective = DEFAULT_LOGICAL_TYPE;
        }
        return where.toString();
    }

    /**
     * Returns the entity's declared instance fields, made accessible for reflection.
     *
     * <p>Static fields (e.g. {@code serialVersionUID}) and compiler/agent-generated synthetic
     * fields are skipped because they are not table columns.
     */
    private List<Field> persistentFields() {
        List<Field> fields = new ArrayList<>();
        for (Field field : beanClass.getDeclaredFields()) {
            if (Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
                continue;
            }
            field.setAccessible(true);
            fields.add(field);
        }
        return fields;
    }

}

annotations包下的TableName.java

package com.example.demo.mybatisplus.annotations;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * Maps an entity class to its database table.
 *
 * <p>The mapper base class reads it reflectively via {@code getAnnotation(TableName.class)},
 * which requires RUNTIME retention. It is only applicable to types.
 *
 * @author mx
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface TableName {

    /**
     * Table name, e.g. {@code "t_user"}; it is spliced verbatim into the generated SQL.
     */
    String value();

}

core包下的两个文件:

JdbcTemplate.java

package com.example.demo.mybatisplus.core;

import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

/**
 * JdbcTemplate: simplifies plain JDBC usage.
 *
 * <p>Every call opens its own connection and binds the given parameters positionally on a
 * {@link PreparedStatement}. It then executes the statement and closes all resources again,
 * also when execution fails.
 *
 * @param <T> type of the beans produced by queries
 * @author mx
 */
public class JdbcTemplate<T> {

    // Connection settings used by the no-arg constructor.
    // TODO could be moved into a properties file
    private static final String DEFAULT_URL = "jdbc:mysql://localhost:3306/demo";
    private static final String DEFAULT_USER = "root";
    private static final String DEFAULT_PASSWORD = "123456";

    private final String url;
    private final String user;
    private final String password;

    /** Creates a template that connects with the built-in default settings. */
    public JdbcTemplate() {
        this(DEFAULT_URL, DEFAULT_USER, DEFAULT_PASSWORD);
    }

    /**
     * Creates a template that connects with the given settings.
     *
     * @param url      JDBC URL passed to {@link DriverManager#getConnection(String, String, String)}
     * @param user     database user
     * @param password database password
     */
    public JdbcTemplate(String url, String user, String password) {
        this.url = url;
        this.user = user;
        this.password = password;
    }

    /** Runs a query and maps every row with the caller-supplied {@code rowMapper}. */
    public List<T> queryForList(String sql, List<Object> params, RowMapper<T> rowMapper) throws SQLException {
        return query(sql, params, rowMapper);
    }

    /** Runs a query and returns the first mapped row, or {@code null} if there is none. */
    public T queryForObject(String sql, List<Object> params, RowMapper<T> rowMapper) throws SQLException {
        List<T> result = query(sql, params, rowMapper);
        return result.isEmpty() ? null : result.get(0);
    }

    /** Runs a query and maps every row reflectively onto a new instance of {@code clazz}. */
    public List<T> queryForList(String sql, List<Object> params, Class<T> clazz) throws Exception {
        return query(sql, params, clazz);
    }

    /** Like {@link #queryForList(String, List, Class)} but returns only the first row, or {@code null}. */
    public T queryForObject(String sql, List<Object> params, Class<T> clazz) throws Exception {
        List<T> result = query(sql, params, clazz);
        return result.isEmpty() ? null : result.get(0);
    }

    /**
     * Executes a data-modifying statement (INSERT / UPDATE / DELETE).
     *
     * @return the number of affected rows
     */
    public int update(String sql, List<Object> params) throws SQLException {
        // try-with-resources closes statement and connection even if execution throws
        try (Connection conn = getConnection();
             PreparedStatement ps = getPreparedStatement(sql, params, conn)) {
            return ps.executeUpdate();
        }
    }

    // ************************* private methods **************************

    private List<T> query(String sql, List<Object> params, RowMapper<T> rowMapper) throws SQLException {
        // mapping rules supplied by the caller (hand-written)
        return baseQuery(sql, params, rowMapper);
    }

    private List<T> query(String sql, List<Object> params, Class<T> clazz) throws Exception {
        // mapping rules derived via reflection
        return baseQuery(sql, params, new BeanHandler<>(clazz));
    }

    /**
     * Base query method; every row is converted by {@code rowMapper}.
     *
     * @param sql       SQL with {@code ?} placeholders
     * @param params    placeholder values in order; {@code null} means no parameters
     * @param rowMapper converts the row the cursor is currently on
     * @return one mapped element per result row, in result order
     * @throws SQLException             if connecting or executing fails
     * @throws IllegalArgumentException if {@code sql} or {@code rowMapper} is {@code null}
     */
    private List<T> baseQuery(String sql, List<Object> params, RowMapper<T> rowMapper) throws SQLException {
        if (sql == null || rowMapper == null) {
            throw new IllegalArgumentException("sql and rowMapper must not be null");
        }

        // All three resources are closed in reverse order, also if mapping throws.
        try (Connection conn = getConnection();
             PreparedStatement ps = getPreparedStatement(sql, params, conn);
             ResultSet rs = ps.executeQuery()) {
            List<T> result = new ArrayList<>();
            // The cursor is advanced only here; mapRow must read the current row.
            while (rs.next()) {
                result.add(rowMapper.mapRow(rs));
            }
            return result;
        }
    }

    /**
     * Reflection-based {@link RowMapper}.
     *
     * <p>Every column is written into the bean field whose name equals the column label.
     *
     * @param <R> bean type to produce
     */
    private static class BeanHandler<R> implements RowMapper<R> {
        // type of the bean to produce; needs an accessible no-arg constructor
        private final Class<R> clazz;

        BeanHandler(Class<R> clazz) {
            this.clazz = clazz;
        }

        /**
         * Maps the row the cursor is currently positioned on.
         *
         * <p>It must not call {@code rs.next()} itself: the caller already advanced the cursor.
         * An extra {@code next()} here skipped every other row and turned single-row results
         * into {@code [null]}.
         *
         * @throws RuntimeException wrapping any SQL/reflection failure, e.g. a column without a
         *     matching field
         */
        @Override
        public R mapRow(ResultSet rs) {
            try {
                ResultSetMetaData metaData = rs.getMetaData();
                R bean = clazz.getDeclaredConstructor().newInstance();
                // JDBC column indexes are 1-based
                for (int column = 1; column <= metaData.getColumnCount(); column++) {
                    Field field = clazz.getDeclaredField(metaData.getColumnLabel(column));
                    field.setAccessible(true);
                    field.set(bean, rs.getObject(column));
                }
                return bean;
            } catch (Exception e) {
                throw new RuntimeException("failed to map row to " + clazz.getName(), e);
            }
        }
    }

    /**
     * Prepares {@code sql} on {@code conn} and binds {@code params} positionally.
     *
     * <p>If binding fails, the statement is released when the caller closes {@code conn}.
     */
    private PreparedStatement getPreparedStatement(String sql, List<Object> params, Connection conn) throws SQLException {
        PreparedStatement ps = conn.prepareStatement(sql);
        if (params != null) {
            for (int i = 0; i < params.size(); i++) {
                // JDBC parameter indexes are 1-based
                ps.setObject(i + 1, params.get(i));
            }
        }
        return ps;
    }

    private Connection getConnection() throws SQLException {
        return DriverManager.getConnection(url, user, password);
    }

}

RowMapper.java

package com.example.demo.mybatisplus.core;

import java.sql.ResultSet;

/**
 * Result-set mapper: converts one row of a {@link ResultSet} into a bean.
 *
 * @param <T> type of the produced bean
 * @author mx
 */
@FunctionalInterface
public interface RowMapper<T> {
    /**
     * Converts the result set into the target bean.
     *
     * <p>JdbcTemplate's query loop calls {@code rs.next()} before every invocation. Implementations
     * should therefore map the row the cursor is currently on and not move the cursor themselves.
     *
     * @param resultSet result set already positioned on the row to map
     * @return the mapped bean
     */
    T mapRow(ResultSet resultSet);
}

query包下的两个文件:

QueryWrapper.java

package com.example.demo.mybatisplus.query;

import com.example.demo.mybatisplus.utils.ConditionFunction;
import com.example.demo.mybatisplus.utils.Reflections;

import java.util.HashMap;
import java.util.Map;

/**
 * Imitates MyBatis-Plus's LambdaQueryWrapper. The implementation approach is completely
 * different; only the API looks alike.
 *
 * <p>Collects column conditions, which the mapper joins with AND. Column names are derived
 * from getter method references via {@link Reflections#fnToColumnName}.
 *
 * @param <T> entity type whose getters name the columns
 * @author mx
 */
public class QueryWrapper<T> {
    // Operator types, e.g. the LIKE in "name like 'bravo'". They are padded with spaces because
    // they are spliced verbatim between the column name and the "?" placeholder.
    private static final String OPERATOR_EQ = " = ";
    private static final String OPERATOR_GT = " > ";
    private static final String OPERATOR_LT = " < ";
    private static final String OPERATOR_LIKE = " LIKE ";

    // Collected conditions, operator -> (column, value):
    // {
    //    " LIKE ": { "name": "%bravo1988%" },
    //    " = ":    { "age": 18 }
    // }
    // With a HashMap, a second condition using the same operator (e.g. eq(name).eq(age))
    // silently replaced the first one, widening queries and updates. IdentityHashMap compares
    // keys by reference, and every condition is stored under its own String copy of the
    // operator, so all conditions are kept. Consequently get(" = ") cannot find entries;
    // consumers must iterate the map (as the mapper does).
    private final Map<String, SqlParam> conditionMap = new java.util.IdentityHashMap<>();

    /** Adds {@code column = value}. */
    public QueryWrapper<T> eq(ConditionFunction<T, ?> fn, Object value) {
        return addCondition(OPERATOR_EQ, fn, value);
    }

    /** Adds {@code column > value}. */
    public QueryWrapper<T> gt(ConditionFunction<T, ?> fn, Object value) {
        return addCondition(OPERATOR_GT, fn, value);
    }

    /** Adds {@code column < value}. */
    public QueryWrapper<T> lt(ConditionFunction<T, ?> fn, Object value) {
        return addCondition(OPERATOR_LT, fn, value);
    }

    /**
     * Adds {@code column LIKE %value%} (substring match).
     *
     * <p>The value is concatenated as a string, so {@code null} becomes {@code "%null%"}.
     */
    public QueryWrapper<T> like(ConditionFunction<T, ?> fn, Object value) {
        return addCondition(OPERATOR_LIKE, fn, "%" + value + "%");
    }

    /**
     * Returns the collected conditions, keyed by operator (several entries may share the same
     * operator text).
     *
     * <p>This is the live internal map, not a copy.
     */
    public Map<String, SqlParam> build() {
        return conditionMap;
    }

    /** Resolves the column of {@code fn} and records one condition under a distinct key. */
    private QueryWrapper<T> addCondition(String operator, ConditionFunction<T, ?> fn, Object value) {
        String columnName = Reflections.fnToColumnName(fn);
        // new String(...) is intentional: it gives every condition its own key identity.
        conditionMap.put(new String(operator), new SqlParam(columnName, value));
        return this;
    }
}

SqlParam.java

package com.example.demo.mybatisplus.query;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * One WHERE condition operand: the column name and the value bound to its {@code ?}
 * placeholder. Accessors and constructors are generated by Lombok.
 *
 * @author mx
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
public class SqlParam {
    // column name, derived by QueryWrapper from a getter method reference
    private String columnName;
    // value bound to the placeholder (for LIKE conditions already wrapped in '%')
    private Object value;
}

utils包下的两个文件:

ConditionFunction.java

package com.example.demo.mybatisplus.utils;

import java.io.Serializable;
import java.util.function.Function;

/**
 * Extends {@link java.util.function.Function} from {@code java.util.function} with
 * {@link Serializable}.
 *
 * <p>This lets a method reference implementing it be inspected through its
 * {@code SerializedLambda}. It is used together with the Reflections utility to obtain the name
 * of the referenced (getter) method.
 *
 * @param <T> input (entity) type
 * @param <R> result (property) type
 * @author mx
 */
@FunctionalInterface
public interface ConditionFunction<T, R> extends Function<T, R>, Serializable {
}

Reflections.java

package com.example.demo.mybatisplus.utils;

import java.beans.Introspector;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
import java.util.regex.Pattern;

/**
 * Resolves the column name that a serializable getter method reference points at,
 * e.g. {@code User::getName} -> {@code "name"}.
 *
 * @author mx
 */
public class Reflections {
    private static final Pattern GET_PATTERN = Pattern.compile("^get[A-Z].*");
    private static final Pattern IS_PATTERN = Pattern.compile("^is[A-Z].*");
    // Lambda-expression bodies are compiled into synthetic methods named "lambda$...".
    private static final String LAMBDA_METHOD_PREFIX = "lambda$";

    /**
     * Converts a getter method reference into a property/column name.
     *
     * <p>A leading {@code get}/{@code is} is stripped, and the rest is decapitalized by
     * {@link Introspector#decapitalize(String)}. Other method names, such as record accessors
     * like {@code name()}, are only decapitalized.
     *
     * <p>Note: accessors of non-standard (non-lowerCamelCase) properties may give unexpected names.
     *
     * @param fn  a method reference such as {@code User::getName}
     * @param <T> entity type
     * @return the derived column name
     * @throws IllegalArgumentException if {@code fn} is a lambda expression instead of a method
     *     reference, or if its {@link SerializedLambda} cannot be obtained
     */
    public static <T> String fnToColumnName(ConditionFunction<T, ?> fn) {
        String implMethodName = toSerializedLambda(fn).getImplMethodName();
        if (implMethodName.startsWith(LAMBDA_METHOD_PREFIX)) {
            // e.g. `u -> u.getName()`: the synthetic method name says nothing about the column,
            // and splicing it into SQL would only fail later. Fail fast instead.
            throw new IllegalArgumentException(
                    "Use a getter method reference (e.g. User::getName) instead of a lambda expression: "
                            + implMethodName);
        }

        String property = implMethodName;
        if (GET_PATTERN.matcher(property).matches()) {
            property = property.substring(3);
        } else if (IS_PATTERN.matcher(property).matches()) {
            property = property.substring(2);
        }
        return Introspector.decapitalize(property);
    }

    /**
     * Obtains the {@link SerializedLambda} by invoking the {@code writeReplace} method present on
     * serializable lambda classes.
     */
    private static SerializedLambda toSerializedLambda(ConditionFunction<?, ?> fn) {
        try {
            Method writeReplace = fn.getClass().getDeclaredMethod("writeReplace");
            writeReplace.setAccessible(true);
            return (SerializedLambda) writeReplace.invoke(fn);
        } catch (ReflectiveOperationException e) {
            throw new IllegalArgumentException(
                    "Cannot resolve SerializedLambda of " + fn.getClass().getName(), e);
        }
    }
}

其实第一篇的内容是最难的,不只是从0到1,而是从0到90,后面两篇其实只是90到100,在这个基础上稍微扩展了一下而已。

AbstractBaseMapper代码还有冗余,有兴趣的同学可以自行完善。但还是那句话,如果你的目的是为了锻炼封装能力,可以精益求精,但我们的AbstractBaseMapper注定不能用于生产,即使要优化,点到为止即可。

学习必须往深处挖,挖的越深,基础越扎实!

阶段1、深入多线程

阶段2、深入多线程设计模式

阶段3、深入juc源码解析

阶段4、深入jdk其余源码解析

阶段5、深入jvm源码解析

文章来源:https://blog.csdn.net/smart_an/article/details/135214271
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。