/**
* Copyright (c) 2009, Regents of the University of Colorado All rights
* reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer. Redistributions in binary
* form must reproduce the above copyright notice, this list of conditions and
* the following disclaimer in the documentation and/or other materials provided
* with the distribution. Neither the name of the University of Colorado at
* Boulder nor the names of its contributors may be used to endorse or promote
* products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
package clear.engine;
import clear.model.AbstractModel;
import clear.train.AbstractTrainer;
import clear.train.BinaryTrainer;
import clear.train.OneVsAllTrainer;
import clear.train.algorithm.IAlgorithm;
import clear.train.algorithm.LibLinearL2;
import clear.train.algorithm.RRM;
import clear.train.kernel.AbstractKernel;
import clear.train.kernel.NoneKernel;
import clear.util.tuple.JIntObjectTuple;
import clear.util.tuple.JObjectObjectTuple;
import com.carrotsearch.hppc.IntArrayList;
import java.io.PrintStream;
import java.util.ArrayList;
import org.apache.commons.compress.archivers.jar.JarArchiveEntry;
import org.apache.commons.compress.archivers.jar.JarArchiveOutputStream;
import org.kohsuke.args4j.Option;
import org.w3c.dom.Element;
/**
* Trains a dependency parser. <b>Last update:</b> 11/16/2010
*
* @author Jinho D. Choi
*/
abstract public class AbstractTrain extends AbstractCommon {
    /** Model file given with the {@code -m} command-line option; optional, {@code null} when absent. */
    @Option(name = "-m", usage = "model file", required = false, metaVar = "OPTIONAL")
    protected String s_modelFile = null;

    /** Name of the configuration element that holds the classification settings. */
    protected final String TAG_CLASSIFY = "classify";
    /** Name of the child element (under {@link #TAG_CLASSIFY}) that describes the learning algorithm. */
    protected final String TAG_CLASSIFY_ALGORITHM = "algorithm";

    /** Kernel type; defaults to {@code KERNEL_NONE}. Not read by this class ({@link #trainModel} always uses a {@code NoneKernel}). */
    protected byte kernel_type = AbstractKernel.KERNEL_NONE;
    /** Trainer type; {@code ST_BINARY} selects a {@code BinaryTrainer}, anything else a {@code OneVsAllTrainer}. */
    protected byte trainer_type = AbstractTrainer.ST_ONE_VS_ALL;
    /** Training data; {@link #trainModel} trains one model per list entry (presumably labels paired with feature vectors -- confirm against the code that fills it). */
    protected ArrayList<JObjectObjectTuple<IntArrayList, ArrayList<int[]>>> a_yx;

    /**
     * Trains the model for the {@code index}'th entry of {@link #a_yx} using the algorithm
     * described in the configuration.
     *
     * <p>When {@code zout} is not {@code null}, a jar entry is opened for the model (named
     * {@code ENTRY_MODEL}, suffixed with {@code "." + index} when there is more than one
     * training set) and a stream over it is handed to the trainer. The entry is closed
     * afterwards; the stream itself (and therefore {@code zout}) is closed only after the
     * last training set, otherwise it is just flushed.
     *
     * @param index position of the training set in {@link #a_yx}
     * @param zout  archive to write the model to, or {@code null} to skip writing
     * @return the trained model, or {@code null} if the configuration specifies no known algorithm
     * @throws Exception if writing to the archive or training fails
     */
    protected AbstractModel trainModel(int index, JarArchiveOutputStream zout) throws Exception {
        JIntObjectTuple<IAlgorithm> tup = getAlgorithm();
        if (tup.object == null) {
            System.err.println("Learning algorithm is not specified in the feature template");
            return null;
        }

        IAlgorithm algorithm = tup.object;
        int numThreads = tup.index;

        PrintStream fout = null;
        if (zout != null) {
            if (a_yx.size() == 1) {
                zout.putArchiveEntry(new JarArchiveEntry(ENTRY_MODEL));
            } else {
                zout.putArchiveEntry(new JarArchiveEntry(ENTRY_MODEL + "." + index));
            }
            fout = new PrintStream(zout);
        }

        long st = System.currentTimeMillis();

        NoneKernel kernel = new NoneKernel();
        kernel.out = out;
        kernel.add(a_yx.get(index));

        AbstractTrainer.out = out;
        // NOTE(review): the elapsed time is taken right after construction, so the trainer
        // presumably does its work inside the constructor -- confirm.
        AbstractTrainer trainer = (trainer_type == AbstractTrainer.ST_BINARY)
                ? new BinaryTrainer(fout, algorithm, kernel, numThreads)
                : new OneVsAllTrainer(fout, algorithm, kernel, numThreads);

        long elapsed = System.currentTimeMillis() - st;
        long hours = elapsed / (1000 * 3600);
        // Minutes left over after whole hours; the total minute count would double-count the hours.
        long minutes = (elapsed / (1000 * 60)) % 60;
        out.printf("- duration: %d h, %d m\n", hours, minutes);

        if (zout != null) {
            zout.closeArchiveEntry();
        }
        if (fout != null) {
            if (index == a_yx.size() - 1) {
                // Last model: closing the print stream also closes the underlying archive.
                fout.close();
            } else {
                fout.flush();
            }
        }

        return trainer.getModel();
    }

    /**
     * Builds the learning algorithm described by the {@code classify/algorithm} element of the
     * configuration and reads the optional {@code threads} element.
     *
     * <p>Prints a summary (algorithm name, options, thread count) to {@code out} whenever the
     * algorithm element exists, even if its name is not recognized.
     *
     * @return a tuple of (number of threads, algorithm); the algorithm is {@code null} when the
     *         configuration has no classify/algorithm element or names an unknown algorithm
     */
    private JIntObjectTuple<IAlgorithm> getAlgorithm() {
        Element eTrain = getElement(e_config, TAG_CLASSIFY);
        Element element = (eTrain != null) ? getElement(eTrain, TAG_CLASSIFY_ALGORITHM) : null;
        if (element == null) {
            // Missing configuration: report "no algorithm" so the caller prints its error
            // message instead of failing with a NullPointerException here.
            return new JIntObjectTuple<IAlgorithm>(1, null);
        }

        String name = element.getAttribute("name").trim();
        StringBuilder options = new StringBuilder();
        IAlgorithm algorithm = null;

        switch (name) {
            case IAlgorithm.LIBLINEAR_L2:
                algorithm = parseLibLinearL2(element, options);
                break;
            case IAlgorithm.RRM:
                algorithm = parseRRM(element, options);
                break;
        }

        int numThreads = 1;
        Element eThreads = getElement(eTrain, "threads");
        if (eThreads != null) {
            numThreads = Integer.parseInt(eThreads.getTextContent().trim());
        }

        out.println("\n* Train model");
        out.println("- algorithm: " + name);
        out.println("- options : " + options.toString());
        out.println("- threads : " + numThreads);
        out.println();

        return new JIntObjectTuple<>(numThreads, algorithm);
    }

    /** Creates a LibLinearL2 algorithm from attributes {@code l}, {@code c}, {@code e}, {@code b} and appends the effective values to {@code options}. */
    private IAlgorithm parseLibLinearL2(Element element, StringBuilder options) {
        byte lossType = 1;
        double c = 0.1, eps = 0.1, bias = -1;
        String tmp;

        if ((tmp = trimmedAttribute(element, "l")).length() > 0) {
            lossType = Byte.parseByte(tmp);
        }
        if ((tmp = trimmedAttribute(element, "c")).length() > 0) {
            c = Double.parseDouble(tmp);
        }
        if ((tmp = trimmedAttribute(element, "e")).length() > 0) {
            eps = Double.parseDouble(tmp);
        }
        if ((tmp = trimmedAttribute(element, "b")).length() > 0) {
            bias = Double.parseDouble(tmp);
        }

        options.append("loss_type = ").append(lossType);
        options.append(", c = ").append(c);
        options.append(", eps = ").append(eps);
        options.append(", bias = ").append(bias);

        return new LibLinearL2(lossType, c, eps, bias);
    }

    /** Creates an RRM algorithm from attributes {@code k}, {@code m}, {@code e}, {@code c} and appends the effective values to {@code options}. */
    private IAlgorithm parseRRM(Element element, StringBuilder options) {
        int k = 40;
        double mu = 1.0, eta = 0.001, c = 0.1;
        String tmp;

        if ((tmp = trimmedAttribute(element, "k")).length() > 0) {
            k = Integer.parseInt(tmp);
        }
        if ((tmp = trimmedAttribute(element, "m")).length() > 0) {
            mu = Double.parseDouble(tmp);
        }
        if ((tmp = trimmedAttribute(element, "e")).length() > 0) {
            eta = Double.parseDouble(tmp);
        }
        if ((tmp = trimmedAttribute(element, "c")).length() > 0) {
            c = Double.parseDouble(tmp);
        }

        options.append("K = ").append(k);
        options.append(", mu = ").append(mu);
        options.append(", eta = ").append(eta);
        options.append(", c = ").append(c);

        return new RRM(k, mu, eta, c);
    }

    /** Returns the named attribute with surrounding whitespace removed. */
    private String trimmedAttribute(Element element, String attribute) {
        return element.getAttribute(attribute).trim();
    }
}