package cn.org.rapid_framework.generator.provider.db.sql.model;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import cn.org.rapid_framework.generator.GeneratorConstants;
import cn.org.rapid_framework.generator.GeneratorProperties;
import cn.org.rapid_framework.generator.provider.db.sql.SqlFactory;
import cn.org.rapid_framework.generator.provider.db.table.model.Column;
import cn.org.rapid_framework.generator.provider.db.table.model.ColumnSet;
import cn.org.rapid_framework.generator.provider.db.table.model.Table;
import cn.org.rapid_framework.generator.util.StringHelper;
import cn.org.rapid_framework.generator.util.sqlparse.SqlParseHelper;
import cn.org.rapid_framework.generator.util.sqlparse.SqlTypeChecker;
import cn.org.rapid_framework.generator.util.sqlparse.SqlParseHelper.NameWithAlias;
import cn.org.rapid_framework.generator.util.typemapping.JavaPrimitiveTypeMapping;
/**
 * SQL object used for code generation; it represents one SQL statement against the database.
 * Instances are created with {@code SqlFactory.parseSql()}. <br />
 *
 * SQL parameters may be written in any of the following syntaxes:
 * <pre>
 * hibernate: :username
 * ibatis2:   #username#, $username$
 * mybatis:   #{username}, ${username}
 * </pre>
 * Example of creating a Sql object:
 * <pre>
 * Sql sql = new SqlFactory().parseSql("select * from user_info where username=#username# and password=#password#");
 * </pre>
 *
 * @see SqlFactory
 * @author badqiu
 */
public class Sql {
    /** The select query returns a single record, e.g. selectOne(). */
    public static final String MULTIPLICITY_ONE = "one";
    /** The select query returns a List, e.g. selectList(). */
    public static final String MULTIPLICITY_MANY = "many";
    /** Paging query. */
    public static final String MULTIPLICITY_PAGING = "paging";
    public static final String PARAMTYPE_PRIMITIVE = "primitive";
    public static final String PARAMTYPE_OBJECT = "object";

    /** Name of the operation this SQL maps to, e.g. findByUsername. */
    String operation = null;
    /** Class of the result set returned by a select query. */
    String resultClass;
    /** parameterClass that wraps the SQL parameters. */
    String parameterClass;
    /** Remarks / comments. */
    String remarks;
    /** one, many or paging. */
    String multiplicity = MULTIPLICITY_ONE;
    /** Whether this is a paging query. */
    boolean paging = false;
    /** For ibatis and mybatis. */
    String sqlmap;
    /** For ibatis and mybatis. */
    String resultMap = null;
    /** Result columns returned by a select query. */
    LinkedHashSet<Column> columns = new LinkedHashSet<Column>();
    /** Parameter list of the SQL statement. */
    LinkedHashSet<SqlParameter> params = new LinkedHashSet<SqlParameter>();
    /** The original (source) SQL. */
    String sourceSql;
    /** The SQL actually executed against the database. */
    String executeSql;
    /** primitive or object. */
    private String paramType = PARAMTYPE_PRIMITIVE;
    /** SQL fragments included by this SQL, e.g. ibatis {@code <include refid='User.Where'/>}. */
    private List<SqlSegment> sqlSegments = new ArrayList<SqlSegment>();
    /** Explicitly configured ibatis SQL; when blank it is derived from getSql(). */
    private String ibatisSql;
    /** Explicitly configured mybatis SQL; when blank it is derived from getSql(). */
    private String mybatisSql;

    public Sql() {
    }

    /**
     * Determines whether all columns returned by the select query belong to one and the same table.
     * <p>
     * Returns false when there are no columns, when the query does not reference exactly one table,
     * when that table is not found in the table cache, or when any result column is not a column
     * of that table.
     */
    public boolean isColumnsInSameTable() {
        // FIXME: should also check that the table's column count equals columns.size()
        // before an include statement can be generated for the select.
        if(columns == null || columns.isEmpty()) return false;
        Table table = getSingleQueryTable();
        if(table == null) {
            return false;
        }
        // Build the lookup once instead of once per column.
        ColumnSet tableColumns = new ColumnSet(table.getColumns());
        for(Column c : columns) {
            if(tableColumns.getBySqlName(c.getSqlName()) == null) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the cached Table for the single table referenced by executeSql.
     * Returns null when the query references zero or several tables, or when the cache has no entry.
     */
    private Table getSingleQueryTable() {
        Collection<NameWithAlias> tableNames = SqlParseHelper.getTableNamesByQuery(executeSql);
        if(tableNames == null || tableNames.size() != 1) {
            return null;
        }
        return SqlFactory.getTableFromCache(tableNames.iterator().next().getName());
    }

    /**
     * Returns the resultClass of the select query. It can be customized via setResultClass().
     * If it is not customized, it is generated automatically. <br />
     * The resultClass may be a fully qualified name such as com.company.User.
     * Examples:
     * <pre>
     * select count(*) from user, result: Long
     * select * from user result: User
     * select count(*) cnt, sum(age) sum_age result: getOperation()+"Result";
     * </pre>
     * For paging and "many" queries the type is passed through
     * JavaPrimitiveTypeMapping.getWrapperType (presumably mapping a primitive such as int to its
     * wrapper Integer, since collections cannot hold primitives).
     * @return the result class name, or null if it cannot be determined
     */
    public String getResultClass() {
        String result = _getResultClass();
        // equalsIgnoreCase keeps this consistent with isPaging()
        if(isPaging() || MULTIPLICITY_MANY.equalsIgnoreCase(multiplicity)) {
            return JavaPrimitiveTypeMapping.getWrapperType(result);
        }
        return result;
    }

    /** Resolves the result class before any wrapper-type conversion. */
    private String _getResultClass() {
        if(StringHelper.isNotBlank(resultClass)) return resultClass;
        if(columns.size() == 1) {
            return columns.iterator().next().getSimpleJavaType();
        }
        if(isColumnsInSameTable()) {
            // FIXME: a custom className will not take effect, because this Table is not the same
            // cached object instance.
            Table table = getSingleQueryTable();
            if(table != null) {
                return table.getClassName();
            }
        }
        if(operation == null) return null;
        return StringHelper.makeAllWordFirstLetterUpperCase(StringHelper.toUnderscoreName(operation))
                + GeneratorProperties.getProperty(GeneratorConstants.GENERATOR_SQL_RESULTCLASS_SUFFIX);
    }

    public void setResultClass(String queryResultClass) {
        this.resultClass = queryResultClass;
    }

    /** Whether a resultClass was explicitly set (non-blank). */
    public boolean isHasCustomResultClass() {
        return StringHelper.isNotBlank(this.resultClass);
    }

    /** Whether a resultMap was explicitly set (non-blank). */
    public boolean isHasResultMap() {
        return StringHelper.isNotBlank(this.resultMap);
    }

    /**
     * Returns the simple class name of getResultClass(). <br />
     * Example: <br />
     * if getResultClass()=com.company.User, returns User.
     * @return the simple name, or null when getResultClass() returns null
     */
    public String getResultClassName() {
        return toSimpleClassName(getResultClass());
    }

    /**
     * Class used to wrap the SQL parameters into one ParameterObject when there are many of them.<br />
     * <pre>
     * It can be customized via setParameterClass().
     * If it is not customized:
     *   for a select query, returns operation+"Query";
     *   otherwise returns operation+"Parameter".
     * </pre>
     * @return the parameter class name, or null if operation is blank and nothing was set
     */
    public String getParameterClass() {
        if(StringHelper.isNotBlank(parameterClass)) return parameterClass;
        if(StringHelper.isBlank(operation)) return null;
        String baseName = StringHelper.makeAllWordFirstLetterUpperCase(StringHelper.toUnderscoreName(operation));
        return isSelectSql() ? baseName + "Query" : baseName + "Parameter";
    }

    public void setParameterClass(String parameterClass) {
        this.parameterClass = parameterClass;
    }

    /**
     * Returns the simple class name of getParameterClass(). <br />
     * Example: <br />
     * if getParameterClass()=com.company.UserQuery, returns UserQuery.
     * @return the simple name, or null when getParameterClass() returns null
     */
    public String getParameterClassName() {
        return toSimpleClassName(getParameterClass());
    }

    /** Strips the package part from a (possibly) fully qualified class name; null-safe. */
    private static String toSimpleClassName(String className) {
        if(className == null) return null;
        int lastDot = className.lastIndexOf('.');
        return lastDot >= 0 ? className.substring(lastDot + 1) : className;
    }

    // TODO: when there are two or more columns that are not from the same table, create a
    // QueryResultClassName class; also consider creating one for columns from a single table.
    public int getColumnsCount() {
        return columns.size();
    }

    public void addColumn(Column c) {
        columns.add(c);
    }

    /**
     * Returns the operation name of this SQL method. Template usage: public List ${operation}().
     * Example value: findByUsername.
     * @return the operation name
     */
    public String getOperation() {
        return operation;
    }

    public void setOperation(String operation) {
        this.operation = operation;
    }

    /** Returns getOperation() with its first letter capitalized (via StringHelper.capitalize). */
    public String getOperationFirstUpper() {
        return StringHelper.capitalize(getOperation());
    }

    /**
     * Controls the query result. Expected values: one, many or paging.
     * @return the multiplicity
     */
    public String getMultiplicity() {
        return multiplicity;
    }

    public void setMultiplicity(String multiplicity) {
        // TODO: consider validating that the value is one of one, many, paging
        this.multiplicity = multiplicity;
    }

    /**
     * Returns the result columns of a select query. For insert, delete and update statements an
     * empty set is returned.<br />
     * Example:
     * <pre>
     * SQL : select count(*) cnt, sum(age) sum_age from user_info
     * columns: cnt,sum_age
     * </pre>
     * @return the result columns
     */
    public LinkedHashSet<Column> getColumns() {
        return columns;
    }

    public void setColumns(LinkedHashSet<Column> columns) {
        this.columns = columns;
    }

    /**
     * Returns the SQL parameters.<br />
     * Example:
     * <pre>
     * SQL : select * from user_info where username=:user and password=:pwd limit :offset,:limit
     * params: user,pwd,offset,limit
     * </pre>
     * @return the parameters
     */
    public LinkedHashSet<SqlParameter> getParams() {
        return params;
    }

    public void setParams(LinkedHashSet<SqlParameter> params) {
        this.params = params;
    }

    /** Returns the parameter whose name equals paramName, or null if there is none. */
    public SqlParameter getParam(String paramName) {
        for(SqlParameter p : getParams()) {
            if(p.getParamName().equals(paramName)) {
                return p;
            }
        }
        return null;
    }

    /**
     * Returns the original SQL statement.
     * @return the source SQL
     */
    public String getSourceSql() {
        return sourceSql;
    }

    public void setSourceSql(String sourceSql) {
        this.sourceSql = sourceSql;
    }

    /** Returns the sqlmap with ${paramN} placeholders substituted by the parameter names. */
    public String getSqlmap() {
        return getSqlmap(getParamNames());
    }

    /** Sets the sqlmap, expanding ${cdata-start} / ${cdata-end} into CDATA delimiters. */
    public void setSqlmap(String sqlmap) {
        if(StringHelper.isNotBlank(sqlmap)) {
            sqlmap = StringHelper.replace(sqlmap, "${cdata-start}", "<![CDATA[");
            sqlmap = StringHelper.replace(sqlmap, "${cdata-end}", "]]>");
        }
        this.sqlmap = sqlmap;
    }

    /** Returns the names of all parameters, in declaration order. */
    private List<String> getParamNames() {
        List<String> paramNames = new ArrayList<String>();
        for(SqlParameter p : params) {
            paramNames.add(p.getParamName());
        }
        return paramNames;
    }

    /**
     * Substitutes ${param1}, ${param2}, ... in the sqlmap.
     * With exactly one parameter, ${param1} is replaced by the literal "value";
     * otherwise ${paramN} is replaced by the N-th parameter name.
     */
    private String getSqlmap(List<String> paramNames) {
        if (sqlmap == null || paramNames == null || paramNames.isEmpty()) {
            return sqlmap;
        }
        if (paramNames.size() == 1) {
            // FIXME: compared with dalgen, verify whether ${param1} should be replaced with: value
            return StringHelper.replace(sqlmap, "${param1}", "value");
        }
        String result = sqlmap;
        for (int i = 0; i < paramNames.size(); i++) {
            result = StringHelper.replace(result, "${param" + (i + 1) + "}", paramNames.get(i));
        }
        return result;
    }

    /** Whether a non-blank sqlmap was set. */
    public boolean isHasSqlMap() {
        return StringHelper.isNotBlank(sqlmap);
    }

    public String getResultMap() {
        return resultMap;
    }

    public void setResultMap(String resultMap) {
        this.resultMap = resultMap;
    }

    /**
     * The sourceSql converted into the SQL actually executed by the database.
     * Example:
     * <pre>
     * sourceSql: select * from user where username=:username and password=:password
     * executeSql: select * from user where username=? and password=?
     * </pre>
     * @return the executable SQL
     */
    public String getExecuteSql() {
        return executeSql;
    }

    public void setExecuteSql(String executeSql) {
        this.executeSql = executeSql;
    }

    /** Count query derived from getHql(), for paging. */
    public String getCountHql() {
        return toCountSqlForPaging(getHql());
    }

    /** Count query derived from getSql(), for paging. */
    public String getCountSql() {
        return toCountSqlForPaging(getSql());
    }

    /** Count query derived from getIbatisSql(), for paging. */
    public String getIbatisCountSql() {
        return toCountSqlForPaging(getIbatisSql());
    }

    /** Count query derived from getMybatisSql(), for paging. */
    public String getMybatisCountSql() {
        return toCountSqlForPaging(getMybatisSql());
    }

    /** Count query derived from getSqlmap(), for paging. */
    public String getSqlmapCountSql() {
        return toCountSqlForPaging(getSqlmap());
    }

    /** Returns sourceSql with a select-clause wildcard expanded into the known column names. */
    public String getSql() {
        return replaceWildcardWithColumnsSqlName(sourceSql);
    }

    /**
     * Converts a select statement into a "select count(*) ..." statement for paging.
     * Non-select statements are returned unchanged; null returns null.
     */
    public static String toCountSqlForPaging(String sql) {
        if(sql == null) return null;
        if(SqlTypeChecker.isSelectSql(sql)) {
            return SqlParseHelper.toCountSqlForPaging(sql, "select count(*) ");
        }
        return sql;
    }

    /** getSql() with parameters rewritten via convert2NamedParametersSql using prefix ":" and empty suffix. */
    public String getSpringJdbcSql() {
        return SqlParseHelper.convert2NamedParametersSql(getSql(),":","");
    }

    /** getSql() with parameters rewritten via convert2NamedParametersSql using prefix ":" and empty suffix. */
    public String getHql() {
        return SqlParseHelper.convert2NamedParametersSql(getSql(),":","");
    }

    /** The explicitly set ibatis SQL, or getSql() with parameters rewritten as #name#. */
    public String getIbatisSql() {
        return StringHelper.isBlank(ibatisSql) ? SqlParseHelper.convert2NamedParametersSql(getSql(),"#","#") : ibatisSql;
    }

    /** The explicitly set mybatis SQL, or getSql() with parameters rewritten as #{name}. */
    public String getMybatisSql() {
        return StringHelper.isBlank(mybatisSql) ? SqlParseHelper.convert2NamedParametersSql(getSql(),"#{","}") : mybatisSql;
    }

    public void setIbatisSql(String ibatisSql) {
        this.ibatisSql = ibatisSql;
    }

    public void setMybatisSql(String mybatisSql) {
        this.mybatisSql = mybatisSql;
    }

    /** Joins the sql names of all result columns with ",". */
    private String joinColumnsSqlName() {
        // TODO: the a.*,b.* case is not handled
        StringBuilder sb = new StringBuilder();
        for(Iterator<Column> it = columns.iterator(); it.hasNext();) {
            sb.append(it.next().getSqlName());
            if(it.hasNext()) sb.append(",");
        }
        return sb.toString();
    }

    /**
     * For a select statement whose select clause contains "*" and is not a count query, replaces the
     * select clause with the explicit list of known result columns. Other statements are returned
     * unchanged, and so is a select when no columns are known (otherwise invalid SQL of the form
     * "select  from ..." would be produced).
     */
    public String replaceWildcardWithColumnsSqlName(String sql) {
        if(sql == null || columns == null || columns.isEmpty() || !SqlTypeChecker.isSelectSql(sql)) {
            return sql;
        }
        String selectClause = SqlParseHelper.getSelect(SqlParseHelper.removeSqlComments(sql));
        boolean hasWildcard = selectClause.indexOf("*") >= 0;
        // Lowercase so that COUNT(*) / Count(*) are recognised as count queries too.
        boolean isCountQuery = selectClause.toLowerCase(Locale.ROOT).indexOf("count(") >= 0;
        if(hasWildcard && !isCountQuery) {
            return SqlParseHelper.getPrettySql("select " + joinColumnsSqlName() + " " + SqlParseHelper.removeSelect(sql));
        }
        return sql;
    }

    public List<SqlSegment> getSqlSegments() {
        return sqlSegments;
    }

    public void setSqlSegments(List<SqlSegment> includeSqls) {
        this.sqlSegments = includeSqls;
    }

    /** Returns the SqlSegment whose id equals the given id, or null if there is none. */
    public SqlSegment getSqlSegment(String id) {
        for(SqlSegment seg : sqlSegments) {
            if(seg.getId().equals(id)) {
                return seg;
            }
        }
        return null;
    }

    /** Returns the parameters that do not belong to any SqlSegment, in declaration order. */
    public List<SqlParameter> getFilterdWithSqlSegmentParams() {
        List<SqlParameter> result = new ArrayList<SqlParameter>();
        for(SqlParameter p : getParams()) {
            if(!isParamInAnySqlSegment(p.getParamName())) {
                result.add(p);
            }
        }
        return result;
    }

    /** Whether any SqlSegment declares a parameter with the given name. */
    private boolean isParamInAnySqlSegment(String paramName) {
        for(SqlSegment seg : getSqlSegments()) {
            // TODO: if the parameter count is 1, no SqlSegment should be generated;
            // the special handling for a single parameter must then be changed here as well.
            if(seg.getParamNames().contains(paramName)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Whether the current sourceSql is a select statement.
     * @return true for select
     */
    public boolean isSelectSql() {
        return SqlTypeChecker.isSelectSql(sourceSql);
    }

    /**
     * Whether the current sourceSql is an update statement.
     * @return true for update
     */
    public boolean isUpdateSql() {
        return SqlTypeChecker.isUpdateSql(sourceSql);
    }

    /**
     * Whether the current sourceSql is a delete statement.
     * @return true for delete
     */
    public boolean isDeleteSql() {
        return SqlTypeChecker.isDeleteSql(sourceSql);
    }

    /**
     * Whether the current sourceSql is an insert statement.
     * @return true for insert
     */
    public boolean isInsertSql() {
        return SqlTypeChecker.isInsertSql(sourceSql);
    }

    /**
     * Returns the remarks.
     * @return the remarks
     */
    public String getRemarks() {
        return remarks;
    }

    public String getParamType() {
        return paramType;
    }

    public void setParamType(String paramType) {
        this.paramType = paramType;
    }

    public void setRemarks(String comments) {
        this.remarks = comments;
    }

    /** True if multiplicity is "paging" (case-insensitive) or the paging flag was set. */
    public boolean isPaging() {
        return MULTIPLICITY_PAGING.equalsIgnoreCase(multiplicity) || paging;
    }

    public void setPaging(boolean paging) {
        this.paging = paging;
    }

    /** Returns the result column whose sql name equals sqlName (case-insensitive), or null. */
    public Column getColumnBySqlName(String sqlName) {
        for(Column c : getColumns()) {
            if(c.getSqlName().equalsIgnoreCase(sqlName)) {
                return c;
            }
        }
        return null;
    }

    /** Looks up a result column by name, falling back to the underscore form of the name. */
    public Column getColumnByName(String name) {
        Column c = getColumnBySqlName(name);
        if(c == null) {
            c = getColumnBySqlName(StringHelper.toUnderscoreName(name));
        }
        return c;
    }

    /** Resolves and assigns the parameters of every SqlSegment against this Sql. */
    public void afterPropertiesSet() {
        for(SqlSegment seg : sqlSegments) {
            seg.setParams(seg.getParams(this));
        }
    }

    @Override
    public String toString() {
        return "sourceSql:\n"+sourceSql+"\nsql:"+getSql();
    }
}