package org.molgenis.matrix.component; import java.util.ArrayList; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import javax.persistence.EntityManager; import org.apache.commons.collections.Closure; import org.apache.commons.collections.CollectionUtils; import org.hibernate.ScrollableResults; import org.hibernate.Session; import org.hibernate.ejb.EntityManagerImpl; import org.molgenis.framework.db.Database; import org.molgenis.framework.db.DatabaseException; import org.molgenis.framework.db.Query; import org.molgenis.matrix.MatrixException; import org.molgenis.matrix.component.interfaces.BasicMatrix; import org.molgenis.matrix.component.interfaces.DatabaseMatrix; import org.molgenis.matrix.component.interfaces.SliceableMatrix; import org.molgenis.matrix.component.sqlbackend.Backend; import org.molgenis.matrix.component.sqlbackend.EAVViewBackend; import org.molgenis.organization.Investigation; import org.molgenis.pheno.Category; import org.molgenis.pheno.Measurement; import org.molgenis.pheno.ObservationElement; import org.molgenis.pheno.ObservedValue; import org.molgenis.protocol.Protocol; /** * Sliceable version of the PhenoMatrix. This assumes the rows are * ObservationTarget, the columns ObservableFeature and there can be zero or * more ObservedValue for each combination (hence return List < ObservedValue * > for each value 'V') * * Slicing will be done by setting filters. * * The data is retrieved by (a) retrieving visible columns and rows and (2) * retrieval of the matching data using columns and rows as filters. The whole * set is filtered by investigation. * */ public class SliceablePhenoMatrixMV<R extends ObservationElement, C extends ObservationElement, V extends ObservedValue> extends AbstractObservationElementMatrix<R, C, V> implements DatabaseMatrix // implements SliceableMatrix<R, C, V>, DatabaseMatrix { private final EntityManager em; private final Investigation investigation; private final LinkedHashMap<Protocol, List<Measurement>> measurementsByProtocol; private final Map<Measurement, List<Category>> categoryByMeasurement = new HashMap<Measurement, List<Category>>(); public final String JOIN_COLUMN = "PA_ID"; private final Backend backend; private Protocol sortProtocol; private Measurement sortMeasurement; private String sortOrder; private Database db; @SuppressWarnings("unchecked") public SliceablePhenoMatrixMV(Database database, Class<R> rowClass, Class<C> colClass, Investigation investigation, LinkedHashMap<Protocol, List<Measurement>> measurementByProtocol) throws MatrixException { this.db = database; this.rowClass = rowClass; this.colClass = colClass; this.valueClass = (Class<V>) ObservedValue.class; this.investigation = investigation; this.measurementsByProtocol = measurementByProtocol; this.em = database.getEntityManager(); this.backend = new EAVViewBackend(this, "LL_VWM_", "PATIENT"); try { loadCategories(); } catch (DatabaseException ex) { throw new MatrixException(ex); } } public void setDatabase(Database db) { this.db = db; } public Database getDatabase() { return db; } private void loadCategories() throws DatabaseException { String qlString = "SELECT m, c FROM Measurement m JOIN m.categories c WHERE m.investigation = :investigation"; List<Object[]> measCats = em.createQuery(qlString).setParameter("investigation", investigation).getResultList(); for (Object[] rec : measCats) { Measurement m = (Measurement) rec[0]; Category c = (Category) rec[1]; if (categoryByMeasurement.containsKey(m)) { categoryByMeasurement.get(m).add(c); } else { List<Category> cats = new ArrayList<Category>(); cats.add(c); categoryByMeasurement.put(m, cats); } } // objects.toString(); // List<Category> categories = // db.query(Category.class).eq(Category.INVESTIGATION, // investigation.getId()).find(); // for (Category category : categories) { // Collection<Measurement> measurements = // category.getCategoriesCollection(); // for (Measurement measurement : measurements) { // if(categoryByMeasurement.containsKey(measurement)) { // categoryByMeasurement.get(measurement).add(category); // } else { // List<Category> cats = new ArrayList<Category>(); // cats.add(category); // categoryByMeasurement.put(measurement, cats); // } // } // } } public void setSort(Protocol protocol, Measurement measurement, String sortOrder) { this.sortProtocol = protocol; this.sortMeasurement = measurement; this.sortOrder = sortOrder; } // public void setColumns(List<String> columNames) throws MatrixException { // boolean firstJOIN_COLUMN = true; // for(final String colName : columNames) { // if(colName.equalsIgnoreCase(JOIN_COLUMN)) { // if(!firstJOIN_COLUMN) { // continue; // } // firstJOIN_COLUMN = false; // } // // final String protocolName = StringUtils.substringBefore(colName, "_"); // try { // final Protocol p = db.query(Protocol.class).eq(Protocol.NAME, // protocolName).find().get(0); // final Measurement m = db.query(Measurement.class).eq(Measurement.NAME, // colName).find().get(0); // // if(getMeasurementsByProtocol().containsKey(p)) { // if(!getMeasurementsByProtocol().get(p).contains(m)) { // getMeasurementsByProtocol().get(p).add(m); // } // } else { // List<Measurement> ms = new ArrayList<Measurement>(); // ms.add(m); // getMeasurementsByProtocol().put(p, ms); // } // } catch (DatabaseException e) { // throw new MatrixException(e); // } // } // } @Deprecated @Override public List<R> getRowHeaders() throws MatrixException { // reload the rowheaders if filters have changed. if (rowDirty) { try { Query<R> query = this.createSelectQuery(getRowClass(), db); this.rowHeaders = query.find(); rowDirty = false; } catch (Exception e) { throw new MatrixException(e); } } return rowHeaders; } @Override public Integer getRowCount() throws MatrixException { try { String query = createCountQuery(); System.out.println(query); Number count = (Number) em.createNativeQuery(query).getSingleResult(); return count.intValue(); } catch (Exception e) { throw new MatrixException(e); } } @Deprecated @Override public List<String> getColPropertyNames() { final List<String> result = new ArrayList<String>(); try { for (final C col : getColHeaders()) { result.add(col.getName()); } } catch (MatrixException e) { throw new RuntimeException(e); } return result; } @SuppressWarnings("unchecked") @Override public List<C> getColHeaders() throws MatrixException { final List<C> result = new ArrayList<C>(); List<Column> columns = getColumns(); CollectionUtils.forAllDo(columns, new Closure() { @Override public void execute(Object arg0) { result.add((C) ((Column) arg0).getMeasurement()); } }); return result; } public List<Column> getColumns() { List<Column> result = new ArrayList<Column>(); for (Map.Entry<Protocol, List<Measurement>> entry : getMeasurementsByProtocol().entrySet()) { for (Measurement measurement : entry.getValue()) { Column c = new Column(entry.getKey(), measurement); result.add(c); } } return result; } // public List<String> getMyColumnNames() { // return null; // } @Deprecated // use getColumns().size() instead of this function @Override public Integer getColCount() throws MatrixException { return getColumns().size(); } @Deprecated @Override public BasicMatrix<R, C, V> getResult() throws Exception { throw new UnsupportedOperationException(); } /** Helper method to produce a selection query for columns or rows */ @Deprecated private <D extends ObservationElement> Query<D> createSelectQuery(Class<D> xClass, Database db) throws MatrixException { return this.createQuery(xClass, false, db); } /** * * @param field * , either ObservedValue.FEATURE or ObservedValue.TARGET * @throws MatrixException */ @Deprecated private <D extends ObservationElement> Query<D> createQuery(Class<D> xClass, boolean countAll, Database db) throws MatrixException { // If xClass == getRowClass(): // A. filter on rowIndex + rowHeaderProperty // B. filter on colValue: 1 subquery per column // C. filter on rowOffset and rowLimit try { Query<D> xQuery = db.query(xClass); // add limit and offset, unless count if (!countAll) { if (xClass.equals(getColClass())) { xQuery.limit(colLimit); xQuery.offset(colOffset); } else { xQuery.limit(rowLimit); xQuery.offset(rowOffset); } } return xQuery; } catch (Exception e) { throw new MatrixException(e); } } @Deprecated // use getColumns().size(); instead private int getVisibleColumnCount() { return getColumns().size(); } public String createQuery() { return backend.createQuery(false, rules); } private String createCountQuery() throws Exception { return backend.createQuery(true, rules); } // Todo add category (labels) @SuppressWarnings("unchecked") public List<Object[]> getTypedValues() throws MatrixException { List<Measurement> colMeasurements = new ArrayList<Measurement>(); for (Entry<Protocol, List<Measurement>> entry : measurementsByProtocol.entrySet()) { for (Measurement value : entry.getValue()) { colMeasurements.add(value); } } String sql = createQuery(); System.out.println(sql); return em.createNativeQuery(sql).setMaxResults(getRowLimit()).setFirstResult(getRowOffset()).getResultList(); } @Override @Deprecated public List<V>[][] getValueLists() throws MatrixException { try { List<Measurement> colMeasurements = (List<Measurement>) getColHeaders(); String sql = createQuery(); System.out.println(sql); int offset = getRowOffset(); @SuppressWarnings("unchecked") List<Object[]> data = em.createNativeQuery(sql).setMaxResults(getRowLimit()).setFirstResult(offset) .getResultList(); if (!data.isEmpty()) { int numColumns = data.get(0).length; if (offset > 0) { // oracle add a rownum column to the end numColumns--; } final List<V>[][] valueMatrix = create(data.size(), numColumns); for (int iRow = 0; iRow < data.size(); ++iRow) { for (int iCol = 0; iCol < numColumns; ++iCol) { valueMatrix[iRow][iCol] = new ArrayList<V>(); @SuppressWarnings("unchecked") V ov = (V) new ObservedValue(); if (data.get(iRow)[iCol] != null) { String value = data.get(iRow)[iCol].toString(); value = getCategoryLabel(colMeasurements, iCol, value); ov.setValue(value); } else { ov.setValue("null"); } valueMatrix[iRow][iCol].add(ov); } } return valueMatrix; } else { return create(0, 0); } } catch (Exception ex) { throw new MatrixException(ex); } } private String getCategoryLabel(List<Measurement> colMeasurements, int iCol, String value) { Measurement measurement = colMeasurements.get(iCol); if (categoryByMeasurement.containsKey(measurement)) { for (Category category : categoryByMeasurement.get(measurement)) { if (category.getCode_String().equalsIgnoreCase(value)) { return category.getLabel(); } } } return value; } @Deprecated public List<V>[][] create(int rows, int cols) { // create all empty rows as well @SuppressWarnings("unchecked") List<V>[][] data = new ArrayList[rows][cols]; for (int i = 0; i < data.length; i++) { for (int j = 0; j < cols; j++) { data[i][j] = new ArrayList<V>(); } } return data; } public ScrollableResults getScrollableValues(boolean exportVisibleRows) throws Exception { String sql = createQuery(); Session session = ((EntityManagerImpl) getDatabase().getEntityManager()).getSession(); ScrollableResults sr; if (exportVisibleRows) { int offset = getRowOffset(); int limit = getRowLimit(); sr = session.createSQLQuery(sql).setFirstResult(offset).setMaxResults(limit).scroll(); } else { sr = session.createSQLQuery(sql).scroll(); } return sr; } @Deprecated @Override public V[][] getValues() throws MatrixException { List<V>[][] values = getValueLists(); int rowCnt = values.length; int colCnt = values[0].length; @SuppressWarnings("unchecked") V[][] result = (V[][]) new ObservedValue[rowCnt][colCnt]; for (int i = 0; i < rowCnt; ++i) { for (int j = 0; j < colCnt; ++j) { result[i][j] = values[i][j].get(0); } } return result; } public EntityManager getEntityManager() { return em; } // TODO implement sorting in Columns @Deprecated public Measurement getSortMeasurement() { return sortMeasurement; } @Deprecated public String getSortOrder() { return sortOrder; } @Deprecated public Protocol getSortProtocol() { return sortProtocol; } public Investigation getInvestigation() { return investigation; } public LinkedHashMap<Protocol, List<Measurement>> getMeasurementsByProtocol() { return measurementsByProtocol; } }