package gov.nih.ncgc.bard.tools; import gov.nih.ncgc.search.SearchCallback; import gov.nih.ncgc.search.SearchParams; import gov.nih.ncgc.search.SearchService2; import java.io.PrintWriter; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; import java.util.concurrent.PriorityBlockingQueue; import java.util.logging.Logger; import chemaxon.struc.Molecule; public class OrderedSearchResultHandler implements SearchCallback<SearchService2.MolEntry> { static final Logger logger = Logger.getLogger (OrderedSearchResultHandler.class.getName()); List<Long> cids; PrintWriter pw; SearchParams params; public boolean debug = false; //Stores a canonical ordered list of MolEntries for buffer PriorityBlockingQueue<SearchService2.MolEntry> MolBuffer = new PriorityBlockingQueue<SearchService2.MolEntry>(); volatile double[] minRank = null; //minimum returned rank per partition //(each partition returns in descending order) volatile Boolean[] wLock = new Boolean[]{true}; volatile Boolean[] tLock = new Boolean[]{true}; double lastRank = Double.POSITIVE_INFINITY; Integer returnCap = Integer.MAX_VALUE; Integer returnStart = 0; int returned = 0; boolean theEnd = false; boolean finished = false; //Consumer thread private Future t; public OrderedSearchResultHandler (SearchParams params, PrintWriter pw, Integer start, Integer resultCap) { cids = new ArrayList<Long>(); this.params = params; this.pw = pw; if (start != null) returnStart = start; if (resultCap != null) this.returnCap = resultCap; } public void start (ExecutorService threadPool) { t = threadPool.submit(new Consumer ()); } private void initialize(int l) { minRank = new double[l]; for (int i = 0; i < l; i++) { minRank[i] = Double.POSITIVE_INFINITY; } } /** * SearchCallback interface. Note that this method is called from * multiple thread, so it should be thread safe! */ public boolean matched(SearchService2.MolEntry entry) { //logger.info("thread:" + entry.getPartitionSig() + " of " + entry.getPartitionCount()); if (theEnd) return false; MolBuffer.add(entry); //if(true)return true; if (minRank == null) { synchronized (wLock) { if (minRank == null) { initialize(entry.getPartitionCount()); } } } minRank[entry.getPartitionSig()] = entry.getRank(); return true; } public boolean consumeMol(SearchService2.MolEntry entry) { int[][] hits = entry.getAtomMappings(); Molecule mol = entry.getMol().cloneMolecule(); switch (params.getType()) { case Substructure: { for (int[] h : hits) { for (int i = 0; i < h.length; ++i) { if (h[i] >= 0) { mol.getAtom(h[i]).setAtomMap(i + 1); } } } mol.setProperty("HIGHLIGHT", mol.toFormat("smiles:q")); } break; case Superstructure: { for (int[] h : hits) { for (int i = 0; i < h.length; ++i) { if (h[i] >= 0) { mol.getAtom(i).setAtomMap(h[i] + 1); } } } mol.setProperty("HIGHLIGHT", mol.toFormat("smiles:q")); } break; } mol.setProperty("SIMILARITY", String.format("%1$.3f", entry.getSimilarity())); mol.setProperty("RANKING", String.format("%1$.3f", entry.getRank())); writeOutput(mol); return true; } double calculateSafeRank() { if (finished) return Double.NEGATIVE_INFINITY; if (minRank == null) return Double.POSITIVE_INFINITY; double max = Double.NEGATIVE_INFINITY; for (int i = 0; i < minRank.length; i++) { if (minRank[i] > max) max = minRank[i]; } return max; } synchronized void writeOutput(Molecule mol) { String highlight = mol.getProperty("HIGHLIGHT"); pw.println(mol.getName()+(highlight != null ? ("\t"+highlight) : "")); cids.add(Long.valueOf(mol.getName())); } // TODO this throws a concurrent modification exception (sometimes) synchronized public List<Long> getCids() { return cids; } public void complete() { if (minRank != null) { for (int i = 0; i < minRank.length; i++) { minRank[i] = Double.NEGATIVE_INFINITY; } } finished = true; try { // wait t.get(); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } } class Consumer implements Runnable { public void consume() { while (!theEnd() && !MolBuffer.isEmpty()) { try { SearchService2.MolEntry m = MolBuffer.take(); if (m.getRank() > lastRank) { if (returned >= returnStart) consumeMol(m); returned++; } else { MolBuffer.add(m); return; } } catch (Exception e) { e.printStackTrace(); return; } } } private boolean moreToConsume() { double safeRank = calculateSafeRank(); if (safeRank < lastRank) { lastRank = safeRank; return true; } return false; } private boolean theEnd() { theEnd = ((returned - returnStart) >= returnCap); return theEnd || (MolBuffer.isEmpty() && finished); } public void run() { logger.info(">>> "+Thread.currentThread().getName()); long start = System.currentTimeMillis(); while (!theEnd()) { synchronized (tLock) { while (!moreToConsume()) { try { tLock.wait(50); } catch (InterruptedException e) { Thread.interrupted(); } } consume(); } } logger.info("<<< "+Thread.currentThread().getName() +" "+cids.size()+" " +(System.currentTimeMillis()-start)+"ms"); } } }