package org.sothis.dal; import java.beans.PropertyDescriptor; import java.io.Serializable; import java.lang.annotation.Annotation; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import javax.persistence.Column; import javax.persistence.GeneratedValue; import javax.persistence.Id; import javax.persistence.Table; import javax.persistence.Transient; import org.apache.commons.beanutils.PropertyUtils; import org.sothis.core.util.StringUtils; import org.sothis.dal.query.Chain; import org.sothis.dal.query.Cnd; import org.sothis.dal.query.Op; /** * 实现了基于JPA的字段映射 * * @author velna * * @param <E> * @param <K> */ public abstract class AbstractJpaCompatibleDao<E extends Entity, K extends Serializable> extends AbstractDao<E, K> { private final String tableName; private final Map<String, PropertyInfo> fieldMap; private final Map<String, PropertyInfo> propertyMap; private String idColumnName; private boolean idGeneratedValue; public AbstractJpaCompatibleDao() { // entity annotation javax.persistence.Entity entity = this.getEntityClass().getAnnotation(javax.persistence.Entity.class); if (null == entity) { throw new RuntimeException("no Entity annotation found of entity class " + this.getEntityClass().getName()); } // table annotation Table table = this.getEntityClass().getAnnotation(Table.class); if (null == table) { throw new RuntimeException("no Table annotation found of entity class " + this.getEntityClass().getName()); } tableName = table.name(); if (StringUtils.isEmpty(tableName)) { throw new RuntimeException("name of Table annotation is empty of entity class " + this.getEntityClass().getName()); } if (!tableName.toLowerCase().equals(tableName)) { throw new IllegalArgumentException("table name of class [" + this.getEntityClass().getName() + "] must be lower cased, current is [" + tableName + "]"); } // fields annotation PropertyDescriptor[] propertyDescriptors = PropertyUtils.getPropertyDescriptors(this.getEntityClass()); Map<String, PropertyInfo> _fieldMap = new HashMap<String, PropertyInfo>(propertyDescriptors.length); Map<String, PropertyInfo> _propertyMap = new HashMap<String, PropertyInfo>(propertyDescriptors.length); boolean idFind = false; for (PropertyDescriptor descriptor : propertyDescriptors) { Column column = getAnnotation(this.getEntityClass(), descriptor, Column.class); if (null == column) { continue; } Id id = getAnnotation(this.getEntityClass(), descriptor, Id.class); if (null != id) { if (idFind) { throw new RuntimeException("multi Id annotation found of entity class " + this.getEntityClass().getName()); } idFind = true; idColumnName = descriptor.getName(); GeneratedValue generatedValue = getAnnotation(this.getEntityClass(), descriptor, GeneratedValue.class); idGeneratedValue = generatedValue != null; } Transient aTransient = getAnnotation(this.getEntityClass(), descriptor, Transient.class); PropertyInfo info = new PropertyInfo(descriptor, column, null != id, null != aTransient, this.getEntityClass()); _propertyMap.put(descriptor.getName(), info); _fieldMap.put(column.name(), info); } propertyMap = Collections.unmodifiableMap(_propertyMap); fieldMap = Collections.unmodifiableMap(_fieldMap); } private static <T extends Annotation> T getAnnotation(Class<?> entityClass, PropertyDescriptor descriptor, Class<T> annotationClass) { T a = null; for (Class<?> clazz = entityClass; clazz != Object.class; clazz = clazz.getSuperclass()) { try { Field f = clazz.getDeclaredField(descriptor.getName()); a = f.getAnnotation(annotationClass); if (a != null) { break; } } catch (NoSuchFieldException e) { } catch (SecurityException e) { } } if (null == a) { Method readMethod = descriptor.getReadMethod(); if (null != readMethod) { a = readMethod.getAnnotation(annotationClass); } } return a; } /** * 得到表名 * * @return */ public String getTableName() { return tableName; } /** * 得到字段映射表,key为实体类的字段名(并非实际数据库的字段名) * * @return */ public Map<String, PropertyInfo> getPropertyMap() { return propertyMap; } /** * 得到字段property的映射,property为实体类的字段名(并非实际数据库的字段名) * * @param property * @return */ public PropertyInfo getPropertyInfoByProperty(String property) { PropertyInfo ret = propertyMap.get(property); if (null == ret) { throw new IllegalArgumentException("no property named [" + property + "] found of entity class " + this.getEntityClass().getName()); } return ret; } /** * 得到数据库字段映射表,key为数据库的字段名(并非实体类的字段名) * * @return */ public Map<String, PropertyInfo> getFieldMap() { return fieldMap; } /** * 得到数据库字段field的映射,field为数据库的字段名(并非实体类的字段名) * * @param field * @return */ public PropertyInfo getPropertyInfoByField(String field) { PropertyInfo ret = fieldMap.get(field); if (null == ret) { throw new IllegalArgumentException("no field named [" + field + "] found of table " + tableName); } return ret; } /** * 得到id字段的字段名 * * @return */ public String getIdColumnName() { return idColumnName; } /** * id字段是否为自动生成 * * @return */ public boolean isIdGeneratedValue() { return idGeneratedValue; } /* * (non-Javadoc) * * @see com.fangjia.dal.EntityDao#update(com.fangjia.dal.Entity) */ @SuppressWarnings("unchecked") public E update(E entity) { if (null == entity) { throw new IllegalArgumentException("entity is null "); } K id = null; Chain chain = Chain.make(); Iterator<Map.Entry<String, PropertyInfo>> iterator = propertyMap.entrySet().iterator(); while (iterator.hasNext()) { Map.Entry<String, PropertyInfo> entry = iterator.next(); PropertyInfo propertyInfo = entry.getValue(); if (propertyInfo.isID()) { PropertyDescriptor descriptor = propertyInfo.getPropertyDescriptor(); Method readMethod = descriptor.getReadMethod(); try { id = (K) readMethod.invoke(entity, new Object[] {}); } catch (Exception e) { throw new IllegalArgumentException(entity.getClass().getName(), e); } continue; } if (propertyInfo.isTransient()) { continue; } PropertyDescriptor descriptor = propertyInfo.getPropertyDescriptor(); Method readMethod = descriptor.getReadMethod(); String columnName = descriptor.getName(); try { Object obj = readMethod.invoke(entity, new Object[] {}); chain.add(columnName, obj); } catch (Exception e) { throw new IllegalArgumentException(entity.getClass().getName(), e); } } if (null == id) { throw new IllegalArgumentException(entity.getClass().getName() + " ID is null !"); } if (chain.size() > 0) { updateById(id, chain); } return entity; } private void assertIdColumnNameNotNull() { if (null == this.getIdColumnName()) { throw new IllegalStateException("no id column defined"); } } @Override public E findById(K id) { assertIdColumnNameNotNull(); return findOne(Cnd.make(this.getIdColumnName(), id), null); } @Override public List<E> findByIds(Collection<K> idList) { assertIdColumnNameNotNull(); if (null == idList || idList.isEmpty()) { return Collections.emptyList(); } return find(Cnd.make(this.getIdColumnName(), Op.IN, idList), null, null); } @Override public int updateById(K id, Chain chain) { assertIdColumnNameNotNull(); return update(Cnd.make(this.getIdColumnName(), id), chain); } @Override public int deleteById(K id) { assertIdColumnNameNotNull(); return delete(Cnd.make(this.getIdColumnName(), id)); } @Override public int deleteByIds(List<K> idList) { assertIdColumnNameNotNull(); return delete(Cnd.make(this.getIdColumnName(), Op.IN, idList)); } /** * 实体类字段属性信息 * * @author velna * */ public static class PropertyInfo { private final PropertyDescriptor propertyDescriptor; private final Column column; private final Class<?> clazz; private final boolean isId; private final boolean transients; public PropertyInfo(PropertyDescriptor propertyDescriptor, Column column, boolean isId, boolean transients, Class<?> clazz) { if (!column.name().toLowerCase().equals(column.name())) { throw new IllegalArgumentException("name of column [" + propertyDescriptor.getName() + "] must be lower cased of class " + clazz.getName() + ", current is [" + column.name() + "]"); } this.isId = isId; this.propertyDescriptor = propertyDescriptor; this.column = column; this.clazz = clazz; this.transients = transients; } /** * 得到字段的描述信息 * * @return */ public PropertyDescriptor getPropertyDescriptor() { return propertyDescriptor; } /** * 得到Column注解 * * @return */ public Column getColumn() { return column; } /** * 是否为id字段 * * @return */ public boolean isID() { return isId; } /** * 得到所属的实体类 * * @return */ public Class<?> getClazz() { return clazz; } /** * 是否为瞬态的 * * @return */ public boolean isTransient() { return transients; } } }