/**
* Copyright 2014, Emory University
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.emory.clir.clearnlp.classification.trainer;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import edu.emory.clir.clearnlp.classification.model.SparseModel;
import edu.emory.clir.clearnlp.classification.model.StringModel;
import edu.emory.clir.clearnlp.classification.vector.BinaryWeightVector;
import edu.emory.clir.clearnlp.classification.vector.SparseFeatureVector;
import edu.emory.clir.clearnlp.util.BinUtils;
import edu.emory.clir.clearnlp.util.MathUtils;
/**
 * Trains a multi-class classifier by decomposing it into one binary sub-problem
 * per label (one-vs-all) and solving the sub-problems concurrently on a
 * fixed-size thread pool.
 * @since 3.0.0
 * @author Jinho D. Choi ({@code jinho.choi@emory.edu})
 */
abstract public class AbstractOneVsAllTrainer extends AbstractTrainer
{
	/** The number of worker threads used by the multi-label training pass. */
	protected int n_threads;

	/**
	 * Constructs a one-vs-all trainer backed by a sparse model.
	 * @param model the sparse model to train.
	 * @param numThreads the number of threads; must be positive.
	 * @throws IllegalArgumentException if {@code numThreads < 1}.
	 */
	public AbstractOneVsAllTrainer(SparseModel model, int numThreads)
	{
		super(TrainerType.ONE_VS_ALL, model);
		setNumberOfThreads(numThreads);
	}

	/**
	 * Constructs a one-vs-all trainer backed by a string model.
	 * @param model the string model to train.
	 * @param labelCutoff the label frequency cutoff (passed through to the superclass).
	 * @param featureCutoff the feature frequency cutoff (passed through to the superclass).
	 * @param numThreads the number of threads; must be positive.
	 * @throws IllegalArgumentException if {@code numThreads < 1}.
	 */
	public AbstractOneVsAllTrainer(StringModel model, int labelCutoff, int featureCutoff, int numThreads)
	{
		super(TrainerType.ONE_VS_ALL, model, labelCutoff, featureCutoff);
		setNumberOfThreads(numThreads);
	}

	/**
	 * Sets the number of worker threads used for multi-label training.
	 * @param numThreads the number of threads; must be positive.
	 * @throws IllegalArgumentException if {@code numThreads < 1}.
	 */
	public void setNumberOfThreads(int numThreads)
	{
		// Fail fast: a non-positive size would otherwise surface later,
		// far from the cause, when the thread pool is created.
		if (numThreads < 1)
			throw new IllegalArgumentException("numThreads must be positive: " + numThreads);

		n_threads = numThreads;
	}

	/** Trains the model: a single binary pass for binary labels, one-vs-all otherwise. */
	public void train()
	{
		if (w_vector.isBinaryLabel()) trainBinary();
		else trainMulti();
	}

	/** Trains the single positive-vs-negative sub-problem of a binary-label model. */
	private void trainBinary()
	{
		update(BinaryWeightVector.POSITIVE);
	}

	/** Trains one binary sub-problem per label, distributed over {@code n_threads} threads. */
	private void trainMulti()
	{
		ExecutorService executor = Executors.newFixedThreadPool(n_threads);
		int currLabel, size = w_vector.getLabelSize();
		BinUtils.LOG.info("One vs. All\n");

		for (currLabel=0; currLabel<size; currLabel++)
			executor.execute(new TrainTask(currLabel));

		executor.shutdown();

		try
		{
			// Block until every per-label sub-problem has completed.
			executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
		}
		catch (InterruptedException e)
		{
			// Restore the interrupt status so callers can observe the interruption;
			// the original code swallowed it after printing the stack trace.
			Thread.currentThread().interrupt();
			e.printStackTrace();
		}
	}

	/** Trains the binary sub-problem for one label; submitted to the executor by {@code trainMulti()}. */
	class TrainTask implements Runnable
	{
		int curr_label;

		/** @param currLabel the current label to train. */
		public TrainTask(int currLabel)
		{
			curr_label = currLabel;
		}

		@Override
		public void run()
		{
			update(curr_label);
		}
	}

	/**
	 * Trains the binary sub-problem treating {@code currLabel} as the positive class.
	 * NOTE(review): invoked concurrently from multiple threads by {@code trainMulti()},
	 * so implementations must be thread-safe with respect to state shared across labels.
	 * @param currLabel the label to train against all others.
	 */
	abstract protected void update(int currLabel);

	/**
	 * @param currLabel the label treated as the positive class.
	 * @return an array whose i-th element is 1 if the i-th training instance has
	 *         {@code currLabel}, or -1 otherwise.
	 */
	protected byte[] getBinaryLabels(int currLabel)
	{
		int i, size = getInstanceSize();
		byte[] aY = new byte[size];

		for (i=0; i<size; i++)
			aY[i] = getInstance(i).isLabel(currLabel) ? (byte)1 : (byte)-1;

		return aY;
	}

	/**
	 * Computes the decision score of a feature vector against a weight vector.
	 * @param weight the weight vector; index 0 holds the bias weight
	 *               (assumes feature indices start at 1 — TODO confirm against the vectorizer).
	 * @param x the sparse feature vector.
	 * @param bias the bias value.
	 * @return {@code weight[0]*bias + sum_i weight[x.index(i)] * x.weight(i)}.
	 */
	protected double getScore(float[] weight, SparseFeatureVector x, double bias)
	{
		double score = weight[0] * bias;
		int i, len = x.size();

		for (i=0; i<len; i++)
			score += weight[x.getIndex(i)] * x.getWeight(i);

		return score;
	}

	/**
	 * Adds {@code cost * x} to the weight vector in place, including the bias term.
	 * @param weight the weight vector to update; index 0 holds the bias weight.
	 * @param x the sparse feature vector.
	 * @param bias the bias value.
	 * @param cost the update scale applied to every touched weight.
	 */
	protected void update(float[] weight, SparseFeatureVector x, double bias, double cost)
	{
		weight[0] += cost * bias;
		int i, len = x.size();

		for (i=0; i<len; i++)
			weight[x.getIndex(i)] += cost * x.getWeight(i);
	}

	/**
	 * Computes a per-instance squared-norm table.
	 * @param init a constant folded into every entry.
	 * @param bias the bias value; its square is added to the constant.
	 * @return an array whose i-th element is {@code init + bias^2 + ||x_i||^2},
	 *         where {@code x_i} is the feature vector of the i-th training instance.
	 */
	protected double[] getSumOfSquares(double init, double bias)
	{
		int i, size = getInstanceSize();
		double[] qd = new double[size];
		init += MathUtils.sq(bias);
		SparseFeatureVector x;

		for (i=0; i<size; i++)
		{
			x = getInstance(i).getFeatureVector();
			qd[i] = init + x.sumOfSquares();
		}

		return qd;
	}
}