package com.w11k.lsql.sqlfile; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import com.google.common.io.CharStreams; import com.w11k.lsql.LSql; import com.w11k.lsql.query.PojoQuery; import com.w11k.lsql.query.RowQuery; import com.w11k.lsql.statement.AbstractSqlStatement; import com.w11k.lsql.statement.SqlStatementToPreparedStatement; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.sql.PreparedStatement; import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; import static com.google.common.collect.ImmutableMap.copyOf; public class LSqlFile { private static final Pattern STMT_BLOCK_BEGIN = Pattern.compile( "^--\\s*(\\w*)\\s*$", Pattern.MULTILINE); private static final Pattern STMT_BLOCK_END = Pattern.compile( ";\\s*$", Pattern.MULTILINE); private final Logger logger = LoggerFactory.getLogger(getClass()); private final LSql lSql; private final String nameForDescription; // without .sql extension private final String path; private final Map<String, SqlStatementToPreparedStatement> statements = Maps.newHashMap(); public LSqlFile(LSql lSql, String nameForDescription, String path) { this.lSql = lSql; this.nameForDescription = nameForDescription; this.path = path; parseSqlStatements(); } // ----- public ----- public ImmutableMap<String, SqlStatementToPreparedStatement> getStatements() { return copyOf(statements); } public AbstractSqlStatement<RowQuery> statement(String name) { final SqlStatementToPreparedStatement stmtToPs = getStatement(name); return new AbstractSqlStatement<RowQuery>(stmtToPs) { @Override protected RowQuery createQueryInstance(LSql lSql, PreparedStatement ps) { return new RowQuery(LSqlFile.this.lSql, ps); } }; } public <T> AbstractSqlStatement<PojoQuery<T>> statement(String name, final Class<T> pojoClass) { final SqlStatementToPreparedStatement stmtToPs = getStatement(name); return new AbstractSqlStatement<PojoQuery<T>>(stmtToPs) { @Override protected PojoQuery<T> createQueryInstance(LSql lSql, PreparedStatement ps) { return new PojoQuery<T>(LSqlFile.this.lSql, ps, pojoClass); } }; } private SqlStatementToPreparedStatement getStatement(String name) { if (!this.statements.containsKey(name)) { throw new IllegalArgumentException("No statement with name '" + name + "' found in file '" + this.path + "'."); } return this.statements.get(name); } // ----- private ----- private void parseSqlStatements() { logger.info("Reading SQL file '" + nameForDescription + "'"); statements.clear(); InputStream is = getClass().getResourceAsStream(path); try { BufferedReader reader = new BufferedReader(new InputStreamReader(is)); String content = CharStreams.toString(reader); Matcher startMatcher = STMT_BLOCK_BEGIN.matcher(content); while (startMatcher.find()) { String name = startMatcher.group(1); String sub = content.substring(startMatcher.end()); Matcher endMatcher = STMT_BLOCK_END.matcher(sub); if (!endMatcher.find()) { throw new IllegalStateException( "Could not find the end of the SQL expression '" + name + "'. Did you add ';' at the end?"); } sub = sub.substring(0, endMatcher.end()).trim(); logger.debug("Found SQL statement '{}'", name); statements.put(name, new SqlStatementToPreparedStatement(lSql, nameForDescription + "." + name, sub)); } } catch (IOException e) { throw new RuntimeException(e); } } }