package org.seqcode.motifs;
import java.util.*;
import java.io.*;
import java.sql.*;
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.awt.RenderingHints;
import java.awt.Color;
import cern.jet.random.Binomial;
import cern.jet.random.engine.RandomEngine;
import java.awt.image.BufferedImage;
import javax.imageio.ImageIO;
import org.seqcode.data.connections.DatabaseException;
import org.seqcode.data.connections.DatabaseFactory;
import org.seqcode.data.connections.UnknownRoleException;
import org.seqcode.data.motifdb.*;
import org.seqcode.genome.Genome;
import org.seqcode.genome.location.Region;
import org.seqcode.genome.sequence.SequenceGenerator;
import org.seqcode.gsebricks.verbs.*;
import org.seqcode.gseutils.*;
/**
 * Discovers motifs by finding k-mers that are over-represented in a foreground
 * sequence set relative to a background set. It clusters the enriched k-mers by
 * similarity (allowing small shifts and reverse complements) and reports each
 * qualifying cluster as a {@link WeightMatrix}.
 *
 * <p>K-mers are packed into a {@code long}, two bits per base (A=0, C=1, G=2,
 * T=3). The last base of the k-mer is in the least-significant bits.
 */
public class DiscriminativeKmers {
    private Genome genome;
    private double minfoldchange;
    private int k, maxmismatch, minclustersize, minclustercount;
    /**
     * Bitmask of the 2*k low bits used by a packed k-mer. It is held as a long:
     * an int mask silently stops masking once k exceeds 15.
     */
    private long mask;
    // NOTE(review): seqgen is created in the constructor but not read anywhere in this class.
    private SequenceGenerator seqgen;
    private int randombgcount = 1000, // number of random background regions to pick
            randombgsize = 100, // size of random background regions
            parsedregionexpand; // expand input regions by this much on either side
    private Map<String, char[]> foreground, background; // foreground and background sequences
    private boolean printKmers;
    /** Weight matrices reported by the most recent call to printClusters. */
    private List<WeightMatrix> pwms;
    private String outbase;
    /** Maximum number of bases by which one k-mer may be shifted when aligned to another. */
    private final static int maxshift = 3;
    /** Decoding table from 2-bit base code to character. */
    private final static char[] toChar = {'A', 'C', 'G', 'T'};

    /**
     * Returns the 2-bit code of a base (A=0, C=1, G=2, T=3, either case), or -1
     * for any other character.
     */
    private static int baseIndex(char c) {
        switch (c) {
        case 'A':
        case 'a':
            return 0;
        case 'C':
        case 'c':
            return 1;
        case 'G':
        case 'g':
            return 2;
        case 'T':
        case 't':
            return 3;
        default:
            return -1;
        }
    }

    /**
     * Packs a character sequence into the two-bits-per-base long encoding.
     * Characters other than A/C/G/T (either case) are encoded like 'A'. For
     * inputs longer than 32 characters, only the last 32 survive.
     */
    public static long charsToLong(char[] chars) {
        long out = 0;
        for (int i = 0; i < chars.length; i++) {
            // the low two bits are zero after the shift, so OR is the same as adding the code
            out = (out << 2) | Math.max(baseIndex(chars[i]), 0);
        }
        return out;
    }

    /**
     * Adds another character to the end (right side) of the long representation
     * of a kmer. mask is the bitmask of usable bits in the long rep. Characters
     * other than A/C/G/T are encoded like 'A'.
     */
    public static long addChar(long existing, long mask, char newchar) {
        return ((existing << 2) | Math.max(baseIndex(newchar), 0)) & mask;
    }

    /**
     * Converts a long representation of a kmer back to a string.
     */
    public static String longToString(long l, int k) {
        return new String(longToChars(l, k));
    }

    /**
     * Decodes the low 2*k bits of {@code l} into k characters, first base first.
     */
    public static char[] longToChars(long l, int k) {
        char[] chars = new char[k];
        while (k-- > 0) {
            int b = (int) (l & 3);
            l >>= 2;
            chars[k] = toChar[b];
        }
        return chars;
    }

    /**
     * Returns the packed reverse complement of a packed k-mer of length k.
     * XOR with 3 maps A<->T and C<->G, and the bases are emitted in reverse order.
     */
    public static long reverseComplement(long kmer, int k) {
        long out = 0;
        for (int i = 0; i < k; i++) {
            byte b = (byte) ((kmer ^ 3) & 3);
            kmer >>= 2;
            out = (out << 2) | b;
        }
        return out;
    }

    /**
     * Adds the count of every k-mer in {@code chars} to {@code map} and returns
     * the map. K-mers overlapping a character other than A/C/G/T (either case)
     * are skipped. Sequences shorter than k contribute nothing.
     *
     * @param chars sequence to scan
     * @param k     k-mer length
     * @param mask  bitmask of the 2*k low bits (as computed by setK)
     * @param map   map to accumulate into, or null to allocate a new one
     * @return map from packed k-mer to its number of occurrences
     */
    public static Map<Long, Integer> count(char[] chars,
                                           int k,
                                           long mask,
                                           Map<Long, Integer> map) {
        if (map == null) {
            map = new HashMap<Long, Integer>();
        }
        long kmer = 0;
        // Number of consecutive A/C/G/T bases ending at the current position. The
        // rolling k-mer is only complete, and free of other characters, once this
        // reaches k.
        int validRun = 0;
        for (int i = 0; i < chars.length; i++) {
            if (baseIndex(chars[i]) < 0) {
                validRun = 0;
                continue;
            }
            kmer = addChar(kmer, mask, chars[i]);
            if (++validRun >= k) {
                Integer previous = map.get(kmer);
                map.put(kmer, previous == null ? 1 : previous + 1);
            }
        }
        return map;
    }

    /**
     * Returns the number of bases that two kmers have in common, comparing the
     * low k bases of each position by position.
     */
    public static short countBasesSame(long a, long b, int k) {
        long x = a ^ b;
        short same = 0;
        while (k-- > 0) {
            if ((x & 3) == 0) {
                same++;
            }
            x >>= 2;
        }
        return same;
    }

    /**
     * Finds the best ungapped alignment of {@code a} against {@code b} in one
     * orientation, trying shifts of up to {@code maxShift} bases each way.
     *
     * <p>A positive shift l aligns the first k-l bases of a with the last k-l
     * bases of b (a sits l bases to the right of b). A negative shift -r aligns
     * the last k-r bases of a with the first k-r bases of b. Only the overlapping
     * bases are compared. On ties the alignment tried first wins: no shift, then
     * positive shifts, then negative shifts.
     *
     * @return an int whose upper 16 bits hold the signed shift and whose lower
     *         16 bits hold the number of matching bases
     */
    public static int countBasesSameOneDir(long a, long b, int k, int maxShift) {
        short bestSame = countBasesSame(a, b, k);
        short bestPos = 0;
        for (short l = 1; l <= maxShift; l++) {
            long newa = a >> l * 2;
            short s = countBasesSame(newa, b, k - l);
            if (s > bestSame) {
                bestSame = s;
                bestPos = l;
            }
        }
        for (short r = 1; r <= maxShift; r++) {
            long newb = b >> r * 2;
            short s = countBasesSame(a, newb, k - r);
            if (s > bestSame) {
                bestSame = s;
                bestPos = (short) (-1 * r);
            }
        }
        int ret = bestPos;
        ret = (ret << 16) | bestSame;
        return ret;
    }

    /**
     * Finds the best alignment of {@code a} against {@code b}, considering both
     * a and its reverse complement.
     *
     * <p>The result is {@code (countBasesSameOneDir(...) << 1) | rc}:
     * <ul>
     * <li>bit 0 is 1 when the reverse complement of a aligned best;</li>
     * <li>bits 1-16 hold the match count;</li>
     * <li>bits 17 and up hold the shift, sign-extended.</li>
     * </ul>
     * Decode with getRC, getSameness and getShift. The shift refers to whichever
     * orientation of a was chosen. On a tie in match count, the
     * reverse-complement alignment is reported.
     */
    public static long countBasesSame(long a, long b, int k, int maxShift) {
        long forw = countBasesSameOneDir(a, b, k, maxShift);
        long rc = countBasesSameOneDir(reverseComplement(a, k), b, k, maxShift);
        if ((forw & 0xffff) > (rc & 0xffff)) {
            return (forw << 1);
        } else {
            return (rc << 1) | 1;
        }
    }

    /** Returns 1 if the packed alignment used the reverse complement, else 0. */
    public static short getRC(long l) {
        return (short) (l & 1);
    }

    /** Returns the number of matching bases from a packed alignment. */
    public static short getSameness(long l) {
        return (short) ((l >> 1) & 0xffff);
    }

    /** Returns the signed shift from a packed alignment. */
    public static short getShift(long l) {
        return (short) ((l >> 17) & 0xffff);
    }

    /**
     * Renders {@code wm} with WeightMatrixPainter onto an 800x200 white canvas
     * and writes it to {@code fname} as a PNG.
     *
     * @throws IOException if the image cannot be written
     */
    public static void paintMotif(WeightMatrix wm, String fname) throws IOException {
        final int width = 800;
        final int height = 200;
        BufferedImage im = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
        Graphics2D g2 = im.createGraphics();
        try {
            g2.setRenderingHints(new RenderingHints(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_ON));
            WeightMatrixPainter wmp = new WeightMatrixPainter();
            g2.setColor(Color.WHITE);
            g2.fillRect(0, 0, width, height);
            wmp.paint(wm, g2, 0, 0, width, height);
        } finally {
            // release the graphics context; the pixels stay in the image
            g2.dispose();
        }
        ImageIO.write(im, "png", new File(fname));
    }

    /**
     * Builds a log-odds weight matrix from per-position base counts.
     *
     * <p>Background base frequencies come from the A/C/G/T composition of the
     * background sequences. Entry [i][base] is
     * {@code ln((counts[i][b] / sums[i]) / bg[b])}.
     *
     * @param sums   total count at each position
     * @param counts per-position counts indexed by 2-bit base code (A, C, G, T)
     */
    public WeightMatrix toWeightMatrix(int[] sums, int[][] counts) {
        double[] bgmodel = new double[4];
        for (char[] chars : background.values()) {
            for (int i = 0; i < chars.length; i++) {
                int b = baseIndex(chars[i]);
                if (b >= 0) {
                    bgmodel[b]++;
                }
            }
        }
        double bgsum = bgmodel[0] + bgmodel[1] + bgmodel[2] + bgmodel[3];
        for (int b = 0; b < bgmodel.length; b++) {
            bgmodel[b] = bgmodel[b] / bgsum;
        }
        WeightMatrix out = new WeightMatrix(sums.length);
        for (int i = 0; i < sums.length; i++) {
            for (int b = 0; b < 4; b++) {
                double freq = (double) counts[i][b] / (double) sums[i];
                out.matrix[i][toChar[b]] = (float) Math.log(freq / bgmodel[b]);
            }
        }
        return out;
    }

    /**
     * Greedily clusters k-mers. K-mers are visited in order of decreasing count.
     * Each one joins the existing cluster whose centroid it matches best, using
     * the four-argument countBasesSame, when that match shares more than
     * {@code k - maxmismatch} bases. Otherwise it founds a new cluster with
     * itself as centroid. The input list is not modified.
     *
     * @return clusters sorted by decreasing total member count; empty if
     *         {@code kmers} is null or empty
     */
    public List<KmerCluster> cluster(List<KmerCount> kmers, int k, int maxmismatch) {
        List<KmerCluster> clusters = new ArrayList<KmerCluster>();
        if (kmers == null || kmers.isEmpty()) {
            return clusters;
        }
        // sort a copy so the caller's list is left untouched
        List<KmerCount> sorted = new ArrayList<KmerCount>(kmers);
        Collections.sort(sorted, new KmerCountComparator());
        clusters.add(new KmerCluster(sorted.get(0)));
        for (int i = 1; i < sorted.size(); i++) {
            KmerCount kc = sorted.get(i);
            int bestCluster = -1, bestSameness = -1;
            for (int j = 0; j < clusters.size(); j++) {
                long cbs = countBasesSame(kc.kmer, clusters.get(j).centroid, k, maxshift);
                short sameness = getSameness(cbs);
                if (sameness > bestSameness) {
                    bestSameness = sameness;
                    bestCluster = j;
                }
            }
            // NOTE(review): the strict '>' tolerates at most maxmismatch-1 mismatches
            // (overhang bases of a shifted alignment count against the match);
            // confirm this is intended given the parameter name.
            if (bestSameness > k - maxmismatch) {
                clusters.get(bestCluster).members.add(kc);
            } else {
                clusters.add(new KmerCluster(kc));
            }
        }
        Collections.sort(clusters, new KmerClusterComparator());
        return clusters;
    }

    /**
     * Reports clusters with at least minclustercount total count and at least
     * minclustersize members.
     *
     * <p>For each such cluster, the members are aligned to the centroid (best
     * shift and strand) into a profile of width k + 2*maxshift, starting from a
     * pseudocount of 1 per base. The profile is turned into a weight matrix and
     * scanned with CompareEnrichment.doScan. When the scan's freqone exceeds 0.1,
     * the frequency profile and scan result are printed, the matrix is added to
     * pwms, and a PNG logo is written to outbase + number + ".png".
     */
    public void printClusters(List<KmerCluster> clusters) {
        int clusterNumber = 0;
        pwms = new ArrayList<WeightMatrix>();
        int width = k + 2 * maxshift;
        for (KmerCluster cluster : clusters) {
            if (cluster.totalCount() < minclustercount || cluster.members.size() < minclustersize) {
                continue;
            }
            int[] sums = new int[width];
            int[][] counts = new int[width][4];
            for (int j = 0; j < width; j++) {
                // pseudocount of one per base at every position
                sums[j] = 4;
                Arrays.fill(counts[j], 1);
            }
            for (KmerCount kc : cluster.members) {
                long cbs = countBasesSame(kc.kmer, cluster.centroid, k, maxshift);
                short shift = getShift(cbs);
                long kmer = kc.kmer;
                if (getRC(cbs) == 1) {
                    // the shift was computed for the reverse complement, so it applies as-is
                    kmer = reverseComplement(kmer, k);
                }
                String kmerString = longToString(kmer, k);
                int firstbase = maxshift + shift;
                // walk the bases from last to first, since the last base is in the low bits
                for (int i = 1; i <= k; i++) {
                    int pos = firstbase + k - i;
                    sums[pos] += kc.count;
                    counts[pos][(int) (kmer & 3L)] += kc.count;
                    kmer >>= 2;
                }
                if (printKmers) {
                    StringBuilder sb = new StringBuilder(" ");
                    for (int i = 0; i < firstbase; i++) {
                        sb.append(" ");
                    }
                    sb.append(kmerString);
                    for (int i = 0; i < (2 * maxshift - firstbase); i++) {
                        sb.append(" ");
                    }
                    sb.append("\t" + kc.count);
                    System.out.println(sb.toString());
                }
            }
            WeightMatrix wm = toWeightMatrix(sums, counts);
            wm.name = outbase + "_" + clusterNumber;
            wm.version = String.format("mfc %.2f expand %d size %d count %d",
                    minfoldchange, parsedregionexpand, minclustersize, minclustercount);
            wm.type = "DiscriminativeKmers";
            CEResult scanresult = CompareEnrichment.doScan(wm,
                    foreground,
                    background,
                    null, null,
                    .1, .01, 2, .05, 1, null, null, false);
            if (scanresult.freqone > .1) {
                pwms.add(wm);
                System.out.println("Cluster centroid " + longToString(cluster.centroid, k) + " and total count " + cluster.totalCount());
                for (int i = 0; i < 4; i++) {
                    System.out.print(toChar[i]);
                    for (int j = 0; j < width; j++) {
                        double s = sums[j];
                        double c = counts[j][i];
                        System.out.print(String.format("\t%.2f", c / s));
                    }
                    System.out.println();
                }
                System.out.println(scanresult.toString());
                String fname = outbase + clusterNumber + ".png";
                clusterNumber++;
                try {
                    paintMotif(wm, fname);
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }

    public DiscriminativeKmers() {
        seqgen = new SequenceGenerator();
    }

    /**
     * Sets the k-mer length and the matching bitmask.
     *
     * @throws IllegalArgumentException if k is outside 1..32, the range that fits
     *         in a long at two bits per base
     */
    public void setK(int k) {
        if (k < 1 || k > 32) {
            throw new IllegalArgumentException("k must be between 1 and 32, got " + k);
        }
        this.k = k;
        this.mask = (k == 32) ? -1L : (1L << (2 * k)) - 1;
    }

    /**
     * Parses command-line options and loads the foreground and background
     * sequences.
     *
     * <p>When --first is absent, foreground regions are read from stdin. When
     * --second is absent, random background regions are generated. Files ending
     * in .fasta or .fa are read as FASTA; other files are read as regions.
     */
    public void parseArgs(String[] args) throws NotFoundException, IOException, FileNotFoundException {
        setK(Args.parseInteger(args, "k", 10));
        printKmers = Args.parseFlags(args).contains("printkmers");
        minfoldchange = Args.parseDouble(args, "minfoldchange", 1);
        parsedregionexpand = Args.parseInteger(args, "expand", 30);
        randombgcount = Args.parseInteger(args, "randombgcount", 1000);
        randombgsize = Args.parseInteger(args, "randombgsize", 100);
        maxmismatch = Args.parseInteger(args, "maxmismatch", 3);
        minclustersize = Args.parseInteger(args, "minclustersize", 2);
        minclustercount = Args.parseInteger(args, "minclustercount", 30);
        outbase = Args.parseString(args, "outbase", "motif");
        genome = Args.parseGenome(args).cdr();
        String firstfname = Args.parseString(args, "first", null);
        String secondfname = Args.parseString(args, "second", null);
        if (firstfname == null) {
            System.err.println("No --first specified. Reading from stdin");
            foreground = CompareEnrichment.readRegions(genome,
                    new BufferedReader(new InputStreamReader(System.in)),
                    parsedregionexpand,
                    null);
        } else {
            foreground = readSequences(firstfname);
        }
        if (secondfname == null) {
            System.err.println("No background file given. Generating " + randombgcount + " regions of size " + randombgsize);
            background = CompareEnrichment.randomRegions(genome,
                    randombgcount,
                    randombgsize);
        } else {
            background = readSequences(secondfname);
        }
    }

    /**
     * Reads sequences from a file: FASTA when the name ends in .fasta or .fa,
     * otherwise regions expanded by parsedregionexpand. The file is always
     * closed.
     */
    private Map<String, char[]> readSequences(String fname) throws NotFoundException, IOException {
        BufferedReader reader = new BufferedReader(new FileReader(fname));
        try {
            if (fname.endsWith(".fasta") || fname.endsWith(".fa")) {
                return CompareEnrichment.readFasta(reader);
            }
            return CompareEnrichment.readRegions(genome, reader, parsedregionexpand, null);
        } finally {
            reader.close();
        }
    }

    /** Returns the sum of the values of a count map. */
    private static int sumValues(Map<Long, Integer> counts) {
        int total = 0;
        for (int c : counts.values()) {
            total += c;
        }
        return total;
    }

    /**
     * Counts k-mers in the foreground and background sets and keeps the enriched
     * ones. A k-mer is kept when its foreground frequency exceeds minfoldchange
     * times its background frequency; the background count carries a
     * pseudocount of 1. Each kept k-mer's count becomes its excess over the
     * expected background count. The enriched k-mers are then clustered and the
     * clusters reported.
     */
    public void run() {
        Map<Long, Integer> fgcounts = new HashMap<Long, Integer>();
        Map<Long, Integer> bgcounts = new HashMap<Long, Integer>();
        for (char[] chars : foreground.values()) {
            count(chars, k, mask, fgcounts);
        }
        for (char[] chars : background.values()) {
            count(chars, k, mask, bgcounts);
        }
        int fgsum = sumValues(fgcounts);
        int bgsum = sumValues(bgcounts);
        System.err.println("Read " + fgsum + " kmers from the fg set and " + bgsum + " from the background set");
        List<KmerCount> enriched = new ArrayList<KmerCount>();
        for (Map.Entry<Long, Integer> entry : fgcounts.entrySet()) {
            long l = entry.getKey();
            Integer observedBg = bgcounts.get(l);
            int bgcount = (observedBg == null ? 0 : observedBg) + 1;
            double bgprob = ((double) bgcount) / ((double) bgsum);
            int fgcount = entry.getValue();
            double fgprob = ((double) fgcount) / ((double) fgsum);
            if (fgprob > bgprob * minfoldchange) {
                KmerCount kc = new KmerCount(l, fgcount);
                // before clustering, change count from absolute counts to count above background
                kc.count = (int) (kc.count - bgprob * fgsum);
                if (kc.count > 0) {
                    enriched.add(kc);
                }
            }
        }
        List<KmerCluster> clusters = cluster(enriched, k, maxmismatch);
        printClusters(clusters);
    }

    public static void main(String[] args) throws Exception {
        DiscriminativeKmers kmers = new DiscriminativeKmers();
        kmers.parseArgs(args);
        kmers.run();
    }
}
/**
 * A k-mer, in the packed long encoding used by DiscriminativeKmers, paired with
 * a count. Both fields are public and mutable; callers overwrite the count.
 */
class KmerCount {
    /** the packed k-mer */
    public long kmer;
    /** number of occurrences (or an adjusted count assigned by the caller) */
    public int count;

    public KmerCount(long kmer, int count) {
        this.kmer = kmer;
        this.count = count;
    }
}
/**
 * A group of k-mers represented by the k-mer that founded it.
 */
class KmerCluster {
    /** packed k-mer of the founding member */
    public long centroid;
    /** all members, with the founding member first */
    public List<KmerCount> members;

    public KmerCluster(KmerCount seed) {
        this.centroid = seed.kmer;
        this.members = new ArrayList<KmerCount>();
        this.members.add(seed);
    }

    /** Returns the sum of the members' counts. */
    public int totalCount() {
        int total = 0;
        Iterator<KmerCount> it = members.iterator();
        while (it.hasNext()) {
            total += it.next().count;
        }
        return total;
    }
}
/**
 * Orders KmerCount objects by decreasing count.
 */
class KmerCountComparator implements Comparator<KmerCount> {
    public int compare(KmerCount a, KmerCount b) {
        // Compare explicitly rather than subtracting: b.count - a.count overflows
        // when the counts differ by more than Integer.MAX_VALUE.
        if (a.count == b.count) {
            return 0;
        }
        return (a.count > b.count) ? -1 : 1;
    }
}
/**
 * Orders clusters by decreasing total member count.
 */
class KmerClusterComparator implements Comparator<KmerCluster> {
    public int compare(KmerCluster a, KmerCluster b) {
        // Compare explicitly rather than subtracting: the difference of two
        // totals can overflow an int.
        int totalA = a.totalCount();
        int totalB = b.totalCount();
        if (totalA == totalB) {
            return 0;
        }
        return (totalA > totalB) ? -1 : 1;
    }
}