package edu.fudan.ml.classifier.hier; import java.io.BufferedWriter; import java.io.FileNotFoundException; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.io.UnsupportedEncodingException; import edu.fudan.ml.types.alphabet.AlphabetFactory; import edu.fudan.ml.types.alphabet.IFeatureAlphabet; import edu.fudan.ml.types.alphabet.LabelAlphabet; import edu.fudan.ml.types.alphabet.StringFeatureAlphabet; import edu.fudan.ml.types.sv.HashSparseVector; import edu.fudan.util.MyHashSparseArrays; import gnu.trove.iterator.TIntFloatIterator; import gnu.trove.iterator.TObjectIntIterator; import gnu.trove.map.hash.TIntObjectHashMap; /** * 优化模型文件,去掉无用的特征 * 权重向量为HashSparseVector[] * @since FudanNLP 1.0 * @author xpqiu * */ public class ModelAnalysis { private Linear cl; public AlphabetFactory factory; private float thresh = 0; HashSparseVector[] weights; private IFeatureAlphabet feature; private LabelAlphabet label; public ModelAnalysis(Linear cl) { this.cl = cl; this.factory = cl.factory; feature = factory.DefaultFeatureAlphabet(); label = factory.DefaultLabelAlphabet(); this.weights = cl.weights; } /** * 统计信息,计算删除非0特征后,权重的长度 * * @throws IOException */ public void removeZero() { boolean freeze = false; if (feature.isStopIncrement()) { feature.setStopIncrement(false); freeze = true; } TIntObjectHashMap<String> index = (TIntObjectHashMap<String>) feature.toInverseIndexMap(); System.out.println("原字典大小"+index.size()); System.out.println("原字典大小"+feature.size()); StringFeatureAlphabet newfeat = new StringFeatureAlphabet(); cl.factory.setDefaultFeatureAlphabet(newfeat); for(int i=0;i<weights.length;i++){ TIntFloatIterator itt = weights[i].data.iterator(); HashSparseVector ww = new HashSparseVector(); while(itt.hasNext()){ itt.advance(); float v = itt.value(); if(Math.abs(v)<1e-3f) continue; String fea = index.get(itt.key()); int newidx = newfeat.lookupIndex(fea); ww.put(newidx, v); } weights[i] = ww; } newfeat.setStopIncrement(freeze); System.out.println("新字典大小"+newfeat.size()); System.out.println("新字典大小"+feature.size()); index.clear(); } /** * @param args * @throws Exception */ public static void main(String[] args) throws Exception { String file = "./tmp/model/tree_model.gz"; Linear cl = Linear.loadFrom(file); ModelAnalysis ma = new ModelAnalysis(cl); ma.getSalientFeatures("./tmp/model/tree_model",100); // ma.removeZero(); // cl.saveTo(file+1); System.out.print("Done"); } private void getSalientFeatures(String string, int topn) throws IOException { PrintWriter pw = new PrintWriter(new OutputStreamWriter( new FileOutputStream(string), "UTF-8")); TIntObjectHashMap<String> index = (TIntObjectHashMap<String>) feature.toInverseIndexMap(); for(int i=0;i<weights.length;i++){ int[] idx = MyHashSparseArrays.sort(weights[i].data); pw.println(label.lookupString(i)); for(int j=0;j<topn;j++){ pw.print(index.get(idx[j])); pw.print("\t"); pw.println(weights[i].get(idx[j])); } pw.println(); } pw.close(); } }