/*************************************************************************
* *
* This file is part of the 20n/act project. *
* 20n/act enables DNA prediction for synthetic biology/bioengineering. *
* Copyright (C) 2017 20n Labs, Inc. *
* *
* Please direct all queries to act@20n.com. *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have received a copy of the GNU General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* *
*************************************************************************/
package com.act.lcms.v2.fullindex;
import com.act.utils.CLIUtil;
import com.act.utils.rocksdb.DBUtil;
import com.act.utils.rocksdb.RocksDBAndHandles;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Option;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.joda.time.DateTime;
import org.rocksdb.RocksDBException;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
/**
 * Queries an on-disk LCMS triple index for (time, m/z, intensity) readings that fall within a time and/or m/z window.
 *
 * This is the conjoined twin of Builder: it reads the column families and keys that Builder writes. If Builder
 * changes in a material way, this class should also.
 */
public class Searcher {
  private static final Logger LOGGER = LogManager.getFormatterLogger(Searcher.class);
  // Separates the lower and upper bounds of a range on the command line, e.g. 151.0:152.0.
  private static final Character RANGE_SEPARATOR = ':';
  private static final String OUTPUT_HEADER = StringUtils.join(new String[] {
      "id", "time", "m/z", "intensity"
  }, "\t");

  public static final String OPTION_INDEX_PATH = "x";
  public static final String OPTION_MZ_RANGE = "m";
  public static final String OPTION_TIME_RANGE = "t";
  public static final String OPTION_OUTPUT_FILE = "o";

  public static final String HELP_MESSAGE = StringUtils.join(new String[]{
      "Queries a triple index constructed by Builder for readings in some m/z and time window.",
  }, "");

  public static final List<Option.Builder> OPTION_BUILDERS = makeOptionBuilders();

  /**
   * Builds the command line option definitions for this tool.
   * @return A mutable list of option builders, one per supported option.
   */
  private static List<Option.Builder> makeOptionBuilders() {
    List<Option.Builder> builders = new ArrayList<>();
    builders.add(Option.builder(OPTION_INDEX_PATH)
        .argName("index path")
        // main() refuses to run unless this directory exists, so say so (the Builder's text said the opposite).
        .desc("A path to the directory containing the on-disk index to query; must already exist")
        .hasArg().required()
        .longOpt("index")
    );
    builders.add(Option.builder(OPTION_MZ_RANGE)
        .argName("m/z range")
        .desc("An m/z range to query separated by a colon, like 151.0:152.0")
        .hasArg()
        .longOpt("mz-range")
    );
    builders.add(Option.builder(OPTION_OUTPUT_FILE)
        .argName("output file")
        .desc("A destination at which to write the found triples as a TSV (default is stdout)")
        .hasArg()
        .longOpt("output")
    );
    builders.add(Option.builder(OPTION_TIME_RANGE)
        .argName("time range")
        .desc("An time range to query separated by a colon, like 45.0:50.0")
        .hasArg()
        .longOpt("time-range")
    );
    return builders;
  }

  /**
   * Creates fully initialized Searcher instances from an index directory.
   */
  public static class Factory {
    /**
     * Opens an existing index and loads the metadata required for searching.
     * @param indexDir The directory holding the RocksDB index.
     * @return A Searcher that is ready to be queried.
     * @throws RocksDBException
     * @throws ClassNotFoundException
     * @throws IOException
     */
    public static Searcher makeSearcher(File indexDir)
        throws RocksDBException, ClassNotFoundException, IOException {
      RocksDBAndHandles<ColumnFamilies> dbAndHandles =
          DBUtil.openExistingRocksDB(indexDir, ColumnFamilies.values());
      Searcher searcher = new Searcher(dbAndHandles);
      searcher.init();
      return searcher;
    }
  }

  private final RocksDBAndHandles<ColumnFamilies> dbAndHandles;
  // Both lists are populated by init(), which must run before any search.
  private List<MZWindow> mzWindows;
  private List<Float> timepoints;

  Searcher(RocksDBAndHandles<ColumnFamilies> dbAndHandles) {
    this.dbAndHandles = dbAndHandles;
  }

  /**
   * Command line entry point: parses the options, runs one search, and writes the results as a TSV.
   * At least one of the m/z range and time range options must be supplied.
   * @param args The command line arguments.
   * @throws Exception
   */
  public static void main(String[] args) throws Exception {
    CLIUtil cliUtil = new CLIUtil(Searcher.class, HELP_MESSAGE, OPTION_BUILDERS);
    CommandLine cl = cliUtil.parseCommandLine(args);

    File indexDir = new File(cl.getOptionValue(OPTION_INDEX_PATH));
    if (!indexDir.exists() || !indexDir.isDirectory()) {
      cliUtil.failWithMessage("Unable to read index directory at %s", indexDir.getAbsolutePath());
    }

    if (!cl.hasOption(OPTION_MZ_RANGE) && !cl.hasOption(OPTION_TIME_RANGE)) {
      cliUtil.failWithMessage("Extracting all readings is not currently supported; specify an m/z or time range");
    }

    // Either range may be null here if its option was omitted; searchIndexInRange treats null as unconstrained.
    Pair<Double, Double> mzRange = extractRange(cl.getOptionValue(OPTION_MZ_RANGE));
    Pair<Double, Double> timeRange = extractRange(cl.getOptionValue(OPTION_TIME_RANGE));

    Searcher searcher = Factory.makeSearcher(indexDir);
    List<TMzI> results = searcher.searchIndexInRange(mzRange, timeRange);

    if (cl.hasOption(OPTION_OUTPUT_FILE)) {
      try (PrintWriter writer = new PrintWriter(new FileWriter(cl.getOptionValue(OPTION_OUTPUT_FILE)))) {
        Searcher.writeOutput(writer, results);
      }
    } else {
      // Don't close the print writer if we're writing to stdout.
      Searcher.writeOutput(new PrintWriter(new OutputStreamWriter(System.out)), results);
    }

    LOGGER.info("Done");
  }

  /**
   * Writes the header and one tab-separated row per triple, then flushes (but does not close) the writer.
   * The id column is the zero-based position of the triple in the results list.
   * @param writer The destination for the TSV.
   * @param results The triples to write.
   * @throws IOException
   */
  private static void writeOutput(PrintWriter writer, List<TMzI> results) throws IOException {
    int counter = 0;
    writer.println(OUTPUT_HEADER);
    for (TMzI triple : results) {
      writer.format("%d\t%.6f\t%.6f\t%.6f\n", counter, triple.getTime(), triple.getMz(), triple.getIntensity());
      counter++;
    }
    writer.flush();
  }

  /**
   * Parses a range of the form "lower:upper", or a single value that is treated as the closed range [v, v].
   * @param rangeStr The raw option value; may be null or empty.
   * @return The parsed (lower, upper) pair, or null if no range was supplied.
   * @throws IllegalArgumentException If the bounds are inverted, a bound is not a number, or the string does not
   *                                  contain exactly one or two values.
   */
  private static Pair<Double, Double> extractRange(String rangeStr) {
    // Skip empty ranges so we can just limit on time or m/z.
    if (rangeStr == null || rangeStr.isEmpty()) {
      return null;
    }

    String[] parts = StringUtils.split(rangeStr, RANGE_SEPARATOR);
    if (parts.length == 1) {
      LOGGER.info("Found only one value in ranged '%s', returning closed range (for exact extraction)", rangeStr);
      Double exactVal = Double.valueOf(parts[0]);
      return Pair.of(exactVal, exactVal);
    } else if (parts.length == 2) {
      Double lowerBound = Double.valueOf(parts[0]);
      Double upperBound = Double.valueOf(parts[1]);
      if (upperBound < lowerBound) {
        String msg = String.format(
            "Lower bound %.6f exceeds upper bound %.6f.  Cowardly refusing to search for an empty range",
            lowerBound, upperBound);
        LOGGER.error(msg);
        throw new IllegalArgumentException(msg);
      }
      return Pair.of(lowerBound, upperBound);
    } else {
      // Both placeholders need an argument; omitting rangeStr here used to raise MissingFormatArgumentException.
      String msg = String.format(
          "Unable to parse range string '%s'; did you use the correct separator ('%c')?",
          rangeStr, RANGE_SEPARATOR);
      LOGGER.error(msg);
      throw new IllegalArgumentException(msg);
    }
  }

  /**
   * Loads the timepoints and m/z windows from the index into memory so searches can select keys by range.
   * @throws RocksDBException
   * @throws ClassNotFoundException
   * @throws IOException
   */
  protected void init() throws RocksDBException, ClassNotFoundException, IOException {
    LOGGER.info("Initializing DB");

    // TODO: hold onto the byte representation of the timepoints so we can use them as keys more easily.
    timepoints = Utils.byteArrayToFloatList(
        dbAndHandles.get(ColumnFamilies.TIMEPOINTS, Builder.TIMEPOINTS_KEY)
    );
    LOGGER.info("Loaded %d timepoints", timepoints.size());
    // Assumes timepoints are sorted.  TODO: check!

    mzWindows = new ArrayList<>();
    RocksDBAndHandles.RocksDBIterator mzIter = dbAndHandles.newIterator(ColumnFamilies.TARGET_TO_WINDOW);
    mzIter.reset();
    while (mzIter.isValid()) {
      // The keys are the target m/z's, so we can ignore them.
      mzWindows.add(Utils.deserializeObject(mzIter.value()));
      mzIter.next();
    }

    // Sort windows so we can easily search through them
    Collections.sort(mzWindows, (a, b) -> a.getTargetMZ().compareTo(b.getTargetMZ()));
    LOGGER.info("Loaded %d m/z windows", mzWindows.size());
  }

  /**
   * Searches an LCMS index for all (time, m/z, intensity) triples within some time and/or m/z ranges.
   *
   * Either range may be null, in which case that dimension is left unconstrained: the candidate ids come only from
   * the other dimension's index and the final filter ignores the missing dimension.  At least one range is required.
   *
   * Note that this method is very much a first-draft/WIP.  There are many opportunities for optimization and
   * improvement here, but this works as an initial attempt.  This method is littered with TODOs, which once TODone
   * should make this a near optimal method of searching through LCMS readings.
   *
   * @param mzRange The range of m/z values for which to search, or null to accept any m/z.
   * @param timeRange The time range for which to search, or null to accept any time.
   * @return A list of (time, m/z, intensity) triples that fall within the specified ranges.
   * @throws IllegalArgumentException If both ranges are null.
   * @throws RocksDBException
   * @throws ClassNotFoundException
   * @throws IOException
   */
  public List<TMzI> searchIndexInRange(
      Pair<Double, Double> mzRange,
      Pair<Double, Double> timeRange)
      throws RocksDBException, ClassNotFoundException, IOException {
    if (mzRange == null && timeRange == null) {
      throw new IllegalArgumentException(
          "Extracting all readings is not currently supported; specify an m/z or time range");
    }
    // TODO: consider producing some sort of query plan structure that can be used for optimization/explanation.

    DateTime start = DateTime.now();
    /* Demote the time range to floats, as we know that that's how we stored times in the DB.  This tight coupling
     * would normally be a bad thing, but given that this class is joined at the hip with Builder necessarily, it
     * doesn't seem like a terrible thing at the moment. */
    final Pair<Float, Float> tRangeF;
    if (timeRange == null) {
      tRangeF = null;
    } else {
      tRangeF = Pair.of(timeRange.getLeft().floatValue(), timeRange.getRight().floatValue());
    }

    LOGGER.info("Running search for %s, %s", describeRange("t", tRangeF), describeRange("m/z", mzRange));

    // A null id set means "no constraint in this dimension", which is distinct from an empty set (no matches).
    // TODO: short circuit these filters.  The first failure after success => no more possible hits.
    Set<Long> unionTimeIds = null;
    if (tRangeF != null) {
      List<Float> timesInRange = timepointsInRange(tRangeF);
      byte[][] timeIndexBytes = extractValueBytes(
          ColumnFamilies.TIMEPOINT_TO_TRIPLES,
          timesInRange,
          Float.BYTES,
          ByteBuffer::putFloat
      );
      LOGGER.info("Found/loaded %d matching time ranges", timesInRange.size());
      // TODO: there is no need to union the time indices since they are necessarily distinct.  Concatenate instead.
      unionTimeIds = unionIdBuffers(timeIndexBytes);
    }

    Set<Long> unionMzIds = null;
    if (mzRange != null) {
      List<MZWindow> mzWindowsInRange = mzWindowsInRange(mzRange);
      byte[][] mzIndexBytes = extractValueBytes(
          ColumnFamilies.WINDOW_ID_TO_TRIPLES,
          mzWindowsInRange,
          Integer.BYTES,
          (buff, mz) -> buff.putInt(mz.getIndex())
      );
      LOGGER.info("Found/loaded %d matching m/z ranges", mzWindowsInRange.size());
      unionMzIds = unionIdBuffers(mzIndexBytes);
    }

    /* TODO: if the number of entries in one range is significantly smaller than the other (like an order of magnitude
     * or more, skip extraction of the other set of ids and just filter at the end.  This will be especially helpful
     * when the number of ids in the m/z domain is small, as each time point will probably have >10k ids. */

    Set<Long> candidateIds;
    if (unionTimeIds == null) {
      candidateIds = unionMzIds;
      LOGGER.info("No time range specified; using %d ids from the m/z range", candidateIds.size());
    } else if (unionMzIds == null) {
      candidateIds = unionTimeIds;
      LOGGER.info("No m/z range specified; using %d ids from the time range", candidateIds.size());
    } else {
      /* TODO: this is effectively a hash join, which isn't optimal for sets of wildly different cardinalities.
       * Consider using sort-merge join instead, which will reduce the object overhead (by a lot) and allow us to pass
       * over the union of the ids from each range just once when joining them.  Additionally, just skip this whole
       * step and filter at the end if one of the set's sizes is less than 1k or so and the other is large. */
      candidateIds = new HashSet<>(unionTimeIds);
      candidateIds.retainAll(unionMzIds);
      LOGGER.info("Id intersection results: t = %d, mz = %d, t ^ mz = %d",
          unionTimeIds.size(), unionMzIds.size(), candidateIds.size());
    }

    List<Long> idsToFetch = new ArrayList<>(candidateIds);
    Collections.sort(idsToFetch); // Sort ids so we retrieve them in an order that exploits index locality.

    LOGGER.info("Collecting TMzI triples");
    // Collect all the triples for the ids we extracted.
    // TODO: don't manifest all the bytes: just create a stream of results from the cursor to reduce memory overhead.
    List<TMzI> results = new ArrayList<>(idsToFetch.size());
    byte[][] resultBytes = extractValueBytes(
        ColumnFamilies.ID_TO_TRIPLE,
        idsToFetch,
        Long.BYTES,
        ByteBuffer::putLong
    );
    for (byte[] tmziBytes : resultBytes) {
      results.add(TMzI.readNextFromByteBuffer(ByteBuffer.wrap(tmziBytes)));
    }

    /* The window/timepoint lookups only select candidates (an m/z window merely has to overlap the query range), so
     * every triple still has to be checked against the exact bounds.  An unspecified range accepts everything.
     * TODO: do this filtering inline with the extraction.  We shouldn't have to load all the triples first. */
    LOGGER.info("Performing final filtering");
    int preFilterTMzICount = results.size();
    results = results.stream().filter(tmzi ->
        (tRangeF == null ||
            (tmzi.getTime() >= tRangeF.getLeft() && tmzi.getTime() <= tRangeF.getRight())) &&
        (mzRange == null ||
            (tmzi.getMz() >= mzRange.getLeft() && tmzi.getMz() <= mzRange.getRight()))
    ).collect(Collectors.toList());
    LOGGER.info("Precise filtering results: %d -> %d", preFilterTMzICount, results.size());

    DateTime end = DateTime.now();
    LOGGER.info("Search completed in %dms", end.getMillis() - start.getMillis());

    // TODO: return a stream instead that can load the triples lazily.
    return results;
  }

  /**
   * Renders a range for log output, tolerating an unspecified (null) range.
   * @param label The name of the quantity the range constrains, e.g. "t" or "m/z".
   * @param range The range to describe, or null if the quantity is unconstrained.
   * @return A human readable description of the constraint.
   */
  private static String describeRange(String label, Pair<? extends Number, ? extends Number> range) {
    if (range == null) {
      return String.format("any %s", label);
    }
    return String.format("%.6f <= %s <= %.6f", range.getLeft(), label, range.getRight());
  }

  /**
   * Selects the loaded timepoints that lie within a closed time range.
   * @param tRange The inclusive (lower, upper) time bounds.
   * @return The matching timepoints, in the order they were loaded.
   */
  private List<Float> timepointsInRange(Pair<Float, Float> tRange) {
    // TODO: short circuit these filters.  The first failure after success => no more possible hits.
    // Collect into an array list explicitly, as we'll be accessing by index.
    List<Float> timesInRange = timepoints.stream()
        .filter(x -> x >= tRange.getLeft() && x <= tRange.getRight())
        .collect(Collectors.toCollection(ArrayList::new));

    if (timesInRange.isEmpty()) {
      LOGGER.warn("Found zero times in range %.6f - %.6f", tRange.getLeft(), tRange.getRight());
    }

    return timesInRange;
  }

  /**
   * Selects the loaded m/z windows whose [min, max] span overlaps a closed m/z range.
   * @param mzRange The inclusive (lower, upper) m/z bounds.
   * @return The overlapping windows, in sorted window order.
   */
  private List<MZWindow> mzWindowsInRange(Pair<Double, Double> mzRange) {
    // Same here--access by index.
    List<MZWindow> mzWindowsInRange = mzWindows.stream()
        .filter(x -> rangesOverlap(mzRange.getLeft(), mzRange.getRight(), x.getMin(), x.getMax()))
        .collect(Collectors.toCollection(ArrayList::new));

    if (mzWindowsInRange.isEmpty()) {
      LOGGER.warn("Found zero m/z windows in range %.6f - %.6f", mzRange.getLeft(), mzRange.getRight());
    }

    return mzWindowsInRange;
  }

  /**
   * Extracts the value bytes from the index corresponding to a list of keys of fixed primitive type.
   * @param cf The column family from which to read.
   * @param keys A list of keys whose values to extract.
   * @param keyBytes The exact number of bytes required by a key; should be uniform for primitive-typed keys
   * @param put A function that writes a key to a ByteBuffer.
   * @param <K> The type of the key.
   * @return An array of arrays of bytes, one per key, containing the values of the key at that position.
   * @throws RocksDBException
   */
  private <K> byte[][] extractValueBytes(
      ColumnFamilies cf, List<K> keys, int keyBytes, BiFunction<ByteBuffer, K, ByteBuffer> put)
      throws RocksDBException {
    byte[][] valBytes = new byte[keys.size()][];
    // One buffer is reused for every key; it is sized to hold exactly one key.
    ByteBuffer keyBuffer = ByteBuffer.allocate(keyBytes);
    for (int i = 0; i < keys.size(); i++) {
      K k = keys.get(i);
      keyBuffer.clear();
      put.apply(keyBuffer, k).flip();
      // TODO: try compacting the keyBuffer array to be safe?
      valBytes[i] = dbAndHandles.get(cf, keyBuffer.array());
      assert(valBytes[i] != null);
    }
    return valBytes;
  }

  /**
   * Tests whether the closed intervals [aMin, aMax] and [bMin, bMax] share at least one point.
   */
  private static boolean rangesOverlap(double aMin, double aMax, double bMin, double bMax) {
    /* You can push this through negation and De Morgan's Law to get
     * !(aMax < bMin || bMax < aMin) -> !(A to the left of B || B to the left of A) = intersection */
    return aMax >= bMin && bMax >= aMin;
  }

  /**
   * Decodes every buffer as a packed sequence of longs and returns the set of distinct ids across all buffers.
   * @param idBytes One byte array of packed long ids per index entry.
   * @return The distinct ids found in any of the buffers.
   */
  private static Set<Long> unionIdBuffers(byte[][] idBytes) {
    /* TODO: this doesn't take advantage of the fact that all of the ids are in sorted order in every idBytes
     * sub-array.  We should be able to exploit that.  For now, we'll just start by hashing the ids. */
    Set<Long> uniqueIds = new HashSet<>();
    for (byte[] ids : idBytes) {
      assert(ids != null);
      ByteBuffer idsBuffer = ByteBuffer.wrap(ids);
      while (idsBuffer.hasRemaining()) {
        uniqueIds.add(idsBuffer.getLong());
      }
    }
    return uniqueIds;
  }
}