/*
 * Apache License
 * Version 2.0, January 2004
 * http://www.apache.org/licenses/
 *
 * Copyright 2013 Aurelian Tutuianu
 * Copyright 2014 Aurelian Tutuianu
 * Copyright 2015 Aurelian Tutuianu
 * Copyright 2016 Aurelian Tutuianu
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

package rapaio.experiment.math.optimization;

import rapaio.math.linear.RV;
import rapaio.util.Pair;

import java.io.Serializable;

/**
 * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 11/24/15.
 */
@Deprecated
public interface Updater extends Serializable {

    /**
     * Computes the updated weights after one gradient step and returns them together
     * with the value of the regularization term evaluated at the new weights.
     */
    Pair<RV, Double> compute(RV weightsOld, RV gradient, double stepSize, int iter, double regParam);
}

/**
 * A simple updater for gradient descent *without* any regularization.
 * Uses a step-size decreasing with the square root of the number of iterations.
 */
@Deprecated
class SimpleUpdater implements Updater {

    private static final long serialVersionUID = -2067278844383126771L;

    @Override
    public Pair<RV, Double> compute(RV weightsOld, RV gradient, double stepSize, int iter, double regParam) {
        // decay the learning rate: eta_t = stepSize / sqrt(iter)
        double thisIterStepSize = stepSize / Math.sqrt(iter);
        // w' = w - eta_t * gradient
        RV brzWeights = weightsOld.solidCopy();
        brzWeights.plus(gradient.solidCopy().dot(-thisIterStepSize));
        // no regularizer, so its contribution to the objective is 0
        return Pair.from(brzWeights, 0.0);
    }
}

/**
 * Updater for L1 regularized problems.
 * R(w) = ||w||_1
 * Uses a step-size decreasing with the square root of the number of iterations.
 * <p>
 * Instead of the subgradient of the regularizer, the proximal operator for
 * L1 regularization is applied after the gradient step. This is known to
 * result in better sparsity of the intermediate solution.
 * <p>
 * The corresponding proximal operator for the L1 norm is the soft-thresholding
 * function. That is, each weight component is shrunk towards 0 by shrinkageVal.
 * <p>
 * If w &gt; shrinkageVal, set weight component to w-shrinkageVal.
 * If w &lt; -shrinkageVal, set weight component to w+shrinkageVal.
 * If -shrinkageVal &lt; w &lt; shrinkageVal, set weight component to 0.
 * <p>
 * Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal).
 */
@Deprecated
class L1Updater implements Updater {

    @Override
    public Pair<RV, Double> compute(RV weightsOld, RV gradient, double stepSize, int iter, double regParam) {
        double thisIterStepSize = stepSize / Math.sqrt(iter);
        // take the gradient step: w' = w - eta_t * gradient
        RV brzWeights = weightsOld.solidCopy();
        brzWeights.plus(gradient.solidCopy().dot(-thisIterStepSize));
        // apply the proximal operator (soft thresholding) component-wise
        double shrinkageVal = regParam * thisIterStepSize;
        int len = brzWeights.count();
        for (int i = 0; i < len; i++) {
            double wi = brzWeights.get(i);
            brzWeights.set(i, Math.signum(wi) * Math.max(0.0, Math.abs(wi) - shrinkageVal));
        }
        // the regularization term is regParam * ||w'||_1
        return Pair.from(brzWeights, brzWeights.norm(1.0) * regParam);
    }
}
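
/*
 * Editor's sketch (not part of the original file): a minimal, self-contained
 * illustration of the soft-thresholding rule documented on L1Updater above.
 * It works on plain doubles on purpose, so it does not assume any particular
 * RV factory method; the class and method names here are hypothetical.
 */
@Deprecated
final class SoftThresholdingDemo {

    /** The same component-wise rule L1Updater applies after its gradient step. */
    static double softThreshold(double w, double shrinkageVal) {
        return Math.signum(w) * Math.max(0.0, Math.abs(w) - shrinkageVal);
    }

    public static void main(String[] args) {
        double shrinkageVal = 0.1;
        System.out.println(softThreshold(0.30, shrinkageVal));  // ~0.2   (w > shrinkageVal: shrunk by shrinkageVal)
        System.out.println(softThreshold(-0.25, shrinkageVal)); // ~-0.15 (w < -shrinkageVal: shrunk by shrinkageVal)
        System.out.println(softThreshold(0.05, shrinkageVal));  // 0.0    (|w| <= shrinkageVal: snapped to zero, which produces sparsity)
    }
}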
/**
 * Updater for L2 regularized problems.
 * R(w) = 1/2 ||w||^2
 * Uses a step-size decreasing with the square root of the number of iterations.
 */
@Deprecated
class SquaredL2Updater implements Updater {

    @Override
    public Pair<RV, Double> compute(RV weightsOld, RV gradient, double stepSize, int iter, double regParam) {
        // add up both updates from the gradient of the loss (= step) as well as
        // the gradient of the regularizer (= regParam * weightsOld):
        // w' = w - thisIterStepSize * (gradient + regParam * w)
        //    = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient
        double thisIterStepSize = stepSize / Math.sqrt(iter);
        RV brzWeights = weightsOld.solidCopy();
        // shrink the weights by the regularization factor ...
        brzWeights.minus(brzWeights.solidCopy().dot(thisIterStepSize * regParam));
        // ... then take the gradient step
        brzWeights.plus(gradient.solidCopy().dot(-thisIterStepSize));
        // the regularization term is regParam * 1/2 * ||w'||^2
        double norm = brzWeights.norm(2.0);
        return Pair.from(brzWeights, 0.5 * regParam * norm * norm);
    }
}
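
/*
 * Editor's usage sketch (not part of the original file): shows how an Updater is
 * meant to be driven from a gradient-descent loop. The loss gradient is supplied
 * as a plain java.util.function.Function so the sketch stays self-contained, and
 * the initial weights w0 are taken as a parameter because no particular RV factory
 * method is assumed. The `_1`/`_2` field names on rapaio's Pair are assumed here.
 */
@Deprecated
final class UpdaterUsageSketch {

    static RV minimize(Updater updater, RV w0, java.util.function.Function<RV, RV> lossGradient,
                       double stepSize, double regParam, int maxIter) {
        RV w = w0.solidCopy();
        // iter starts at 1: every updater above divides by sqrt(iter),
        // so starting at 0 would produce an infinite first step
        for (int iter = 1; iter <= maxIter; iter++) {
            RV gradient = lossGradient.apply(w);
            Pair<RV, Double> step = updater.compute(w, gradient, stepSize, iter, regParam);
            // step._2 holds the regularization value at the new weights,
            // useful if the caller wants to track the full objective
            w = step._1;
        }
        return w;
    }
}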