package edu.fudan.nlp.similarity.train; import gnu.trove.iterator.TIntIterator; import java.util.Date; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.cli.BasicParser; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; /** * Brown 词聚类算法,多线程版 * @author xpqiu * @since FudanNLP 1.5 */ public class WordClusterM extends WordCluster{ private static final long serialVersionUID = 58160232476872689L; transient int numThread =4; transient private ExecutorService pool; transient float maxL; transient int maxc1; transient int maxc2; transient AtomicInteger count = new AtomicInteger(); public WordClusterM(int threads) { this.numThread = threads; pool = Executors.newFixedThreadPool(numThread); } public synchronized void getmax(float f, int i, int j){ if (f > maxL) { maxL = f; maxc1 = i; maxc2 = j; } } class Multiplesolve implements Runnable { int c1,c2; public Multiplesolve(int c1, int c2) { this.c1 = c1; this.c2 = c2; } @Override public void run() { float l= calcL(c1, c2); getmax(l,c1,c2); count.decrementAndGet(); } } /** * merge clusters */ public void mergeCluster() { maxc1 = -1; maxc2 = -1; maxL = Float.NEGATIVE_INFINITY; TIntIterator it1 = slots.iterator(); while(it1.hasNext()){ int i = it1.next(); TIntIterator it2 = slots.iterator(); // System.out.print(i+": "); while(it2.hasNext()){ int j= it2.next(); if(i>=j) continue; // System.out.print(j+" "); Multiplesolve c = new Multiplesolve(i,j); count.incrementAndGet(); pool.execute(c); } // System.out.println(); } while(count.get()!=0){//等待所有子线程执行完 try { Thread.sleep(slotsize*slotsize/1000); } catch (InterruptedException e) { e.printStackTrace(); } } merge(maxc1,maxc2); } /** * @param args * @throws Exception */ public static void main(String[] args) throws Exception { /** * 分析命令参数 */ Options opt = new Options(); opt.addOption("path", true, "保存路径"); opt.addOption("res", true, "评测结果保存路径"); opt.addOption("slot", true, "槽大小"); opt.addOption("thd", true, "线程个数"); BasicParser parser = new BasicParser(); CommandLine cl; try { cl = parser.parse(opt, args); } catch (Exception e) { System.err.println("Parameters format error"); return; } int threads = Integer.parseInt(cl.getOptionValue("thd", "3")); System.out.println("线程数量:"+threads); int slotsize = Integer.parseInt(cl.getOptionValue("slot", "20")); System.out.println("槽大小:"+slotsize); String file = cl.getOptionValue("path", "./tmp/SogouCA.mini.txt"); System.out.println("数据路径:"+file); String resfile = cl.getOptionValue("res", "./tmp/cluster.txt"); System.out.println("测试结果:"+resfile); long starttime = System.currentTimeMillis(); SougouCA sca = new SougouCA(file); WordClusterM wc = new WordClusterM(threads); wc.slotsize = slotsize; wc.read(sca); wc.startClustering(); wc.saveModel(resfile+".m"); wc.saveTxt(resfile); wc = (WordClusterM) WordCluster.loadFrom(resfile+".m"); wc.saveTxt(resfile+"1"); long endtime = System.currentTimeMillis(); System.out.println("Total Time:"+(endtime-starttime)/60000); System.out.println("Done"); System.exit(0); } }