/*************************************************************************
* *
* This file is part of the 20n/act project. *
* 20n/act enables DNA prediction for synthetic biology/bioengineering. *
* Copyright (C) 2017 20n Labs, Inc. *
* *
* Please direct all queries to act@20n.com. *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU General Public License for more details. *
* *
* You should have received a copy of the GNU General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* *
*************************************************************************/
package com.act.lcms.v2;
import com.act.lcms.LCMSNetCDFParser;
import com.act.lcms.LCMSSpectrum;
import com.act.lcms.MS1;
import com.act.lcms.XZ;
import com.act.utils.rocksdb.ColumnFamilyEnumeration;
import com.act.utils.rocksdb.DBUtil;
import com.act.utils.rocksdb.RocksDBAndHandles;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.lang3.tuple.Triple;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.rocksdb.FlushOptions;
import org.rocksdb.RocksDB;
import org.rocksdb.RocksDBException;
import org.rocksdb.RocksIterator;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.stream.XMLStreamException;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
public class TraceIndexExtractor {
// Formatter logger: the log calls in this class pass printf-style templates (%s, %d, %.3f) plus arguments.
private static final Logger LOGGER = LogManager.getFormatterLogger(TraceIndexExtractor.class);
// Charset used to encode the fixed string key below.
private static final Charset UTF8 = StandardCharsets.UTF_8;
/* TIMEPOINTS_KEY is a fixed key into a separate column family in the index that just holds a list of time points.
* Within that column family, there is only one entry:
* "timepoints" -> serialized array of time point doubles
* and we use this key to write/read those time points. Since time points are shared across all traces, we can
* maintain this one copy in the index and reconstruct the XZ pairs as we read trace intensity arrays. */
private static final byte[] TIMEPOINTS_KEY = "timepoints".getBytes(UTF8);
// Half-width of every m/z window: MZWindow spans [target - width, target + width].
private static final Double WINDOW_WIDTH_FROM_CENTER = MS1.MS1_MZ_TOLERANCE_DEFAULT;
// TODO: make this take a plate barcode and well coordinates instead of a scan file.
// Short names of the command-line options; the long names are set where the options are built.
public static final String OPTION_INDEX_PATH = "x";
public static final String OPTION_SCAN_FILE = "i";
public static final String OPTION_TARGET_MASSES = "m";
// Header text passed to the help formatter when usage is printed.
public static final String HELP_MESSAGE = StringUtils.join(new String[]{
"This class extracts traces from an LCMS scan files for a list of target m/z values, ",
"and writes them to an on-disk index for later processing."
}, "");
/**
 * Builders for every command-line option this tool accepts: the index path, the scan file, and the
 * target mass file (all required, each taking one argument), plus a help flag.
 *
 * The list is populated in a static initializer rather than with double-brace initialization, which
 * would create an anonymous ArrayList subclass just to run the add calls.
 */
public static final List<Option.Builder> OPTION_BUILDERS = new ArrayList<>();
static {
  OPTION_BUILDERS.add(Option.builder(OPTION_INDEX_PATH)
      .argName("index path")
      .desc("A path to the directory where the on-disk index will be stored; must not already exist")
      .hasArg().required()
      .longOpt("index")
  );
  OPTION_BUILDERS.add(Option.builder(OPTION_SCAN_FILE)
      .argName("scan file")
      .desc("A path to the LCMS NetCDF scan file to read")
      .hasArg().required()
      .longOpt("input")
  );
  OPTION_BUILDERS.add(Option.builder(OPTION_TARGET_MASSES)
      .argName("target mass file")
      .desc("A file containing m/z values for which to search")
      .hasArg().required()
      .longOpt("target-masses")
  );
  OPTION_BUILDERS.add(Option.builder("h")
      .argName("help")
      .desc("Prints this help message")
      .longOpt("help")
  );
}

/** Formatter used for all usage output; widened to 100 columns. */
public static final HelpFormatter HELP_FORMATTER = new HelpFormatter();
static {
  HELP_FORMATTER.setWidth(100);
}
/**
 * The column families that make up a trace index, each identified by its on-disk name:
 * target m/z to serialized window, trace id to serialized trace, and the shared list of time points.
 */
public enum COLUMN_FAMILIES implements ColumnFamilyEnumeration<COLUMN_FAMILIES> {
  TARGET_TO_WINDOW("target_mz_to_window_obj"),
  ID_TO_TRACE("id_to_trace"),
  TIMEPOINTS("timepoints"),
  ;

  // Name -> constant lookup, filled once after all constants exist. A plain HashMap plus a static
  // initializer avoids the anonymous HashMap subclass that double-brace initialization creates.
  private static final Map<String, COLUMN_FAMILIES> reverseNameMap = new HashMap<>();
  static {
    for (COLUMN_FAMILIES cf : COLUMN_FAMILIES.values()) {
      reverseNameMap.put(cf.getName(), cf);
    }
  }

  // Assigned once in the constructor and never changed, hence final.
  private final String name;

  COLUMN_FAMILIES(String name) {
    this.name = name;
  }

  /** @return the name of this column family as stored in the index. */
  public String getName() {
    return name;
  }

  /**
   * Looks up a column family by its stored name.
   * @param name the column family name.
   * @return the matching constant, or null when no constant has that name.
   */
  @Override
  public COLUMN_FAMILIES getFamilyByName(String name) {
    return reverseNameMap.get(name);
  }
}
/** Creates an extractor. The class declares no instance fields, so there is nothing to initialize. */
public TraceIndexExtractor() {
}
/**
 * Command-line entry point: parses the options, validates the heap size and the input/output paths,
 * reads the target m/z values (one per line), and runs the extraction.
 *
 * Exits with status 1 (after printing usage) when arguments cannot be parsed, when the scan file is
 * missing, or when the index path already exists.
 *
 * @param args command-line arguments; see OPTION_BUILDERS.
 * @throws RuntimeException when the maximum heap size is below 16GB.
 * @throws NumberFormatException when a non-blank line of the target mass file is not a number.
 * @throws Exception for any failure while reading the inputs or building the index.
 */
public static void main(String[] args) throws Exception {
  Options opts = new Options();
  for (Option.Builder b : OPTION_BUILDERS) {
    opts.addOption(b.build());
  }

  CommandLine cl = null;
  try {
    CommandLineParser parser = new DefaultParser();
    cl = parser.parse(opts, args);
  } catch (ParseException e) {
    System.err.format("Argument parsing failed: %s\n", e.getMessage());
    HELP_FORMATTER.printHelp(TraceIndexExtractor.class.getCanonicalName(), HELP_MESSAGE, opts, null, true);
    System.exit(1);
  }

  if (cl.hasOption("help")) {
    HELP_FORMATTER.printHelp(TraceIndexExtractor.class.getCanonicalName(), HELP_MESSAGE, opts, null, true);
    return;
  }

  // Not enough memory available? We're gonna need a bigger heap.
  long maxMemory = Runtime.getRuntime().maxMemory();
  /* The shift must be done on a long: for an int left operand Java only uses the low five bits of the
   * shift distance, so `1 << 34` evaluates to `1 << 2` == 4 and the check could never fail. */
  if (maxMemory < (1L << 34)) { // 16GB
    String msg = StringUtils.join(
        String.format("You have run this class with a maximum heap size of less than 16GB (%d to be exact). ",
            maxMemory),
        "There is no way this process will complete with that much space available. ",
        "Crank up your heap allocation with -Xmx and try again."
        , "");
    throw new RuntimeException(msg);
  }

  File inputFile = new File(cl.getOptionValue(OPTION_SCAN_FILE));
  if (!inputFile.exists()) {
    System.err.format("Cannot find input scan file at %s\n", inputFile.getAbsolutePath());
    HELP_FORMATTER.printHelp(TraceIndexExtractor.class.getCanonicalName(), HELP_MESSAGE, opts, null, true);
    System.exit(1);
  }

  File rocksDBFile = new File(cl.getOptionValue(OPTION_INDEX_PATH));
  if (rocksDBFile.exists()) {
    System.err.format("Index file at %s already exists--remove and retry\n", rocksDBFile.getAbsolutePath());
    HELP_FORMATTER.printHelp(TraceIndexExtractor.class.getCanonicalName(), HELP_MESSAGE, opts, null, true);
    System.exit(1);
  }

  List<Double> targetMZs = new ArrayList<>();
  try (BufferedReader reader = new BufferedReader(new FileReader(cl.getOptionValue(OPTION_TARGET_MASSES)))) {
    String line;
    while ((line = reader.readLine()) != null) {
      // Tolerate blank lines (e.g. a trailing newline at the end of the file) instead of failing to parse them.
      if (line.trim().isEmpty()) {
        continue;
      }
      targetMZs.add(Double.valueOf(line));
    }
  }

  TraceIndexExtractor extractor = new TraceIndexExtractor();
  extractor.processScan(targetMZs, inputFile, rocksDBFile);
}
/**
 * Extracts one trace per target m/z from a scan file and stores windows, time points, and traces in a
 * newly created on-disk index. The index is closed before returning, whether or not extraction succeeds.
 *
 * @param targetMZs the m/z values to build windows around.
 * @param scanFile the LCMS scan file to read.
 * @param rocksDBFile the location at which to create the index.
 */
public void processScan(List<Double> targetMZs, File scanFile, File rocksDBFile)
    throws RocksDBException, ParserConfigurationException, XMLStreamException, IOException {
  final String scanPath = scanFile.getAbsolutePath();
  LOGGER.info("Accessing scan file at %s", scanPath);
  Iterator<LCMSSpectrum> spectra = new LCMSNetCDFParser().getIterator(scanPath);

  LOGGER.info("Opening index at %s", rocksDBFile.getAbsolutePath());
  RocksDB.loadLibrary();

  RocksDBAndHandles<COLUMN_FAMILIES> index = null;
  try {
    // TODO: add to existing DB instead of complaining if the DB already exists. That'll enable one index per scan.
    index = DBUtil.createNewRocksDB(rocksDBFile, COLUMN_FAMILIES.values());

    // TODO: split targetMZs into batches of ~100k and extract incrementally to allow huge input sets.
    LOGGER.info("Extracting traces");
    IndexedTraces extracted = runSweepLine(targetMZs, spectra);

    LOGGER.info("Writing search targets to on-disk index");
    writeWindowsToDB(index, extracted.getWindows());

    LOGGER.info("Writing trace data to on-disk index");
    writeTracesToDB(index, extracted.getTimes(), extracted.getAllTraces());
  } finally {
    // Only close what was actually opened; creation itself may have thrown.
    if (index != null) {
      index.getDb().close();
    }
  }

  LOGGER.info("Done");
}
/**
 * A fixed-width m/z range centered on one target value, together with the id used to find its trace.
 * Both bounds are treated as inclusive by the extraction code. This class is public so that it can be
 * serialized and deserialized.
 */
public static class MZWindow implements Serializable {
  private static final long serialVersionUID = -3326765598920871504L;

  // Id of this window's trace; assigned by whoever constructs the window.
  int index;
  // Center of the window.
  Double targetMZ;
  // Lower and upper bounds: the center minus/plus WINDOW_WIDTH_FROM_CENTER.
  double min;
  double max;

  /**
   * @param index the id under which this window's trace is stored.
   * @param targetMZ the center of the window; must not be null.
   */
  public MZWindow(int index, Double targetMZ) {
    final double center = targetMZ;
    final double halfWidth = WINDOW_WIDTH_FROM_CENTER;
    this.index = index;
    this.targetMZ = targetMZ;
    this.min = center - halfWidth;
    this.max = center + halfWidth;
  }

  /** @return the id under which this window's trace is stored. */
  public int getIndex() {
    return this.index;
  }

  /** @return the m/z value at the center of this window. */
  public Double getTargetMZ() {
    return this.targetMZ;
  }

  /** @return the lower bound of this window. */
  public double getMin() {
    return this.min;
  }

  /** @return the upper bound of this window. */
  public double getMax() {
    return this.max;
  }
}
/**
 * Holder for the result of trace extraction: a matrix of summed intensities over (m/z window, time),
 * plus the two axes needed to interpret it.
 *
 *   windows:    one entry per target, each describing min / target / max of its m/z range
 *   times:      t_0, t_1, t_2, ... (one entry per time point, shared by every trace)
 *   allTraces:  one row per window id, one column per time point:
 *                 i_0_0, i_0_1, i_0_2, ...
 *                 i_1_0, i_1_1, i_1_2, ...
 *                 ...
 *               so i_1_2 is the summed intensity inside window 1 at time point 2.
 *
 * Rows of allTraces are addressed by MZWindow.getIndex(), which is not necessarily the window's position
 * in the windows list (the extraction code sorts that list by m/z after assigning ids).
 *
 * Windows and times are stored apart from the intensities for efficiency and to keep ordering explicit
 * (no window -> array maps). A <time, intensity> trace (a List<XZ>) for one window is only assembled on
 * demand by pairing the single times list with that window's row; this avoids holding several hundred
 * million XZ objects in memory, which turns out to be fairly expensive.
 */
private static class IndexedTraces {
  List<MZWindow> windows;
  List<Double> times;
  List<List<Double>> allTraces;

  public IndexedTraces(List<MZWindow> windows, List<Double> times, List<List<Double>> allTraces) {
    // The lists are stored as given, without copying.
    this.allTraces = allTraces;
    this.times = times;
    this.windows = windows;
  }

  /** @return the m/z windows. */
  public List<MZWindow> getWindows() {
    return this.windows;
  }

  /** @return the time points shared by all traces. */
  public List<Double> getTimes() {
    return this.times;
  }

  /** @return the intensity rows, one per window id. */
  public List<List<Double>> getAllTraces() {
    return this.allTraces;
  }
}
/**
 * Initiate a data feast of all traces within some window allocation. OM NOM NOM.
 *
 * One window is created per target m/z, and for every spectrum (time point) the intensities of all
 * m/z readings that fall inside a window (bounds inclusive) are summed into that window's trace.
 *
 * @param targetMZs The m/z values to center windows on. This list is read but never modified or reordered.
 * @param iter An iterator over an LCMS data file.
 * @return The windows (sorted by target m/z), time points, and per-window traces (addressed by window id).
 */
private IndexedTraces runSweepLine(List<Double> targetMZs, Iterator<LCMSSpectrum> iter)
    throws RocksDBException, IOException {
  /* Create windows for sweep-linin'. Each window's id is its position in targetMZs. Plain lists are used
   * here: double-brace initialization inside an instance method would create anonymous subclasses that
   * keep a hidden reference to this extractor. */
  List<MZWindow> windows = new ArrayList<>(targetMZs.size());
  for (Double targetMZ : targetMZs) {
    windows.add(new MZWindow(windows.size(), targetMZ));
  }

  /* We *must* ensure the windows are sorted in m/z order for the sweep line to work. However, we don't know
   * anything about the input targetMZs list, which may be immutable or may be in some order the client wants to
   * preserve. Rather than mess with that list, we sort the windows in our internal list and leave the client's
   * targets be. Since every window has the same width, sorting by target also sorts by min and by max. */
  Collections.sort(windows, (a, b) -> a.getTargetMZ().compareTo(b.getTargetMZ()));
  final int windowCount = windows.size();

  List<Double> times = new ArrayList<>();
  List<List<Double>> allTraces = new ArrayList<>(windowCount);
  for (int i = 0; i < windowCount; i++) {
    allTraces.add(new ArrayList<>());
  }

  // Keep an array of accumulators around to reduce the overhead of accessing the trace matrix for accumulation.
  double[] sumIntensitiesInEachWindow = new double[windowCount];
  // Windows whose range may contain the current m/z value; reused (and cleared) for every time point.
  LinkedList<MZWindow> workingQueue = new LinkedList<>();

  int timepointCounter = 0;
  while (iter.hasNext()) {
    LCMSSpectrum spectrum = iter.next();
    Double time = spectrum.getTimeVal();

    // Store one list of the time values so we can knit times and intensity sums later to form XZs.
    times.add(time);

    for (int i = 0; i < sumIntensitiesInEachWindow.length; i++) {
      sumIntensitiesInEachWindow[i] = 0.0;
    }

    timepointCounter++;
    if (timepointCounter % 100 == 0) {
      LOGGER.info("Extracted %d timepoints (now at %.3fs)", timepointCounter, time);
    }

    /* We use a sweep-line approach to scanning through the m/z windows so that we can aggregate all intensities
     * in one pass over the current LCMSSpectrum (this saves us one inner loop in our extraction process). The m/z
     * values in the LCMSSpectrum become our "critical" or "interesting points" over which we sweep our m/z ranges.
     * The next window in m/z order is guaranteed to be the next one we want to consider since we address the
     * points in m/z order as well. As soon as we've passed out of the range of one of our windows, we discard it.
     * It is valid for a window to be added to and discarded from the working queue in one application of the
     * work loop.
     *
     * Windows not yet reached are tracked with a cursor into the sorted list instead of copying the whole list
     * into a fresh queue for every time point. */
    workingQueue.clear();
    int nextWindow = 0;

    // Assumption: these arrive in m/z order.
    for (Pair<Double, Double> mzIntensity : spectrum.getIntensities()) {
      Double mz = mzIntensity.getLeft();
      Double intensity = mzIntensity.getRight();

      // First, shift any applicable ranges onto the working queue based on their minimum mz.
      while (nextWindow < windowCount && windows.get(nextWindow).getMin() <= mz) {
        workingQueue.add(windows.get(nextWindow));
        nextWindow++;
      }

      // Next, remove any ranges we've passed.
      while (!workingQueue.isEmpty() && workingQueue.peekFirst().getMax() < mz) {
        workingQueue.pop();
      }

      if (workingQueue.isEmpty()) {
        if (nextWindow >= windowCount) {
          // Nothing active and nothing pending: there are no more windows to consider. On to the next timepoint!
          break;
        }
        // If there's nothing that happens to fit in this range, skip it!
        continue;
      }

      // The working queue should now hold only ranges that include this m/z value. Sweep line swept!

      /* Now add this intensity to accumulator value for each of the items in the working queue.
       * By the end of the outer loop, trace(t) = Sum(intensity) | win_min <= m/z <= win_max @ time point # t */
      for (MZWindow window : workingQueue) {
        // TODO: count the number of times we add intensities to each window's accumulator for MS1-style warnings.
        sumIntensitiesInEachWindow[window.getIndex()] += intensity;
      }
    }

    /* Extend allTraces to add a column of accumulated intensity values for this time point. We build this
     * incrementally because the LCMSSpectrum iterator doesn't tell us how many time points to expect up front. */
    for (int i = 0; i < sumIntensitiesInEachWindow.length; i++) {
      allTraces.get(i).add(sumIntensitiesInEachWindow[i]);
    }
  }

  // Trace data has been devoured. Might want to loosen the belt at this point...
  LOGGER.info("Done extracting %d traces", allTraces.size());

  return new IndexedTraces(windows, times, allTraces);
}
/**
 * Stores every window in the TARGET_TO_WINDOW column family, keyed by its serialized target m/z, and
 * then flushes the DB. Windows that share a target m/z share a key.
 *
 * @param dbAndHandles the open index to write to.
 * @param windows the windows to store.
 */
private void writeWindowsToDB(RocksDBAndHandles<COLUMN_FAMILIES> dbAndHandles, List<MZWindow> windows)
    throws RocksDBException, IOException {
  for (MZWindow current : windows) {
    dbAndHandles.put(
        COLUMN_FAMILIES.TARGET_TO_WINDOW,
        serializeObject(current.getTargetMZ()),
        serializeObject(current));
  }

  FlushOptions flushOptions = new FlushOptions();
  dbAndHandles.getDb().flush(flushOptions);
  LOGGER.info("Done writing window data to index");
}
/**
 * Stores the shared time points under TIMEPOINTS_KEY and every trace under its serialized row number,
 * then flushes the DB.
 *
 * Note that this empties the caller's data as it goes: each row of allTraces is replaced with an empty
 * list once it has been written.
 *
 * @param dbAndHandles the open index to write to.
 * @param times the time points shared by all traces.
 * @param allTraces the intensity rows; row i is stored under key i. Must support set().
 */
private void writeTracesToDB(RocksDBAndHandles<COLUMN_FAMILIES> dbAndHandles,
                             List<Double> times,
                             List<List<Double>> allTraces) throws RocksDBException, IOException {
  LOGGER.info("Writing timepoints to on-disk index (%d points)", times.size());
  byte[] encodedTimes = serializeDoubleList(times);
  dbAndHandles.put(COLUMN_FAMILIES.TIMEPOINTS, TIMEPOINTS_KEY, encodedTimes);

  int traceId = 0;
  while (traceId < allTraces.size()) {
    List<Double> trace = allTraces.get(traceId);
    dbAndHandles.put(COLUMN_FAMILIES.ID_TO_TRACE, serializeObject(traceId), serializeDoubleList(trace));

    if (traceId % 1000 == 0) {
      LOGGER.info("Finished writing %d traces", traceId);
    }

    // Release the row right after it is written so the GC can reclaim it and ease memory pressure.
    allTraces.set(traceId, Collections.emptyList());
    traceId++;
  }

  FlushOptions flushOptions = new FlushOptions();
  dbAndHandles.getDb().flush(flushOptions);
  LOGGER.info("Done writing trace data to index");
}
/**
 * Opens an existing trace index and returns an iterator over its traces. Each element pairs a target
 * m/z with its trace, expressed as (time, summed intensity) points built from the stored time points and
 * that window's stored intensities.
 *
 * The time points are loaded eagerly; windows and traces are read from the index one at a time as the
 * iterator advances.
 *
 * NOTE(review): nothing in this method closes the DB it opens, and callers receive no handle with which
 * to close it.
 *
 * @param index the location of the index to read.
 * @return an iterator over (target m/z, trace) pairs, in the order of the index's window entries.
 * @throws RuntimeException when the index holds no time points entry, or a RocksDB error occurs while
 *         fetching it.
 * @throws UncheckedIOException when the stored time points cannot be decoded.
 */
public Iterator<Pair<Double, List<XZ>>> getIteratorOverTraces(File index)
    throws IOException, RocksDBException {
  RocksDBAndHandles<COLUMN_FAMILIES> dbAndHandles = DBUtil.openExistingRocksDB(index, COLUMN_FAMILIES.values());
  final RocksDBAndHandles.RocksDBIterator rangesIterator = dbAndHandles.newIterator(COLUMN_FAMILIES.TARGET_TO_WINDOW);
  // Presumably positions the iterator on the first entry -- TODO confirm against RocksDBIterator.
  rangesIterator.reset();

  final List<Double> times;
  try {
    byte[] timeBytes = dbAndHandles.get(COLUMN_FAMILIES.TIMEPOINTS, TIMEPOINTS_KEY);
    if (timeBytes == null) {
      // Without this check a missing entry surfaces as a bare NullPointerException during decoding.
      LOGGER.error("Got null byte array back for time points in index at %s", index.getAbsolutePath());
      throw new RuntimeException(
          String.format("No time points found in index at %s", index.getAbsolutePath()));
    }
    times = deserializeDoubleList(timeBytes);
  } catch (RocksDBException e) {
    LOGGER.error("Caught RocksDBException when trying to fetch times: %s", e.getMessage());
    throw new RuntimeException(e);
  } catch (IOException e) {
    LOGGER.error("Caught IOException when trying to fetch times: %s", e.getMessage());
    throw new UncheckedIOException(e);
  }

  return new Iterator<Pair<Double, List<XZ>>>() {
    // Counts the windows read so far; only used in error messages.
    int windowNum = 0;

    @Override
    public boolean hasNext() {
      return rangesIterator.isValid();
    }

    @Override
    public Pair<Double, List<XZ>> next() {
      // Honor the Iterator contract rather than reading from an iterator that has run off the end.
      if (!rangesIterator.isValid()) {
        throw new NoSuchElementException("No more traces in index");
      }

      byte[] valBytes = rangesIterator.value();
      MZWindow window;
      windowNum++;
      try {
        window = deserializeObject(valBytes);
      } catch (IOException e) {
        LOGGER.error("Caught IOException when iterating over mz windows (%d): %s", windowNum, e.getMessage());
        throw new UncheckedIOException(e);
      } catch (ClassNotFoundException e) {
        LOGGER.error("Caught ClassNotFoundException when iterating over mz windows (%d): %s",
            windowNum, e.getMessage());
        throw new RuntimeException(e);
      }

      // Traces are keyed by the serialized window id.
      byte[] traceKeyBytes;
      try {
        traceKeyBytes = serializeObject(window.getIndex());
      } catch (IOException e) {
        throw new UncheckedIOException(e);
      }

      List<Double> trace;
      try {
        byte[] traceBytes = dbAndHandles.get(COLUMN_FAMILIES.ID_TO_TRACE, traceKeyBytes);
        if (traceBytes == null) {
          String msg = String.format("Got null byte array back for trace key %d (target: %.6f)",
              window.getIndex(), window.getTargetMZ());
          LOGGER.error(msg);
          throw new RuntimeException(msg);
        }
        trace = deserializeDoubleList(traceBytes);
      } catch (RocksDBException e) {
        LOGGER.error("Caught RocksDBException when trying to extract trace %d (%.6f): %s",
            window.getIndex(), window.getTargetMZ(), e.getMessage());
        throw new RuntimeException(e);
      } catch (IOException e) {
        LOGGER.error("Caught IOException when trying to extract trace %d (%.6f): %s",
            window.getIndex(), window.getTargetMZ(), e.getMessage());
        throw new UncheckedIOException(e);
      }

      if (trace.size() != times.size()) {
        LOGGER.error("Found mismatching trace and times size (%d vs. %d), continuing anyway",
            trace.size(), times.size());
      }

      // On a size mismatch, pair up only as many points as both lists have.
      List<XZ> xzs = new ArrayList<>(times.size());
      for (int i = 0; i < trace.size() && i < times.size(); i++) {
        xzs.add(new XZ(times.get(i), trace.get(i)));
      }

      /* The Rocks iterator pattern is a bit backwards from the Java model, as we don't need an initial next() call
       * to prime the iterator, and `isValid` indicates whether we've gone past the end of the iterator. We thus
       * advance only after we've read the current value, which means the next hasNext call after we've walked off
       * the edge will return false. */
      rangesIterator.next();

      return Pair.of(window.getTargetMZ(), xzs);
    }
  };
}
/**
 * Encodes an object with Java's built-in serialization.
 *
 * @param obj the object to encode.
 * @return the bytes produced by writing obj to an ObjectOutputStream.
 * @throws IOException when obj cannot be written (e.g. it is not serializable).
 */
private static <T> byte[] serializeObject(T obj) throws IOException {
  ByteArrayOutputStream buffer = new ByteArrayOutputStream();
  try (ObjectOutputStream objectWriter = new ObjectOutputStream(buffer)) {
    objectWriter.writeObject(obj);
  }
  // Closing the object stream has flushed everything it buffered into `buffer`.
  return buffer.toByteArray();
}
/**
 * Decodes an object that was written with Java's built-in serialization.
 *
 * The result is cast to whatever type the caller asks for without any check here, so a wrong expectation
 * only fails later at the use site. Assumes you know what you're getting into; don't use this blindly.
 *
 * @param bytes the serialized form of a single object.
 * @return the decoded object.
 * @throws IOException when the bytes are not a readable object stream.
 * @throws ClassNotFoundException when the class of the stored object cannot be found.
 */
private static <T> T deserializeObject(byte[] bytes) throws IOException, ClassNotFoundException {
  ByteArrayInputStream source = new ByteArrayInputStream(bytes);
  try (ObjectInputStream objectReader = new ObjectInputStream(source)) {
    Object decoded = objectReader.readObject();
    @SuppressWarnings("unchecked") // The caller vouches for the type; see the method comment.
    T typed = (T) decoded;
    return typed;
  }
}
/**
 * Packs a list of doubles into a byte array, eight bytes per value, in list order and in ByteBuffer's
 * default (big-endian) byte order.
 *
 * @param vals the values to encode; must not contain null.
 * @return an array of vals.size() * Double.BYTES bytes.
 */
private static byte[] serializeDoubleList(List<Double> vals) throws IOException {
  ByteBuffer packed = ByteBuffer.allocate(vals.size() * Double.BYTES);
  for (Double value : vals) {
    packed.putDouble(value);
  }
  return packed.array();
}
/**
 * Unpacks a byte array of consecutive eight-byte doubles (ByteBuffer's default, big-endian order) into
 * a list, preserving order.
 *
 * @param byteStream the packed doubles; must not be null.
 * @return the decoded values; empty when byteStream is empty.
 * @throws RuntimeException when the input ends with fewer than eight bytes left over.
 */
private static List<Double> deserializeDoubleList(byte[] byteStream) throws IOException {
  List<Double> decoded = new ArrayList<>(byteStream.length / Double.BYTES);
  ByteBuffer packed = ByteBuffer.wrap(byteStream);
  while (packed.hasRemaining()) {
    int available = packed.remaining();
    if (available < Double.BYTES) {
      // A trailing fragment means the input was not written as whole doubles.
      throw new RuntimeException(String.format("Couldn't read a whole double at a time: %d", available));
    }
    decoded.add(packed.getDouble());
  }
  return decoded;
}
}