wisdomisite-java/src/main/java/com/zhgd/mybatis/DataScopeInterceptor.java

374 lines
17 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package com.zhgd.mybatis;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ReflectUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import com.zhgd.annotation.DataScope;
import com.zhgd.jeecg.common.util.SpringContextUtils;
import com.zhgd.xmgl.constant.Cts;
import com.zhgd.xmgl.entity.dto.OperLogDataChange;
import com.zhgd.xmgl.entity.dto.OperLogInsertChange;
import com.zhgd.xmgl.security.util.SecurityUtils;
import com.zhgd.xmgl.util.EnvironmentUtil;
import com.zhgd.xmgl.util.LogMdcUtil;
import com.zhgd.xmgl.util.PrintColorUtil;
import com.zhgd.xmgl.util.ThreadLocalUtil;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.type.TypeHandlerRegistry;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Array;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.sql.SQLException;
import java.text.DateFormat;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * MyBatis-Plus inner interceptor with three responsibilities.
 *
 * <ul>
 *   <li><b>Data scope (SELECT).</b> When a user is logged in ({@link SecurityUtils#getUser()} is
 *       non-null) and the mapper method or mapper interface carries an enabled {@link DataScope}
 *       annotation, the SQL is parsed with JSqlParser. Every plain SELECT whose FROM item is a table
 *       is then handed to {@link DataScopeHandler#getSqlSegment}. That call's return value is
 *       ignored, so presumably it mutates the {@link PlainSelect} in place — TODO confirm.</li>
 *   <li><b>Change capture (UPDATE/DELETE/INSERT).</b> This applies only when the thread-local flag
 *       {@code Cts.TL_IS_FROM_WEB} is {@code true}. For UPDATE/DELETE the matching rows are selected
 *       <em>before</em> the statement runs. For INSERT the parameter object is recorded. Both are
 *       appended to thread-local lists (presumably consumed by an operation-log writer — verify).</li>
 *   <li><b>Dev tracing.</b> In the {@code gsx-other-env-show-dev} environment, each prepared
 *       mapper statement is logged at DEBUG together with the project call stack.</li>
 * </ul>
 */
@Slf4j
public class DataScopeInterceptor extends JsqlParserSupport implements InnerInterceptor {

    /** Active-environment name in which {@link #beforePrepare} echoes mapper calls to the log. */
    private static final String DEV_SHOW_ENV = "gsx-other-env-show-dev";

    /** Request parameter whose presence forces data-scope processing (see {@link #isNotSqlTest()}). */
    private static final String SQL_TEST_PARAM = "qqq";

    private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd");
    private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
    private static final DateTimeFormatter TIME_FORMATTER = DateTimeFormatter.ofPattern("HH:mm:ss");

    /** The SQL keyword WHERE as a whole word, in any letter case. */
    private static final Pattern WHERE_KEYWORD = Pattern.compile("\\bWHERE\\b", Pattern.CASE_INSENSITIVE);

    @Setter
    private DataScopeHandler dataScopeHandler;

    /**
     * Decides whether data-scope filtering should be skipped for a query.
     *
     * <p>Filtering is skipped when any of the following holds:
     * <ul>
     *   <li>there is no {@link DataScope} annotation, or it is disabled;</li>
     *   <li>the parameter is the string {@code Cts.IGNORE_DATA_SCOPE};</li>
     *   <li>the parameter is a map that has the key {@code Cts.IGNORE_DATA_SCOPE};</li>
     *   <li>the parameter map contains a nested map with that key (the {@code @Param} case);</li>
     *   <li>the parameter map contains a {@link QueryWrapper} whose SQL segment contains
     *       {@code Cts.IGNORE_DATA_SCOPE_CONDITION} (typically added via {@code last(...)}).</li>
     * </ul>
     *
     * @param parameter  the MyBatis parameter object of the query (may be {@code null})
     * @param annotation the resolved {@link DataScope}, or {@code null} if none applies
     * @return {@code true} if data scope should NOT be applied
     */
    public static boolean findIgnoreDataScope(Object parameter, DataScope annotation) {
        if (annotation == null || !annotation.enable()) {
            return true;
        }
        if (parameter instanceof Map) {
            for (Map.Entry<?, ?> entry : ((Map<?, ?>) parameter).entrySet()) {
                Object key = entry.getKey();
                // Single map parameter without @Param: the ignore marker is a top-level key.
                if (key instanceof String && key.equals(Cts.IGNORE_DATA_SCOPE)) {
                    return true;
                }
                Object val = entry.getValue();
                if (val instanceof QueryWrapper) {
                    // Mybatis-plus mapper call that appended IGNORE_DATA_SCOPE_CONDITION via last(...).
                    String sqlSegment = ((QueryWrapper<?>) val).getSqlSegment();
                    if (StrUtil.isNotBlank(sqlSegment) && sqlSegment.contains(Cts.IGNORE_DATA_SCOPE_CONDITION)) {
                        return true;
                    }
                } else if (val instanceof Map) {
                    // Single map parameter with @Param: the marker is one level down.
                    // Objects.equals keeps a null key from throwing.
                    for (Map.Entry<?, ?> innerEntry : ((Map<?, ?>) val).entrySet()) {
                        if (Objects.equals(innerEntry.getKey(), Cts.IGNORE_DATA_SCOPE)) {
                            return true;
                        }
                    }
                }
            }
        }
        if (parameter instanceof String) {
            return parameter.equals(Cts.IGNORE_DATA_SCOPE);
        }
        return false;
    }

    /**
     * Reports whether the current HTTP request is NOT a "SQL test" request.
     *
     * <p>A request is a SQL test when it carries the {@code qqq} parameter. Such requests force
     * data-scope processing even when {@link #findIgnoreDataScope} would skip it. With no request
     * bound to the thread (or any failure reading it), this returns {@code true}.
     *
     * @return {@code true} unless the current request has a {@code qqq} parameter
     */
    public static boolean isNotSqlTest() {
        try {
            HttpServletRequest request = ((ServletRequestAttributes) (RequestContextHolder.currentRequestAttributes())).getRequest();
            return request.getParameter(SQL_TEST_PARAM) == null;
        } catch (Exception e) {
            return true;
        }
    }

    /**
     * Builds a colourised, multi-line list of {@code com.zhgd} stack frames for dev logging.
     *
     * <p>Frames are excluded when {@code LogMdcUtil.notInPackage} rejects them, or when their class
     * name contains {@code $$} (generated proxy classes).
     */
    private static String getCallPositionForDev() {
        StringBuilder sb = new StringBuilder();
        for (StackTraceElement e : Thread.currentThread().getStackTrace()) {
            String className = e.getClassName();
            if (className.startsWith("com.zhgd") && LogMdcUtil.notInPackage(className) && !className.contains("$$")) {
                String simpleName = StrUtil.subAfter(className, ".", true);
                sb.append("\r\n ")
                        .append(simpleName)
                        .append(".")
                        .append(e.getMethodName())
                        .append("(")
                        .append(simpleName)
                        .append(".java:")
                        .append(e.getLineNumber())
                        .append(")");
            }
        }
        return PrintColorUtil.getPrintColorStr(sb.toString());
    }

    /**
     * In the dev-show environment, logs which mapper method is about to be prepared and from where.
     *
     * <p>The environment and DEBUG level are checked before anything else, so the stack walk is
     * skipped whenever its output would be discarded. Failures are only logged: tracing must never
     * break statement preparation.
     */
    @Override
    public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
        try {
            if (!log.isDebugEnabled() || !DEV_SHOW_ENV.equals(EnvironmentUtil.getActiveEnvironment())) {
                return;
            }
            MappedStatement ms = PluginUtils.mpStatementHandler(sh).mappedStatement();
            String id = ms.getId();
            int lastDot = id.lastIndexOf(StringPool.DOT);
            String mapperName = id.substring(0, lastDot);
            String methodName = id.substring(lastDot + 1);
            switch (ms.getSqlCommandType()) {
                case SELECT:
                    log.debug("查询mapper ↙↙↙ \r\n {}#{}{}", mapperName, methodName, getCallPositionForDev());
                    break;
                case UPDATE:
                    log.debug("更新mapper ↙↙↙ \r\n {}#{}{}", mapperName, methodName, getCallPositionForDev());
                    break;
                case INSERT:
                    log.debug("插入mapper ↙↙↙ \r\n {}#{}{}", mapperName, methodName, getCallPositionForDev());
                    break;
                case DELETE:
                    log.debug("删除mapper ↙↙↙ \r\n {}#{}{}", mapperName, methodName, getCallPositionForDev());
                    break;
                default:
                    break;
            }
        } catch (Exception e) {
            log.error(e.getMessage(), e);
        }
    }

    /** Intentionally empty: INSERT statements are not rewritten by this interceptor. */
    @Override
    protected void processInsert(Insert insert, int index, String sql, Object obj) {
    }

    /**
     * Applies data-scope filtering to a SELECT before it runs.
     *
     * <p>The parser context passed down to {@link #processSelect} is a {@link JSONObject} with two
     * keys: {@code "ds"} (the resolved annotation) and {@code "parameter"} (the query parameter).
     *
     * <p>NOTE(review): any exception is logged and swallowed. The query then runs WITHOUT data-scope
     * filtering (fail-open). Confirm this is acceptable for a permission filter.
     */
    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        try {
            if (SecurityUtils.getUser() == null) {
                return;
            }
            DataScope annotation = findDataScopeAnnotation(ms.getId());
            if (findIgnoreDataScope(parameter, annotation) && isNotSqlTest()) {
                return;
            }
            PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
            JSONObject jo = new JSONObject();
            jo.put("ds", annotation);
            jo.put("parameter", parameter);
            mpBs.sql(this.parserSingle(mpBs.sql(), jo));
        } catch (Exception e) {
            log.error(e.getMessage(), e);
        }
    }

    /**
     * Resolves the {@link DataScope} for a mapped statement.
     *
     * <p>The first annotated method declared on the mapper interface with the statement's method name
     * wins. Otherwise the annotation on the mapper interface itself is used.
     *
     * @param mappedStatementId a MyBatis statement id, {@code <mapper FQCN>.<method>}
     * @return the annotation, or {@code null} if neither the method nor the interface has one
     * @throws ClassNotFoundException if the mapper class cannot be loaded
     */
    private static DataScope findDataScopeAnnotation(String mappedStatementId) throws ClassNotFoundException {
        int lastDot = mappedStatementId.lastIndexOf(StringPool.DOT);
        Class<?> mapperClass = Class.forName(mappedStatementId.substring(0, lastDot));
        String methodName = mappedStatementId.substring(lastDot + 1);
        return Arrays.stream(mapperClass.getDeclaredMethods())
                .filter(method -> method.getName().equals(methodName))
                .map(method -> method.getAnnotation(DataScope.class))
                .filter(Objects::nonNull)
                .findFirst()
                .orElseGet(() -> mapperClass.getAnnotation(DataScope.class));
    }

    /**
     * Captures "before" data for web-originated writes.
     *
     * <ul>
     *   <li>UPDATE/DELETE: the rows matching the statement's WHERE clause are selected and recorded
     *       under {@code Cts.TL_UPDATE_DEL_BEFORE_PARAM}.</li>
     *   <li>INSERT: the parameter object is recorded under {@code Cts.TL_INSERT_BEFORE_PARAM}.</li>
     * </ul>
     *
     * <p>Statements without a WHERE clause are skipped. Snapshotting them would mean running
     * {@code ... WHERE } (invalid SQL) or reading the whole table. Errors are logged and never block
     * the write.
     */
    @Override
    public void beforeUpdate(Executor executor, MappedStatement ms, Object parameter) throws SQLException {
        try {
            if (!Objects.equals(ThreadLocalUtil.getByKey(Cts.TL_IS_FROM_WEB, Boolean.class), true)) {
                return;
            }
            String mapperName = StrUtil.sub(ms.getId(), 0, StringUtils.lastIndexOf(ms.getId(), "."));
            SqlCommandType sct = ms.getSqlCommandType();
            if (sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE) {
                String sql = this.getShowSql(ms.getConfiguration(), ms.getBoundSql(parameter));
                String whereSql = extractLastWhereClause(sql);
                if (StrUtil.isBlank(whereSql)) {
                    log.debug("skip before-change snapshot, no WHERE clause: {}", sql);
                    return;
                }
                saveQueryResult(mapperName, whereSql);
            } else if (sct == SqlCommandType.INSERT) {
                List<OperLogInsertChange> paramList = ThreadLocalUtil.getByKey(Cts.TL_INSERT_BEFORE_PARAM, List.class);
                if (paramList == null) {
                    paramList = new ArrayList<>();
                    ThreadLocalUtil.addInKey(Cts.TL_INSERT_BEFORE_PARAM, paramList);
                }
                OperLogInsertChange operLogInsertChange = new OperLogInsertChange();
                operLogInsertChange.setMapperName(mapperName);
                operLogInsertChange.setResult(new ArrayList<>(Arrays.asList(parameter)));
                operLogInsertChange.setTimestamp(System.currentTimeMillis());
                paramList.add(operLogInsertChange);
            }
        } catch (Exception e) {
            log.error("前后数据变化错误", e);
        }
    }

    /**
     * Returns the text after the LAST whole-word {@code WHERE} (any case), or {@code ""} if none.
     *
     * <p>The leading space after the keyword is kept, matching the previous output format. The last
     * occurrence is used, so a WHERE clause that itself contains a sub-query with WHERE is cut short.
     * A string literal equal to {@code where} can also mislead it. Both are known limitations of this
     * text-based approach.
     */
    private static String extractLastWhereClause(String sql) {
        Matcher matcher = WHERE_KEYWORD.matcher(sql);
        int end = -1;
        while (matcher.find()) {
            end = matcher.end();
        }
        return end < 0 ? "" : sql.substring(end);
    }

    /**
     * Selects the rows matching {@code whereSql} through the mapper bean's {@code selectList} and
     * appends them to the {@code Cts.TL_UPDATE_DEL_BEFORE_PARAM} thread-local list.
     *
     * <p>{@code whereSql} is raw SQL appended via {@link QueryWrapper#last}. Its literals must already
     * be escaped, which {@link #getParameterValue} guarantees for strings.
     *
     * @param mapperName fully qualified mapper interface name (also the Spring bean type)
     * @param whereSql   condition text following the WHERE keyword
     * @throws ClassNotFoundException if the mapper class cannot be loaded
     */
    private void saveQueryResult(String mapperName, String whereSql) throws ClassNotFoundException {
        QueryWrapper<Object> wrapper = Wrappers.query().last("WHERE " + whereSql);
        Object mapperObj = SpringContextUtils.getBean(Class.forName(mapperName));
        Method selectListMethod = ReflectUtil.getMethod(mapperObj.getClass(), "selectList", QueryWrapper.class);
        Object rs = ReflectUtil.invoke(mapperObj, selectListMethod, wrapper);
        List<OperLogDataChange> paramList = ThreadLocalUtil.getByKey(Cts.TL_UPDATE_DEL_BEFORE_PARAM, List.class);
        if (paramList == null) {
            paramList = new ArrayList<>();
            ThreadLocalUtil.addInKey(Cts.TL_UPDATE_DEL_BEFORE_PARAM, paramList);
        }
        OperLogDataChange operLogDataChange = new OperLogDataChange();
        operLogDataChange.setMapperName(mapperName);
        operLogDataChange.setWhereSql(whereSql);
        operLogDataChange.setResult(rs);
        operLogDataChange.setTimestamp(System.currentTimeMillis());
        paramList.add(operLogDataChange);
    }

    /**
     * Renders the bound SQL with every {@code ?} replaced by its literal parameter value.
     *
     * <p>Whitespace runs are collapsed to single spaces. Parameter values are resolved in the same
     * order MyBatis' {@code DefaultParameterHandler} uses:
     * <ol>
     *   <li>additional (dynamic-SQL) parameters;</li>
     *   <li>the parameter object itself when it is a simple type;</li>
     *   <li>a property getter on the parameter object.</li>
     * </ol>
     * Mappings that resolve to nothing are skipped. Placeholders are filled in a single left-to-right
     * pass, so a {@code ?} inside an already-inlined value is never substituted again.
     *
     * @param configuration the MyBatis configuration (for type handlers and meta objects)
     * @param boundSql      the statement's bound SQL
     * @return the SQL with literals inlined; unchanged when there are no mappings or no parameter object
     */
    private String getShowSql(Configuration configuration, BoundSql boundSql) {
        String sql = boundSql.getSql().replaceAll("[\\s]+", " ");
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        Object parameterObject = boundSql.getParameterObject();
        if (CollUtil.isEmpty(parameterMappings) || parameterObject == null) {
            return sql;
        }
        TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
        boolean simpleParameter = typeHandlerRegistry.hasTypeHandler(parameterObject.getClass());
        MetaObject metaObject = simpleParameter ? null : configuration.newMetaObject(parameterObject);
        List<String> values = new ArrayList<>(parameterMappings.size());
        for (ParameterMapping parameterMapping : parameterMappings) {
            String propertyName = parameterMapping.getProperty();
            if (boundSql.hasAdditionalParameter(propertyName)) {
                values.add(getParameterValue(boundSql.getAdditionalParameter(propertyName)));
            } else if (simpleParameter) {
                values.add(getParameterValue(parameterObject));
            } else if (metaObject.hasGetter(propertyName)) {
                values.add(getParameterValue(metaObject.getValue(propertyName)));
            }
        }
        return fillPlaceholders(sql, values);
    }

    /**
     * Replaces successive {@code ?} markers in {@code sql} with {@code values}, in order.
     *
     * <p>Surplus values (more values than markers) are ignored. Surplus markers (more markers than
     * values) are left as is.
     */
    private static String fillPlaceholders(String sql, List<String> values) {
        StringBuilder sb = new StringBuilder(sql.length() + 16 * values.size());
        int from = 0;
        for (String value : values) {
            int mark = sql.indexOf('?', from);
            if (mark < 0) {
                break;
            }
            sb.append(sql, from, mark).append(value);
            from = mark + 1;
        }
        sb.append(sql, from, sql.length());
        return sb.toString();
    }

    /**
     * Renders one parameter value as an SQL literal.
     *
     * <p>Formatting rules:
     * <ul>
     *   <li>{@code null} becomes {@code NULL}.</li>
     *   <li>Strings are quoted and escaped (see {@link #quote}).</li>
     *   <li>{@code java.sql.Date} becomes {@code 'yyyy-MM-dd'}.</li>
     *   <li>{@code java.sql.Time} becomes {@code 'HH:mm:ss'}.</li>
     *   <li>Other {@link Date}s, in the JVM default time zone, and {@link LocalDateTime} become
     *       {@code 'yyyy-MM-dd HH:mm:ss'}.</li>
     *   <li>{@link LocalDate} becomes {@code 'yyyy-MM-dd'}.</li>
     *   <li>Anything else is rendered with {@code toString()}, unquoted.</li>
     * </ul>
     * Sub-second precision is dropped.
     */
    private String getParameterValue(Object obj) {
        if (obj == null) {
            return "NULL";
        }
        if (obj instanceof String) {
            return quote((String) obj);
        }
        // java.sql.Date/Time extend java.util.Date, so they must be checked first.
        if (obj instanceof java.sql.Date) {
            return quote(((java.sql.Date) obj).toLocalDate().format(DATE_FORMATTER));
        }
        if (obj instanceof java.sql.Time) {
            return quote(((java.sql.Time) obj).toLocalTime().format(TIME_FORMATTER));
        }
        if (obj instanceof Date) {
            LocalDateTime dateTime = LocalDateTime.ofInstant(Instant.ofEpochMilli(((Date) obj).getTime()), ZoneId.systemDefault());
            return quote(dateTime.format(DATE_TIME_FORMATTER));
        }
        if (obj instanceof LocalDate) {
            return quote(((LocalDate) obj).format(DATE_FORMATTER));
        }
        if (obj instanceof LocalDateTime) {
            return quote(((LocalDateTime) obj).format(DATE_TIME_FORMATTER));
        }
        return obj.toString();
    }

    /**
     * Wraps {@code text} in single quotes, doubling embedded quotes and escaping backslashes.
     *
     * <p>The result is executed again by {@link #saveQueryResult}, so unescaped request data here
     * would be SQL injection. NOTE(review): the escaping assumes MySQL-style literals (backslash is an
     * escape character). Confirm the target database; under standard-conforming strings a backslash
     * would be doubled in the value, which is wrong but not exploitable.
     */
    private static String quote(String text) {
        return "'" + text.replace("\\", "\\\\").replace("'", "''") + "'";
    }

    @Override
    protected void processSelect(Select select, int index, String sql, Object obj) {
        this.processSelectBody(select.getSelectBody(), obj);
    }

    /**
     * Dispatches a select body to the matching handler.
     *
     * <p>Plain selects are scoped directly. WITH items are unwrapped. For set operations
     * (UNION etc.) only the first branch is processed. Other body kinds are left untouched.
     *
     * <p>NOTE(review): later UNION branches are NOT data-scoped, which lets rows from those branches
     * bypass the filter. Confirm this is intentional before relying on it. Applying the handler to
     * every branch was not done here because {@link DataScopeHandler}'s behaviour on
     * differently-shaped branches is not visible from this file.
     */
    protected void processSelectBody(SelectBody selectBody, Object obj) {
        if (selectBody == null) {
            return;
        }
        if (selectBody instanceof PlainSelect) {
            this.processPlainSelect((PlainSelect) selectBody, obj);
        } else if (selectBody instanceof WithItem) {
            this.processSelectBody(((WithItem) selectBody).getSelectBody(), obj);
        } else if (selectBody instanceof SetOperationList) {
            SetOperationList operationList = (SetOperationList) selectBody;
            if (CollUtil.isNotEmpty(operationList.getSelects())) {
                this.processSelectBody(operationList.getSelects().get(0), obj);
            }
        }
    }

    /**
     * Scopes a plain SELECT.
     *
     * <p>A table in FROM is handed to the data-scope handler. A sub-select in FROM is recursed into.
     * Sub-selects used as join targets are recursed into as well.
     */
    protected void processPlainSelect(PlainSelect plainSelect, Object obj) {
        FromItem fromItem = plainSelect.getFromItem();
        if (fromItem instanceof Table) {
            this.dataScopeHandler.getSqlSegment(plainSelect, obj);
        } else {
            processFromItem(fromItem, obj);
        }
        List<Join> joins = plainSelect.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            for (Join join : joins) {
                processJoin(join, obj);
            }
        }
    }

    /** Recurses into a FROM item when it is a sub-select; other FROM item kinds are ignored. */
    protected void processFromItem(FromItem fromItem, Object obj) {
        if (fromItem instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) fromItem;
            if (subSelect.getSelectBody() != null) {
                processSelectBody(subSelect.getSelectBody(), obj);
            }
        }
    }

    /**
     * Handles a join: when its right-hand item is a sub-select, that sub-select is scoped.
     *
     * @param join the join to inspect
     * @param obj  parser context forwarded to the data-scope handler
     */
    protected void processJoin(Join join, Object obj) {
        FromItem joinTable = join.getRightItem();
        if (joinTable instanceof SubSelect) {
            processSelectBody(((SubSelect) joinTable).getSelectBody(), obj);
        }
    }
}