/* * Copyright [2012-2014] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.alg; import ml.shifu.shifu.container.obj.ModelConfig; import ml.shifu.shifu.core.AbstractTrainer; import org.apache.commons.io.FileUtils; import org.encog.ml.svm.KernelType; import org.encog.ml.svm.SVM; import org.encog.ml.svm.SVMType; import org.encog.ml.svm.training.SVMTrain; import org.encog.persist.EncogDirectoryPersistence; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; import java.io.IOException; import java.util.HashMap; import java.util.Map; /** * Implementation of AbstractTrainer for support vector machine classification */ public class SVMTrainer extends AbstractTrainer { public static final String SVM_KERNEL = "Kernel"; public static final String SVM_GAMMA = "Gamma"; public static final String SVM_CONST = "Const"; private SVM svm; private static Map<String, KernelType> kernel = new HashMap<String, KernelType>(); private static Map<String, SVMType> type = new HashMap<String, SVMType>(); protected Logger log = LoggerFactory.getLogger(SVMTrainer.class); static { kernel.put("Leaner kernel".toLowerCase(), KernelType.Linear); kernel.put("linear".toLowerCase(), KernelType.Linear); kernel.put("Poly kernel".toLowerCase(), KernelType.Poly); kernel.put("poly", KernelType.Poly); kernel.put("Sigmoid kernel".toLowerCase(), KernelType.Sigmoid); kernel.put("Sigmoid".toLowerCase(), KernelType.Sigmoid); kernel.put("RadialBasisFunction".toLowerCase(), KernelType.RadialBasisFunction); kernel.put("RBF".toLowerCase(), KernelType.RadialBasisFunction); type.put("classification", SVMType.SupportVectorClassification); type.put("regresssion", SVMType.EpsilonSupportVectorRegression); // kernel.put(KernelType.Precomputed, "") } /** * SVMTrainer Constructor * * @param modelConfig * modelConfig * @param trainerID * trainerID * @param dryRun * dryRun */ public SVMTrainer(ModelConfig modelConfig, int trainerID, Boolean dryRun) { super(modelConfig, trainerID, dryRun); } @Override public double train() throws IOException { if(this.trainerID == 0) { log.info("Trainer #" + (this.trainerID + 1) + " Using SVM algorithm..."); } encogTrain(); return 0.0d; } /** * Setup SVM */ private void buildSVM() { svm = new SVM(this.trainSet.getInputSize(), SVMType.SupportVectorClassification, kernel.get(modelConfig .getParams().get("Kernel"))); } /** * using Encog's SVM trainer */ private void encogTrain() { buildSVM(); SVMTrain trainer = new SVMTrain(svm, trainSet); trainer.setC((Double) modelConfig.getParams().get("Const")); trainer.setGamma((Double) modelConfig.getParams().get("Gamma")); if(this.trainerID == 0) { log.info("Using kenerl function " + svm.getKernelType()); } SVMRunner runner = new SVMRunner(trainer); Thread thread = new Thread(runner); thread.start(); long second = 1000; while(!runner.isFinish()) { try { Thread.sleep(second); log.info("Trainer #" + this.trainerID + " is running"); } catch (InterruptedException e) { throw new RuntimeException("Within system interrupted"); } } log.info("Trainer #" + this.trainerID + " finish training"); trainer = runner.trainer; log.info("Train #" + this.trainerID + " Error: " + df.format(trainer.getError()) + " Validation Error:" + df.format(getValidSetError())); saveModel(); } private void saveModel() { File folder = new File("./models"); if(!folder.exists()) { try { FileUtils.forceMkdir(folder); } catch (IOException e) { log.error("Failed to create directory: {}", folder.getAbsolutePath()); e.printStackTrace(); } } EncogDirectoryPersistence.saveObject(new File("./models/model" + this.trainerID + ".svm"), svm); } public double getValidSetError() { return svm.calculateError(this.validSet); } public SVM getSVM() { return svm; } /** * SVMtrainer worker */ private static class SVMRunner implements Runnable { private SVMTrain trainer; private boolean isFinish; public SVMRunner(SVMTrain trainer) { this.trainer = trainer; this.isFinish = false; } @Override public void run() { trainer.setFold(1); trainer.iteration(); isFinish = true; } public boolean isFinish() { return isFinish; } } }