/**
 * Copyright 2014, Emory University
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *     http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.emory.clir.clearnlp.classification.trainer;

import java.util.Random;

import edu.emory.clir.clearnlp.classification.model.SparseModel;
import edu.emory.clir.clearnlp.classification.model.StringModel;
import edu.emory.clir.clearnlp.classification.vector.SparseFeatureVector;
import edu.emory.clir.clearnlp.util.BinUtils;
import edu.emory.clir.clearnlp.util.DSUtils;

/**
 * Linear SVM trainer in the style of LIBLINEAR's dual coordinate descent:
 * for each label it solves a one-vs-rest binary subproblem over the dual
 * variables {@code alpha}, using a shrinking heuristic on the active set and
 * projected-gradient bounds ({@code Gmax}/{@code Gmin}) as the stopping rule.
 * <p>
 * NOTE(review): the dual upper bound is the raw cost {@code C} and the
 * per-instance curvature {@code QD} carries no extra {@code 1/(2C)} diagonal
 * term, which matches LIBLINEAR's L1-loss (hinge) dual formulation; confirm
 * that the "L2" in the class name refers to L2 <i>regularization</i>.
 *
 * @since 3.0.0
 * @author Jinho D. Choi ({@code jinho.choi@emory.edu})
 */
public class LiblinearL2SVM extends AbstractLiblinear
{
	/**
	 * Creates a trainer over a pre-indexed sparse model.
	 * @param model      the sparse model whose weight vector is trained.
	 * @param numThreads the number of threads used for training.
	 * @param cost       the cost (the dual upper bound {@code C}).
	 * @param eps        the tolerance of termination criterion.
	 * @param bias       the bias.
	 */
	public LiblinearL2SVM(SparseModel model, int numThreads, double cost, double eps, double bias)
	{
		super(model, numThreads, cost, eps, bias);
	}
	
	/**
	 * Creates a trainer over a string model, pruning rare labels/features first.
	 * @param model         the string model whose weight vector is trained.
	 * @param labelCutoff   labels seen fewer than this many times are discarded
	 *                      (presumably; cutoff semantics live in the superclass — confirm).
	 * @param featureCutoff features seen fewer than this many times are discarded
	 *                      (presumably; cutoff semantics live in the superclass — confirm).
	 * @param numThreads    the number of threads used for training.
	 * @param cost          the cost (the dual upper bound {@code C}).
	 * @param eps           the tolerance of termination criterion.
	 * @param bias          the bias.
	 */
	public LiblinearL2SVM(StringModel model, int labelCutoff, int featureCutoff, int numThreads, double cost, double eps, double bias)
	{
		super(model, labelCutoff, featureCutoff, numThreads, cost, eps, bias);
	}
	
	/**
	 * Trains the one-vs-rest binary subproblem for {@code currLabel} by dual
	 * coordinate descent and writes the resulting weight vector back into
	 * {@code w_vector}. Deterministic: the instance-order shuffle is seeded
	 * with the fixed {@code RANDOM_SEED}.
	 *
	 * @param currLabel the label whose weight vector is updated.
	 */
	@Override
	public void update(int currLabel)
	{
		final Random rand = new Random(RANDOM_SEED);
		final int N = getInstanceSize();
		
		float[] weight = w_vector.getWeights(currLabel);
		double[] alpha = new double[N];		// dual variables, one per training instance; start at 0
		double G, d, alpha_old;
		
		// Projected gradient bounds from the previous pass, used both for
		// shrinking instances out of the active set and for the stopping test.
		double Gmax_old = Double.POSITIVE_INFINITY;
		double Gmin_old = Double.NEGATIVE_INFINITY;
		double violation, Gmax_new, Gmin_new;
		
		double upper_bound = d_cost;		// dual box constraint: 0 <= alpha[i] <= C
		int i, s, iter, active_size = N;
		SparseFeatureVector xi;
		byte yi;
		
		int [] index = DSUtils.range(N);					// permutation over instances; active ones occupy [0, active_size)
		byte[] aY = getBinaryLabels(currLabel);				// +1 for currLabel, presumably -1 for the rest — confirm in superclass
		double[] QD = getSumOfSquares(0, d_bias);			// per-instance x_i . x_i (plus bias term); Newton step denominator
		
		for (iter=0; iter<MAX_ITER; iter++)
		{
			Gmax_new = Double.NEGATIVE_INFINITY;
			Gmin_new = Double.POSITIVE_INFINITY;
			// Shuffle only the active prefix; shrunk instances stay parked at the tail.
			DSUtils.shuffle(index, rand, active_size);
			
			for (s=0; s<active_size; s++)
			{
				i = index[s];
				yi = aY[i];
				xi = getInstance(i).getFeatureVector();
				G = getScore(weight, xi, d_bias) * yi - 1;	// gradient of the dual objective w.r.t. alpha[i]
				
				if (alpha[i] == 0)
				{
					// At the lower bound: a strongly positive gradient means alpha[i]
					// will stay at 0, so shrink this instance out of the active set.
					if (G > Gmax_old)
					{
						active_size--;
						DSUtils.swap(index, s, active_size);
						s--;		// the swapped-in instance at position s must still be visited
						continue;
					}
					
					violation = Math.min(G, 0);		// only a negative gradient can move alpha off 0
				}
				else if (alpha[i] == upper_bound)
				{
					// At the upper bound: a strongly negative gradient means alpha[i]
					// will stay pinned at C, so shrink it out of the active set.
					if (G < Gmin_old)
					{
						active_size--;
						DSUtils.swap(index, s, active_size);
						s--;
						continue;
					}
					
					violation = Math.max(G, 0);		// only a positive gradient can move alpha off C
				}
				else
				{
					violation = G;		// strictly interior: the full gradient is the violation
				}
				
				Gmax_new = Math.max(Gmax_new, violation);
				Gmin_new = Math.min(Gmin_new, violation);
				
				// Take a single-variable Newton step, clipped to the box [0, C],
				// and fold the change back into the primal weight vector.
				if (Math.abs(violation) > 1.0e-12)
				{
					alpha_old = alpha[i];
					alpha[i] = Math.min(Math.max(alpha[i] - G / QD[i], 0d), upper_bound);
					d = (alpha[i] - alpha_old) * yi;
					if (d != 0) update(weight, xi, d_bias, d);
				}
			}
			
			if (Gmax_new - Gmin_new <= d_eps)
			{
				// Converged on the active set. If nothing was shrunk, we are done;
				// otherwise re-activate everything and verify convergence globally.
				if (active_size == N)
					break;
				else
				{
					active_size = N;
					Gmax_old = Double.POSITIVE_INFINITY;
					Gmin_old = Double.NEGATIVE_INFINITY;
					continue;
				}
			}
			
			Gmax_old = Gmax_new;
			Gmin_old = Gmin_new;
			// Degenerate bounds would shrink every bound-constrained instance;
			// relax them to infinity so shrinking stays conservative.
			if (Gmax_old <= 0) Gmax_old = Double.POSITIVE_INFINITY;
			if (Gmin_old >= 0) Gmin_old = Double.NEGATIVE_INFINITY;
		}
		
		// NOTE(review): index 0 is presumably the bias weight; this rescales it by
		// the bias factor before publishing — confirm against AbstractLiblinear.
		weight[0] *= d_bias;
		w_vector.setWeights(currLabel, weight);
		
		StringBuilder build = new StringBuilder();
		
		build.append("- label = ");
		build.append(currLabel);
		build.append(": iter = ");
		build.append(iter);
		build.append("\n");
		
		BinUtils.LOG.info(build.toString());
	}
	
	/**
	 * Returns a human-readable description of this trainer's configuration,
	 * tagged with the algorithm name {@code "SVM"}.
	 *
	 * @return the trainer information string built by the superclass.
	 */
	@Override
	public String trainerInfo()
	{
		return trainerInfo("SVM");
	}
}