/* * #%L * gitools-core * %% * Copyright (C) 2013 Universitat Pompeu Fabra - Biomedical Genomics group * %% * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as * published by the Free Software Foundation, either version 3 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public * License along with this program. If not, see * <http://www.gnu.org/licenses/gpl-3.0.html>. * #L% */ package org.gitools.plugins.mutex.analysis; import com.google.common.primitives.Doubles; import com.google.common.primitives.Ints; import org.gitools.analysis.AnalysisProcessor; import org.gitools.api.analysis.IProgressMonitor; import org.gitools.api.matrix.*; import org.gitools.api.modulemap.IModuleMap; import org.gitools.api.resource.ResourceReference; import org.gitools.heatmap.Heatmap; import org.gitools.heatmap.decorator.impl.NonEventToNullFunction; import org.gitools.matrix.filter.NotNullPredicate; import org.gitools.matrix.filter.ValueFilterFunction; import org.gitools.matrix.model.hashmatrix.HashMatrix; import org.gitools.matrix.model.hashmatrix.HashMatrixDimension; import org.gitools.matrix.model.matrix.element.LayerAdapter; import org.gitools.matrix.sort.AggregationFunction; import org.gitools.utils.aggregation.NonNullCountAggregator; import org.gitools.utils.cutoffcmp.CutoffCmp; import java.util.*; import static com.google.common.collect.Lists.newArrayList; import static org.gitools.api.matrix.MatrixDimensionKey.COLUMNS; import static org.gitools.api.matrix.MatrixDimensionKey.ROWS; public class MutualExclusiveProcessor implements AnalysisProcessor { private final static Key<MutualExclusiveWeightCache> CACHEKEY = new Key<MutualExclusiveWeightCache>() { }; private final MutualExclusiveAnalysis analysis; public MutualExclusiveProcessor(MutualExclusiveAnalysis analysis) { this.analysis = analysis; } @Override public void run(IProgressMonitor monitor) { Date startTime = new Date(); // Prepare results matrix IMatrixDimension testDimension = analysis.getTestDimension(); IModuleMap testGroups; IModuleMap weightGroups; if (testDimension.getId().equals(ROWS)) { testGroups = analysis.getRowsModuleMap().get(); weightGroups = analysis.getColumnsModuleMap().get(); } else { weightGroups = analysis.getRowsModuleMap().get(); testGroups = analysis.getColumnsModuleMap().get(); } IMatrix results = calculate( monitor, analysis.getData().get(), analysis.getData().get().getLayers().get(analysis.getLayer()), testGroups, weightGroups, testDimension, analysis.getWeightDimension(), analysis.getIterations(), analysis.getEventFunction(), analysis.isDiscardEmpty() ); analysis.setResults(new ResourceReference<>("results", new Heatmap(results))); analysis.setStartTime(startTime); analysis.setElapsedTime(new Date().getTime() - startTime.getTime()); } private IMatrix calculate(final IProgressMonitor monitor, final Heatmap data, final IMatrixLayer<Double> dataLayer, final IModuleMap testGroups, final IModuleMap weightGroups, final IMatrixDimension testDimension, final IMatrixDimension weightDimension, final int iterations, NonEventToNullFunction eventFunction, boolean discardEmpty) { final IMatrixDimension resultWeightDimension = new HashMatrixDimension(COLUMNS, weightGroups.getModules()); final IMatrixDimension resultTestDimension = new HashMatrixDimension(ROWS, testGroups.getModules()); final MutualExclusiveTest test = new MutualExclusiveTest(); final LayerAdapter<MutualExclusiveResult> adapter = new LayerAdapter<>(MutualExclusiveResult.class); IMatrix resultsMatrix = new HashMatrix( adapter.getMatrixLayers(), resultTestDimension, resultWeightDimension ); String weightGroupInfo; int counter = 0; for (final String weightGroup : resultWeightDimension) { ++counter; double[] weights = new double[0]; weightGroupInfo = weightGroup + " (" + counter + "/" + resultWeightDimension.size() + ")"; boolean weightGroupInfoSet = false; IMatrixDimension weightDimensionSubset = null; for (String testGroup : resultTestDimension) { IMatrixDimension testDimensionSubset = testDimension.subset(testGroups.getMappingItems(testGroup)); Set<String> samples = weightGroups.getMappingItems(weightGroup); if (discardEmpty) { samples = getNonEmptySamples(data, dataLayer, testDimensionSubset, weightDimension.subset(samples), eventFunction); weightDimensionSubset = weightDimension.subset(samples); weights = getWeights(monitor, data, dataLayer, testDimension.getId(), weightDimensionSubset, weightGroupInfo + ": " + testGroup, eventFunction); } else if (weights.length == 0) { weightDimensionSubset = weightDimension.subset(samples); weights = getWeights(monitor, data, dataLayer, testDimension.getId(), weightDimensionSubset, weightGroupInfo, eventFunction); } if (!weightGroupInfoSet) { monitor.begin("Performing test for " + weightGroupInfo, resultTestDimension.size() * iterations); weightGroupInfoSet = true; } if (monitor.isCancelled()) { break; } monitor.info("Group: " + testGroup); int coverage = getCoverage(data, dataLayer, testDimensionSubset, weightDimensionSubset, eventFunction); int[] draws = getDraws(data, dataLayer, testDimensionSubset, weightDimensionSubset, eventFunction); int signal = 0; for (int d : draws) { signal += d; } if (monitor.isCancelled()) { break; } //sets monitor.worked MutualExclusiveResult r = new MutualExclusiveTest().processTest(draws, weights, coverage, signal, iterations, monitor); if (monitor.isCancelled()) { break; } adapter.set(resultsMatrix, r, testGroup, weightGroup); } } return resultsMatrix; } private Set<String> getNonEmptySamples(IMatrix data, IMatrixLayer<Double> dataLayer, IMatrixDimension testDimensionSubset, IMatrixDimension weightDimensionSubset, NonEventToNullFunction eventFunction) { //Set<String> samples = new HashSet<>(); IMatrixIterable<String> iterable = data.newPosition() .iterate(weightDimensionSubset) .transform(new AggregationFunction(dataLayer, NonNullCountAggregator.INSTANCE, testDimensionSubset, eventFunction)) .transform(new DoubleToIntegerFunction()) .filter(new NotNullPredicate<Integer>()) .transform(new DimensionIdFunction(weightDimensionSubset)); ArrayList<String> samples = newArrayList(iterable); return new HashSet<>(samples); } private Map<String, Double> getCachedWeights(IMatrixLayer<Double> dataLayer, IMatrixDimension weightDimension, NonEventToNullFunction<?> eventFunction) { MutualExclusiveWeightCache mutexCache = dataLayer.getCache(CACHEKEY); if (mutexCache == null) { mutexCache = new MutualExclusiveWeightCache(); } return mutexCache.getCacheWeights( createFingerprint(eventFunction, weightDimension)); } private String createFingerprint(NonEventToNullFunction<?> eventFunction, IMatrixDimension weightDimension) { return weightDimension.getId().getLabel() + "-" + eventFunction.getDescription(); } private void setCachedWeights(IMatrixLayer<Double> dataLayer, Map<String, Double> cachedWeights, IMatrixDimension weightDimension, NonEventToNullFunction function) { MutualExclusiveWeightCache cache = dataLayer.getCache(CACHEKEY); if (cache == null) { cache = new MutualExclusiveWeightCache(); } cache.setCacheWeights(createFingerprint(function, weightDimension), cachedWeights); dataLayer.setCache(CACHEKEY, cache); } private double[] getWeights(IProgressMonitor monitor, Heatmap data, final IMatrixLayer<Double> dataLayer, MatrixDimensionKey testDimensionKey, IMatrixDimension weightDimension, String weightGroupInfo, final NonEventToNullFunction<?> eventFunction) { IMatrixIterable<Double> weightIterator; final Map<String, Double> cachedWeights = getCachedWeights(dataLayer, weightDimension, eventFunction); int cacheSize = cachedWeights.size(); IMatrixDimension completeDataDimension = data.getContents().getDimension(testDimensionKey); final AggregationFunction aggregation = new AggregationFunction(dataLayer, NonNullCountAggregator.INSTANCE, completeDataDimension, eventFunction); final AbstractMatrixFunction<Double, String> readWeightFunction = new AbstractMatrixFunction<Double, String>() { @Override public Double apply(String identifier, IMatrixPosition position) { Double v; if (cachedWeights.containsKey(identifier)) { v = cachedWeights.get(identifier); } else { v = aggregation.apply(identifier, position); cachedWeights.put(identifier, v); } return v; } }; weightIterator = data.newPosition().iterate(weightDimension) .monitor(monitor, "Calculating weights for " + weightGroupInfo) .transform(readWeightFunction) .filter(new ValueFilterFunction(dataLayer, CutoffCmp.NE, 0.0, 0.0)); double[] weights = Doubles.toArray(newArrayList(weightIterator)); if (cacheSize != cachedWeights.size()) { setCachedWeights(dataLayer, cachedWeights, weightDimension, eventFunction); } return weights; } /** * Calculates the coverage: Which items of the weightDimension have at least * one positive event. */ private int getCoverage(IMatrix data, IMatrixLayer<Double> dataLayer, IMatrixDimension testDimension, final IMatrixDimension weightDimension, NonEventToNullFunction<?> eventFunction) { IMatrixIterable<String> it = data.newPosition() .iterate(dataLayer, testDimension, weightDimension) .transform(eventFunction) .transform(new AbstractMatrixFunction<String, Double>() { @Override public String apply(Double value, IMatrixPosition position) { return (value != null) ? position.get(weightDimension) : ""; } }); HashSet<String> ids = new HashSet<>(); for (String id : it) { if (id.equals("")) continue; ids.add(id); } return ids.size(); } /** * Calculate for each item how many positive events are available. */ private int[] getDraws(IMatrix data, IMatrixLayer<Double> dataLayer, IMatrixDimension testDimension, IMatrixDimension weightDimension, NonEventToNullFunction<?> eventFunction) { IMatrixIterable<Integer> it; it = data.newPosition() .iterate(testDimension) .transform(new AggregationFunction(dataLayer, NonNullCountAggregator.INSTANCE, weightDimension, eventFunction)) .transform(new DoubleToIntegerFunction()) .filter(new NotNullPredicate<Integer>()); return Ints.toArray(newArrayList(it)); } private class MutualExclusiveWeightCache { //1st String: DimensionID + event description //2nd String id of row or col //Double weight Map<String, Map<String, Double>> cacheWeightsCatalog; public MutualExclusiveWeightCache() { cacheWeightsCatalog = new HashMap<>(); } public Map<String, Double> getCacheWeights(String fingerprint) { if (cacheWeightsCatalog == null || !cacheWeightsCatalog.containsKey(fingerprint)) { return new HashMap<>(); } return cacheWeightsCatalog.get(fingerprint); } public void setCacheWeights(String fingerprint, Map<String, Double> cacheWeights) { this.cacheWeightsCatalog.put(fingerprint, cacheWeights); } } }