package edu.fudan.example.ml; import java.io.File; import edu.fudan.data.reader.SimpleFileReader; import edu.fudan.ml.classifier.linear.Linear; import edu.fudan.ml.classifier.linear.OnlineTrainer; import edu.fudan.ml.classifier.linear.inf.Inferencer; import edu.fudan.ml.classifier.linear.inf.LinearMax; import edu.fudan.ml.classifier.linear.update.LinearMaxPAUpdate; import edu.fudan.ml.feature.Generator; import edu.fudan.ml.feature.SFGenerator; import edu.fudan.ml.loss.ZeroOneLoss; import edu.fudan.ml.types.InstanceSet; import edu.fudan.ml.types.alphabet.AlphabetFactory; import edu.fudan.ml.types.alphabet.IFeatureAlphabet; import edu.fudan.ml.types.alphabet.LabelAlphabet; import edu.fudan.nlp.pipe.StringArray2IndexArray; import edu.fudan.nlp.pipe.Pipe; import edu.fudan.nlp.pipe.SeriesPipes; import edu.fudan.nlp.pipe.Target2Label; /** * 线性分类器使用示例 * * @author xpqiu * */ public class SimpleClassifier2 { static InstanceSet train; static InstanceSet test; static AlphabetFactory factory = AlphabetFactory.buildFactory(); static LabelAlphabet al = factory.DefaultLabelAlphabet(); static IFeatureAlphabet af = factory.DefaultFeatureAlphabet(); static String path = null; public static void main(String[] args) throws Exception { long start = System.currentTimeMillis(); path = "./example-data/data-classification.txt"; Pipe lpipe = new Target2Label(al); Pipe fpipe = new StringArray2IndexArray(factory, true); //构造转换器组 Pipe pipe = new SeriesPipes(new Pipe[]{lpipe,fpipe}); //构建训练集 train = new InstanceSet(pipe, factory); SimpleFileReader reader = new SimpleFileReader (path,true); train.loadThruStagePipes(reader); al.setStopIncrement(true); //构建测试集 test = new InstanceSet(pipe, factory); reader = new SimpleFileReader (path,true); test.loadThruStagePipes(reader); System.out.println("Train Number: " + train.size()); System.out.println("Test Number: " + test.size()); System.out.println("Class Number: " + al.size()); float c = 1.0f; int round = 20; Generator featureGen = new SFGenerator(); ZeroOneLoss loss = new ZeroOneLoss(); LinearMaxPAUpdate update = new LinearMaxPAUpdate(loss); Inferencer msolver = new LinearMax(featureGen, al.size() ); OnlineTrainer trainer = new OnlineTrainer(msolver, update, loss, af.size(), round, c); Linear classify = trainer.train(train, test); String modelFile = path+".m.gz"; classify.saveTo(modelFile); long end = System.currentTimeMillis(); System.out.println("Total Time: " + (end - start)); System.out.println("End!"); (new File(modelFile)).deleteOnExit(); System.exit(0); } }