/** * Copyright 2014, Emory University * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package edu.emory.clir.clearnlp.bin.helper; import java.io.InputStream; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import org.kohsuke.args4j.Option; import edu.emory.clir.clearnlp.classification.configuration.AbstractTrainerConfiguration; import edu.emory.clir.clearnlp.classification.instance.AbstractInstance; import edu.emory.clir.clearnlp.classification.instance.AbstractInstanceReader; import edu.emory.clir.clearnlp.classification.instance.SparseInstanceReader; import edu.emory.clir.clearnlp.classification.instance.StringInstanceReader; import edu.emory.clir.clearnlp.classification.model.AbstractModel; import edu.emory.clir.clearnlp.classification.model.SparseModel; import edu.emory.clir.clearnlp.classification.model.StringModel; import edu.emory.clir.clearnlp.classification.prediction.StringPrediction; import edu.emory.clir.clearnlp.classification.trainer.AbstractTrainer; import edu.emory.clir.clearnlp.classification.vector.AbstractFeatureVector; import edu.emory.clir.clearnlp.experiment.AbstractArgsReader; import edu.emory.clir.clearnlp.util.BinUtils; import edu.emory.clir.clearnlp.util.IOUtils; import edu.emory.clir.clearnlp.util.MathUtils; import edu.emory.clir.clearnlp.util.adapter.Adapter1; /** * @since 3.0.0 * @author Jinho D. Choi ({@code jinho.choi@emory.edu}) */ abstract public class AbstractClassify { static public final byte TYPE_SPARSE = 0; static public final byte TYPE_STRING = 1; @Option(name="-trainFile", usage="the training file (optional)", required=false, metaVar="<filename>") protected String s_trainFile; @Option(name="-modelFile", usage="the model file (optional)", required=false, metaVar="<filename>") protected String s_modelFile; @Option(name="-testFile", usage="the test filename (optional)", required=false, metaVar="<filename>") protected String s_testFile = null; @Option(name="-lcutoff", usage="label frequency cutoff (default: 0)\n"+"exclusive, string vector space only", required=false, metaVar="<integer>") protected int i_labelCutoff = 0; @Option(name="-fcutoff", usage="feature frequency cutoff (default: 0)\n"+"exclusive, string vector space only", required=false, metaVar="<integer>") protected int i_featureCutoff = 0; @Option(name="-threads", usage="the number of threads to be used (default: 1)", required=false, metaVar="<integer>") protected int i_numberOfThreads = 1; @Option(name="-binary", usage="if set, train a binary model (default: false)", required=false, metaVar="<boolean>") protected boolean b_binary = false; @Option(name="-type", usage="the type of vector space (default: "+AbstractClassify.TYPE_STRING+")\n"+ AbstractClassify.TYPE_SPARSE+": sparse vector space\n"+ AbstractClassify.TYPE_STRING+": string vector space\n", required=false, metaVar="<byte>") protected byte i_vectorType = AbstractClassify.TYPE_STRING; /** Called by {@link #AbstractClassify(String[])}. */ abstract protected AbstractTrainerConfiguration createTrainConfiguration(); /** Called by {@link #train(AbstractTrainerConfiguration, String)}. */ abstract protected AbstractTrainer getTrainer(AbstractTrainerConfiguration trainConfiguration, AbstractModel<?,?> model); @SuppressWarnings("unchecked") /** Called by {@link #trainModel(AbstractTrainConfiguration, String)}. */ public <I extends AbstractInstance<F>, F extends AbstractFeatureVector>AbstractModel<I,F> createModel(byte vectorType, boolean binary) { return (AbstractModel<I,F>)(vectorType == AbstractClassify.TYPE_SPARSE ? new SparseModel(binary) : new StringModel(binary)); } public AbstractModel<?,?> loadModel(String modelFile, byte vectorType) { try { ObjectInputStream in = IOUtils.createObjectXZBufferedInputStream(modelFile); AbstractModel<?,?> model = vectorType == AbstractClassify.TYPE_SPARSE ? new SparseModel(in) : new StringModel(in); in.close(); return model; } catch (Exception e) {e.printStackTrace();} return null; } public void saveModel(AbstractModel<?,?> model, String modelFile) { try { ObjectOutputStream out = IOUtils.createObjectXZBufferedOutputStream(modelFile); model.save(out); out.close(); } catch (Exception e) {e.printStackTrace();} } public <I extends AbstractInstance<F>, F extends AbstractFeatureVector>double evaluate(AbstractModel<I,F> model, String testFile) { BinUtils.LOG.info("Evaluating: "+testFile+"\n"); EvaluateAdapter<I,F> adapter = new EvaluateAdapter<>(model); process(adapter, testFile, isSparseModel(model)); double acc = adapter.getAccuracy(); BinUtils.LOG.info(String.format("- Accuracy: %7.4f (%d/%d)\n", acc, adapter.getCorrect(), adapter.getTotal())); return acc; } /** Called by {@link #train(AbstractTrainerConfiguration, String)}. */ protected <I extends AbstractInstance<F>, F extends AbstractFeatureVector>void readInstances(AbstractModel<I,F> model, String trainFile) { BinUtils.LOG.info("Reading: "+trainFile+"\n"); InstanceAdapter<I,F> adapter = new InstanceAdapter<>(model); process(adapter, trainFile, isSparseModel(model)); BinUtils.LOG.info("- "+adapter.getTotal()+" instances\n"); } protected <I extends AbstractInstance<F>, F extends AbstractFeatureVector>void process(AbstractAdapter<I,F> adapter, String filename, boolean sparse) { AbstractInstanceReader<I,F> reader = getInstanceReader(filename, sparse); reader.applyAll(adapter); reader.close(); } @SuppressWarnings({ "unchecked" }) /** Called by {@link #process(AbstractAdapter, String, boolean)}. */ private <I extends AbstractInstance<F>, F extends AbstractFeatureVector>AbstractInstanceReader<I,F> getInstanceReader(String filename, boolean sparse) { InputStream in = IOUtils.createFileInputStream(filename); return (AbstractInstanceReader<I,F>)(sparse ? new SparseInstanceReader(in) : new StringInstanceReader(in)); } protected boolean isSparseModel(AbstractModel<?,?> model) { return model instanceof SparseModel; } protected class ArgsReader extends AbstractArgsReader { public ArgsReader(String[] args, Object obj) { super(args, obj); } @Override protected String getErrorMessage() { if (s_trainFile == null && s_modelFile == null) return "Either a \"training filename\" or a \"model filename\" must be specified."; return null; } } abstract private class AbstractAdapter<I extends AbstractInstance<F>, F extends AbstractFeatureVector> implements Adapter1<I> { protected AbstractModel<I,F> a_model; protected int n_total; public AbstractAdapter(AbstractModel<I,F> model) { a_model = model; n_total = 0; } public int getTotal() { return n_total; } } private class InstanceAdapter<I extends AbstractInstance<F>, F extends AbstractFeatureVector> extends AbstractAdapter<I,F> { public InstanceAdapter(AbstractModel<I,F> model) { super(model); } @Override public void apply(I instance) { a_model.addInstance(instance); n_total++; } } private class EvaluateAdapter<I extends AbstractInstance<F>, F extends AbstractFeatureVector> extends AbstractAdapter<I,F> { private int n_correct; public EvaluateAdapter(AbstractModel<I,F> model) { super(model); n_correct = 0; } @Override public void apply(I instance) { StringPrediction p = a_model.predictBest(instance.getFeatureVector()); if (instance.isLabel(p.getLabel())) n_correct++; n_total++; } public int getCorrect() { return n_correct; } public double getAccuracy() { return MathUtils.accuracy(n_correct, n_total); } } }