package org.seqcode.motifs; import java.util.*; import java.io.*; import org.seqcode.genome.Genome; import org.seqcode.genome.GenomeConfig; import org.seqcode.genome.location.Point; import org.seqcode.genome.location.Region; import org.seqcode.genome.sequence.SequenceGenerator; import org.seqcode.genome.sequence.SequenceUtils; import org.seqcode.gsebricks.verbs.location.PointParser; import org.seqcode.gsebricks.verbs.location.RegionParser; import org.seqcode.gseutils.*; /** Use: * java org.seqcode.motifs.CountKmers --species "$SC;sacCer3" --mink 4 --maxk 6 [--outputcounts] [--includerc] [--topn 100] * * outputcounts: output counts instead of frequencies * includerc: include counts from reverse complement strand too * topn: only output the top N kmers instead of all of them. * table: output frequencies per region * */ public class CountKmers { private int mink, maxk; private Genome genome; private GenomeConfig gconfig; private List<Region> regions; private int win; private List<int[]> counts; private SequenceGenerator seqgen; private boolean outputCounts, includeRC; private int topN; private boolean fullTable=false; /* use this if you're going to feed in sequences, but not Regions */ public void init(int mink, int maxk) { this.mink = mink; this.maxk = maxk; counts = new ArrayList<int[]>(); for (int i = 0; i <= maxk; i++) { int[] l = new int[(4 << ((i-1) * 2))]; for (int j = 0; j < l.length; j++) { l[j] = 0; } counts.add(l); } } public void parseArgs(String args[]) throws NotFoundException { gconfig = new GenomeConfig(args); Genome genome = gconfig.getGenome(); win = Args.parseInteger(args,"win",-1); String regFile = Args.parseString(args, "regions", null); regions = loadRegionsFromFile(regFile, genome, win); outputCounts = Args.parseFlags(args).contains("outputcounts"); includeRC = Args.parseFlags(args).contains("includerc"); fullTable = Args.parseFlags(args).contains("table"); topN = Args.parseInteger(args,"topn",-1); seqgen = gconfig.getSequenceGenerator(); for(Region r : regions){ System.out.println(r.getLocationString()+"\t"+r.getWidth()); } init(Args.parseInteger(args,"mink",1), Args.parseInteger(args,"maxk",4)); } /** * Loads a set of regions from the third or first column of a file * (Suitable for GPS & StatisticalPeakFinder files * @param filename String * @param win integer width of region to impose (-1 leaves region width alone) * @return */ public static List<Region> loadRegionsFromFile(String filename, Genome gen, int win){ List<Region> regs = new ArrayList<Region>(); try{ File pFile = new File(filename); if(!pFile.isFile()){System.err.println("Invalid file name: "+filename);System.exit(1);} BufferedReader reader = new BufferedReader(new FileReader(pFile)); String line; while ((line = reader.readLine()) != null) { line = line.trim(); String[] words = line.split("\\s+"); if(words.length>0 && !words[0].contains("#") && !words[0].equals("Region") && !words[0].equals("Position")){ if(words.length>=3 && words[2].contains(":")){ PointParser pparser = new PointParser(gen); Point p = pparser.execute(words[2]); if(win==-1 && words[0].contains(":") && words[0].contains("-")){ RegionParser rparser = new RegionParser(gen); Region q = rparser.execute(words[0]); regs.add(q); }else{ regs.add(p.expand(win/2)); } }else if(words.length>=1 && words[0].contains(":")){ String[] coords = words[0].split(":"); if(coords[1].contains("-")){ RegionParser rparser = new RegionParser(gen); Region q = rparser.execute(words[0]); if(win==-1){ if(q!=null){regs.add(q);} }else regs.add(q.getMidpoint().expand(win/2)); }else{ PointParser pparser = new PointParser(gen); Point p = pparser.execute(words[0]); regs.add(p.expand(win/2)); } } } }reader.close(); } catch (FileNotFoundException e) { e.printStackTrace(); } catch (IOException e) { e.printStackTrace(); } return(regs); } public void addToCounts(Region r) { String s = seqgen.execute(r); addToCounts(s); if (includeRC) { addToCounts(SequenceUtils.reverseComplement(s)); } } public void addToCounts(String s) { char[] chars = s.toCharArray(); for (int i = 0; i < chars.length; i++) { if (chars[i] == 'A' || chars[i] == 'a') { chars[i] = 0; } else if (chars[i] == 'C' || chars[i] == 'c') { chars[i] = 1; } else if (chars[i] == 'G' || chars[i] == 'g') { chars[i] = 2; } else { chars[i] = 3; } } for (int k = mink; k <= maxk; k++) { int[] l = counts.get(k); for (int i = 0; i < chars.length - k; i++) { int index = 0; for (int j = 0; j < k; j++) { char c = chars[i + j]; if (c > 3) { i += k; break; } index = (index << 2) + c; } l[index]++; } } } public String indexToString(int index, int k) { char[] out = new char[k]; int pos = k - 1; while (pos >= 0) { out[pos--] = (char)(index & 3); index >>= 2; } for (int i = 0; i < out.length; i++) { if (out[i] == 0) { out[i] = 'A'; } else if (out[i] == 1) { out[i] = 'C'; } else if (out[i] == 2) { out[i] = 'G'; } else { out[i] = 'T'; } } return new String(out); } public Set<String> getKeySet(int k) { HashSet<String> s = new HashSet<String>(); int[] l = counts.get(k); for (int i = 0; i < l.length; i++) { if (l[i] > 0) { s.add(indexToString(i,k)); } } return s; } public Map<String,Integer> getCounts(int k) { int[] l = counts.get(k); Map<String,Integer> output = new HashMap<String,Integer>(); for (int i = 0; i < l.length; i++) { if (l[i] > 0) { output.put(indexToString(i,k), l[i]); } } return output; } public int getCount(String key, int k) throws NumberFormatException { char[] chars = key.toCharArray(); for (int i = 0; i < chars.length; i++) { if (chars[i] == 'A' || chars[i] == 'a') { chars[i] = 0; } else if (chars[i] == 'C' || chars[i] == 'c') { chars[i] = 1; } else if (chars[i] == 'G' || chars[i] == 'g') { chars[i] = 2; } else { chars[i] = 3; } } int index = 0; for (int j = 0; j < k; j++) { char c = chars[j]; if (c > 3) { throw new NumberFormatException("Invalid character in " + key); } index = (index << 2) + c; } return counts.get(k)[index]; } public int getMinCount(int k) { if (topN < 0) { return 0; } ArrayList<Integer> list = new ArrayList<Integer>(); int[] array = counts.get(k); for (int i = 0; i < array.length; i++) { list.add(array[i]); } Collections.sort(list); if (topN == 0) { return list.get(list.size() - 1) + 1; } return list.get(list.size() - topN); } public void print(PrintWriter pw) { System.err.println(String.format("Printing from %d to %d", mink, maxk)); for (int k = mink; k <= maxk; k++) { int[] l = counts.get(k); int minCount = getMinCount(k); System.err.println(String.format("Length at %d is %d",k,l.length)); long total = 0; for (int i = 0; i < l.length; i++) { total += l[i]; } for (int i = 0; i < l.length; i++) { String key = indexToString(i,k); if (l[i] < minCount) { continue; } if (outputCounts) { pw.println(key + "\t" + l[i]); } else { pw.println(key + "\t" + (((double)l[i]) / ((double)total))); } } } } public static void main(String args[]) throws Exception { CountKmers counter = new CountKmers(); counter.parseArgs(args); for (Region r : counter.regions) { counter.addToCounts(r); } PrintWriter pw = new PrintWriter(System.out); counter.print(pw); pw.close(); } }