package com.jqmobile.core.server.db.orm;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Date;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import com.jqmobile.core.orm.AutoVersionControl;
import com.jqmobile.core.orm.DBColumn;
import com.jqmobile.core.orm.ORM;
import com.jqmobile.core.orm.TableUtil;
import com.jqmobile.core.orm.exception.ORMException;
import com.jqmobile.core.orm.exception.ORMNotDBTableException;
import com.jqmobile.core.orm.exception.VersionFieldTypeException;
import com.jqmobile.core.utils.TypeArgFinder;
import com.jqmobile.core.utils.plain.BeanUtils;
import com.jqmobile.core.utils.plain.GUIDUtils;
import com.jqmobile.core.utils.plain.Log;
class ORMImpl<T> extends ORMSImpl implements ORM<T> {
private final BaseDBTable table;
private final Class<T> beanClass;
/**
* 初始化
*
* @param conn
* 数据库connection
* @param c
* 要操作的实体class
* @throws ORMNotDBTableException
* @throws ORMException
*/
public ORMImpl(Connection conn, Class<T> c) throws ORMNotDBTableException, ORMException {
super(conn);
this.beanClass = c;
table = BaseDBTable.getInstance(c);
table.autoCreateTable(conn);
table.updateTable(conn);
}
/**
 * Converts a field value into the object that is bound to a statement.
 *
 * @param obj
 *            raw value read from the entity
 * @param column
 *            column metadata used to format the value; may be {@code null}
 * @return {@code column.getFormatObj(obj)} when a column is given, otherwise
 *         the result of the single-argument {@code getParam(obj)}
 */
private Object getParam(Object obj, BaseDBColumn column) {
    if (column != null) {
        return column.getFormatObj(obj);
    }
    // No column metadata: fall back to the generic conversion.
    return getParam(obj);
}
/**
 * Inserts {@code t} into its table and, when the table declares child
 * classes, inserts every child object reachable from {@code t} together
 * with the parent/child link rows.
 *
 * @param t
 *            entity to insert
 * @return number of rows reported by the parent insert statement
 * @throws ORMException
 *             if the SQL fails, or if a field of the entity cannot be read
 *             reflectively or an id cannot be converted (these were
 *             previously only printed, which let a partially failed insert
 *             return normally)
 */
@Override
public int insert(T t) throws ORMException {
    try {
        int row;
        PreparedStatement ps = getPrepareStatement(table.getInstnerSql(), t);
        try {
            row = ps.executeUpdate();
        } finally {
            ps.close();
        }
        if (table.hasChild()) {
            for (Class<?> childClass : table.getChildClasses()) {
                table.autoCreateMiddleTable(beanClass, childClass, getConnection());
                // Collect every associated object of this child class from the entity.
                for (Object child : getChildObj(beanClass, childClass, t)) {
                    insertChildData(child, t);
                }
            }
        }
        return row;
    } catch (SQLException e) {
        throw new ORMException(e);
    } catch (IllegalArgumentException e) {
        throw new ORMException(e);
    } catch (IllegalAccessException e) {
        throw new ORMException(e);
    }
}
/**
 * Returns the table utility obtained from {@code BaseTableUtil.getInstance}
 * for this instance's connection.
 */
private TableUtil getTableUtil() {
    final Connection connection = getConnection();
    return BaseTableUtil.getInstance(connection);
}
/**
 * Builds the insert statement for the link table between the class of
 * {@code t} and the class of {@code obj}.
 * <p>
 * The candidate names come from {@code BaseTableUtil.getTableName}; the first
 * one that exists is used. The original concatenated every existing name
 * into one statement and produced {@code "insert into ) values (?,?)"} when
 * none existed, both of which are invalid SQL.
 *
 * @param obj
 *            child object
 * @param t
 *            parent entity
 * @return {@code insert into <table>(C1,C2) values (?,?)}
 * @throws ORMException
 *             if none of the candidate link tables exists
 */
private String getInsertSql(Object obj, T t) throws ORMException {
    TableUtil btu = getTableUtil();
    for (String tableName : BaseTableUtil.getTableName(t.getClass(), obj.getClass())) {
        if (btu.valiTableExist(tableName)) {
            return "insert into " + tableName + "(C1,C2) values (?,?)";
        }
    }
    throw new ORMException("no middle table exists for " + t.getClass().getName() + " and " + obj.getClass().getName());
}
/**
 * Inserts a child object through its own ORM and then writes the link row
 * (C1 = id of {@code t}, C2 = id of {@code obj}) into the middle table.
 *
 * @param obj
 *            child object to insert
 * @param t
 *            parent entity the child belongs to
 * @throws ORMException
 *             if the child insert fails, an id cannot be converted, or the
 *             link-row statement fails (the latter was previously printed
 *             and swallowed, silently losing the association)
 * @throws IllegalAccessException
 *             if an id field cannot be read reflectively
 * @throws IllegalArgumentException
 *             if an id field cannot be read from the given object
 */
// The ORM is resolved from a runtime class, so its type argument is unknown here.
@SuppressWarnings({ "rawtypes", "unchecked" })
private void insertChildData(Object obj, T t) throws ORMException, IllegalArgumentException, IllegalAccessException {
    ORM orm = ORMFactory.instance(getConnection(), obj.getClass());
    orm.insert(obj);
    String sql = getInsertSql(obj, t);
    try {
        PreparedStatement ps = getConnection().prepareStatement(sql);
        try {
            ps.setObject(1, getIDValue(getIdValue(t)));
            ps.setObject(2, getIDValue(getIdValue(obj)));
            ps.executeUpdate();
        } finally {
            ps.close();
        }
    } catch (SQLException e) {
        throw new ORMException(e);
    }
}
/**
 * Normalises a primary-key value for binding to a statement.
 * <p>
 * A {@link UUID}, or a {@code String} parsed with {@link UUID#fromString},
 * is converted with {@code GUIDUtils.getBytes}; a {@code byte[]} or
 * {@code Byte[]} is returned unchanged.
 *
 * @param obj
 *            primary-key value
 * @return the key as produced by {@code GUIDUtils.getBytes}, or {@code obj}
 *         itself when it is already a byte array
 * @throws ORMException
 *             if {@code obj} is none of the supported types
 */
private Object getIDValue(Object obj) throws ORMException {
    if (obj instanceof UUID) {
        return GUIDUtils.getBytes((UUID) obj);
    }
    if (obj instanceof String) {
        UUID parsed = UUID.fromString((String) obj);
        return GUIDUtils.getBytes(parsed);
    }
    boolean alreadyBinary = obj instanceof byte[] || obj instanceof Byte[];
    if (!alreadyBinary) {
        throw new ORMException("主键必须为uuid或uuid的String形式");
    }
    return obj;
}
/**
 * Collects the child objects of type {@code childClass} held by {@code t}.
 * <p>
 * Fields are scanned in the order returned by
 * {@code BeanUtils.getAllFields_Cache}; null-valued fields are skipped and
 * scanning stops at the first field that matches. A field matches when its
 * declared type is assignable from {@code childClass} (its value is added),
 * or when it is a {@code List}/{@code Set} whose element type, as reported
 * by {@link #getFieldClass(Field)}, is assignable from {@code childClass}
 * (all its elements are added).
 * <p>
 * Compared with the original, the field value is read once, the raw
 * collection casts are replaced by a wildcard cast, and the inner catch
 * blocks (which guarded a read that had already succeeded) are removed.
 * Behaviour is unchanged.
 *
 * @param beanClass2
 *            class whose fields are scanned
 * @param childClass
 *            child type to look for
 * @param t
 *            entity holding the children
 * @return the children found; empty when no field matches
 * @throws ORMException
 *             propagated from {@link #getFieldClass(Field)}
 * @throws IllegalAccessException
 *             if a field cannot be read reflectively
 * @throws IllegalArgumentException
 *             if {@code t} is not an instance of the field's declaring class
 */
private List<Object> getChildObj(Class<T> beanClass2, Class<?> childClass, T t) throws ORMException, IllegalArgumentException, IllegalAccessException {
    List<Object> list = new ArrayList<Object>();
    for (Field f : BeanUtils.getAllFields_Cache(beanClass2)) {
        f.setAccessible(true);
        Object value = f.get(t);
        if (null == value) {
            continue;
        }
        Class<?> fieldType = f.getType();
        if (fieldType.isAssignableFrom(childClass)) {
            list.add(value);
            break;
        }
        boolean listOrSet = List.class.isAssignableFrom(fieldType) || Set.class.isAssignableFrom(fieldType);
        if (listOrSet && getFieldClass(f).isAssignableFrom(childClass)) {
            list.addAll((Collection<?>) value);
            break;
        }
    }
    return list;
}
/**
 * Resolves the type a field contributes to the mapping: for a {@code List}
 * or {@code Set} field the generic type reported by
 * {@code TypeArgFinder.getFieldClassGenricType}, for any other supported
 * field its declared type.
 *
 * @param field
 *            field to inspect
 * @return element type for list/set fields, declared type otherwise
 * @throws ORMException
 *             if the field is a {@code Map} or an object array
 */
public static Class<?> getFieldClass(Field field) throws ORMException {
    final Class<?> declaredType = field.getType();
    if (List.class.isAssignableFrom(declaredType) || Set.class.isAssignableFrom(declaredType)) {
        return TypeArgFinder.getFieldClassGenricType(field);
    }
    if (Map.class.isAssignableFrom(declaredType)) {
        throw new ORMException("集合不能是MAP");
    }
    if (Object[].class.isAssignableFrom(declaredType)) {
        throw new ORMException("集合不能是Object");
    }
    return declaredType;
}
/**
 * Prepares {@code sql} and binds one parameter per mapped column of
 * {@code t}, in mapping order.
 * <p>
 * Columns whose {@link DBColumn} annotation has {@code date()} set are read
 * as a {@code Long} and bound as a {@link Timestamp}; a value of 0 is
 * replaced by the current time. All other columns are bound with
 * {@code setObject}. A failure to bind a single column is logged and that
 * parameter is left unset, as in the original. When {@code sql} starts with
 * {@code "update"}, the primary-key value is bound as one extra trailing
 * parameter.
 * <p>
 * The statement is now closed if this method exits with an exception, so a
 * failed primary-key binding no longer leaks it, and reflective failures
 * there go to {@code Log} like the column failures do.
 *
 * @param sql
 *            SQL text to prepare
 * @param t
 *            entity supplying the parameter values
 * @return the prepared statement with its parameters bound; the caller
 *         closes it
 * @throws SQLException
 *             if preparing the statement or binding the primary key fails
 */
private PreparedStatement getPrepareStatement(String sql, T t) throws SQLException {
    PreparedStatement ps = getPrepareStatement(sql);
    boolean bound = false;
    try {
        List<BaseDBColumn> columns = table.getMappingFields();
        for (int i = 0; i < columns.size(); i++) {
            try {
                BaseDBColumn column = columns.get(i);
                DBColumn annotation = column.getField().getAnnotation(DBColumn.class);
                if (annotation != null && annotation.date()) {
                    long longDate = (Long) column.getField().get(t);
                    Timestamp tsDate = new Timestamp(longDate == 0 ? new Date().getTime() : longDate);
                    ps.setTimestamp(i + 1, tsDate);
                } else {
                    ps.setObject(i + 1, getParam(column.getField().get(t), column));
                }
            } catch (Exception e) {
                // Best effort per column: log and move on to the next one.
                Log.getLog(getClass()).e(e);
            }
        }
        if (sql.startsWith("update")) {
            BaseDBColumn parmaryId = table.getParmaryId();
            try {
                ps.setObject(columns.size() + 1, getParam(parmaryId.getField().get(t), parmaryId));
            } catch (IllegalArgumentException e) {
                Log.getLog(getClass()).e(e);
            } catch (IllegalAccessException e) {
                Log.getLog(getClass()).e(e);
            }
        }
        bound = true;
        return ps;
    } finally {
        if (!bound) {
            // Leaving with an exception: release the statement without masking the cause.
            try {
                ps.close();
            } catch (SQLException closeError) {
                Log.getLog(getClass()).e(closeError);
            }
        }
    }
}
/**
 * Updates the row of {@code t} by primary key and rewrites its child
 * associations.
 * <p>
 * When {@link #judgeInterface()} reports that the bean class implements
 * {@code AutoVersionControl}, the clause from {@link #getVerSql(Object)} is
 * appended to the update statement. For every child class the existing link
 * rows of {@code t} are deleted and each child currently held by {@code t}
 * is written again through {@link #updateChildData(Object, Object)}.
 *
 * @param t
 *            entity carrying the new values
 * @return number of rows reported by the update statement (never 0)
 * @throws ORMException
 *             if no row was updated, if the SQL fails, or if a field of the
 *             entity cannot be read reflectively (the latter was previously
 *             only printed, letting a partially applied update return
 *             normally)
 */
@Override
public int update(T t) throws ORMException {
    try {
        StringBuilder sb = new StringBuilder(table.getUpdateSqlByPID());
        if (judgeInterface()) {
            sb.append(getVerSql(t));
        }
        int row;
        PreparedStatement ps = getPrepareStatement(sb.toString(), t);
        try {
            row = ps.executeUpdate();
        } finally {
            ps.close();
        }
        if (row == 0) {
            throw new ORMException("更新失败,可能版本不一致或者sql错误");
        }
        if (table.hasChild()) {
            for (Class<?> childClass : table.getChildClasses()) {
                List<Object> objList = getChildObj(beanClass, childClass, t);
                table.autoCreateMiddleTable(beanClass, childClass, getConnection());
                // Drop the old associations before re-creating them from the entity.
                deleteMiddleData(GUIDUtils.getGUID((byte[]) getIdValue(t)), childClass);
                for (Object o : objList) {
                    updateChildData(o, t);
                }
            }
        }
        return row;
    } catch (SQLException e) {
        throw new ORMException(e);
    } catch (IllegalArgumentException e) {
        throw new ORMException(e);
    } catch (IllegalAccessException e) {
        throw new ORMException(e);
    }
}
/**
 * Reads the primary-key value of an arbitrary mapped object.
 * <p>
 * The table metadata is looked up from the runtime class of {@code object},
 * not from this instance's own table.
 *
 * @param object
 *            mapped object to read the key from
 * @return the key formatted by the id column's {@code getFormatObj}, or
 *         {@code null} when the table metadata reports no primary-key column
 * @throws IllegalArgumentException
 *             if the id field cannot be read from {@code object}
 * @throws IllegalAccessException
 *             if the id field is not accessible
 * @throws ORMNotDBTableException
 *             declared for the metadata lookup
 */
private Object getIdValue(Object object) throws IllegalArgumentException, IllegalAccessException, ORMNotDBTableException {
    BaseDBColumn idColumn = BaseDBTable.getInstance(object.getClass()).getParmaryId();
    if (idColumn == null) {
        return null;
    }
    Object rawId = idColumn.getField().get(object);
    return idColumn.getFormatObj(rawId);
}
/**
 * Re-creates the data of one child during an update. This currently just
 * delegates to {@link #insertChildData(Object, Object)}.
 *
 * @param obj
 *            child object
 * @param t
 *            parent entity
 * @throws IllegalAccessException
 *             propagated from the delegate
 * @throws IllegalArgumentException
 *             propagated from the delegate
 * @throws ORMException
 *             propagated from the delegate
 */
private void updateChildData(Object obj, T t) throws IllegalArgumentException, IllegalAccessException, ORMException {
    this.insertChildData(obj, t);
}
/**
 * Builds the version-check clause appended to the update statement:
 * {@code " and <versionColumn> = <versionValue>"}.
 * <p>
 * Both parts are concatenated into the SQL text rather than bound as
 * parameters; a {@code null} name or value is rendered as the text
 * {@code null}.
 *
 * @param t
 *            entity being updated
 * @return the clause, starting with a leading space
 */
private String getVerSql(T t) {
    // Name of the version field declared by the entity.
    String versionColumn = getValidateWord(t);
    // Current version value held by the entity.
    Object versionValue = getEntityVW(t, versionColumn);
    StringBuilder clause = new StringBuilder(" and ");
    clause.append(versionColumn).append(" = ").append(versionValue);
    return clause.toString();
}
/**
 * Deletes by primary key given in its textual UUID form.
 *
 * @param recid
 *            primary key, parsed with {@link UUID#fromString}
 * @return result of {@link #delete(UUID)}
 * @throws ORMException
 *             propagated from {@link #delete(UUID)}
 */
@Override
public int delete(String recid) throws ORMException {
    UUID id = UUID.fromString(recid);
    return delete(id);
}
/**
 * Finds an entity by primary key given in its textual UUID form.
 *
 * @param recid
 *            primary key, parsed with {@link UUID#fromString}
 * @return result of {@link #find(UUID)}
 */
@Override
public T find(String recid) {
    UUID id = UUID.fromString(recid);
    return find(id);
}
/**
 * Deletes the row with the given primary key and, for every child class of
 * the table, the link rows that reference it. Only link rows are removed
 * here; no statement against the child tables is issued.
 *
 * @param recid
 *            primary key of the row to delete
 * @return number of rows reported by the delete statement
 * @throws ORMException
 *             if any statement fails
 */
@Override
public int delete(UUID recid) throws ORMException {
    try {
        int affected;
        PreparedStatement statement = getPreparedStatement(table.getDeleteSqlByPID(), recid);
        try {
            affected = statement.executeUpdate();
        } finally {
            statement.close();
        }
        if (!table.hasChild()) {
            return affected;
        }
        for (Class<?> childClass : table.getChildClasses()) {
            table.autoCreateMiddleTable(beanClass, childClass, getConnection());
            deleteMiddleData(recid, childClass);
        }
        return affected;
    } catch (SQLException e) {
        throw new ORMException(e);
    }
}
/**
 * Removes the link rows of one parent from every existing middle table
 * between the bean class and {@code childClass}. Used when the parent row is
 * deleted or updated.
 * <p>
 * Each statement is now closed in a {@code finally} block; the original
 * never closed it.
 *
 * @param recid
 *            id of the parent row (matched against column C1)
 * @param childClass
 *            child class whose link tables are cleaned
 * @throws SQLException
 *             if a delete statement fails
 * @throws ORMException
 *             declared for the table-existence lookup
 */
private void deleteMiddleData(UUID recid, Class<?> childClass) throws SQLException, ORMException {
    TableUtil btu = getTableUtil();
    for (String tableName : BaseTableUtil.getTableName(beanClass, childClass)) {
        if (!btu.valiTableExist(tableName)) {
            continue;
        }
        String sql = "delete from " + tableName + " where C1=?";
        PreparedStatement ps = getConnection().prepareStatement(sql);
        try {
            ps.setObject(1, GUIDUtils.getBytes(recid));
            ps.executeUpdate();
        } finally {
            ps.close();
        }
    }
}
/**
 * Loads the entity with the given primary key, including its children.
 *
 * @param recid
 *            primary key to look up
 * @return the populated entity, or {@code null} when no row matches or when
 *         any exception occurs (the exception is logged, not rethrown)
 */
@Override
public T find(UUID recid) {
    try {
        PreparedStatement statement = getPreparedStatement(table.getFindSqlByPID(), recid);
        try {
            ResultSet rows = statement.executeQuery();
            try {
                if (rows.next()) {
                    T entity = instanceObject(rows);
                    return fillChildObj(entity, recid);
                }
            } finally {
                rows.close();
            }
        } finally {
            statement.close();
        }
    } catch (Exception e) {
        Log.getLog(getClass()).e(e);
    }
    return null;
}
/**
 * Converts every remaining row of {@code rs} into an entity.
 * <p>
 * When the table has child classes, each entity is additionally passed
 * through {@link #fillChildObj(Object, UUID)}, using the id read from the
 * primary-key column of the current row (expected to be a {@code byte[]}).
 *
 * @param rs
 *            result set positioned before the first row to convert
 * @return the entities in row order; empty when there are no rows
 * @throws Exception
 *             if reading the result set, instantiating the bean, or loading
 *             children fails
 */
private List<T> getList(ResultSet rs) throws Exception {
    List<T> result = new ArrayList<T>();
    List<BaseDBColumn> columns = table.getMappingFields();
    // Name of the primary-key column; the last column flagged as primary id wins.
    String idColumnName = "";
    for (BaseDBColumn column : columns) {
        if (column.isPaimaryId()) {
            idColumnName = column.getColumnName();
        }
    }
    boolean withChildren = table.hasChild();
    while (rs.next()) {
        if (withChildren) {
            T entity = instanceObject(rs);
            UUID rowId = GUIDUtils.getGUID((byte[]) rs.getObject(idColumnName));
            result.add(fillChildObj(entity, rowId));
        } else {
            result.add(instanceObject(rs, columns));
        }
    }
    return result;
}
/**
 * Instantiates the bean class and fills it from the current row of
 * {@code rs}, using the given columns.
 *
 * @param rs
 *            result set positioned on the row to read
 * @param fields
 *            columns to read; a column whose value is {@code null} is left
 *            untouched on the bean
 * @return the new bean
 * @throws InstantiationException
 *             if the bean class cannot be instantiated
 * @throws IllegalAccessException
 *             if the bean constructor or a field is not accessible
 */
private T instanceObject(ResultSet rs, List<BaseDBColumn> fields) throws InstantiationException, IllegalAccessException {
    T bean = beanClass.newInstance();
    for (BaseDBColumn column : fields) {
        Object value = column.getValue(rs);
        if (value != null) {
            column.set(bean, value);
        }
    }
    return bean;
}
/**
 * Instantiates the bean class and fills it from the current row of
 * {@code rs}, using all mapped columns of this instance's table.
 *
 * @param rs
 *            result set positioned on the row to read
 * @return the new bean
 * @throws InstantiationException
 *             if the bean class cannot be instantiated
 * @throws IllegalAccessException
 *             if the bean constructor or a field is not accessible
 * @throws SQLException
 *             declared for callers; kept for interface compatibility
 */
private T instanceObject(ResultSet rs) throws InstantiationException, IllegalAccessException, SQLException {
    // Same per-column logic as the two-argument overload, applied to every mapped column.
    return instanceObject(rs, table.getMappingFields());
}
/**
 * Runs the query built by {@code table.getQuerySql(where)} with the given
 * arguments and maps every row to an entity.
 *
 * @param where
 *            condition text handed to {@code table.getQuerySql}
 * @param args
 *            values bound to the statement's parameters
 * @return the matching entities; empty when nothing matches
 * @throws ORMException
 *             wrapping any failure while executing or mapping
 */
@Override
public List<T> query(String where, Object... args) throws ORMException {
    try {
        PreparedStatement statement = getPreparedStatement(table.getQuerySql(where), args);
        try {
            ResultSet rows = statement.executeQuery();
            try {
                List<T> entities = getList(rows);
                return entities;
            } catch (Exception mappingError) {
                throw new ORMException(mappingError);
            } finally {
                rows.close();
            }
        } finally {
            statement.close();
        }
    } catch (SQLException sqlError) {
        throw new ORMException(sqlError);
    }
}
/**
 * Runs the query built by {@code table.getQuerySql(where)} with
 * {@code " limit ?,?"} appended, and maps every returned row to an entity.
 * <p>
 * {@code startIndex} and {@code endIndex} are bound, in that order, as the
 * two limit parameters after {@code args}.
 * NOTE(review): in MySQL's {@code LIMIT offset,count} form the second value
 * is a row count, not an end index — confirm what callers expect.
 * <p>
 * The limit parameters are now bound inside the {@code try} that closes the
 * statement; previously a failure while binding them leaked the statement.
 *
 * @param where
 *            condition text handed to {@code table.getQuerySql}
 * @param startIndex
 *            first limit parameter
 * @param endIndex
 *            second limit parameter
 * @param args
 *            values bound to the statement's other parameters
 * @return the entities of the requested page; empty when nothing matches
 * @throws ORMException
 *             wrapping any failure while executing or mapping
 */
@Override
public List<T> queryPage(String where, long startIndex, long endIndex, Object... args) throws ORMException {
    try {
        PreparedStatement ps = getPreparedStatement(table.getQuerySql(where) + " limit ?,?", args);
        try {
            ps.setLong(args.length + 1, startIndex);
            ps.setLong(args.length + 2, endIndex);
            ResultSet rs = ps.executeQuery();
            try {
                return getList(rs);
            } catch (Exception e) {
                throw new ORMException(e);
            } finally {
                rs.close();
            }
        } finally {
            ps.close();
        }
    } catch (SQLException e) {
        throw new ORMException(e);
    }
}
/**
 * Runs the query built by {@code table.getQuerySql(where)} and maps only
 * its first row. Children are not loaded here: the row is mapped with
 * {@code instanceObject} alone.
 *
 * @param where
 *            condition text handed to {@code table.getQuerySql}
 * @param args
 *            values bound to the statement's parameters
 * @return the first matching entity, or {@code null} when nothing matches
 * @throws ORMException
 *             wrapping any failure while executing or mapping
 */
@Override
public T queryFirst(String where, Object... args) throws ORMException {
    try {
        PreparedStatement statement = getPreparedStatement(table.getQuerySql(where), args);
        try {
            ResultSet rows = statement.executeQuery();
            try {
                if (!rows.next()) {
                    return null;
                }
                return instanceObject(rows);
            } catch (Exception mappingError) {
                throw new ORMException(mappingError);
            } finally {
                rows.close();
            }
        } finally {
            statement.close();
        }
    } catch (SQLException sqlError) {
        throw new ORMException(sqlError);
    }
}
/**
 * Runs the statement built by {@code table.getQueryRowSql(where)} and
 * returns the integer in the first column of its first row.
 *
 * @param where
 *            condition text handed to {@code table.getQueryRowSql}
 * @param args
 *            values bound to the statement's parameters
 * @return the value of column 1 of the first row, or 0 when the result is
 *         empty
 * @throws ORMException
 *             wrapping any failure while executing or reading
 */
@Override
public int queryRow(String where, Object... args) throws ORMException {
    try {
        PreparedStatement statement = getPreparedStatement(table.getQueryRowSql(where), args);
        try {
            ResultSet rows = statement.executeQuery();
            try {
                if (!rows.next()) {
                    return 0;
                }
                return rows.getInt(1);
            } catch (Exception readError) {
                throw new ORMException(readError);
            } finally {
                rows.close();
            }
        } finally {
            statement.close();
        }
    } catch (SQLException sqlError) {
        throw new ORMException(sqlError);
    }
}
/**
 * Executes the statement built by {@code table.getModifySql(set)} with the
 * given arguments.
 *
 * @param set
 *            text handed to {@code table.getModifySql}
 * @param args
 *            values bound to the statement's parameters
 * @return number of rows reported by the statement
 * @throws ORMException
 *             wrapping any {@link SQLException}
 */
@Override
public int update(String set, Object... args) throws ORMException {
    try {
        PreparedStatement statement = getPreparedStatement(table.getModifySql(set), args);
        int affected;
        try {
            affected = statement.executeUpdate();
        } finally {
            statement.close();
        }
        return affected;
    } catch (SQLException sqlError) {
        throw new ORMException(sqlError);
    }
}
/**
 * Executes the statement built by {@code table.getDeleteSql(where)} with
 * the given arguments. Link rows in middle tables are not touched here.
 *
 * @param where
 *            condition text handed to {@code table.getDeleteSql}
 * @param args
 *            values bound to the statement's parameters
 * @return number of rows reported by the statement
 * @throws ORMException
 *             wrapping any {@link SQLException}
 */
@Override
public int delete(String where, Object... args) throws ORMException {
    try {
        PreparedStatement statement = getPreparedStatement(table.getDeleteSql(where), args);
        int affected;
        try {
            affected = statement.executeUpdate();
        } finally {
            statement.close();
        }
        return affected;
    } catch (SQLException sqlError) {
        throw new ORMException(sqlError);
    }
}
/**
 * Returns the result of {@link #query(String, Object...)} called with the
 * text {@code "select * from <tableName>"} and no arguments.
 * <p>
 * NOTE(review): the other callers pass a condition as {@code where}; here a
 * complete SELECT is passed through {@code table.getQuerySql}. Confirm that
 * {@code getQuerySql} accepts this form.
 *
 * @return the entities returned by the query
 * @throws ORMException
 *             propagated from {@code query}
 */
@Override
public List<T> getAll() throws ORMException {
    String selectAll = "select * from " + table.getTableName();
    return query(selectAll);
}
/**
 * Reads the current value of the entity's version-control field through its
 * public getter.
 * <p>
 * The field is looked up with {@code getDeclaredField}, so it must be
 * declared on the bean class itself, and it must be of primitive type
 * {@code long}. The getter name is {@code "get"} plus the field name with
 * its first letter upper-cased.
 * <p>
 * Changes from the original: the first letter is upper-cased with a fixed
 * locale (the default-locale form yields a wrong getter name where case
 * mapping differs, e.g. "i" in Turkish), failures go to the project
 * {@code Log} instead of {@code printStackTrace}, and the reflective calls
 * no longer pass an untyped {@code null} as the varargs array.
 *
 * @param t
 *            entity to read from
 * @param validateWord
 *            name of the version field
 * @return the getter's result, or {@code null} when the field or getter
 *         cannot be resolved or invoked, or the field is not a {@code long}
 *         (the cause is logged)
 */
private Object getEntityVW(T t, String validateWord) {
    try {
        Field field = beanClass.getDeclaredField(validateWord);
        if (field.getType() != long.class) {
            throw new VersionFieldTypeException("版本字段必修是long类型");
        }
        String getterName = "get" + validateWord.substring(0, 1).toUpperCase(Locale.ENGLISH) + validateWord.substring(1);
        Method getter = beanClass.getMethod(getterName);
        return getter.invoke(t);
    } catch (NoSuchMethodException e) {
        Log.getLog(getClass()).e(e);
    } catch (SecurityException e) {
        Log.getLog(getClass()).e(e);
    } catch (IllegalAccessException e) {
        Log.getLog(getClass()).e(e);
    } catch (IllegalArgumentException e) {
        Log.getLog(getClass()).e(e);
    } catch (InvocationTargetException e) {
        Log.getLog(getClass()).e(e);
    } catch (NoSuchFieldException e) {
        Log.getLog(getClass()).e(e);
    } catch (VersionFieldTypeException e) {
        Log.getLog(getClass()).e(e);
    }
    return null;
}
/**
 * Asks the entity for the name of its version-control field by reflectively
 * invoking its public no-argument {@code getVersionWord} method.
 * <p>
 * Failures are now reported through the project {@code Log}, as elsewhere in
 * this class, instead of {@code printStackTrace}; the reflective calls no
 * longer pass an untyped {@code null} as the varargs array.
 *
 * @param t
 *            entity to query
 * @return the name returned by {@code getVersionWord}, or {@code null} when
 *         the method cannot be resolved or invoked (the cause is logged)
 */
private String getValidateWord(T t) {
    try {
        Method method = beanClass.getMethod("getVersionWord");
        return (String) method.invoke(t);
    } catch (NoSuchMethodException e) {
        Log.getLog(getClass()).e(e);
    } catch (SecurityException e) {
        Log.getLog(getClass()).e(e);
    } catch (IllegalAccessException e) {
        Log.getLog(getClass()).e(e);
    } catch (IllegalArgumentException e) {
        Log.getLog(getClass()).e(e);
    } catch (InvocationTargetException e) {
        Log.getLog(getClass()).e(e);
    }
    return null;
}
/**
 * Tells whether the bean class lists {@code AutoVersionControl} among the
 * interfaces it implements directly. Interfaces inherited from a superclass
 * or reached through another interface are not considered, because only
 * {@code Class.getInterfaces()} of the bean class is inspected.
 *
 * @return {@code true} when the interface is declared directly on the bean
 *         class, {@code false} otherwise
 */
private boolean judgeInterface() {
    for (Class<?> declared : beanClass.getInterfaces()) {
        if (AutoVersionControl.class.equals(declared)) {
            return true;
        }
    }
    return false;
}
/**
 * Eagerly loads the children of {@code t}.
 * <p>
 * For every child class of the table and every existing middle table
 * between the bean class and that child class, the child ids linked to
 * {@code recid} are read, each child is loaded through its own ORM, and the
 * children found so far for that child class are stored on {@code t} via
 * {@link #setObjectValue(Object, List)}.
 * <p>
 * The link-table statement and result set are now closed (in
 * {@link #loadChildIds(String, UUID)}) before the children are loaded; the
 * original never closed them.
 *
 * @param t
 *            parent entity to populate
 * @param recid
 *            primary key of {@code t}
 * @return {@code t}, populated with the children that were found
 * @throws ORMException
 *             declared for the table lookups and child ORM
 * @throws SQLException
 *             if reading a middle table fails
 * @throws IllegalAccessException
 *             if a field of {@code t} cannot be written
 * @throws IllegalArgumentException
 *             if a field of {@code t} rejects the value being stored
 */
private T fillChildObj(T t, UUID recid) throws ORMException, SQLException, IllegalArgumentException, IllegalAccessException {
    if (!table.hasChild()) {
        return t;
    }
    for (Class<?> childClass : table.getChildClasses()) {
        // Accumulates across all middle tables of this child class.
        List<Object> list = new ArrayList<Object>();
        TableUtil btu = getTableUtil();
        for (String tableName : BaseTableUtil.getTableName(beanClass, childClass)) {
            if (!btu.valiTableExist(tableName)) {
                continue;
            }
            List<byte[]> childIds = loadChildIds(tableName, recid);
            ORM<?> orm = ORMFactory.instance(getConnection(), childClass);
            for (byte[] b : childIds) {
                Object obj = orm.find(GUIDUtils.getGUID(b));
                if (null != obj) {
                    list.add(obj);
                }
            }
            if (list.size() > 0) {
                t = setObjectValue(t, list);
            }
        }
    }
    return t;
}

/**
 * Reads the C2 values of all rows in {@code tableName} whose C1 equals the
 * given parent id, closing the statement and result set before returning.
 */
private List<byte[]> loadChildIds(String tableName, UUID recid) throws SQLException {
    List<byte[]> ids = new ArrayList<byte[]>();
    String query = "select C2 from " + tableName + " where C1=?";
    PreparedStatement ps = getConnection().prepareStatement(query);
    try {
        ps.setBytes(1, GUIDUtils.getBytes(recid));
        ResultSet rs = ps.executeQuery();
        try {
            while (rs.next()) {
                ids.add(rs.getBytes("C2"));
            }
        } finally {
            rs.close();
        }
    } finally {
        ps.close();
    }
    return ids;
}
/**
 * Stores loaded children on the matching fields of {@code t}.
 * <p>
 * The runtime class of the first element decides which fields match. For
 * every field whose column metadata reports {@code isMapping()}: if the
 * field type is assignable from that class, the first element is stored;
 * otherwise, if the field is a {@code Collection} whose generic element type
 * is assignable from that class, the whole collection is stored.
 * <p>
 * The original always stored the {@code List} itself, which makes
 * {@code Field.set} throw {@link IllegalArgumentException} for
 * {@code Set}-typed fields. Those fields now receive a
 * {@link LinkedHashSet} copy (see
 * {@link #toFieldCollection(Class, List)}).
 *
 * @param t
 *            entity to populate
 * @param list
 *            loaded children; must contain at least one element
 * @return {@code t}
 * @throws IllegalAccessException
 *             if a field cannot be written
 * @throws IllegalArgumentException
 *             if a field's type accepts neither the list nor a
 *             {@code LinkedHashSet}
 */
private T setObjectValue(T t, List<Object> list) throws IllegalArgumentException, IllegalAccessException {
    Field[] fields = BeanUtils.getAllFields_Cache(beanClass);
    Class<?> c = list.get(0).getClass();
    for (Field f : fields) {
        f.setAccessible(true);
        BaseDBColumn column = BaseDBColumn.getInstance(f);
        if (!column.isMapping()) {
            continue;
        }
        Class<?> fieldType = f.getType();
        if (fieldType.isAssignableFrom(c)) {
            f.set(t, list.get(0));
        } else if (Collection.class.isAssignableFrom(fieldType) && TypeArgFinder.getFieldClassGenricType(f).isAssignableFrom(c)) {
            f.set(t, toFieldCollection(fieldType, list));
        }
    }
    return t;
}

/**
 * Picks a collection instance that can be assigned to a field of type
 * {@code fieldType}: the list itself when the field accepts it, otherwise an
 * insertion-ordered set copy when the field accepts a {@code LinkedHashSet}.
 * Any other field type gets the list, which keeps the previous behaviour.
 */
private Object toFieldCollection(Class<?> fieldType, List<Object> list) {
    if (!fieldType.isAssignableFrom(list.getClass()) && fieldType.isAssignableFrom(LinkedHashSet.class)) {
        return new LinkedHashSet<Object>(list);
    }
    return list;
}
}