/*
* RHQ Management Platform
* Copyright (C) 2005-2010 Red Hat, Inc.
* All rights reserved.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation version 2 of the License.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
*/
package org.rhq.helpers.perftest.support.dbunit;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.dbunit.database.DatabaseSequenceFilter;
import org.dbunit.database.IDatabaseConnection;
import org.dbunit.dataset.DataSetException;
import org.dbunit.dataset.IDataSet;
import org.dbunit.dataset.ITableIterator;
import org.dbunit.dataset.filter.ITableFilter;
import org.rhq.helpers.perftest.support.jpa.ColumnValues;
import org.rhq.helpers.perftest.support.jpa.DependencyInclusionResolver;
import org.rhq.helpers.perftest.support.jpa.DependencyType;
import org.rhq.helpers.perftest.support.jpa.Edge;
import org.rhq.helpers.perftest.support.jpa.EntityDependencyGraph;
import org.rhq.helpers.perftest.support.jpa.Node;
import org.rhq.helpers.perftest.support.jpa.mapping.ColumnValuesTableMap;
import org.rhq.helpers.perftest.support.jpa.mapping.EntityTranslation;
import org.rhq.helpers.perftest.support.jpa.mapping.RelationshipTranslation;
/**
* This implementation of the {@link ITableFilter} interface acts as a proxy
* between the {@link EntityDependencyGraph} and dbUnit.
* <p>
* This filter is able to produce a table iterator that traverses the tables in the
* correct order so that foreign key constraints are obeyed during insertion of data.
* <p>
* It is also able to filter the data from the tables corresponding to the entities by only
* allowing entities (and the underlying table rows) of certain primary key values to be included.
*
* @author Lukas Krejci
*/
public class EntityRelationshipFilter extends DatabaseSequenceFilter {
/**
 * The resolved primary keys per table name. Assigned once in the constructor and handed to every
 * table iterator created by {@link #iterator(IDataSet, boolean)}. The resolution code in this class
 * maps a table to {@code null} to mean "include all rows".
 */
private ColumnValuesTableMap resolvedPks;
/**
 * Hand-over slot between the static helper that runs while the arguments of the super-constructor
 * call are evaluated and the constructor body that runs after that call.
 * <p>
 * We need to compute the resolved primary keys *AND* provide the tables necessary for inclusion
 * to the super-constructor. Instance fields cannot be assigned before the super-constructor call,
 * so the resolution is parked here (per thread) and picked up by the constructor right afterwards.
 */
private static final ThreadLocal<ColumnValuesTableMap> RESOLUTION_IN_CONSTRUCTOR = new ThreadLocal<ColumnValuesTableMap>();
/**
 * Creates a filter restricted to the given entities and everything they depend on.
 * <p>
 * The primary keys to include are resolved (which queries the database) while the arguments of the
 * super-constructor call are evaluated; the resulting table names are passed to the super class and
 * the full resolution is kept in this instance for the table iterators.
 *
 * @param connection the dbUnit connection used by the super class and for the resolution queries
 * @param allowedPks the primary keys to include, keyed by entity class; a {@code null} set requests
 *                   all rows of that entity
 * @param inclusionResolver consulted to decide which outgoing relationships are followed
 * @throws DataSetException propagated from the super-class constructor
 * @throws SQLException if a resolution query fails (or propagated from the super-class constructor)
 */
public EntityRelationshipFilter(IDatabaseConnection connection, Map<Class<?>, Set<ColumnValues>> allowedPks,
    DependencyInclusionResolver inclusionResolver) throws DataSetException, SQLException {
    super(connection, getNeccesaryTablesAndSetResolution(connection, inclusionResolver, allowedPks));
    resolvedPks = RESOLUTION_IN_CONSTRUCTOR.get();
    //remove() instead of set(null): this drops the current thread's entry altogether rather than
    //leaving an entry holding null behind in the thread's map.
    //NOTE: if the super constructor throws, this line is never reached and the slot is not cleared;
    //that cannot be fixed here because nothing may precede or wrap the super(...) call.
    RESOLUTION_IN_CONSTRUCTOR.remove();
}
/**
 * Returns the table iterator of the super class wrapped in an
 * {@link EntityRelationshipTableIterator} that is given the primary keys resolved by this filter.
 *
 * @param dataSet the data set to iterate over
 * @param reversed passed through unchanged to the super-class iterator
 * @return the wrapping iterator
 * @throws DataSetException if the super class fails to create its iterator
 */
@Override
public ITableIterator iterator(IDataSet dataSet, boolean reversed) throws DataSetException {
    ITableIterator orderedTables = super.iterator(dataSet, reversed);
    return new EntityRelationshipTableIterator(orderedTables, resolvedPks);
}
/**
 * Resolves the primary keys to include, parks the resolution in {@code RESOLUTION_IN_CONSTRUCTOR}
 * for the constructor to pick up, and returns the names of all tables taking part in it.
 * <p>
 * This is called while the arguments of the super-constructor call are being evaluated, which is
 * why the resolution has to be handed over through the thread-local instead of an instance field.
 *
 * @param connection the connection used for the resolution queries
 * @param inclusionResolver consulted to decide which outgoing relationships are followed
 * @param primaryPks the requested primary keys keyed by entity class
 * @return the distinct, lower-cased names of the tables present in the resolution
 * @throws SQLException if a resolution query fails
 */
private static String[] getNeccesaryTablesAndSetResolution(IDatabaseConnection connection,
    DependencyInclusionResolver inclusionResolver, Map<Class<?>, Set<ColumnValues>> primaryPks) throws SQLException {
    ColumnValuesTableMap resolution = resolve(connection, inclusionResolver, primaryPks);
    RESOLUTION_IN_CONSTRUCTOR.set(resolution);
    Set<String> tables = new HashSet<String>();
    for (String t : resolution.keySet()) {
        //table names are identifiers, not natural-language text: lower-case them with a fixed
        //locale so that the result does not depend on the JVM's default locale (with a Turkish
        //default locale an upper-case "I" would otherwise turn into a dotless "ı").
        tables.add(t.toLowerCase(Locale.ENGLISH));
    }
    return tables.toArray(new String[tables.size()]);
}
/**
 * Builds the dependency graph for the requested entity classes and computes which primary keys
 * of which tables have to be included.
 *
 * @param connection the connection used for the resolution queries
 * @param inclusionResolver consulted to decide which outgoing relationships are followed
 * @param primaryPks the requested primary keys keyed by entity class
 * @return a freshly created map holding the resolution
 * @throws SQLException if a resolution query fails
 */
private static ColumnValuesTableMap resolve(IDatabaseConnection connection,
    DependencyInclusionResolver inclusionResolver, Map<Class<?>, Set<ColumnValues>> primaryPks) throws SQLException {
    EntityDependencyGraph graph = new EntityDependencyGraph();
    graph.addEntities(primaryPks.keySet());

    ColumnValuesTableMap result = new ColumnValuesTableMap();
    resolvePks(connection, graph, inclusionResolver, primaryPks, result);
    return result;
}
/**
 * Entry point of the resolution: for every requested entity class, makes sure the requested
 * primary key columns carry names and then resolves the entity's node of the dependency graph.
 *
 * @param connection the connection used for the resolution queries
 * @param edg the dependency graph containing a node for each requested entity class
 * @param inclusionResolver consulted to decide which outgoing relationships are followed
 * @param primaryPks the requested primary keys keyed by entity class; a {@code null} set is
 *                   passed on unchanged
 * @param resolvedPks the resolution being built, updated in place
 * @throws SQLException if a resolution query fails
 */
private static void resolvePks(IDatabaseConnection connection, EntityDependencyGraph edg,
    DependencyInclusionResolver inclusionResolver, Map<Class<?>, Set<ColumnValues>> primaryPks,
    ColumnValuesTableMap resolvedPks) throws SQLException {
    for (Map.Entry<Class<?>, Set<ColumnValues>> requested : primaryPks.entrySet()) {
        Node entityNode = edg.getNode(requested.getKey());
        Set<ColumnValues> requestedPks = requested.getValue();

        nameAnonymousPkColumns(entityNode, requestedPks);

        resolvePks(connection, inclusionResolver, entityNode, requestedPks, resolvedPks);
    }
}

/**
 * Gives every column without a user-defined name the name of the entity's primary key column at
 * the same position. Does nothing when {@code pks} is {@code null}.
 */
private static void nameAnonymousPkColumns(Node node, Set<ColumnValues> pks) {
    if (pks == null) {
        return;
    }

    for (ColumnValues pk : pks) {
        int position = 0;
        for (ColumnValues.Column column : pk) {
            if (column.getName() == null) {
                column.setName(node.getTranslation().getPkColumns()[position]);
            }
            position++;
        }
    }
}
/**
 * Recursively records in {@code resolvedPks} the primary keys to include for the table of
 * {@code node} and for the tables connected to it by its incoming and outgoing edges.
 * <p>
 * Conventions used for {@code resolvedPks} in this method: a table mapped to {@code null} means
 * "include all rows"; a table mapped to a set means "include these primary keys". Tables that are
 * not followed are still registered through {@code getOrCreate} so that they count as "done".
 *
 * @param connection the connection used to look up related keys in the database
 * @param inclusionResolver consulted for each outgoing edge to decide whether it is followed
 * @param node the entity whose table is being resolved
 * @param nodePks the primary keys requested for {@code node}, or {@code null} to request all rows
 * @param resolvedPks the resolution built so far, updated in place
 * @throws SQLException if a lookup query fails
 */
private static void resolvePks(IDatabaseConnection connection, DependencyInclusionResolver inclusionResolver,
Node node, Set<ColumnValues> nodePks, ColumnValuesTableMap resolvedPks) throws SQLException {
Set<ColumnValues> unresolvedPks;
Set<ColumnValues> resolvedTablePks = resolvedPks.get(node.getTranslation().getTableName());
//determine whether to bail out...
if (resolvedPks.containsKey(node.getTranslation().getTableName())) {
if (resolvedTablePks == null) {
//yes, this table has been identified as "include all"
return;
}
if (nodePks == null) {
//there is an entry for this table in the resolved pks already and we're
//telling it to include everything... let's leave what's in the resolution
//already and quit.
return;
}
}
//work out which of the requested keys have not been recorded for this table yet
if (resolvedTablePks == null || resolvedTablePks.isEmpty()) {
//nothing recorded so far: everything requested (possibly null, i.e. "all") is new
unresolvedPks = nodePks;
} else {
//nodePks cannot be null here - that case returned above
unresolvedPks = new LinkedHashSet<ColumnValues>();
for (ColumnValues pk : nodePks) {
if (!resolvedTablePks.contains(pk)) {
unresolvedPks.add(pk);
}
}
}
if (unresolvedPks != null) {
if (unresolvedPks.isEmpty()) {
//there are no data to include for this table. bail out.
return;
}
} else {
//"include all": record it right away so that a recursive call reaching this table again
//returns at the check at the top of this method
resolvedPks.put(node.getTranslation().getTableName(), null);
}
//incoming edges: entities that point to this node
for (Edge e : node.getIncomingEdges()) {
//NOTE(review): presumably getToField() is the field declaring the relationship on this
//("to") side of the edge - confirm against Edge
if (e.getToField() != null) {
Set<ColumnValues> dependingPks = resolveDependingPks(connection, e, unresolvedPks, resolvedPks);
resolvePks(connection, inclusionResolver, e.getFrom(), dependingPks, resolvedPks);
} else {
//not followed: only make sure the table has an entry in the resolution
resolvedPks.getOrCreate(e.getFrom().getTranslation().getTableName());
}
}
//the new keys of this table are recorded only after the incoming edges have been processed
if (unresolvedPks != null) {
resolvedPks.getOrCreate(node.getTranslation().getTableName()).addAll(unresolvedPks);
}
//outgoing edges: entities this node points to
for (Edge e : node.getOutgoingEdges()) {
//only include the dependents if the relationship
//is actually defined on the entity (i.e. don't include
//"back-references", like combined @JoinColumn @ManyToOne defined only on the target
//entity)
if (e.getFromField() != null && inclusionResolver.isValid(e)) {
Set<ColumnValues> dependentPks = resolveDependentPks(connection, e, unresolvedPks, resolvedPks);
resolvePks(connection, inclusionResolver, e.getTo(), dependentPks, resolvedPks);
} else {
//add nothing or create a new record for this table
//this will mark it as "done"
resolvedPks.getOrCreate(e.getTo().getTranslation().getTableName());
}
}
}
/**
 * Follows {@code edge} from its "from" node to its "to" node: determines the primary keys of the
 * "to" table that the given rows of the "from" table refer to.
 * <p>
 * If the relationship goes through a relation table, the matching rows of the relation table are
 * looked up and additionally recorded in {@code resolvedPks} (or the relation table is marked as
 * "include all" when {@code fromPks} is {@code null}). Otherwise the from-columns of the
 * relationship are read from the "from" table and matched against the "to" table.
 *
 * @param connection the connection used for the lookup queries
 * @param edge the relationship to follow
 * @param fromPks the primary keys of the rows in the "from" table, or {@code null} for all rows
 * @param resolvedPks the resolution built so far; only the entry of the relation table (if any)
 *                    is modified here
 * @return the primary keys found in the "to" table, without entries containing a {@code null}
 *         value; {@code null} if {@code fromPks} is {@code null}
 * @throws SQLException if a lookup query fails
 */
private static Set<ColumnValues> resolveDependentPks(IDatabaseConnection connection, Edge edge,
Set<ColumnValues> fromPks, ColumnValuesTableMap resolvedPks) throws SQLException {
RelationshipTranslation translation = edge.getTranslation();
if (translation.getRelationTable() != null) {
//copy the fromPks to columnValues. We'll use the pks from the from table
//to find the corresponding entries in the relation table
Set<ColumnValues> columnValues = null;
if (fromPks != null) {
columnValues = new HashSet<ColumnValues>();
for (ColumnValues pk : fromPks) {
//work on clones so that the renaming below does not touch the caller's keys
columnValues.add(pk.clone());
}
//now change the names of the columns in columnValues to the corresponding
//relationTableFromColumns (this assumes the same order of the columns
//in the case of composite pk)
for (int i = 0; i < translation.getRelationTableFromColumns().length; ++i) {
for (ColumnValues cols : columnValues) {
cols.getColumns().get(i).setName(translation.getRelationTableFromColumns()[i]);
}
}
}
//the columns to read from the relation table: all "from" columns followed by all "to" columns
String[] fromAndToCols = new String[translation.getRelationTableFromColumns().length
+ translation.getRelationTableToColumns().length];
System.arraycopy(translation.getRelationTableFromColumns(), 0, fromAndToCols, 0,
translation.getRelationTableFromColumns().length);
System.arraycopy(translation.getRelationTableToColumns(), 0, fromAndToCols,
translation.getRelationTableFromColumns().length, translation.getRelationTableToColumns().length);
if (fromPks != null) {
Set<ColumnValues> fromAndToValues = getValuesFromTable(connection, translation.getRelationTable(),
fromAndToCols, columnValues);
//add the relation table to the resolvedPks using fromAndToValues as its primary keys
resolvedPks.getOrCreate(translation.getRelationTable()).addAll(fromAndToValues);
//now read out the "to" values from fromAndToValues and return them as the "to" table primary keys
//(the i-th relation table "to" column is taken to be the i-th pk column of the "to" table)
Set<ColumnValues> toPks = new HashSet<ColumnValues>();
for (ColumnValues cols : fromAndToValues) {
ColumnValues toPk = new ColumnValues();
for (int i = 0; i < translation.getRelationTableToColumns().length; ++i) {
String colName = translation.getRelationTableToColumns()[i];
String pkName = edge.getTo().getTranslation().getPkColumns()[i];
toPk.add(pkName, cols.getColumnByName(colName).getValue());
}
toPks.add(toPk);
}
return removeValuesWithNullColumn(toPks);
} else {
//no restriction on the "from" side: include the whole relation table and
//put no restriction on the "to" side either
resolvedPks.put(translation.getRelationTable(), null);
return null;
}
} else {
if (fromPks == null) {
//no restriction on the "from" side means no restriction on the "to" side
return null;
}
//get the values of the "fromColumns" of the relation from the "from" table
Set<ColumnValues> columnValues = getValuesFromTable(connection, edge.getFrom().getTranslation()
.getTableName(), translation.getFromColumns(), fromPks);
//now change the names of the columns in columnValues to correspond to the ones
//in the "to" table (this assumes that the columns in fromColumns and toColumns
//correspond to each other by position)
for (int i = 0; i < translation.getToColumns().length; ++i) {
for (ColumnValues cols : columnValues) {
cols.getColumns().get(i).setName(translation.getToColumns()[i]);
}
}
//now translate the foreign keys into primary keys
//but first check if we even need to do it by comparing the column names
boolean columnsDiffer = false;
Set<String> pkColumns = new HashSet<String>(Arrays.asList(edge.getTo().getTranslation().getPkColumns()));
for(String col : translation.getToColumns()) {
if (!pkColumns.contains(col)) {
columnsDiffer = true;
break;
}
}
if (columnsDiffer) {
//the relationship targets non-pk columns: look up the pks of the rows having those values
columnValues =
getValuesFromTable(connection, edge.getTo().getTranslation().getTableName(), edge.getTo()
.getTranslation().getPkColumns(), removeValuesWithNullColumn(columnValues));
}
//read the pks back from the "to" table, so only keys of rows returned by the query are reported
Set<ColumnValues> ret = getValuesFromTable(connection, edge.getTo().getTranslation().getTableName(), edge
.getTo().getTranslation().getPkColumns(), columnValues);
return removeValuesWithNullColumn(ret);
}
}
/**
 * Follows {@code edge} backwards, from its "to" node to its "from" node: determines the primary
 * keys of the "from" table whose rows depend on the given rows of the "to" table.
 * <p>
 * Without a relation table, the to-columns of the relationship are read from the "to" table and
 * translated to primary keys of the "from" table. With a relation table, the relation table is
 * consulted (and its matching rows recorded in {@code resolvedPks}) unless the relationship is
 * many-to-many, in which case nothing is considered to be depending.
 *
 * @param connection the connection used for the lookup queries
 * @param edge the relationship to follow backwards
 * @param toPks the primary keys of the rows in the "to" table, or {@code null} for all rows
 * @param resolvedPks the resolution built so far; only the entry of the relation table (if any)
 *                    is modified here
 * @return the primary keys found for the "from" table, without entries containing a {@code null}
 *         value; {@code null} if {@code toPks} is {@code null}; an empty set for a many-to-many
 *         relationship with non-null {@code toPks}
 * @throws SQLException if a lookup query fails
 */
private static Set<ColumnValues> resolveDependingPks(IDatabaseConnection connection, Edge edge,
Set<ColumnValues> toPks, ColumnValuesTableMap resolvedPks) throws SQLException {
RelationshipTranslation translation = edge.getTranslation();
if (translation.getRelationTable() == null) {
if (toPks == null) {
//no restriction on the "to" side means no restriction on the "from" side
return null;
}
//get the values of the relationship's "to" columns from the "to" table
Set<ColumnValues> columnValues = getValuesFromTable(connection, edge.getTo().getTranslation()
.getTableName(), translation.getToColumns(), toPks);
//now rename them to their counterparts in the "from" table (this assumes that the
//columns in toColumns and fromColumns correspond to each other by position)
for (int i = 0; i < translation.getFromColumns().length; ++i) {
for (ColumnValues cols : columnValues) {
cols.getColumns().get(i).setName(translation.getFromColumns()[i]);
}
}
EntityTranslation fromTranslation = edge.getFrom().getTranslation();
//now translate the foreign keys into primary keys
//but first check if we even need to do it by comparing the column names
boolean columnsDiffer = false;
Set<String> pkColumns = new HashSet<String>(Arrays.asList(fromTranslation.getPkColumns()));
for(String col : translation.getFromColumns()) {
if (!pkColumns.contains(col)) {
columnsDiffer = true;
break;
}
}
if (columnsDiffer) {
//the from-columns are not the pk of the "from" table: look up the pks of the rows having those values
columnValues = getValuesFromTable(connection, fromTranslation.getTableName(),
fromTranslation.getPkColumns(), removeValuesWithNullColumn(columnValues));
}
return removeValuesWithNullColumn(columnValues);
} else {
//only bother with one-to-many relationships. A many-to-many
//relationship implicitly means that the two entities are not tightly
//connected (with a many-to-many relationship, either of the entities
//can always "live without" the entities from the other side of the relationship).
if (edge.getDependencyType() != DependencyType.MANY_TO_MANY) {
//copy the toPks to columnValues. We'll use the pks from the to table
//to find the corresponding entries in the relation table
Set<ColumnValues> columnValues = null;
if (toPks != null) {
columnValues = new HashSet<ColumnValues>();
for (ColumnValues pk : toPks) {
//work on clones so that the renaming below does not touch the caller's keys
columnValues.add(pk.clone());
}
//now change the names of the columns in columnValues to the corresponding
//relationTableToColumns (this assumes the same order of the columns
//in the case of composite pk)
for (int i = 0; i < translation.getRelationTableToColumns().length; ++i) {
for (ColumnValues cols : columnValues) {
cols.getColumns().get(i).setName(translation.getRelationTableToColumns()[i]);
}
}
}
//the columns to read from the relation table: all "from" columns followed by all "to" columns
String[] fromAndToCols = new String[translation.getRelationTableFromColumns().length
+ translation.getRelationTableToColumns().length];
System.arraycopy(translation.getRelationTableFromColumns(), 0, fromAndToCols, 0,
translation.getRelationTableFromColumns().length);
System.arraycopy(translation.getRelationTableToColumns(), 0, fromAndToCols,
translation.getRelationTableFromColumns().length, translation.getRelationTableToColumns().length);
if (toPks != null) {
Set<ColumnValues> fromAndToValues = getValuesFromTable(connection, translation.getRelationTable(),
fromAndToCols, columnValues);
//add the relation table to the resolvedPks using fromAndToValues as its primary keys
resolvedPks.getOrCreate(translation.getRelationTable()).addAll(fromAndToValues);
//now read out the "from" values from fromAndToValues and return them as the "from" table primary keys
//(the i-th relation table "from" column is taken to be the i-th pk column of the "from" table)
Set<ColumnValues> fromPks = new HashSet<ColumnValues>();
for (ColumnValues cols : fromAndToValues) {
ColumnValues fromPk = new ColumnValues();
for (int i = 0; i < translation.getRelationTableFromColumns().length; ++i) {
String colName = translation.getRelationTableFromColumns()[i];
String pkName = edge.getFrom().getTranslation().getPkColumns()[i];
fromPk.add(pkName, cols.getColumnByName(colName).getValue());
}
fromPks.add(fromPk);
}
return removeValuesWithNullColumn(fromPks);
} else {
//no restriction on the "to" side: include the whole relation table and
//put no restriction on the "from" side either
resolvedPks.put(translation.getRelationTable(), null);
return null;
}
} else {
//put no restrictions on the search if the toPks are null (unrestricted)
//otherwise pretend there's nothing depending.
return toPks == null ? null : new HashSet<ColumnValues>();
}
}
}
/**
 * Filters out the entries in which at least one column has a {@code null} value.
 *
 * @param columnValues the entries to filter; the set itself is left untouched
 * @return a new set holding only the entries whose columns all have non-null values
 */
private static Set<ColumnValues> removeValuesWithNullColumn(Set<ColumnValues> columnValues) {
    Set<ColumnValues> complete = new HashSet<ColumnValues>();

    nextEntry:
    for (ColumnValues candidate : columnValues) {
        for (ColumnValues.Column column : candidate) {
            if (column.getValue() == null) {
                //one missing value disqualifies the whole entry
                continue nextEntry;
            }
        }
        complete.add(candidate);
    }

    return complete;
}
/**
 * Joins the column names into a comma separated list suitable for a SELECT clause.
 *
 * @param colNames the column names to join
 * @return the names separated by {@code ", "} with no leading or trailing separator or
 *         whitespace; an empty string if there are no names
 */
private static String colNamesToSql(String[] colNames) {
    StringBuilder bld = new StringBuilder();
    for (int i = 0; i < colNames.length; ++i) {
        //put the separator *between* the names instead of stripping it afterwards, so that no
        //part of it (previously a leading space) is left over in the result
        if (i > 0) {
            bld.append(", ");
        }
        bld.append(colNames[i]);
    }
    return bld.toString();
}
/**
 * Reads the values of the given columns from every row of a table that matches at least one of
 * the supplied column/value combinations.
 * <p>
 * The generated statement has the form
 * {@code SELECT v1, v2 FROM table WHERE (a = ? AND b = ?) OR (a = ? AND b = ?)} with one
 * parenthesized group per element of {@code knownlColumns}; the values are bound as statement
 * parameters.
 *
 * @param connection the dbUnit connection providing the JDBC connection to query
 * @param tableName the name of the table to read from
 * @param valueColumns the names of the columns to read
 * @param knownlColumns the known column values identifying the rows of interest; must not be
 *                      {@code null}
 * @return one entry per distinct result row, with columns named as reported by the result set
 *         meta data; an empty set if {@code knownlColumns} is empty
 * @throws SQLException if preparing or executing the query fails
 */
private static Set<ColumnValues> getValuesFromTable(IDatabaseConnection connection, String tableName,
    String[] valueColumns, Set<ColumnValues> knownlColumns) throws SQLException {
    Set<ColumnValues> ret = new HashSet<ColumnValues>();
    if (knownlColumns.isEmpty()) {
        //nothing to look for. Without this guard the WHERE clause would end up without any
        //condition and the statement would not be valid SQL.
        return ret;
    }
    StringBuilder sql = new StringBuilder("SELECT ").append(colNamesToSql(valueColumns)).append(" FROM ")
        .append(tableName).append(" WHERE ");
    boolean firstGroup = true;
    for (ColumnValues cols : knownlColumns) {
        if (!firstGroup) {
            sql.append(" OR ");
        }
        firstGroup = false;
        sql.append("(");
        boolean firstColumn = true;
        for (ColumnValues.Column c : cols) {
            if (!firstColumn) {
                sql.append(" AND ");
            }
            firstColumn = false;
            sql.append(c.getName()).append(" = ?");
        }
        sql.append(")");
    }
    PreparedStatement st = null;
    ResultSet rs = null;
    try {
        st = connection.getConnection().prepareStatement(sql.toString());
        //bind the values in exactly the order in which the placeholders were generated above
        int idx = 1;
        for (ColumnValues cols : knownlColumns) {
            for (ColumnValues.Column c : cols) {
                st.setObject(idx++, c.getValue());
            }
        }
        rs = st.executeQuery();
        ResultSetMetaData rsmd = rs.getMetaData();
        while (rs.next()) {
            ColumnValues vals = new ColumnValues();
            for (int i = 1; i <= rsmd.getColumnCount(); ++i) {
                String columnName = rsmd.getColumnName(i);
                Object value = rs.getObject(i);
                vals.add(columnName, value);
            }
            ret.add(vals);
        }
    } finally {
        //release the result set explicitly, then the statement even if closing the result set fails
        try {
            if (rs != null) {
                rs.close();
            }
        } finally {
            if (st != null) {
                st.close();
            }
        }
    }
    return ret;
}
}