package com.jqmobile.core.server.db.orm; import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Timestamp; import java.util.ArrayList; import java.util.Collection; import java.util.Date; import java.util.List; import java.util.Map; import java.util.Set; import java.util.UUID; import com.jqmobile.core.orm.AutoVersionControl; import com.jqmobile.core.orm.DBColumn; import com.jqmobile.core.orm.ORM; import com.jqmobile.core.orm.TableUtil; import com.jqmobile.core.orm.exception.ORMException; import com.jqmobile.core.orm.exception.ORMNotDBTableException; import com.jqmobile.core.orm.exception.VersionFieldTypeException; import com.jqmobile.core.utils.TypeArgFinder; import com.jqmobile.core.utils.plain.BeanUtils; import com.jqmobile.core.utils.plain.GUIDUtils; import com.jqmobile.core.utils.plain.Log; class ORMImpl<T> extends ORMSImpl implements ORM<T> { private final BaseDBTable table; private final Class<T> beanClass; /** * 初始化 * * @param conn * 数据库connection * @param c * 要操作的实体class * @throws ORMNotDBTableException * @throws ORMException */ public ORMImpl(Connection conn, Class<T> c) throws ORMNotDBTableException, ORMException { super(conn); this.beanClass = c; table = BaseDBTable.getInstance(c); table.autoCreateTable(conn); table.updateTable(conn); } /** * 获取参数 * * @param obj * @param column * @return */ private Object getParam(Object obj, BaseDBColumn column) { if (null == column) { return getParam(obj); } return column.getFormatObj(obj); } @Override public int insert(T t) throws ORMException { int row = 0; try { PreparedStatement ps = getPrepareStatement(table.getInstnerSql(), t); try { row = ps.executeUpdate(); } finally { ps.close(); } if (table.hasChild()) { List<Class<?>> childClasses = table.getChildClasses(); for (Class<?> childClass : childClasses) { table.autoCreateMiddleTable(beanClass, childClass, getConnection()); // 从对象中获取所有关联对象 List<Object> objList = getChildObj(beanClass, childClass, t); for (Object o : objList) { insertChildData(o, t); } } } } catch (SQLException e) { throw new ORMException(e); } catch (IllegalArgumentException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } return row; } private TableUtil getTableUtil() { return BaseTableUtil.getInstance(getConnection()); } /** * 构建插入语句 * * @param obj * @param t * @return * @throws ORMException */ private String getInsertSql(Object obj, T t) throws ORMException { StringBuilder sb = new StringBuilder("insert into "); TableUtil btu = getTableUtil(); for (String tableName : BaseTableUtil.getTableName(t.getClass(), obj.getClass())) { if (btu.valiTableExist(tableName)) { sb.append(tableName + "(C1,C2"); } } sb.append(") values (?,?)"); return sb.toString(); } /** * 子表数据插入 * * @param obj * @param t * @throws ORMException * @throws IllegalAccessException * @throws IllegalArgumentException */ private void insertChildData(Object obj, T t) throws ORMException, IllegalArgumentException, IllegalAccessException { ORM orm =ORMFactory.instance(getConnection(), obj.getClass()); orm.insert(obj); PreparedStatement ps = null; String sql = getInsertSql(obj, t); try { ps = getConnection().prepareStatement(sql); ps.setObject(1, getIDValue(getIdValue(t))); ps.setObject(2, getIDValue(getIdValue(obj))); ps.executeUpdate(); } catch (SQLException e) { e.printStackTrace(); } finally { if (ps != null) { try { ps.close(); } catch (SQLException e) { e.printStackTrace(); } } } } /** * 获取主键值 * @param obj * @return byte数组 * @throws ORMException */ private Object getIDValue(Object obj) throws ORMException { if (obj instanceof UUID) { return GUIDUtils.getBytes((UUID) obj); } else if (obj instanceof String) { return GUIDUtils.getBytes(UUID.fromString((String) obj)); } else if (obj instanceof byte[] || obj instanceof Byte[]) { return obj; } else { throw new ORMException("主键必须为uuid或uuid的String形式"); } } /** * 获取子表对象 * * @param beanClass2 * @param childClass * @param t * @return * @throws IllegalAccessException * @throws IllegalArgumentException * @throws ORMException */ private List<Object> getChildObj(Class<T> beanClass2, Class<?> childClass, T t) throws ORMException, IllegalArgumentException, IllegalAccessException { Field[] fields = BeanUtils.getAllFields_Cache(beanClass2); List<Object> list = new ArrayList<Object>(); for (Field f : fields) { f.setAccessible(true); if (null != f.get(t)) { if (f.getType().isAssignableFrom(childClass)) { try { list.add(f.get(t)); break; } catch (IllegalArgumentException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } } else if (List.class.isAssignableFrom(f.getType()) && getFieldClass(f).isAssignableFrom(childClass)) { list.addAll((List) f.get(t)); break; } else if (Set.class.isAssignableFrom(f.getType()) && getFieldClass(f).isAssignableFrom(childClass)) { list.addAll((Set) f.get(t)); break; } } } return list; } /** * 如果属性是普通类,则直接返回,如果是集合,则返回泛型 * * @param c * @return * @throws ORMException */ public static Class<?> getFieldClass(Field field) throws ORMException { if (List.class.isAssignableFrom(field.getType())) { return TypeArgFinder.getFieldClassGenricType(field); } else if (Set.class.isAssignableFrom(field.getType())) { return TypeArgFinder.getFieldClassGenricType(field); } else if (Map.class.isAssignableFrom(field.getType())) { throw new ORMException("集合不能是MAP"); } else if (Object[].class.isAssignableFrom(field.getType())) { throw new ORMException("集合不能是Object"); } else { return field.getType(); } } /** * 将值封装进sql里 * * @param sql * 原生sql * @param t * 对象 * @return 返回PreparedStatement对象 * @throws SQLException */ private PreparedStatement getPrepareStatement(String sql, T t) throws SQLException { PreparedStatement ps = getPrepareStatement(sql); for (int i = 0; i < table.getMappingFields().size(); i++) { try { BaseDBColumn column = table.getMappingFields().get(i); if (column.getField().getAnnotation(DBColumn.class) != null && column.getField().getAnnotation(DBColumn.class).date()) { long longDate = (Long) column.getField().get(t); Timestamp tsDate = new Timestamp(longDate == 0 ? new Date().getTime() : longDate); ps.setTimestamp(i + 1, tsDate); } else { ps.setObject(i + 1, getParam(column.getField().get(t), column)); } } catch (Exception e) { Log.getLog(getClass()).e(e); continue; } } if (sql.startsWith("update")) { BaseDBColumn parmaryId = table.getParmaryId(); try { ps.setObject(table.getMappingFields().size() + 1, getParam(parmaryId.getField().get(t), parmaryId)); } catch (IllegalArgumentException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } } return ps; } @Override public int update(T t) throws ORMException { int row = 1; try { StringBuilder sb = new StringBuilder(table.getUpdateSqlByPID()); if (judgeInterface()) { sb.append(getVerSql(t)); row--; } PreparedStatement ps = getPrepareStatement(sb.toString(), t); try { row = ps.executeUpdate(); if (row == 0) { throw new ORMException("更新失败,可能版本不一致或者sql错误"); } } finally { ps.close(); } if (table.hasChild()) { List<Class<?>> childClasses = table.getChildClasses(); for (Class<?> childClass : childClasses) { List<Object> objList = getChildObj(beanClass, childClass, t); table.autoCreateMiddleTable(beanClass, childClass, getConnection()); deleteMiddleData(GUIDUtils.getGUID((byte[]) getIdValue(t)), childClass); for (Object o : objList) { updateChildData(o, t); } } } } catch (SQLException e) { throw new ORMException(e); } catch (IllegalArgumentException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } return row; } /** * 获取对象主键值 * * @param t * @return * @throws IllegalArgumentException * @throws IllegalAccessException * @throws ORMNotDBTableException */ private Object getIdValue(Object object) throws IllegalArgumentException, IllegalAccessException, ORMNotDBTableException { BaseDBTable table = BaseDBTable.getInstance(object.getClass()); BaseDBColumn column = table.getParmaryId(); if (null != column) return column.getFormatObj(column.getField().get(object)); return null; } /** * 更新映射表数据 * * @param obj * @param t * @throws IllegalAccessException * @throws IllegalArgumentException * @throws ORMException * @throws ORMNotDBTableException */ private void updateChildData(Object obj, T t) throws IllegalArgumentException, IllegalAccessException, ORMException { insertChildData(obj, t); } private String getVerSql(T t) { // 获取类中指定的版本字段 String validateWord = getValidateWord(t); // 获取对象中版本值 Object entityVW = getEntityVW(t, validateWord); return " and " + validateWord + " = " + entityVW; } @Override public int delete(String recid) throws ORMException { return delete(UUID.fromString(recid)); } @Override public T find(String recid) { return find(UUID.fromString(recid)); } @Override public int delete(UUID recid) throws ORMException { int row = 0; try { PreparedStatement ps = getPreparedStatement(table.getDeleteSqlByPID(), recid); try { row = ps.executeUpdate(); } finally { ps.close(); } if (table.hasChild()) { List<Class<?>> childClasses = table.getChildClasses(); for (Class<?> childClass : childClasses) { table.autoCreateMiddleTable(beanClass, childClass, getConnection()); deleteMiddleData(recid, childClass); } } } catch (SQLException e) { throw new ORMException(e); } return row; } /** * 主表删除(更新)数据,中间表对应关系删除 * * @param recid * 主表id * @param childClass * 字表类 * @throws SQLException * @throws ORMException */ private void deleteMiddleData(UUID recid, Class<?> childClass) throws SQLException, ORMException { TableUtil btu = getTableUtil(); for (String tableName : BaseTableUtil.getTableName(beanClass, childClass)) { if (btu.valiTableExist(tableName)) { String sql = "delete from " + tableName + " where C1=?"; PreparedStatement ps = getConnection().prepareStatement(sql); ps.setObject(1, GUIDUtils.getBytes(recid)); ps.executeUpdate(); } } } @Override public T find(UUID recid) { try { PreparedStatement ps = getPreparedStatement(table.getFindSqlByPID(), recid); try { ResultSet rs = ps.executeQuery(); try { if (rs.next()) return fillChildObj(instanceObject(rs), recid); } finally { rs.close(); } } finally { ps.close(); } } catch (Exception e) { Log.getLog(getClass()).e(e); } return null; } /** * 反射获取返回对象结果list。数据库字段名称,必须和属性名称一致 * * @param <T> * @param c * @param rs * @return List<T> * @throws Exception */ private List<T> getList(ResultSet rs) throws Exception { List<T> list = new ArrayList<T>(); List<BaseDBColumn> fields = table.getMappingFields(); String id = ""; for (BaseDBColumn bc : fields) { if (bc.isPaimaryId()) { id = bc.getColumnName(); } } if (!table.hasChild()) { while (rs.next()) { list.add(instanceObject(rs, fields)); } } else { while (rs.next()) { list.add(fillChildObj(instanceObject(rs), GUIDUtils.getGUID((byte[]) rs.getObject(id)))); } } return list; } /** * rs封装成对象 * * @param rs * @param fields * @return * @throws InstantiationException * @throws IllegalAccessException */ private T instanceObject(ResultSet rs, List<BaseDBColumn> fields) throws InstantiationException, IllegalAccessException { T t = beanClass.newInstance(); for (BaseDBColumn field : fields) { Object obj = field.getValue(rs); if (null == obj) { continue; } field.set(t, obj); } return t; } /** * rs封装成对象 * * @param rs * @return * @throws InstantiationException * @throws IllegalAccessException * @throws SQLException */ private T instanceObject(ResultSet rs) throws InstantiationException, IllegalAccessException, SQLException { T t = beanClass.newInstance(); List<BaseDBColumn> fields = table.getMappingFields(); for (BaseDBColumn field : fields) { Object obj = field.getValue(rs); if (null == obj) { continue; } field.set(t, obj); } return t; } @Override public List<T> query(String where, Object... args) throws ORMException { try { PreparedStatement ps = getPreparedStatement(table.getQuerySql(where), args); try { ResultSet rs = ps.executeQuery(); try { return getList(rs); } catch (Exception e) { throw new ORMException(e); } finally { rs.close(); } } finally { ps.close(); } } catch (SQLException e) { throw new ORMException(e); } } @Override public List<T> queryPage(String where, long startIndex, long endIndex, Object... args) throws ORMException { try { PreparedStatement ps = getPreparedStatement(table.getQuerySql(where) + " limit ?,?", args); ps.setLong(args.length + 1, startIndex); ps.setLong(args.length + 2, endIndex); try { ResultSet rs = ps.executeQuery(); try { return getList(rs); } catch (Exception e) { throw new ORMException(e); } finally { rs.close(); } } finally { ps.close(); } } catch (SQLException e) { throw new ORMException(e); } } @Override public T queryFirst(String where, Object... args) throws ORMException { try { PreparedStatement ps = getPreparedStatement(table.getQuerySql(where), args); try { ResultSet rs = ps.executeQuery(); try { if (rs.next()) return instanceObject(rs); return null; } catch (Exception e) { throw new ORMException(e); } finally { rs.close(); } } finally { ps.close(); } } catch (SQLException e) { throw new ORMException(e); } } @Override public int queryRow(String where, Object... args) throws ORMException { try { PreparedStatement ps = getPreparedStatement(table.getQueryRowSql(where), args); try { ResultSet rs = ps.executeQuery(); try { if (rs.next()) return rs.getInt(1); return 0; } catch (Exception e) { throw new ORMException(e); } finally { rs.close(); } } finally { ps.close(); } } catch (SQLException e) { throw new ORMException(e); } } @Override public int update(String set, Object... args) throws ORMException { try { PreparedStatement ps = getPreparedStatement(table.getModifySql(set), args); try { return ps.executeUpdate(); } finally { ps.close(); } } catch (SQLException e) { throw new ORMException(e); } } @Override public int delete(String where, Object... args) throws ORMException { try { PreparedStatement ps = getPreparedStatement(table.getDeleteSql(where), args); try { return ps.executeUpdate(); } finally { ps.close(); } } catch (SQLException e) { throw new ORMException(e); } } @Override public List<T> getAll() throws ORMException { return query("select * from " + table.getTableName()); } /** * 获取实体对象中版本控制字段值 * * @param t * 实体对象 * @param validateWord * 指定版本字段名称 * @return */ private Object getEntityVW(T t, String validateWord) { try { Field field = beanClass.getDeclaredField(validateWord); if (field.getType() == long.class) { String entityVW = "get" + validateWord.substring(0, 1).toUpperCase() + validateWord.substring(1, validateWord.length()); Method method = beanClass.getMethod(entityVW, null); return (Object) method.invoke(t, null); } else { throw new VersionFieldTypeException("版本字段必修是long类型"); } } catch (NoSuchMethodException e) { e.printStackTrace(); } catch (SecurityException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } catch (IllegalArgumentException e) { e.printStackTrace(); } catch (InvocationTargetException e) { e.printStackTrace(); } catch (NoSuchFieldException e) { e.printStackTrace(); } catch (VersionFieldTypeException e) { e.printStackTrace(); } return null; } /** * 获取类中指定的版本字段 * * @param t * 实体对象 * @return */ private String getValidateWord(T t) { try { Method method = beanClass.getMethod("getVersionWord", null); return (String) method.invoke(t, null); } catch (NoSuchMethodException e) { e.printStackTrace(); } catch (SecurityException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } catch (IllegalArgumentException e) { e.printStackTrace(); } catch (InvocationTargetException e) { e.printStackTrace(); } return null; } /** * 判断是否实现版本控制接口 * * @param t * 实体对象 * @return 若实现返回true,反之,返回false */ private boolean judgeInterface() { Class<?>[] classes = beanClass.getInterfaces(); for (Class<?> c : classes) { if (c.equals(AutoVersionControl.class)) { return true; } } return false; } /** * 非懒加载查询 * * @param t * @param recid * @return * @throws ORMException * @throws SQLException * @throws IllegalAccessException * @throws IllegalArgumentException */ private T fillChildObj(T t, UUID recid) throws ORMException, SQLException, IllegalArgumentException, IllegalAccessException { if (table.hasChild()) { List<Class<?>> childClasses = table.getChildClasses(); List<byte[]> bytes = new ArrayList<byte[]>(); for (Class<?> childClass : childClasses) { List<Object> list = new ArrayList<Object>(); TableUtil btu = getTableUtil(); for (String tableName : BaseTableUtil.getTableName(beanClass, childClass)) { if (btu.valiTableExist(tableName)) { String query = "select C2 from " + tableName + " where C1=?"; PreparedStatement ps = getConnection().prepareStatement(query); ps.setBytes(1, GUIDUtils.getBytes(recid)); ResultSet rs = ps.executeQuery(); while (rs.next()) { bytes.add(rs.getBytes("C2")); } ORM orm = ORMFactory.instance(getConnection(), childClass); for (byte[] b : bytes) { Object obj=orm.find(GUIDUtils.getGUID(b)); if(null!=obj){ list.add(obj); } } if (list.size() > 0) { t = setObjectValue(t, list); } bytes.clear(); } } } } return t; } /** * 查询通用方法 * * @param t * @param object * @throws IllegalAccessException * @throws IllegalArgumentException */ private T setObjectValue(T t, List<Object> list) throws IllegalArgumentException, IllegalAccessException { Field[] fields = BeanUtils.getAllFields_Cache(beanClass); Class<?> c = list.get(0).getClass(); for (Field f : fields) { f.setAccessible(true); BaseDBColumn column = BaseDBColumn.getInstance(f); if (column.isMapping() && f.getType().isAssignableFrom(c)) { f.set(t, list.get(0)); } else if (column.isMapping() && Collection.class.isAssignableFrom(f.getType()) && TypeArgFinder.getFieldClassGenricType(f).isAssignableFrom(c)) { f.set(t, list); } } return t; } }