package edu.fudan.ml.classifier.linear; import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.Serializable; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; import edu.fudan.ml.classifier.AbstractClassifier; import edu.fudan.ml.classifier.LabelParser; import edu.fudan.ml.classifier.LabelParser.Type; import edu.fudan.ml.classifier.Predict; import edu.fudan.ml.classifier.linear.inf.Inferencer; import edu.fudan.ml.types.Instance; import edu.fudan.ml.types.alphabet.AlphabetFactory; import edu.fudan.nlp.pipe.Pipe; import edu.fudan.util.exception.LoadModelException; /** * 线性分类器 * * @author xpqiu * */ public class Linear extends AbstractClassifier implements Serializable { private static final long serialVersionUID = -2626247109469506636L; protected Inferencer inferencer; protected AlphabetFactory factory; protected Pipe pipe; public Linear(Inferencer inferencer, AlphabetFactory factory) { this.inferencer = inferencer; this.factory = factory; } public Linear() { } public Predict classify(Instance instance, int n) { return (Predict) inferencer.getBest(instance, n); } @Override public Predict classify(Instance instance, Type t, int n) { Predict res = (Predict) inferencer.getBest(instance, n); return LabelParser.parse(res,factory.DefaultLabelAlphabet(),t); } /** * 得到类标签 * @param idx 类标签对应的索引 * @return */ public String getLabel(int idx) { return factory.DefaultLabelAlphabet().lookupString(idx); } /** * 将分类器保存到文件 * @param file * @throws IOException */ public void saveTo(String file) throws IOException { File f = new File(file); File path = f.getParentFile(); if(!path.exists()){ path.mkdirs(); } ObjectOutputStream out = new ObjectOutputStream(new GZIPOutputStream( new BufferedOutputStream(new FileOutputStream(file)))); out.writeObject(this); out.close(); } /** * 从文件读入分类器 * @param file * @return * @throws LoadModelException */ public static Linear loadFrom(String file) throws LoadModelException{ Linear cl = null; try { ObjectInputStream in = new ObjectInputStream(new GZIPInputStream( new BufferedInputStream(new FileInputStream(file)))); cl = (Linear) in.readObject(); in.close(); } catch (Exception e) { throw new LoadModelException(e,file); } return cl; } public Inferencer getInferencer() { return inferencer; } public void setInferencer(Inferencer inferencer) { this.inferencer = inferencer; } public AlphabetFactory getAlphabetFactory() { return factory; } public void setWeights(float[] weights) { inferencer.setWeights(weights); } public float[] getWeights() { return inferencer.getWeights(); } public void setPipe(Pipe pipe) { this.pipe = pipe; } public Pipe getPipe() { return pipe; } }