package org.seqcode.motifs;
import java.util.*;
import java.io.*;
import java.sql.*;
import java.text.DecimalFormat;
import org.seqcode.data.connections.DatabaseException;
import org.seqcode.data.connections.DatabaseFactory;
import org.seqcode.data.connections.UnknownRoleException;
import org.seqcode.data.io.parsing.FASTAStream;
import org.seqcode.data.motifdb.*;
import org.seqcode.genome.Genome;
import org.seqcode.genome.location.Region;
import org.seqcode.genome.location.StrandedRegion;
import org.seqcode.genome.sequence.SequenceGenerator;
import org.seqcode.genome.sequence.SequenceUtils;
import org.seqcode.gsebricks.verbs.*;
import org.seqcode.gseutils.*;
import org.seqcode.motifs.*;
import cern.jet.random.Binomial;
import cern.jet.random.engine.RandomEngine;
//import org.seqcode.math.probability.Binomial;
/** Compare the frequencies of a set of motifs between two FASTA files
* usage:
* java org.seqcode.motifs.CompareEnrichment --species "$SC;SGDv2" [--first foo.fasta] [--second bar.fasta]
*
* can also specify --acceptwm to give a regex which the motif name must match
* or --rejectwm to specify a regex that the motif name must not match. Remember the regex must match
* the *entire* name, so use something like Hnf.*
*
* If --first is not supplied, then regions are read from STDIN. If --second is not supplied, then random
* regions are chosen according to --randombgcount and --randombgsize.
*
* Input files ending in .fasta or .fa are parsed as fasta. Otherwise, they're parsed as a list of regions.
*
* --cutoff .5 minimum percent (specify between 0 and 1) match to maximum motif score that counts as a match.
* --filtersig .001 maximum pvalue for reporting an enrichment between the two files
* --minfoldchange 1
* --minfrac 0 minimum fraction of the sequences that must contain the motif (can be in either file)
* --savedata motif_presence_file.txt
* --savedatahits
* --outfg output_fg.fasta
* --outbg output_bg.fasta
* --mask name;version;cutoff
* --threads 4 to control number of parallel threads
*
* The comparison code will check all percent cutoffs between the value you specify as --cutoff and 1 (in increments of .05)
* to find the most significant threshold that also meets the other criteria.
*
* Output columns are
* 1) log foldchange in frequency between the two files
* 2) motif count in first set
* 3) size of first set
* 4) frequency in first set
* 5) motif count in second set
* 6) size of second set
* 7) frequency in second set
* 8) pvalue of first count given second frequency
* 9) motif name
* 10) motif version
* 11) percent cutoff used
*
*/
public class CompareEnrichment {
// Increment applied to the percent-of-max-score cutoff on each try when searching for the best threshold.
public static final double step = .05; // when looking for most significant threshold, increment % identity by this much each try
// Defaults below are overwritten by parseArgs (--randombgcount, --randombgsize, --parsedregionexpand).
int randombgcount = 1000, // number of random background regions to pick
randombgsize = 100, // size of random background regions
parsedregionexpand; // expand input regions by this much on either side
// Genome parsed from the command line; used to resolve region strings and to draw random background regions.
Genome genome;
// Filtering parameters, set in parseArgs from --cutoff, --minfrac, --minfoldchange, --filtersig and --maxbackfrac.
double cutoffpercent, minfrac, minfoldchange, filtersig, maxbackfrac;
ArrayList<WeightMatrix> matrices; // these are the matrices to scan for
Map<String,char[]> foreground, background; // foreground and background sequences
// Values are score thresholds: parseArgs stores --mask cutoffs below 1 multiplied by the matrix's max score, others as given.
Map<WeightMatrix, Double> maskingMatrices; // matrices to mask out of foreground and background
ArrayList<String> fgkeys, bgkeys; // order in which to scan; also used when saving region list to a matrix of motif-sequence presence
// Optional output writers; each stays null unless the matching option (--savedata, --outfg, --outbg) was given.
PrintWriter savedatafg = null, savedatabg = null, outfg = null, outbg = null;
// savedatahits: write the passing hits themselves rather than per-sequence counts.
// matchedbg: use regions flanking the foreground regions as the background.
boolean savedatahits = false, matchedbg = false;
// Parsed from --threads; not read anywhere else in the code visible here.
int threads = 1;
/**
 * Writes every entry of {@code seqs} to {@code pw} as a FASTA record.
 *
 * Each record is a {@code >name} header line followed by the sequence
 * wrapped at 60 characters per line. An empty sequence produces only the
 * header line. The header is written before the sequence is validated.
 *
 * @param pw   destination writer; it is neither flushed nor closed here
 * @param seqs map from sequence name to sequence characters
 * @throws RuntimeException if a sequence contains a character other than
 *         A, C, G, T or N (either case)
 */
public static void saveFasta(PrintWriter pw, Map<String, char[]> seqs) throws IOException {
    final int lineWidth = 60;
    for (Map.Entry<String, char[]> record : seqs.entrySet()) {
        String name = record.getKey();
        pw.println(">" + name);
        char[] residues = record.getValue();
        String asText = new String(residues);
        if (!asText.matches("[ACTGactgNn]*")) {
            throw new RuntimeException("Invalid sequence " + name + ": " + asText);
        }
        for (int offset = 0; offset < residues.length; offset += lineWidth) {
            // the final line of a record may be shorter than lineWidth
            int length = Math.min(lineWidth, residues.length - offset);
            pw.write(residues, offset, length);
            pw.println();
        }
    }
}
/**
 * Reads all FASTA records available from {@code reader}.
 *
 * @param reader source of FASTA text, wrapped in a FASTAStream
 * @return map from record name to sequence characters; if a name occurs
 *         more than once, the last record read wins
 */
public static Map<String,char[]> readFasta(BufferedReader reader) throws IOException {
    Map<String,char[]> sequences = new HashMap<String,char[]>();
    for (FASTAStream records = new FASTAStream(reader); records.hasNext(); ) {
        Pair<String,String> record = records.next();
        String name = record.car();
        String residues = record.cdr();
        sequences.put(name, residues.toCharArray());
    }
    return sequences;
}
/**
 * Reads region strings (e.g. "3:100-5000"), one per line, from
 * {@code reader} and returns the corresponding genomic sequences keyed by
 * the input line.
 *
 * A line is first parsed as a stranded region; if that yields null it is
 * parsed as a plain region and treated as '+' strand. Lines that cannot be
 * parsed either way are reported on stderr and skipped. Sequences of '-'
 * strand regions are passed through SequenceUtils.reverseComplement.
 *
 * If {@code matchedRegions} is not null, it is filled with the two
 * flanking regions (each as wide as the region itself) of every region
 * read, keyed by the flank's string form. These can serve as background
 * sequence since they come from the same approximate loci as the
 * foreground.
 *
 * @param g                  genome used to resolve the region strings
 * @param reader             source of region lines
 * @param parsedregionexpand if positive, each region is expanded by this
 *                           much on both sides before its sequence is fetched
 * @param matchedRegions     optional output map for flanking sequences, may be null
 * @return map from input line to sequence characters
 */
public static Map<String,char[]> readRegions(Genome g,
                                             BufferedReader reader,
                                             int parsedregionexpand,
                                             Map<String,char[]> matchedRegions) throws IOException, NotFoundException {
    Map<String,char[]> sequences = new HashMap<String,char[]>();
    SequenceGenerator generator = new SequenceGenerator();
    generator.useCache(true);
    generator.useLocalFiles(true);
    for (String line = reader.readLine(); line != null; line = reader.readLine()) {
        StrandedRegion region = StrandedRegion.fromString(g, line);
        if (region == null) {
            // fall back to an unstranded region on the '+' strand
            Region unstranded = Region.fromString(g, line);
            if (unstranded != null) {
                region = new StrandedRegion(unstranded, '+');
            }
        }
        if (region == null) {
            System.err.println("Couldn't parse a region from " + line);
            continue;
        }
        if (parsedregionexpand > 0) {
            region = region.expand(parsedregionexpand, parsedregionexpand);
        }
        boolean minusStrand = region.getStrand() == '-';
        char[] residues = generator.execute(region).toCharArray();
        if (minusStrand) {
            SequenceUtils.reverseComplement(residues);
        }
        sequences.put(line, residues);
        if (matchedRegions == null) {
            continue;
        }
        Region lowerFlank = new Region(region.getGenome(),
                                       region.getChrom(),
                                       region.getStart() - region.getWidth(),
                                       region.getStart());
        Region upperFlank = new Region(region.getGenome(),
                                       region.getChrom(),
                                       region.getEnd(),
                                       region.getEnd() + region.getWidth());
        // on the '-' strand the higher-coordinate flank is stored first
        Region firstFlank = minusStrand ? upperFlank : lowerFlank;
        Region secondFlank = minusStrand ? lowerFlank : upperFlank;
        matchedRegions.put(firstFlank.toString(),
                           generator.execute(firstFlank).toCharArray());
        matchedRegions.put(secondFlank.toString(),
                           generator.execute(secondFlank).toCharArray());
    }
    return sequences;
}
/**
 * Generates random genomic regions and returns their sequences.
 *
 * A position is drawn uniformly over the summed chromosome lengths and
 * then walked across the chromosomes in the order of the genome's
 * chromosome-length map until it lands on one where the region fits.
 * Draws whose sequence contains characters other than ACGT (either case)
 * or uppercase N, or that contain a run of eight or more N, are discarded
 * and redrawn.
 *
 * A draw that cannot be placed on the last chromosome is redrawn as well,
 * so no region extends past the end of a chromosome.
 *
 * @param genome genome to sample from
 * @param count  number of regions to draw; because results are keyed by
 *               region string, duplicate draws collapse and the map may
 *               hold fewer entries
 * @param size   distance between start and end of each region
 * @return map from region string to sequence characters
 * @throws IllegalArgumentException if a draw cannot be placed and no
 *         chromosome is long enough to hold a region of {@code size}
 */
public static Map<String,char[]> randomRegions(Genome genome, int count, int size) {
    Map<String,char[]> output = new HashMap<String,char[]>();
    SequenceGenerator seqgen = new SequenceGenerator();
    Map<String,Integer> chromlengthmap = genome.getChromLengthMap();
    List<String> chromnames = new ArrayList<String>(chromlengthmap.keySet());
    List<Integer> chromlengths = new ArrayList<Integer>(chromnames.size());
    long totallength = 0;
    boolean anyChromFits = false;
    for (String s : chromnames) {
        int length = chromlengthmap.get(s);
        totallength += length;
        chromlengths.add(length);
        if (size < length) {
            anyChromFits = true;
        }
    }
    int lastIndex = chromlengths.size() - 1;
    while (count-- > 0) {
        long target = (long)(Math.random() * totallength);
        for (int i = 0; i <= lastIndex; i++) {
            int length = chromlengths.get(i);
            if (target + size < length) {
                Region r = new Region(genome,
                                      chromnames.get(i),
                                      (int)target,
                                      (int)target + size);
                String seq = seqgen.execute(r);
                // Only ACGT/N characters get past the first test, so a plain
                // substring search is equivalent to the ".*NNNNNNNN.*" regex.
                if (!seq.matches("[ACTGactgN]*") || seq.contains("NNNNNNNN")) {
                    count++; // rejected: draw a replacement
                } else {
                    output.put(r.toString(), seq.toCharArray());
                }
                break;
            }
            if (i == lastIndex) {
                // The region would run off the end of the last chromosome.
                // Redraw rather than emit an out-of-range region; bail out
                // if no redraw could ever succeed.
                if (!anyChromFits) {
                    throw new IllegalArgumentException("No chromosome is long enough for a random region of size " + size);
                }
                count++;
                break;
            }
            target -= length;
            if (target < 0) {
                // the draw fell within the last `size` bases of chromosome i;
                // restart near the beginning of the next chromosome
                target = 10;
            }
        }
    }
    return output;
}
/**
 * Counts the sequences that have at least one hit scoring strictly above
 * the threshold {@code t}. Each sequence contributes at most one to the
 * count, however many of its hits pass.
 *
 * @param hits map from sequence name to that sequence's hits
 * @param t    score threshold (exclusive)
 * @return number of map entries with a hit whose score exceeds {@code t}
 */
public static int countMeetsThreshold(Map<String,List<WMHit>> hits,
                                      double t) {
    int sequencesWithHit = 0;
    for (List<WMHit> sequenceHits : hits.values()) {
        for (WMHit hit : sequenceHits) {
            if (hit.getScore() > t) {
                sequencesWithHit++;
                break; // one passing hit is enough for this sequence
            }
        }
    }
    return sequencesWithHit;
}
/**
 * Masks every matrix in maskingMatrices, at its associated threshold, out
 * of all foreground and then all background sequences. The sequences are
 * modified in place.
 */
public void maskSequence() {
    for (Map.Entry<WeightMatrix, Double> mask : maskingMatrices.entrySet()) {
        WeightMatrix wm = mask.getKey();
        double threshold = mask.getValue();
        for (char[] seq : foreground.values()) {
            maskSequence(wm, threshold, seq);
        }
        for (char[] seq : background.values()) {
            maskSequence(wm, threshold, seq);
        }
    }
}
/**
 * Overwrites with 'N' every position covered by a hit that
 * WeightMatrixScanner.scanSequence reports for {@code wm} in {@code seq}
 * at the given threshold. Positions from a hit's start (inclusive) to its
 * end (exclusive) are replaced; {@code seq} is modified in place.
 *
 * @param wm        matrix whose occurrences are masked
 * @param threshold score threshold handed to the scanner (narrowed to float)
 * @param seq       sequence to scan and mask
 */
public void maskSequence(WeightMatrix wm, double threshold, char[] seq) {
    float cutoff = (float)threshold;
    List<WMHit> matches = WeightMatrixScanner.scanSequence(wm, cutoff, seq);
    for (WMHit match : matches) {
        int pos = match.start;
        while (pos < match.end) {
            seq[pos] = 'N';
            pos++;
        }
    }
}
/**
 * Scans foreground and background sequences with one matrix and looks for
 * the score threshold at which the motif frequency differs most between
 * the two sets.
 *
 * Both sets are scanned once at {@code cutoffpercent} of the matrix's
 * maximum score. Thresholds are then evaluated from {@code cutoffpercent}
 * up to 100% of the maximum score in increments of {@link #step},
 * including the starting cutoff itself. Among thresholds that satisfy all
 * filters, the one with the largest absolute log fold change is kept
 * (ties go to the higher threshold).
 *
 * @param matrix        matrix to scan with
 * @param fg            foreground sequences by name
 * @param bg            background sequences by name
 * @param fgkeys        order in which foreground sequences are scanned and
 *                      reported; if null, the sorted keys of {@code fg}
 * @param bgkeys        same for the background
 * @param cutoffpercent lowest threshold to consider, as a fraction of the max score
 * @param filtersig     maximum p-value for a threshold to qualify
 * @param minfoldchange minimum fold change (compared on the absolute log scale)
 * @param minfrac       minimum frequency that the motif must reach in at least one set
 * @param maxbackfrac   background frequency must be strictly below this
 * @param savedatafg    if not null, receives one tab-prefixed field per foreground sequence
 * @param savedatabg    if not null, receives one tab-prefixed field per background sequence
 * @param savedatahits  if true the saved fields list the passing hits, otherwise their count
 * @return the result for the best qualifying threshold; if none qualifies,
 *         a result with pval 1.0 and logfoldchange 0
 */
public static CEResult doScan(WeightMatrix matrix,
                              Map<String,char[]> fg,
                              Map<String,char[]> bg,
                              List<String> fgkeys,
                              List<String> bgkeys,
                              double cutoffpercent,
                              double filtersig,
                              double minfoldchange,
                              double minfrac,
                              double maxbackfrac,
                              PrintWriter savedatafg,
                              PrintWriter savedatabg,
                              boolean savedatahits) {
    Binomial binomial = new Binomial(100, .01, RandomEngine.makeDefault());
    if (fgkeys == null) {
        fgkeys = new ArrayList<String>(fg.keySet());
        Collections.sort(fgkeys);
    }
    if (bgkeys == null) {
        bgkeys = new ArrayList<String>(bg.keySet());
        Collections.sort(bgkeys);
    }
    double maxscore = matrix.getMaxScore();
    float scanthreshold = (float)(maxscore * cutoffpercent);
    Map<String,List<WMHit>> fghits = scanSequences(matrix, scanthreshold, fg, fgkeys);
    Map<String,List<WMHit>> bghits = scanSequences(matrix, scanthreshold, bg, bgkeys);

    int fgsize = fg.size();
    int bgsize = bg.size();
    CEResult result = new CEResult();
    result.pval = 1.0;
    result.matrix = matrix;
    result.logfoldchange = 0;
    result.sizeone = fgsize;
    result.sizetwo = bgsize;
    double bestthresh = maxscore * cutoffpercent;

    // Derive each percent from an integer step count so that the starting
    // cutoff is evaluated, no threshold above 100% is tried, and rounding
    // error does not accumulate; the tolerance keeps 100% itself in range.
    final double tolerance = 1e-9;
    for (int stepcount = 0; cutoffpercent + stepcount * step <= 1.0 + tolerance; stepcount++) {
        double percent = cutoffpercent + stepcount * step;
        double t = maxscore * percent;
        int fgcount = countMeetsThreshold(fghits, t);
        // pseudo-count: never let the background count be zero
        int bgcount = Math.max(countMeetsThreshold(bghits, t), 1);
        double thetaone = ((double)fgcount) / ((double)fgsize);
        double thetatwo = ((double)bgcount) / ((double)bgsize);
        if (fgsize <= 0 || thetatwo <= 0 || thetatwo >= 1) {
            continue;
        }
        binomial.setNandP(fgsize, thetatwo);
        // NOTE(review): 1 - cdf(k) is P(X > k); an upper-tail p-value that
        // includes the observed count would be 1 - cdf(k - 1). Left as is
        // because changing it alters every reported p-value - confirm intent.
        double pval = 1 - binomial.cdf(fgcount);
        double fc = Math.log(thetaone / thetatwo);
        if (pval <= filtersig &&
            Math.abs(fc) >= Math.abs(result.logfoldchange) &&
            Math.abs(fc) >= Math.abs(Math.log(minfoldchange)) &&
            (thetatwo < maxbackfrac) &&
            (thetaone >= minfrac || thetatwo >= minfrac)) {
            bestthresh = t;
            result.pval = pval;
            result.percentString = Double.toString(percent);
            result.cutoffString = Double.toString(t);
            result.countone = fgcount;
            result.counttwo = bgcount;
            result.logfoldchange = fc;
            result.freqone = thetaone;
            result.freqtwo = thetatwo;
        }
    }

    // Each writer is optional on its own, so a missing background writer
    // no longer causes a NullPointerException.
    if (savedatafg != null) {
        printHitData(savedatafg, fgkeys, fghits, bestthresh, savedatahits);
    }
    if (savedatabg != null) {
        printHitData(savedatabg, bgkeys, bghits, bestthresh, savedatahits);
    }
    return result;
}

/** Scans each named sequence with the matrix at the given threshold and returns the hits by sequence name. */
private static Map<String,List<WMHit>> scanSequences(WeightMatrix matrix,
                                                     float threshold,
                                                     Map<String,char[]> seqs,
                                                     List<String> keys) {
    Map<String,List<WMHit>> hits = new HashMap<String,List<WMHit>>();
    for (String key : keys) {
        hits.put(key, WeightMatrixScanner.scanSequence(matrix, threshold, seqs.get(key)));
    }
    return hits;
}

/** Prints, for each key in order, a tab followed by the hits scoring at least threshold (or just their count). */
private static void printHitData(PrintWriter out,
                                 List<String> keys,
                                 Map<String,List<WMHit>> hits,
                                 double threshold,
                                 boolean printhits) {
    List<WMHit> passing = new ArrayList<WMHit>();
    for (String key : keys) {
        passing.clear();
        for (WMHit hit : hits.get(key)) {
            if (hit.getScore() >= threshold) {
                passing.add(hit);
            }
        }
        if (printhits) {
            out.print("\t" + passing);
        } else {
            out.print("\t" + passing.size());
        }
    }
}
/**
 * Parses the command line and loads everything needed for a run: the
 * genome, the filtering parameters, the optional output writers, the
 * matrices to scan for (converted to log-odds, against the Markov
 * background model named by --bgmodel when its metadata is found), the
 * masking matrices, and the foreground and background sequences.
 *
 * Foreground comes from --first (stdin if absent). Background comes from
 * the foreground's flanking regions when --matchedbg is set, otherwise
 * from --second, otherwise from random genomic regions. Files ending in
 * .fasta or .fa are read as FASTA; anything else as a list of regions.
 *
 * @param args command-line arguments
 * @throws IllegalArgumentException if a --mask value has no ';' separated cutoff
 * @throws Exception on any parsing, database or I/O failure
 */
public void parseArgs(String args[]) throws Exception {
    genome = Args.parseGenome(args).cdr();
    cutoffpercent = Args.parseDouble(args,"cutoff",.5);
    filtersig = Args.parseDouble(args,"filtersig",.001);
    minfoldchange = Args.parseDouble(args,"minfoldchange",1);
    minfrac = Args.parseDouble(args,"minfrac",0);
    maxbackfrac = Args.parseDouble(args,"maxbackfrac",1.0);
    String firstfname = Args.parseString(args,"first",null);
    String secondfname = Args.parseString(args,"second",null);
    randombgcount = Args.parseInteger(args,"randombgcount",1000);
    randombgsize = Args.parseInteger(args,"randombgsize",100);
    parsedregionexpand = Args.parseInteger(args,"parsedregionexpand",0);
    String savefile = Args.parseString(args,"savedata",null);
    savedatahits = Args.parseFlags(args).contains("savedatahits");
    matchedbg = Args.parseFlags(args).contains("matchedbg");
    String outfgfile = Args.parseString(args,"outfg",null);
    String outbgfile = Args.parseString(args,"outbg",null);
    threads = Args.parseInteger(args,"threads",threads);
    if (savefile != null) {
        savedatafg = new PrintWriter(savefile + ".fg");
        savedatabg = new PrintWriter(savefile + ".bg");
    }
    if (outfgfile != null) {
        outfg = new PrintWriter(outfgfile);
    }
    if (outbgfile != null) {
        outbg = new PrintWriter(outbgfile);
    }

    MarkovBackgroundModel bgModel = null;
    String bgmodelname = Args.parseString(args,"bgmodel","whole genome zero order");
    // reuse the genome parsed above instead of parsing the arguments again
    BackgroundModelMetadata md = BackgroundModelLoader.getBackgroundModel(bgmodelname,
                                                                          1,
                                                                          "MARKOV",
                                                                          genome.getDBID());
    if (md != null) {
        bgModel = BackgroundModelLoader.getMarkovModel(md);
    } else {
        System.err.println("Couldn't get metadata for " + bgmodelname);
    }
    matrices = new ArrayList<WeightMatrix>();
    matrices.addAll(Args.parseWeightMatrices(args));
    for (WeightMatrix m : matrices) {
        if (bgModel == null) {
            m.toLogOdds();
        } else {
            m.toLogOdds(bgModel);
        }
    }

    // NOTE(review): unlike the scanning matrices, masking matrices are not
    // passed through toLogOdds here - confirm that is intended.
    maskingMatrices = new HashMap<WeightMatrix,Double>();
    WeightMatrixLoader wmloader = new WeightMatrixLoader();
    try {
        Collection<String> maskstrings = Args.parseStrings(args,"mask");
        for (String maskstring : maskstrings) {
            // format is name;version;cutoff, where the version may itself contain ';'
            String[] pieces = maskstring.split(";");
            if (pieces.length < 2) {
                throw new IllegalArgumentException("--mask must be name;version;cutoff but got " + maskstring);
            }
            String name = pieces[0];
            StringBuilder version = new StringBuilder();
            for (int i = 1; i < pieces.length - 1; i++) {
                if (i > 1) {
                    version.append(";");
                }
                version.append(pieces[i]);
            }
            double threshold = Double.parseDouble(pieces[pieces.length - 1]);
            Collection<WeightMatrix> maskMatches = wmloader.query(name, version.toString(), null);
            for (WeightMatrix m : maskMatches) {
                // a cutoff below 1 is a fraction of the matrix's maximum score;
                // anything else is used as the score threshold directly
                if (threshold < 1) {
                    maskingMatrices.put(m, threshold * m.getMaxScore());
                } else {
                    maskingMatrices.put(m, threshold);
                }
            }
        }
    } finally {
        // release the loader even if a mask specification could not be parsed
        wmloader.close();
    }
    System.err.println("Going to scan for " + matrices.size() + " matrices");

    if (matchedbg) {
        System.err.println("Using matched background");
        background = new HashMap<String,char[]>();
    }
    Map<String,char[]> matched = matchedbg ? background : null;
    if (firstfname == null) {
        System.err.println("No --first specified. Reading from stdin");
        // stdin is deliberately left open
        foreground = readRegions(genome,
                                 new BufferedReader(new InputStreamReader(System.in)),
                                 parsedregionexpand,
                                 matched);
    } else {
        foreground = loadSequences(firstfname, matched);
    }
    if (matchedbg && background.isEmpty()) {
        // FASTA input carries no coordinates, so no flanking regions were collected
        System.err.println("Warning: matched background is empty; --matchedbg needs the foreground given as regions");
    }
    if (!matchedbg) {
        if (secondfname == null) {
            System.err.println("No background file given. Generating " + randombgcount + " regions of size " + randombgsize);
            background = randomRegions(genome,
                                       randombgcount,
                                       randombgsize);
        } else {
            background = loadSequences(secondfname, null);
        }
    }
}

/** Reads sequences from a FASTA file (.fasta/.fa) or a region-list file, always closing the file afterwards. */
private Map<String,char[]> loadSequences(String fname,
                                         Map<String,char[]> matchedRegions) throws IOException, NotFoundException {
    BufferedReader reader = new BufferedReader(new FileReader(fname));
    try {
        if (fname.matches(".*\\.fasta") || fname.matches(".*\\.fa")) {
            return readFasta(reader);
        }
        return readRegions(genome, reader, parsedregionexpand, matchedRegions);
    } finally {
        reader.close();
    }
}
/**
 * Writes the foreground and background sequences as FASTA to the --outfg
 * and --outbg writers, when those were requested, then closes the writers
 * and clears the fields so the sequences are not written twice.
 *
 * The writers are closed even if writing fails (saveFasta throws on an
 * invalid sequence), so output written so far is flushed and the file
 * handles are released.
 *
 * @throws IOException declared by saveFasta
 */
public void saveSequences() throws IOException {
    try {
        if (outfg != null) {
            saveFasta(outfg, foreground);
        }
        if (outbg != null) {
            saveFasta(outbg, background);
        }
    } finally {
        if (outfg != null) {
            outfg.close();
            outfg = null;
        }
        if (outbg != null) {
            outbg.close();
            outbg = null;
        }
    }
}
/**
 * Scans every matrix against the foreground and background sequences and
 * prints to stdout each result that passes the significance, fold-change,
 * background-frequency and minimum-frequency filters.
 *
 * When --savedata was given, the .fg and .bg files receive a header row
 * of sequence names followed by one row per matrix. Both key lists are
 * sorted so that the column order is deterministic. The save-data writers
 * are closed when the scan ends, whether normally or by an exception.
 */
public void doScan() {
    fgkeys = new ArrayList<String>(foreground.keySet());
    bgkeys = new ArrayList<String>(background.keySet());
    Collections.sort(fgkeys);
    // background keys are sorted too, so the .bg columns do not depend on hash order
    Collections.sort(bgkeys);
    try {
        if (savedatafg != null) {
            savedatafg.print("Motif");
            savedatabg.print("Motif");
            for (String s : fgkeys) {
                savedatafg.print("\t" + s);
            }
            savedatafg.println();
            for (String s : bgkeys) {
                savedatabg.print("\t" + s);
            }
            savedatabg.println();
        }
        for (WeightMatrix matrix : matrices) {
            if (savedatafg != null) {
                savedatafg.print(matrix.toString());
                savedatabg.print(matrix.toString());
            }
            CEResult result = doScan(matrix,
                                     foreground,
                                     background,
                                     fgkeys,
                                     bgkeys,
                                     cutoffpercent,
                                     filtersig,
                                     minfoldchange,
                                     minfrac,
                                     maxbackfrac,
                                     savedatafg,
                                     savedatabg,
                                     savedatahits);
            if (savedatafg != null) {
                savedatafg.println();
                savedatabg.println();
            }
            if (result.pval <= filtersig &&
                Math.abs(result.logfoldchange) >= Math.abs(Math.log(minfoldchange)) &&
                (result.freqtwo <= maxbackfrac) &&
                (result.freqone >= minfrac || result.freqtwo >= minfrac)) {
                System.out.println(result.toString());
            }
        }
    } finally {
        // flush and release the save-data files even if a scan failed part way
        if (savedatafg != null) {
            savedatafg.close();
        }
        if (savedatabg != null) {
            savedatabg.close();
        }
    }
}
/**
 * Command-line entry point: parses the arguments, masks the requested
 * motifs out of the sequences, optionally saves the sequences as FASTA,
 * then runs the enrichment scan.
 */
public static void main(String args[]) throws Exception {
    final CompareEnrichment comparison = new CompareEnrichment();
    comparison.parseArgs(args);
    // masking happens before saving, so saved sequences are the masked ones
    comparison.maskSequence();
    comparison.saveSequences();
    comparison.doScan();
}
}