package net.sf.jeasyorm;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.NoSuchElementException;
import net.sf.jeasyorm.Mapping.ColumnInfo;
import net.sf.jeasyorm.Mapping.FieldInfo;
/**
 * Entity manager that maps entities to database tables through plain JDBC.
 * <p>
 * The SQL used for load/select/insert/update/delete is generated from the entity's
 * {@link Mapping} by {@link #newSqlInfo(Class)}. Paged queries append
 * {@code limit ... offset ...}, so this implementation assumes a database that accepts that
 * syntax.
 */
public class BasicEntityManager extends AbstractEntityManager {

    public BasicEntityManager(Connection connection) {
        super(connection);
    }

    /**
     * Builds the SQL statements used for {@code entityClass}.
     * <p>
     * Every mapped column is selected. Auto-increment columns are left out of the insert.
     * Primary-key columns form the {@code where} clause of load/update/delete; every other
     * column goes into the update's {@code set} list.
     * NOTE(review): if the mapping has no primary-key column, the generated load/update/delete
     * statements end in a bare {@code where} and are not valid SQL.
     *
     * @param entityClass the mapped entity class
     * @return the generated statements together with the mapping
     */
    protected SqlInfo newSqlInfo(Class<?> entityClass) {
        Mapping mapping = Mapping.getMapping(this, entityClass);
        SqlInfo info = new SqlInfo(mapping);
        StringBuilder selectColumns = new StringBuilder();
        StringBuilder insertColumns = new StringBuilder();
        StringBuilder insertPlaceholders = new StringBuilder();
        StringBuilder updateSet = new StringBuilder();
        StringBuilder where = new StringBuilder();
        for (ColumnInfo ci : mapping.getColumns()) {
            if (selectColumns.length() > 0) selectColumns.append(", ");
            selectColumns.append(ci.getName());
            if (!ci.isAutoIncrement()) {
                // auto-increment values are produced by the database, not bound by us
                if (insertColumns.length() > 0) insertColumns.append(", ");
                insertColumns.append(ci.getName());
                if (insertPlaceholders.length() > 0) insertPlaceholders.append(", ");
                insertPlaceholders.append("?");
            }
            if (ci.isPrimaryKey()) {
                if (where.length() > 0) where.append(" and ");
                where.append(ci.getName());
                where.append(" = ?");
            } else {
                if (updateSet.length() > 0) updateSet.append(", ");
                updateSet.append(ci.getName());
                updateSet.append(" = ?");
            }
        }
        String tableName = (mapping.getSchemaName() != null ? mapping.getSchemaName() + "." : "") +
                mapping.getTableName();
        info.loadSql = "select " + selectColumns + " from " + tableName + " where " + where;
        info.selectSql = "select " + selectColumns + " from " + tableName;
        info.insertSql = "insert into " + tableName + " ("+ insertColumns + ") values (" + insertPlaceholders + ")";
        info.updateSql = "update " + tableName + " set "+ updateSet + " where " + where;
        info.deleteSql = "delete from " + tableName + " where " + where;
        return info;
    }

    /**
     * Loads a single entity by its primary key.
     *
     * @param entityClass the mapped entity class
     * @param pk primary-key values, in the order the primary-key columns appear in the mapping
     * @return the entity, or {@code null} if no row matches
     * @throws IllegalArgumentException if the number of key values does not match the number of
     *         primary-key columns
     * @throws RuntimeSQLException if the query fails or returns more than one row
     */
    @SuppressWarnings("unchecked")
    public <T> T load(Class<T> entityClass, Object... pk) {
        SqlInfo info = getSqlInfo(this, entityClass);
        int pkCount = 0;
        for (ColumnInfo ci : info.mapping.getColumns()) {
            if (ci.isPrimaryKey()) pkCount++;
        }
        int given = pk == null ? 0 : pk.length;
        if (given != pkCount) {
            // fail clearly instead of with an ArrayIndexOutOfBoundsException (too few values)
            // or by silently ignoring extra values
            throw new IllegalArgumentException("Expected " + pkCount + " primary key value(s) for "
                    + entityClass.getName() + " but got " + given);
        }
        PreparedStatement stmt = null;
        try {
            stmt = prepareStatement(info.loadSql, null);
            int i = 0;
            for (ColumnInfo ci : info.mapping.getColumns()) {
                if (ci.isPrimaryKey()) {
                    setParameter(stmt, i+1, pk[i], ci.getType());
                    i++;
                }
            }
            ResultSet rs = null;
            try {
                rs = stmt.executeQuery();
                T entity = null;
                if (rs.next()) entity = (T) getObject(rs, info.mapping);
                if (rs.next()) {
                    throw new RuntimeSQLException("Multiple rows returned for statement [" + info.loadSql + "]");
                }
                return entity;
            } catch (RuntimeSQLException se) {
                throw se;
            } catch (Exception e) {
                throw new RuntimeSQLException("Error executing statement [" + info.loadSql + "]", e);
            } finally {
                close(rs);
            }
        } finally {
            close(stmt);
        }
    }

    /**
     * Expands a query fragment into full SQL for the entity described by {@code info}.
     * <ul>
     * <li>a fragment starting with {@code from} gets the entity's select list prepended;</li>
     * <li>a fragment starting with {@code where} or {@code order} is appended to the entity's
     * full select statement;</li>
     * <li>anything else is returned unchanged.</li>
     * </ul>
     */
    protected String getSql(SqlInfo info, String query) {
        String lcQuery = query.trim().toLowerCase(Locale.ROOT);
        if (lcQuery.startsWith("from")) {
            int pos = info.selectSql.indexOf(" from ");
            return info.selectSql.substring(0, pos+1) + query;
        } else if (lcQuery.startsWith("where") || lcQuery.startsWith("order")) {
            return info.selectSql + " " + query;
        } else {
            return query;
        }
    }

    /**
     * Runs a query expected to return at most one row.
     * For native types (per {@code Utils.isNativeType}) the query is used verbatim and the first
     * column is converted to {@code entityClass}; otherwise the query is expanded by
     * {@link #getSql(SqlInfo, String)} and the row is mapped to an entity.
     *
     * @return the single result, or {@code null} if there is none
     * @throws RuntimeSQLException if the query fails or returns more than one row
     */
    @SuppressWarnings("unchecked")
    public <T> T findUnique(Class<T> entityClass, String query, Object... params) {
        boolean isNative = Utils.isNativeType(entityClass);
        SqlInfo info = isNative ? null : getSqlInfo(this, entityClass);
        String sql = isNative ? query : getSql(info, query);
        PreparedStatement stmt = null;
        try {
            stmt = prepareStatement(sql, null);
            for (int i=0; i<params.length; i++) {
                setParameter(stmt, i+1, params[i], Types.VARCHAR);
            }
            ResultSet rs = null;
            try {
                rs = stmt.executeQuery();
                T entity = null;
                if (rs.next()) {
                    entity = !isNative ? (T) getObject(rs, info.mapping) : (T) getValue(rs, 1, entityClass, Types.VARCHAR);
                }
                if (rs.next()) {
                    throw new RuntimeSQLException("Multiple rows returned for statement [" + sql + "]");
                }
                return entity;
            } catch (RuntimeSQLException se) {
                throw se;
            } catch (Exception e) {
                throw new RuntimeSQLException("Error executing statement [" + sql + "]", e);
            } finally {
                close(rs);
            }
        } finally {
            close(stmt);
        }
    }

    /**
     * Runs a query and returns all result rows, converted as in
     * {@link #findUnique(Class, String, Object...)}.
     *
     * @return a new, possibly empty list of results
     */
    @SuppressWarnings("unchecked")
    public <T> List<T> find(Class<T> entityClass, String query, Object... params) {
        boolean isNative = Utils.isNativeType(entityClass);
        SqlInfo info = isNative ? null : getSqlInfo(this, entityClass);
        String sql = isNative ? query : getSql(info, query);
        List<T> entities = new ArrayList<T>();
        PreparedStatement stmt = null;
        try {
            stmt = prepareStatement(sql, null);
            for (int i=0; i<params.length; i++) {
                setParameter(stmt, i+1, params[i], Types.VARCHAR);
            }
            ResultSet rs = null;
            try {
                rs = stmt.executeQuery();
                while (rs.next()) {
                    T entity = !isNative ? (T) getObject(rs, info.mapping) : (T) getValue(rs, 1, entityClass, Types.VARCHAR);
                    entities.add(entity);
                }
                return entities;
            } catch (RuntimeSQLException se) {
                throw se;
            } catch (Exception e) {
                throw new RuntimeSQLException("Error executing statement [" + sql + "]", e);
            } finally {
                close(rs);
            }
        } finally {
            close(stmt);
        }
    }

    /**
     * Counts the rows a select statement would return.
     * Builds {@code select count(*)} followed by everything from the first {@code " from "} up
     * to (excluding) the last {@code " order by "} that comes after it. Both keywords are matched
     * case-insensitively and must be surrounded by single spaces.
     *
     * @param sql a complete select statement
     * @param params parameter values bound in order
     * @return the row count, or 0 if the count query returns no row
     * @throws IllegalArgumentException if {@code sql} contains no {@code " from "}
     */
    public int count(String sql, Object... params) {
        int fromPos = indexOfKeyword(sql, " from ", false);
        if (fromPos < 0) {
            throw new IllegalArgumentException("No ' from ' clause found in statement [" + sql + "]");
        }
        int orderPos = indexOfKeyword(sql, " order by ", true);
        // ordering is irrelevant for a count and some databases reject it with an aggregate
        String countSql = "select count(*)" + sql.substring(fromPos, orderPos > fromPos ? orderPos : sql.length());
        PreparedStatement stmt = null;
        try {
            stmt = prepareStatement(countSql, null);
            for (int i=0; i<params.length; i++) {
                setParameter(stmt, i+1, params[i], Types.VARCHAR);
            }
            ResultSet rs = null;
            try {
                rs = stmt.executeQuery();
                if (rs.next()) return rs.getInt(1);
            } catch (RuntimeSQLException se) {
                throw se;
            } catch (Exception e) {
                throw new RuntimeSQLException("Error executing statement [" + countSql + "]", e);
            } finally {
                close(rs);
            }
        } finally {
            close(stmt);
        }
        return 0;
    }

    /**
     * Case-insensitive search for {@code keyword} inside {@code text}.
     * {@code regionMatches} is used instead of lower-casing the text so that indices always
     * refer to the original string.
     *
     * @param fromEnd {@code true} to find the last occurrence, {@code false} for the first
     * @return the start index of the match, or -1 if there is none
     */
    private static int indexOfKeyword(String text, String keyword, boolean fromEnd) {
        int last = text.length() - keyword.length();
        if (fromEnd) {
            for (int i = last; i >= 0; i--) {
                if (text.regionMatches(true, i, keyword, 0, keyword.length())) return i;
            }
        } else {
            for (int i = 0; i <= last; i++) {
                if (text.regionMatches(true, i, keyword, 0, keyword.length())) return i;
            }
        }
        return -1;
    }

    /**
     * Runs a query and returns one page of results (zero-based {@code pageNum}) using
     * {@code limit}/{@code offset}.
     * The total size is derived from the page when possible. A separate count query runs when
     * the page is full, since more rows may follow, or when the page is empty but not the first,
     * since the page then lies past the end of the result.
     */
    @SuppressWarnings("unchecked")
    public <T> Page<T> find(Class<T> entityClass, int pageNum, int pageSize, String query, Object... params) {
        boolean isNative = Utils.isNativeType(entityClass);
        SqlInfo info = isNative ? null : getSqlInfo(this, entityClass);
        String sql = isNative ? query : getSql(info, query);
        int offset = pageNum * pageSize;
        List<T> entities = new ArrayList<T>();
        PreparedStatement stmt = null;
        try {
            stmt = prepareStatement(sql + " limit " + pageSize + " offset " + offset, null);
            for (int i=0; i<params.length; i++) {
                setParameter(stmt, i+1, params[i], Types.VARCHAR);
            }
            ResultSet rs = null;
            try {
                rs = stmt.executeQuery();
                while (rs.next()) {
                    T entity = !isNative ? (T) getObject(rs, info.mapping) : (T) getValue(rs, 1, entityClass, Types.VARCHAR);
                    entities.add(entity);
                }
                int totalSize;
                if (entities.isEmpty() && pageNum > 0) {
                    // past the end: the offset is not a lower bound of the real size
                    totalSize = count(sql, params);
                } else {
                    totalSize = offset + entities.size();
                    if (entities.size() >= pageSize) {
                        totalSize = Math.max(totalSize, count(sql, params));
                    }
                }
                return new Page<T>(entities, pageNum, pageSize, totalSize);
            } catch (RuntimeSQLException se) {
                throw se;
            } catch (Exception e) {
                throw new RuntimeSQLException("Error executing statement [" + sql + "]", e);
            } finally {
                close(rs);
            }
        } finally {
            close(stmt);
        }
    }

    /**
     * Runs a query and returns a lazy iterator over its rows, converted as in
     * {@link #find(Class, String, Object...)}.
     * The statement and result set stay open while iterating. They are closed automatically once
     * the iterator is exhausted, or explicitly via {@link ResultSetIterator#close()}.
     */
    public <T> Iterator<T> iterator(Class<T> entityClass, String query, Object... params) {
        // native types have no mapping, matching the other finder methods
        boolean isNative = Utils.isNativeType(entityClass);
        String sql = isNative ? query : getSql(getSqlInfo(this, entityClass), query);
        PreparedStatement stmt = null;
        ResultSet rs = null;
        try {
            stmt = prepareStatement(sql, null);
            for (int i=0; i<params.length; i++) {
                setParameter(stmt, i+1, params[i], Types.VARCHAR);
            }
            rs = stmt.executeQuery();
            return new ResultSetIterator<T>(stmt, rs, entityClass);
        } catch (RuntimeSQLException se) {
            close(rs);
            close(stmt);
            throw se;
        } catch (Exception e) {
            close(rs);
            close(stmt);
            throw new RuntimeSQLException("Error executing statement [" + sql + "]", e);
        }
    }

    /**
     * Executes an update/DDL statement with parameters bound in order.
     *
     * @return the update count reported by the driver
     */
    public int execute(String sql, Object... params) {
        PreparedStatement stmt = null;
        try {
            stmt = prepareStatement(sql, null);
            for (int i=0; i<params.length; i++) {
                setParameter(stmt, i+1, params[i], Types.VARCHAR);
            }
            return stmt.executeUpdate();
        } catch (RuntimeSQLException se) {
            throw se;
        } catch (Exception e) {
            throw new RuntimeSQLException("Error executing statement [" + sql + "]", e);
        } finally {
            close(stmt);
        }
    }

    /**
     * Inserts {@code entity}.
     * A primary key that is {@code null} is filled in from {@link #getPrimaryKey} before the
     * insert. Values of auto-increment columns are read back from the generated keys and written
     * into the entity.
     *
     * @return the same {@code entity} instance, now carrying any generated values
     */
    public <T> T insert(T entity) {
        SqlInfo info = getSqlInfo(this, entity.getClass());
        PreparedStatement stmt = null;
        try {
            List<String> columnNames = new ArrayList<String>();
            for (ColumnInfo ci : info.mapping.getColumns()) {
                if (ci.isAutoIncrement()) columnNames.add(ci.getName());
            }
            stmt = prepareStatement(info.insertSql, columnNames.toArray(new String[columnNames.size()]));
            int i = 0;
            for (ColumnInfo ci : info.mapping.getColumns()) {
                if (!ci.isAutoIncrement()) {
                    FieldInfo fi = info.mapping.getFieldForColumn(ci);
                    Object value = get(entity, fi);
                    if (ci.isPrimaryKey() && value == null) {
                        value = getPrimaryKey(info.mapping, ci);
                        set(entity, fi, Utils.convertTo(value, fi.getType()));
                    }
                    setParameter(stmt, i+1, value, ci.getType());
                    i++;
                }
            }
            int num = stmt.executeUpdate();
            if (num < 1) {
                throw new RuntimeSQLException("Error inserting entity");
            }
            if (columnNames.size() > 0) {
                ResultSet rs = null;
                try {
                    rs = stmt.getGeneratedKeys();
                    // one row per inserted row, one column per requested key column
                    if (!rs.next()) {
                        throw new RuntimeSQLException("No generated keys returned for statement [" + info.insertSql + "]");
                    }
                    for (int k = 0; k < columnNames.size(); k++) {
                        String columnName = columnNames.get(k);
                        FieldInfo fi = info.mapping.getFieldForColumn(columnName);
                        ColumnInfo ci = info.mapping.getColumnForColumn(columnName);
                        set(entity, fi, getValue(rs, k + 1, fi.getType(), ci.getType()));
                    }
                } catch (SQLException e) {
                    throw new RuntimeSQLException("Error updating generated keys", e);
                } finally {
                    close(rs);
                }
            }
        } catch (SQLException e) {
            throw new RuntimeSQLException("Error inserting entity", e);
        } finally {
            close(stmt);
        }
        return entity;
    }

    /**
     * Hook for supplying a primary-key value when an entity is inserted with a {@code null}
     * key. This implementation returns {@code null}.
     */
    protected Object getPrimaryKey(Mapping mapping, ColumnInfo ci) {
        return null;
    }

    /**
     * Updates all non-key columns of {@code entity}, identified by its primary key.
     *
     * @throws RuntimeSQLException if no row was updated or the statement fails
     */
    public <T> void update(T entity) {
        SqlInfo info = getSqlInfo(this, entity.getClass());
        PreparedStatement stmt = null;
        try {
            stmt = prepareStatement(info.updateSql, null);
            int i = 0;
            // set-clause parameters first, then the where-clause keys, matching newSqlInfo
            for (ColumnInfo ci : info.mapping.getColumns()) {
                if (!ci.isPrimaryKey()) {
                    FieldInfo fi = info.mapping.getFieldForColumn(ci);
                    Object value = get(entity, fi);
                    setParameter(stmt, i+1, value, ci.getType());
                    i++;
                }
            }
            for (ColumnInfo ci : info.mapping.getColumns()) {
                if (ci.isPrimaryKey()) {
                    FieldInfo fi = info.mapping.getFieldForColumn(ci);
                    Object value = get(entity, fi);
                    setParameter(stmt, i+1, value, ci.getType());
                    i++;
                }
            }
            int num = stmt.executeUpdate();
            if (num < 1) {
                throw new RuntimeSQLException("Error updating entity");
            }
        } catch (SQLException e) {
            throw new RuntimeSQLException("Error updating entity", e);
        } finally {
            close(stmt);
        }
    }

    /**
     * Deletes the row identified by the primary key of {@code entity}.
     *
     * @throws RuntimeSQLException if no row was deleted or the statement fails
     */
    public <T> void delete(T entity) {
        SqlInfo info = getSqlInfo(this, entity.getClass());
        PreparedStatement stmt = null;
        try {
            stmt = prepareStatement(info.deleteSql, null);
            int i = 0;
            for (ColumnInfo ci : info.mapping.getColumns()) {
                if (ci.isPrimaryKey()) {
                    FieldInfo fi = info.mapping.getFieldForColumn(ci);
                    Object value = get(entity, fi);
                    setParameter(stmt, i+1, value, ci.getType());
                    i++;
                }
            }
            int num = stmt.executeUpdate();
            if (num < 1) {
                throw new RuntimeSQLException("Error deleting entity");
            }
        } catch (SQLException e) {
            throw new RuntimeSQLException("Error deleting entity", e);
        } finally {
            close(stmt);
        }
    }

    /**
     * Iterator over an open result set that converts one row per {@link #next()} call.
     * It reads one row ahead to answer {@link #hasNext()}. The result set and statement are
     * released when the rows are exhausted, when reading fails, or when {@link #close()} is
     * called.
     */
    public class ResultSetIterator<T> implements Iterator<T> {
        private PreparedStatement stmt;
        private ResultSet rs;
        boolean isNative;
        private Class<T> entityClass;
        private Mapping mapping;
        boolean hasNext;

        protected ResultSetIterator(PreparedStatement stmt, ResultSet rs, Class<T> entityClass) {
            this.stmt = stmt;
            this.rs = rs;
            this.isNative = Utils.isNativeType(entityClass);
            this.entityClass = entityClass;
            this.mapping = isNative ? null : Mapping.getMapping(BasicEntityManager.this, entityClass);
            advance();
        }

        @Override
        public boolean hasNext() {
            return hasNext;
        }

        /**
         * Converts the current row and moves to the next one.
         *
         * @throws NoSuchElementException if there are no more rows
         * @throws RuntimeSQLException if moving to the next row fails
         */
        @Override
        @SuppressWarnings("unchecked")
        public T next() {
            if (!hasNext) {
                throw new NoSuchElementException();
            }
            T entity = !isNative ? (T) getObject(rs, mapping) : (T) getValue(rs, 1, entityClass, Types.VARCHAR);
            advance();
            return entity;
        }

        /**
         * Moves the cursor forward and releases resources once the end is reached.
         * A driver error is reported rather than treated as end-of-data, so results are never
         * silently truncated.
         */
        private void advance() {
            boolean more;
            try {
                more = rs.next();
            } catch (SQLException e) {
                close();
                throw new RuntimeSQLException("Error reading next row", e);
            }
            hasNext = more;
            if (!more) {
                close();
            }
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }

        /**
         * Releases the result set and statement. Safe to call more than once; after closing,
         * {@link #hasNext()} returns {@code false}.
         */
        public void close() {
            hasNext = false;
            BasicEntityManager.this.close(rs);
            BasicEntityManager.this.close(stmt);
            rs = null;
            stmt = null;
        }
    }
}