package org.howsun.dao.jpadao; import java.io.Serializable; import java.math.BigInteger; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import javax.inject.Named; import javax.persistence.EmbeddedId; import javax.persistence.EntityManager; import javax.persistence.Id; import javax.persistence.PersistenceContext; import javax.persistence.Query; import javax.persistence.criteria.CriteriaQuery; import org.howsun.dao.ExtendExecutant; import org.howsun.dao.GenericDao; import org.howsun.dao.OrderBean; import org.howsun.dao.page.Page; import org.howsun.log.Log; import org.howsun.log.LogFactory; import org.howsun.util.Asserts; import org.howsun.util.Numbers; import org.howsun.util.Strings; /** * * 功能描述: * * PROPAGATION_REQUIRED:支持当前事务,如果当前没有事务,就新建一个事务。这是最常见的选择。<br> * PROPAGATION_SUPPORTS:支持当前事务,如果当前没有事务,就以非事务方式执行。<br> * PROPAGATION_MANDATORY:支持当前事务,如果当前没有事务,就抛出异常。<br> * PROPAGATION_REQUIRES_NEW:新建事务,如果当前存在事务,把当前事务挂起。<br> * PROPAGATION_NOT_SUPPORTED:以非事务方式执行操作,如果当前存在事务,就把当前事务挂起。<br> * PROPAGATION_NEVER:以非事务方式执行,如果当前存在事务,则抛出异常。<br> * PROPAGATION_NESTED:支持当前事务,如果当前事务存在,则执行一个嵌套事务,如果当前没有事务,就新建一个事务。 * 嵌套事务实现了隔离机制,例如B事务嵌套在A事务中,B失败不会影响A提交。而PROPAGATION_REQUIRED则全部回滚<br> * * @author howsun(howsun.zhang@google.com) * @version 1.0.2 */ @Named("jpaGenericDao") public class JpaGenericDao implements GenericDao { protected Log log = LogFactory.getLog(JpaGenericDao.class); @PersistenceContext protected EntityManager entityManager; @Override public void clear() { entityManager.clear(); } @Override public <T> int delete(Class<T> entityClass, Serializable... entityids) { int i = 0; for(Serializable id : entityids){ try { entityManager.remove(entityManager.find(entityClass, id)); i++; } catch (Exception e) { log.info(e.getMessage(), e); continue; } } return i; } @Override public int delete(Object object) { entityManager.remove(object); return 1; } @Override public <T> int delete(Class<T> entityClass, String condition, Object[] params){ if(Strings.hasLength(condition)){ Query query = entityManager.createQuery("DELETE " + entityClass.getName() + " WHERE " + condition); this.setParameter(query, params); return query.executeUpdate(); } return 0; } @Override public <T> T find(Class<T> entityClass, Serializable entityid) { return (T)entityManager.find(entityClass, entityid); } @Override @SuppressWarnings("unchecked") public <T> T findBySQL(String sql, Object[] params) { T t = null; Query query = entityManager.createNativeQuery(sql); this.setParameter(query, params); try { t = (T)query.getSingleResult(); } catch (javax.persistence.NoResultException e) { return null; } catch (Exception e) { throw new RuntimeException(e); } return t; } @SuppressWarnings("unchecked") @Override public <T> T findByXQL(Class<T> entityClass, String condition, Object[] params) { T t = null; StringBuffer sql = new StringBuffer("from " + entityClass.getName() ); if(condition != null && condition.trim().length() > 0) sql.append(" where ").append(condition); Query query = entityManager.createQuery(sql.toString()); this.setParameter(query, params); try { t = (T)query.getSingleResult(); } catch (javax.persistence.NoResultException e) { return null; } catch (Exception e) { throw new RuntimeException(e); } return t; } @SuppressWarnings("unchecked") @Override public <T> List<T> finds(Class<T> entityClass, String fields, Page page, String condition, Object[] params, OrderBean order) { String entityName = entityClass.getName(); StringBuffer sql = new StringBuffer("select new " + entityName + "(" + fields + ") from " + entityName +" o"); StringBuffer countSql = new StringBuffer("select count(" + getCountField(entityClass) + ") from " + entityName +" o"); //设置查询条件 if(condition != null && condition.trim().length() > 0){ sql.append(" where ").append(condition); countSql.append(" where ").append(condition); } Query query = null; if(page != null && page.getTotalCount() == 0){ query = entityManager.createQuery(countSql.toString()); this.setParameter(query, params); this.setCount(query, page); } //设置排序 //设置排序 if(order != null) sql.append(order.toSQL("o")); query = entityManager.createQuery(sql.toString()); this.setParameter(query, params); //分页 if(page != null){ query.setFirstResult(page.getFirstIndex()).setMaxResults(page.getPageSize()); } return query.getResultList(); } @Override public <T> List<T> finds(Class<T> entityClass, Page page, String condition, Object[] params, OrderBean order) { return getScrollData(entityClass, page, condition, params, order); } @Override public <T> List<T> finds(Class<T> entityClass, Page page, String condition, Object[] params) { return getScrollData(entityClass, page, condition, params, null); } @Override public <T> List<T> finds(Class<T> entityClass, Page page, OrderBean order) { return getScrollData(entityClass, page, null, null, order); } @Override public <T> List<T> finds(Class<T> entityClass, Page page) { return getScrollData(entityClass, page, null, null, null); } @Override public <T> List<T> finds(Class<T> entityClass, OrderBean order) { return getScrollData(entityClass, null, null, null, order); } @Override public <T> List<T> finds(Class<T> entityClass) { CriteriaQuery<T> criteriaQuery = entityManager.getCriteriaBuilder().createQuery(entityClass); return entityManager.createQuery(criteriaQuery).getResultList(); } @Override public List<?> findsBySQL(String sql, Page page, Object[] params) { String sqlCopy = sql.toLowerCase(); Asserts.isTrue(sqlCopy.contains("select") && sqlCopy.contains("from"), "不合格的SQL语句"); if(page != null && page.getTotalCount() == 0){ int start = sqlCopy.indexOf("select") + 6; int end = sqlCopy.indexOf("from"); StringBuffer s = new StringBuffer(); s.append(sql.substring(0, start)).append(" count(*) ").append(sql.substring(end, sql.length())); Query query = entityManager.createNativeQuery(s.toString()); this.setParameter(query, params); this.setCount(query, page); } Query query = entityManager.createNativeQuery(sql); setParameter(query, params); return query.getResultList(); } @Override public List<?> findsBySQL(String sql, Object[] params) { return this.findsBySQL(sql, null, params); } @Override public List<?> findsByXQL(String xql, Page page, Object[] params) { Query query = entityManager.createQuery(xql); if(page != null && page.getTotalCount() == 0){ String countSql = "select count(*) " + xql.substring(xql.indexOf("from")); query = entityManager.createQuery(countSql.toString()); this.setParameter(query, params); this.setCount(query, page); } // 分页 if (page !=null) query.setFirstResult(page.getFirstIndex()).setMaxResults(page.getPageSize()); this.setParameter(query, params); return query.getResultList(); } @Override public void flush() { entityManager.flush(); } @Override public <T> long getCount(Class<T> entityClass) { StringBuffer countSql = new StringBuffer("select count(" + getCountField(entityClass) + ") from " + entityClass.getName() +" o"); Query query = entityManager.createQuery(countSql.toString()); return getCount(query); } @Override public <T> long getCount(Class<T> entityClass, String condition, Object[] params) { String entityName = entityClass.getName(); StringBuffer countSql = new StringBuffer("select count(" + getCountField(entityClass) + ") from " + entityName +" o"); //设置查询条件 if(condition != null && condition.trim().length() > 0){ countSql.append(" where ").append(condition); } Query query = entityManager.createQuery(countSql.toString()); this.setParameter(query, params); return getCount(query); } @Override public void merge(Object object) { entityManager.merge(object); } @Override public void save(Object object) { entityManager.persist(object); } @Override public void update(Object object) { entityManager.merge(object); } /** * update() * e.g: update Entity set f1=?, f2=? where id=? * * @param <T> * @param entityName * @param fields * @param values * @param id */ public <T> int update(Class<T> entityName, String[] fields, Object[] values, Serializable id){ boolean valid = fields == null || fields.length == 0 || id == null; Asserts.isFalse(valid, "无法按要求数据更新实体"); StringBuffer fs = new StringBuffer(); for (String field : fields) { fs.append(field).append(","); } List<Object> vs = new ArrayList<Object>(Arrays.asList(values)); vs.add(id); String idField = getIdField(entityName); return updateByBatch(entityName, fs.toString(), idField, vs.toArray()); } @Override public <T> int updateByBatch(Class<T> entityName, String fields, String condition, Object[] values){ Asserts.isTrue(Strings.hasLength(fields), "无法按要求数据更新实体"); String fs[] = fields.split(","); StringBuffer xQL = new StringBuffer("UPDATE ") .append(entityName.getName()) .append(" SET"); for(String field : fs){ xQL.append(" ").append(field).append(field.indexOf('=') > -1 ? "," : "=?,"); } if(xQL.toString().endsWith(",")){ xQL.deleteCharAt(xQL.length() - 1); } if(Strings.hasLength(condition)){ xQL.append(" WHERE"); String conditions[] = condition.split(","); if(conditions.length > 1){ for(String c : conditions){ xQL.append(" ").append(c).append(c.indexOf('=') > -1 ? "" : "=?").append(" AND"); } }else if(conditions[0].indexOf('=') == -1){ xQL.append(" ").append(conditions[0]).append("=?"); }else{ xQL.append(" ").append(conditions[0]); } if(xQL.toString().endsWith(" AND")){ xQL.delete(xQL.length() - 4, xQL.length()); } } Query query = entityManager.createQuery(xQL.toString()); setParameter(query, values); return query.executeUpdate(); } @Override public Long nextId(Class<?> entityClass) { Query query = entityManager.createQuery("SELECT MAX(id) FROM " + entityClass.getName()); Integer maxId = (Integer)query.getSingleResult(); return maxId == null ? 1L : ++maxId; } @Override public <T> void increaseFieldValue(Class<T> entityName, String field, Integer defaultValue, Serializable id) { String idFieldName = getCountField(entityName); Asserts.notNull(idFieldName, "未找到主键字段"); if(!Numbers.thanZero(defaultValue)){ defaultValue = 1; } Query query = entityManager.createQuery(String.format("UPDATE %s SET %s=%s+" + defaultValue + " WHERE %s=?:id", entityName.getName(), field, field, idFieldName)); query.setParameter("id", id); query.executeUpdate(); } @Override public void execute(ExtendExecutant extendExecutant){ extendExecutant.executing(entityManager); } ////////////////////////////////////////////////////////private method//////////////////////////////////////////////// /** * 为查询对象设定参数,注意:如果是JPA,则索引位要加一 * @param query * @param params */ protected void setParameter(Query query, Object[] params){ if(params != null){ for (int i = 0; i < params.length; i++) { query.setParameter(i+1, params[i]); } } } private <T> String getIdField(Class<T> entityName){ String idField = null; try { idField = entityManager.getEntityManagerFactory().getMetamodel().entity(entityName).getId(Id.class).getName(); } catch (Exception e) { try { idField = entityManager.getEntityManagerFactory().getMetamodel().entity(entityName).getId(EmbeddedId.class).getName(); } catch (Exception e2) { idField = "id"; } } return idField == null ? "id" : idField; } /** * 得到统计字段 * @param <T> * @param entityClass * @return String */ private <T> String getCountField(Class<T> entityClass){ String idField = null; try { idField = entityManager.getEntityManagerFactory().getMetamodel().entity(entityClass).getId(Id.class).getName(); } catch (Exception e) { // TODO: handle exception } return idField == null ? "*" : "o." + idField; /* try { PropertyDescriptor[] ps = Introspector.getBeanInfo(entityClass).getPropertyDescriptors(); for(PropertyDescriptor propertydesc : ps){ Method getter = propertydesc.getReadMethod(); if(getter != null && getter.isAnnotationPresent(EmbeddedId.class)){ PropertyDescriptor[] idClassps = Introspector.getBeanInfo(propertydesc.getPropertyType()).getPropertyDescriptors(); return "o." + propertydesc.getName() + "." + (idClassps[0].getName().equals("class") ? idClassps[1].getName() : idClassps[0].getName()); } //需要修改,查找到id字段 String idField = entityManager.getEntityManagerFactory().getMetamodel().entity(entityClass).getId(Id.class).getName(); if(idField == null){ idField = "*"; } return "o." + idField; } } catch (Exception e) { e.printStackTrace(); } return "o"; */ } //主方法 @SuppressWarnings("unchecked") private <T> List<T> getScrollData(Class<T> entityClass, Page page, String condition, Object[] params, OrderBean order) { StringBuffer sql = new StringBuffer("select o from " + entityClass.getName() +" o"); StringBuffer countSql = new StringBuffer("select count("+getCountField(entityClass)+") from " + entityClass.getName() +" o"); //设置查询条件 if(Strings.hasLengthBytrim(condition)){ sql.append(" where ").append(condition); countSql.append(" where ").append(condition); } Query query = null; if(page != null && page.getTotalCount() == 0){ log.debug(countSql.toString()); query = entityManager.createQuery(countSql.toString()); this.setParameter(query, params); this.setCount(query, page); } //设置排序 if(order != null) sql.append(order.toSQL("o"));//排序 query = entityManager.createQuery(sql.toString()); this.setParameter(query, params); //分页 if(page != null){ query.setFirstResult(page.getFirstIndex()).setMaxResults(page.getPageSize()); } log.debug(sql.toString()); return query.getResultList(); } private void setCount(Query query, Page page){ int count = getCount(query); page.setTotalCount(count); } private int getCount(Query query){ Object object = query.getSingleResult(); if(object instanceof BigInteger){ BigInteger bi = (BigInteger)object; return bi.intValue(); } if(object instanceof Long){ long i = (Long)object; return (int)i; } if(object instanceof Integer){ Integer i = (Integer)object; return i; } return 0; } @Override public void finalize() throws Throwable { if(this.entityManager != null){ entityManager.close(); } super.finalize(); } //////////////////////////////////////////////////////////////////// public EntityManager getEntityManager() { return entityManager; } public void setEntityManager(EntityManager entityManager) { this.entityManager = entityManager; } public static void main(String[] args) { } }