/** * AnalyzerBeans * Copyright (C) 2014 Neopost - Customer Information Management * * This copyrighted material is made available to anyone wishing to use, modify, * copy, or redistribute it subject to the terms and conditions of the GNU * Lesser General Public License, as published by the Free Software Foundation. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License * for more details. * * You should have received a copy of the GNU Lesser General Public License * along with this distribution; if not, write to: * Free Software Foundation, Inc. * 51 Franklin Street, Fifth Floor * Boston, MA 02110-1301 USA */ package org.eobjects.analyzer.storage; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; import org.eobjects.analyzer.data.InputColumn; import org.eobjects.analyzer.data.InputRow; import org.eobjects.analyzer.data.MockInputRow; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class SqlDatabaseRowAnnotationFactory implements RowAnnotationFactory { private final static Logger logger = LoggerFactory.getLogger(SqlDatabaseRowAnnotationFactory.class); private final Map<InputColumn<?>, String> _inputColumnNames = new LinkedHashMap<InputColumn<?>, String>(); private final Map<RowAnnotation, String> _annotationColumnNames = new HashMap<RowAnnotation, String>(); private final Connection _connection; private final String _tableName; private final AtomicInteger _nextColumnIndex = new AtomicInteger(1); public SqlDatabaseRowAnnotationFactory(Connection connection, String tableName) { _connection = connection; _tableName = tableName; String intType = SqlDatabaseUtils.getSqlType(Integer.class); performUpdate(SqlDatabaseUtils.CREATE_TABLE_PREFIX + tableName + " (id " + intType + " PRIMARY KEY, distinct_count " + intType + ")"); } @Override protected void finalize() throws Throwable { super.finalize(); performUpdate("DROP TABLE " + _tableName); } private void performUpdate(String sql) { SqlDatabaseUtils.performUpdate(_connection, sql); } @Override public RowAnnotation createAnnotation() { return new RowAnnotationImpl(); } private boolean containsRow(InputRow row) { ResultSet rs = null; PreparedStatement st = null; try { st = _connection.prepareStatement("SELECT COUNT(*) FROM " + _tableName + " WHERE id = ?"); boolean contains; st.setInt(1, row.getId()); rs = st.executeQuery(); if (rs.next()) { int count = rs.getInt(1); if (count == 0) { contains = false; } else if (count == 1) { contains = true; } else { throw new IllegalStateException(count + " rows with id=" + row.getId() + " exists in database!"); } } else { contains = false; } return contains; } catch (SQLException e) { throw new IllegalStateException(e); } finally { SqlDatabaseUtils.safeClose(rs, st); } } @Override public void annotate(InputRow[] rows, RowAnnotation annotation) { for (InputRow row : rows) { annotate(row, 1, annotation); } } @Override public synchronized void annotate(InputRow row, int distinctCount, RowAnnotation annotation) { RowAnnotationImpl a = (RowAnnotationImpl) annotation; List<InputColumn<?>> inputColumns = row.getInputColumns(); List<String> columnNames = new ArrayList<String>(inputColumns.size()); List<Object> values = new ArrayList<Object>(inputColumns.size()); for (InputColumn<?> inputColumn : inputColumns) { String columnName = getColumnName(inputColumn, true); columnNames.add(columnName); Object value = row.getValue(inputColumn); values.add(value); } String annotationColumnName = getColumnName(annotation, true); if (containsRow(row)) { PreparedStatement st = null; ResultSet rs = null; boolean annotated; try { st = _connection.prepareStatement("SELECT " + annotationColumnName + " FROM " + _tableName + " WHERE id=?"); st.setInt(1, row.getId()); rs = st.executeQuery(); if (rs.next()) { annotated = rs.getBoolean(1); } else { logger.error("No rows returned on annotation status for id={}", row.getId()); annotated = false; } } catch (SQLException e) { throw new IllegalStateException(e); } finally { SqlDatabaseUtils.safeClose(rs, st); } if (!annotated) { try { st = _connection.prepareStatement("UPDATE " + _tableName + " SET " + annotationColumnName + "=TRUE WHERE id=?"); st.setInt(1, row.getId()); st.executeUpdate(); a.incrementRowCount(distinctCount); } catch (SQLException e) { throw new IllegalStateException(e); } finally { SqlDatabaseUtils.safeClose(null, st); } } } else { StringBuilder sb = new StringBuilder(); sb.append("INSERT INTO "); sb.append(_tableName); sb.append(" (id,distinct_count"); sb.append(','); sb.append(annotationColumnName); for (String columnName : columnNames) { sb.append(','); sb.append(columnName); } sb.append(") VALUES (?,?,?"); for (int i = 0; i < values.size(); i++) { sb.append(",?"); } sb.append(")"); PreparedStatement st = null; try { st = _connection.prepareStatement(sb.toString()); st.setInt(1, row.getId()); st.setInt(2, distinctCount); st.setBoolean(3, true); for (int i = 0; i < values.size(); i++) { st.setObject(i + 4, values.get(i)); } st.executeUpdate(); a.incrementRowCount(distinctCount); } catch (SQLException e) { throw new IllegalStateException(e); } finally { SqlDatabaseUtils.safeClose(null, st); } } } private String getColumnName(RowAnnotation annotation, boolean createIfNonExisting) { String columnName = _annotationColumnNames.get(annotation); if (columnName == null) { if (createIfNonExisting) { int index = _nextColumnIndex.getAndIncrement(); columnName = "col" + index; performUpdate("ALTER TABLE " + _tableName + " ADD COLUMN " + columnName + " " + SqlDatabaseUtils.getSqlType(Boolean.class) + " DEFAULT FALSE"); _annotationColumnNames.put(annotation, columnName); } } return columnName; } private String getColumnName(InputColumn<?> inputColumn, boolean createIfNonExisting) { String columnName = _inputColumnNames.get(inputColumn); if (columnName == null) { if (createIfNonExisting) { int index = _nextColumnIndex.getAndIncrement(); columnName = "col" + index; Class<?> javaType = inputColumn.getDataType(); performUpdate("ALTER TABLE " + _tableName + " ADD COLUMN " + columnName + " " + SqlDatabaseUtils.getSqlType(javaType)); _inputColumnNames.put(inputColumn, columnName); } } return columnName; } @Override public synchronized void reset(RowAnnotation annotation) { String columnName = getColumnName(annotation, false); if (columnName != null) { performUpdate("UPDATE " + _tableName + " SET " + columnName + " = FALSE"); } } @Override public InputRow[] getRows(RowAnnotation annotation) { String annotationColumnName = getColumnName(annotation, false); if (annotationColumnName == null) { return new InputRow[0]; } ResultSet rs = null; Statement st = null; try { st = _connection.createStatement(); StringBuilder sb = new StringBuilder(); sb.append("SELECT id"); ArrayList<InputColumn<?>> inputColumns = new ArrayList<InputColumn<?>>(_inputColumnNames.keySet()); for (InputColumn<?> inputColumn : inputColumns) { sb.append(','); String columnName = _inputColumnNames.get(inputColumn); sb.append(columnName); } sb.append(" FROM "); sb.append(_tableName); sb.append(" WHERE "); sb.append(annotationColumnName); sb.append(" = TRUE"); rs = st.executeQuery(sb.toString()); List<InputRow> rows = new ArrayList<InputRow>(); while (rs.next()) { int id = rs.getInt(1); MockInputRow row = new MockInputRow(id); int colIndex = 2; for (InputColumn<?> inputColumn : inputColumns) { Object value = rs.getObject(colIndex); row.put(inputColumn, value); colIndex++; } rows.add(row); } return rows.toArray(new InputRow[rows.size()]); } catch (SQLException e) { throw new IllegalStateException(e); } finally { SqlDatabaseUtils.safeClose(rs, st); } } @Override public Map<Object, Integer> getValueCounts(RowAnnotation annotation, InputColumn<?> inputColumn) { HashMap<Object, Integer> map = new HashMap<Object, Integer>(); String inputColumnName = getColumnName(inputColumn, false); if (inputColumnName == null) { return map; } String annotationColumnName = getColumnName(annotation, false); if (annotationColumnName == null) { return map; } ResultSet rs = null; PreparedStatement st = null; try { st = _connection.prepareStatement("SELECT " + inputColumnName + ", SUM(distinct_count) FROM " + _tableName + " WHERE " + annotationColumnName + " = TRUE GROUP BY " + inputColumnName); rs = st.executeQuery(); while (rs.next()) { Object value = rs.getObject(1); int count = rs.getInt(2); map.put(value, count); } return map; } catch (SQLException e) { throw new IllegalStateException(e); } finally { SqlDatabaseUtils.safeClose(rs, st); } } @Override public void transferAnnotations(RowAnnotation from, RowAnnotation to) { final int increment = from.getRowCount(); ((RowAnnotationImpl) to).incrementRowCount(increment); // TODO: Copy records to new annotation also? } }