/**
* Copyright 2014, Emory University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.emory.clir.clearnlp.classification.model;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.List;
import org.tukaani.xz.LZMA2Options;
import org.tukaani.xz.XZInputStream;
import org.tukaani.xz.XZOutputStream;
import edu.emory.clir.clearnlp.classification.instance.AbstractInstance;
import edu.emory.clir.clearnlp.classification.instance.AbstractInstanceCollector;
import edu.emory.clir.clearnlp.classification.instance.IntInstance;
import edu.emory.clir.clearnlp.classification.map.LabelMap;
import edu.emory.clir.clearnlp.classification.prediction.StringPrediction;
import edu.emory.clir.clearnlp.classification.vector.AbstractFeatureVector;
import edu.emory.clir.clearnlp.classification.vector.AbstractWeightVector;
import edu.emory.clir.clearnlp.classification.vector.BinaryWeightVector;
import edu.emory.clir.clearnlp.classification.vector.MultiWeightVector;
import edu.emory.clir.clearnlp.collection.pair.DoubleIntPair;
import edu.emory.clir.clearnlp.collection.pair.Pair;
import edu.emory.clir.clearnlp.util.BinUtils;
import edu.emory.clir.clearnlp.util.DSUtils;
/**
 * Base class of classification models: couples a label map with a weight vector and
 * offers best / top-2 / all-label prediction helpers on top of the scores computed
 * by concrete subclasses.
 *
 * <p>Serialization is customized: {@code readObject}/{@code writeObject} delegate to the
 * subclass-provided {@link #load(ObjectInputStream)} and {@link #save(ObjectOutputStream)},
 * so what is actually persisted is decided by the subclass.
 *
 * @param <I> the type of training instances accepted by this model.
 * @param <F> the type of feature vectors scored by this model.
 * @since 3.0.0
 * @author Jinho D. Choi ({@code jinho.choi@emory.edu})
 */
public abstract class AbstractModel<I extends AbstractInstance<F>, F extends AbstractFeatureVector> implements Serializable
{
    private static final long serialVersionUID = 6096015874433178106L;

    /** Instance collector; not initialized by this class (left to subclasses). */
    protected AbstractInstanceCollector<I,F> i_collector;
    /** Weight vector; binary or multi-class depending on how the model was created. */
    protected AbstractWeightVector w_vector;
    /** Mapping between label strings and label indices. */
    protected LabelMap m_labels;

    /**
     * Initializes this model for training.
     * @param binary if {@code true}, a {@link BinaryWeightVector} is used; otherwise, a {@link MultiWeightVector}.
     */
    public AbstractModel(boolean binary)
    {
        w_vector = binary ? new BinaryWeightVector() : new MultiWeightVector();
        m_labels = new LabelMap();
    }

    /**
     * Initializes this model by calling {@link #load(ObjectInputStream)} on the given stream.
     *
     * <p>A {@link ClassNotFoundException} or {@link IOException} thrown by {@code load} is NOT
     * propagated: its stack trace is printed and the constructor returns normally, in which
     * case this model may be only partially initialized.
     * NOTE(review): {@code load} is an overridable method invoked from a constructor, so it runs
     * before the subclass's own field initializers and constructor body.
     *
     * @param in the stream to read the model from.
     */
    public AbstractModel(ObjectInputStream in)
    {
        try
        {
            load(in);
        }
        catch (ClassNotFoundException | IOException e) {e.printStackTrace();}
    }

    // =============================== Serialization ===============================

    /** Java serialization hook; delegates to {@link #load(ObjectInputStream)}. */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException
    {
        load(in);
    }

    /** Java serialization hook; delegates to {@link #save(ObjectOutputStream)}. */
    private void writeObject(ObjectOutputStream out) throws IOException
    {
        save(out);
    }

    /**
     * Restores the state of this model from the input stream.
     * @param in the stream to read from.
     * @throws IOException if reading from the stream fails.
     * @throws ClassNotFoundException if a class of a serialized object cannot be found.
     */
    public abstract void load(ObjectInputStream in ) throws IOException, ClassNotFoundException;

    /**
     * Writes the state of this model to the output stream.
     * @param out the stream to write to.
     * @throws IOException if writing to the stream fails.
     */
    public abstract void save(ObjectOutputStream out) throws IOException;

    // =============================== Training ===============================

    /**
     * Adds a training instance to this model.
     * @param instance the instance to add.
     */
    public abstract void addInstance(I instance);

    /**
     * Adds every instance in the collection via {@link #addInstance(AbstractInstance)}, in iteration order.
     * @param instances the instances to add.
     */
    public void addInstances(Collection<I> instances)
    {
        for (I instance : instances)
            addInstance(instance);
    }

    // =============================== Labels/Features/Weights ===============================

    /**
     * @param label the label string.
     * @return the index of the label as reported by the label map.
     */
    public int getLabelIndex(String label)
    {
        return m_labels.getLabelIndex(label);
    }

    /** @return the number of labels as reported by the weight vector. */
    public int getLabelSize()
    {
        return w_vector.getLabelSize();
    }

    /** @return the number of features as reported by the weight vector. */
    public int getFeatureSize()
    {
        return w_vector.getFeatureSize();
    }

    /** @return the labels as reported by the label map. */
    public String[] getLabels()
    {
        return m_labels.getLabels();
    }

    /** @return the weight vector held by this model (the same reference, not a copy). */
    public AbstractWeightVector getWeightVector()
    {
        return w_vector;
    }

    /**
     * Replaces the weight vector of this model.
     * @param vector the new weight vector.
     */
    public void setWeightVector(AbstractWeightVector vector)
    {
        w_vector = vector;
    }

    /** @return {@code true} if the weight vector reports binary labels. */
    public boolean isBinaryLabel()
    {
        return w_vector.isBinaryLabel();
    }

    /**
     * Replaces the weight vector with the one deserialized from the XZ-compressed byte array,
     * i.e., the format produced by {@link #saveWeightVectorToByteArray()}.
     *
     * <p>NOTE(review): this performs native Java deserialization; only pass arrays that
     * come from a trusted source.
     *
     * @param array the XZ-compressed, Java-serialized weight vector.
     * @throws Exception if decompression or deserialization fails.
     */
    public void loadWeightVectorFromByteArray(byte[] array) throws Exception
    {
        // try-with-resources: the streams are closed even when decoding fails half-way.
        try (XZInputStream xz = new XZInputStream(new BufferedInputStream(new ByteArrayInputStream(array)));
             ObjectInputStream ois = new ObjectInputStream(xz))
        {
            setWeightVector((AbstractWeightVector)ois.readObject());
        }
    }

    /**
     * Serializes the weight vector and compresses it with XZ using default {@link LZMA2Options}.
     * @return the XZ-compressed, Java-serialized weight vector.
     * @throws Exception if serialization or compression fails.
     */
    public byte[] saveWeightVectorToByteArray() throws Exception
    {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();

        // try-with-resources: the streams are closed even when writing fails half-way.
        try (XZOutputStream xz = new XZOutputStream(new BufferedOutputStream(bos), new LZMA2Options());
             ObjectOutputStream oos = new ObjectOutputStream(xz))
        {
            oos.writeObject(w_vector);
        }

        // Read the buffer only after the streams are closed so that all pending data has been flushed into it.
        return bos.toByteArray();
    }

    // =============================== Conversion ===============================

    /**
     * Converts the instance into an {@link IntInstance}.
     * @param instance the instance to convert.
     * @return the converted instance; {@link #toIntInstanceList(Deque)} skips {@code null} results.
     */
    public abstract IntInstance toIntInstance(I instance);

    /**
     * Converts all instances in the deque into {@link IntInstance}s, skipping {@code null} conversions.
     * The deque is drained by this call (every element is polled), and progress is logged
     * once per 100,000 polled instances.
     *
     * @param sInstances the instances to convert; empty when this method returns.
     * @return the list of converted instances, in polling order.
     */
    public List<IntInstance> toIntInstanceList(Deque<I> sInstances)
    {
        BinUtils.LOG.info("Vectorizing: "+sInstances.size()+"\n");

        final int PRINT = 100000;
        // Declared as ArrayList (not List) because trimToSize() is needed below.
        ArrayList<IntInstance> iInstances = new ArrayList<>(sInstances.size());
        int processed = 0;
        IntInstance iInstance;

        while (!sInstances.isEmpty())
        {
            iInstance = toIntInstance(sInstances.poll());
            if (iInstance != null) iInstances.add(iInstance);
            if (++processed % PRINT == 0) BinUtils.LOG.info(".");
        }

        // Progress dots are emitted per polled instance, so the terminator is chosen from the
        // polled count as well (not from the number of non-null conversions).
        if (processed >= PRINT) BinUtils.LOG.info("\n\n");
        else BinUtils.LOG.info("\n");

        iInstances.trimToSize();
        return iInstances;
    }

    // =============================== Predictions ===============================

    /**
     * @param x the feature vector.
     * @return the scores of the feature vector; the prediction methods of this class read it by label index.
     */
    public abstract double[] getScores(F x);

    /**
     * @param x the feature vector.
     * @param include the label indices to score.
     * @return the scores of the feature vector; the prediction methods of this class read it by
     * label index (i.e., {@code scores[include[k]]}), not by position in {@code include}.
     */
    public abstract double[] getScores(F x, int[] include);

    /**
     * @param labelIndex the index of the label.
     * @param score the score of the label.
     * @return the prediction pairing the label string of the index with the score.
     */
    public StringPrediction getPrediction(int labelIndex, double score)
    {
        return new StringPrediction(m_labels.getLabel(labelIndex), score);
    }

    /** @return the best prediction given the specific feature vector. */
    public StringPrediction predictBest(F x)
    {
        return isBinaryLabel() ? predictBestBinary(x) : predictBestMulti(x);
    }

    /** @return label 0 if its score is strictly positive; otherwise, label 1. */
    private StringPrediction predictBestBinary(F x)
    {
        double[] scores = getScores(x);
        return (scores[0] > 0) ? getPrediction(0, scores[0]) : getPrediction(1, scores[1]);
    }

    /** @return the label with the highest score; ties are resolved in favor of the lowest index. */
    private StringPrediction predictBestMulti(F x)
    {
        double[] scores = getScores(x);
        int i, size = scores.length, maxIndex = 0;
        double maxValue = scores[maxIndex];

        for (i=1; i<size; i++)
        {
            if (maxValue < scores[i])
            {
                maxIndex = i;
                maxValue = scores[maxIndex];
            }
        }

        return getPrediction(maxIndex, maxValue);
    }

    /** @return the top 2 predictions given the specific feature vector. */
    public StringPrediction[] predictTop2(F x)
    {
        return isBinaryLabel() ? predictTop2Binary(x) : predictTop2Multi(x);
    }

    /** @return both labels; label 0 comes first if its score is strictly positive, otherwise label 1 comes first. */
    private StringPrediction[] predictTop2Binary(F x)
    {
        double[] scores = getScores(x);
        StringPrediction fst = getPrediction(0, scores[0]);
        StringPrediction snd = getPrediction(1, scores[1]);
        return (scores[0] > 0) ? new StringPrediction[]{fst,snd} : new StringPrediction[]{snd,fst};
    }

    /** @return the two predictions built from the (score, index) pairs selected by {@code DSUtils.top2}. */
    private StringPrediction[] predictTop2Multi(F x)
    {
        double[] scores = getScores(x);
        Pair<DoubleIntPair,DoubleIntPair> top2 = DSUtils.top2(scores);
        DoubleIntPair p1 = top2.o1;
        DoubleIntPair p2 = top2.o2;
        return new StringPrediction[]{getPrediction(p1.i,p1.d), getPrediction(p2.i,p2.d)};
    }

    /** @return the list of predictions given the specific feature vector sorted in descending order. */
    public StringPrediction[] predictAll(F x)
    {
        return isBinaryLabel() ? predictTop2Binary(x) : predictAllMulti(x);
    }

    /** @return one prediction per label, sorted with {@code DSUtils.sortReverseOrder}. */
    private StringPrediction[] predictAllMulti(F x)
    {
        double[] scores = getScores(x);
        int i, lsize = getLabelSize();
        StringPrediction[] array = new StringPrediction[lsize];

        for (i=0; i<lsize; i++)
            array[i] = getPrediction(i, scores[i]);

        DSUtils.sortReverseOrder(array);
        return array;
    }

    /**
     * @param x the feature vector.
     * @param indices the label indices to choose from; must contain at least one index for multi-class models.
     * For binary models, this parameter is ignored.
     * @return the best prediction among the specific labels.
     */
    public StringPrediction predictBest(F x, int[] indices)
    {
        return isBinaryLabel() ? predictBestBinary(x) : predictBestMulti(x, indices);
    }

    /** @return the highest-scoring label among {@code indices}; ties are resolved in favor of the earliest entry. */
    private StringPrediction predictBestMulti(F x, int[] indices)
    {
        double[] scores = getScores(x, indices);
        int i, size = indices.length, maxIndex = indices[0];
        double maxValue = scores[maxIndex];

        for (i=1; i<size; i++)
        {
            if (maxValue < scores[indices[i]])
            {
                maxIndex = indices[i];
                maxValue = scores[maxIndex];
            }
        }

        return getPrediction(maxIndex, maxValue);
    }

    /**
     * @param x the feature vector.
     * @param indices the label indices to choose from; ignored for binary models.
     * @return the top 2 predictions given the specific feature vector.
     */
    public StringPrediction[] predictTop2(F x, int[] indices)
    {
        return isBinaryLabel() ? predictTop2Binary(x) : predictTop2Multi(x, indices);
    }

    /** @return the two predictions built from the (score, index) pairs selected by {@code DSUtils.top2} over {@code indices}. */
    private StringPrediction[] predictTop2Multi(F x, int[] indices)
    {
        double[] scores = getScores(x, indices);
        Pair<DoubleIntPair,DoubleIntPair> top2 = DSUtils.top2(scores, indices);
        DoubleIntPair p1 = top2.o1;
        DoubleIntPair p2 = top2.o2;
        return new StringPrediction[]{getPrediction(p1.i,p1.d), getPrediction(p2.i,p2.d)};
    }

    /**
     * @param x the feature vector.
     * @param indices the label indices to include; ignored for binary models.
     * @return the list of predictions given the specific feature vector sorted in descending order.
     */
    public StringPrediction[] predictAll(F x, int[] indices)
    {
        return isBinaryLabel() ? predictTop2Binary(x) : predictAllMulti(x, indices);
    }

    /** @return one prediction per entry of {@code indices}, sorted with {@code DSUtils.sortReverseOrder}. */
    private StringPrediction[] predictAllMulti(F x, int[] indices)
    {
        double[] scores = getScores(x, indices);
        int i, j, lsize = indices.length;
        StringPrediction[] array = new StringPrediction[lsize];

        for (j=0; j<lsize; j++)
        {
            i = indices[j];
            array[j] = getPrediction(i, scores[i]);
        }

        DSUtils.sortReverseOrder(array);
        return array;
    }
}