package org.simpleflatmapper.jdbi; import org.simpleflatmapper.jdbc.JdbcMapperFactory; import org.simpleflatmapper.jdbc.QueryPreparer; import org.simpleflatmapper.jdbc.SqlTypeColumnProperty; import org.simpleflatmapper.jdbc.named.NamedSqlQuery; import org.simpleflatmapper.util.ErrorHelper; import org.skife.jdbi.v2.Binding; import org.skife.jdbi.v2.SQLStatement; import org.skife.jdbi.v2.StatementContext; import org.skife.jdbi.v2.sqlobject.Binder; import org.skife.jdbi.v2.sqlobject.BinderFactory; import org.skife.jdbi.v2.tweak.RewrittenStatement; import org.skife.jdbi.v2.tweak.StatementRewriter; import java.sql.PreparedStatement; import java.sql.SQLException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; public class SfmBinderFactory implements BinderFactory<SfmBind> { public Binder build(SfmBind annotation) { return new SfmBinder(); } private static class SfmBinder<T> implements Binder<SfmBind, T> { private final ConcurrentMap<QueryPreparerKey, QueryPreparer<T>> cache = new ConcurrentHashMap<QueryPreparerKey, QueryPreparer<T>>(); @Override public void bind(SQLStatement<?> sqlStatement, SfmBind annotation, T o) { QueryPreparer<T> queryPreparer = getQueryPreparer(sqlStatement, annotation, o.getClass()); sqlStatement.setStatementRewriter(new SfmStatementRewriter<T>(queryPreparer, o)); } private QueryPreparer<T> getQueryPreparer(SQLStatement<?> sqlStatement, SfmBind annotation, Class<?> aClass) { QueryPreparerKey key = new QueryPreparerKey(sqlStatement.getContext().getRawSql(), aClass); QueryPreparer<T> queryPreparer = cache.get(key); if (queryPreparer == null) { NamedSqlQuery parse = NamedSqlQuery.parse(sqlStatement.getContext().getRawSql()); JdbcMapperFactory jdbcMapperFactory = JdbcMapperFactory .newInstance(); for (SqlType col : annotation.sqlTypes()) { jdbcMapperFactory.addColumnProperty(col.name(), SqlTypeColumnProperty.of(col.type())); } queryPreparer = jdbcMapperFactory.<T>from(aClass).to(parse); QueryPreparer<T> cachedQP = cache.putIfAbsent(key, queryPreparer); if (cachedQP != null) { queryPreparer = cachedQP; } } return queryPreparer; } } private static class QueryPreparerKey { private final String sql; private final Class<?> target; private QueryPreparerKey(String sql, Class<?> target) { this.sql = sql; this.target = target; } @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; QueryPreparerKey that = (QueryPreparerKey) o; if (!sql.equals(that.sql)) return false; return target.equals(that.target); } @Override public int hashCode() { int result = sql.hashCode(); result = 31 * result + target.hashCode(); return result; } } private static class SfmStatementRewriter<T> implements StatementRewriter { private final QueryPreparer<T> queryPreparer; private final T o; public SfmStatementRewriter(QueryPreparer<T> queryPreparer, T o) { this.queryPreparer = queryPreparer; this.o = o; } @Override public RewrittenStatement rewrite(String s, Binding binding, StatementContext statementContext) { final String sql = queryPreparer.toRewrittenSqlQuery(o); return new SfmRewrittenStatement(sql); } private class SfmRewrittenStatement implements RewrittenStatement { private final String sql; public SfmRewrittenStatement(String sql) { this.sql = sql; } @Override public void bind(Binding binding, PreparedStatement preparedStatement) throws SQLException { try { queryPreparer.mapper().mapTo(o, preparedStatement, null); } catch (Exception e) { ErrorHelper.rethrow(e); } } @Override public String getSql() { return sql; } } } }