/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.udf.lib;
import java.io.IOException;
import java.util.Iterator;
import java.util.Random;
import org.apache.sysml.runtime.controlprogram.caching.CacheException;
import org.apache.sysml.runtime.matrix.data.IJV;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.udf.FunctionParameter;
import org.apache.sysml.udf.Matrix;
import org.apache.sysml.udf.PackageFunction;
import org.apache.sysml.udf.Scalar;
import org.apache.sysml.udf.Matrix.ValueType;
/**
* Use this class to perform an SGD update with Nesterov momentum in CP.
* Assumption: the input batch fits in CP (which is also the assumption of most deep learning systems).
*
* Usage:
* update_nesterov = externalFunction(matrix[double] X, matrix[double] dX, double lr, double mu, matrix[double] v, double lambda) return (matrix[double] X, matrix[double] v) implemented in (classname="org.apache.sysml.udf.lib.SGDNesterovUpdate",exectype="mem");
* [X, v] = update_nesterov(X, dX, lr, mu, v);
*
*
* This class eliminates the unnecessary instruction overhead as well as memory pressure.
*
*/
public class SGDNesterovUpdate extends PackageFunction {
private static final long serialVersionUID = -3905212831582648882L;
private Matrix updatedX;
private Matrix updatedV;
private Random rand = new Random();
@Override
public int getNumFunctionOutputs() {
return 2;
}
@Override
public FunctionParameter getFunctionOutput(int pos) {
if(pos == 0)
return updatedX;
else if(pos == 1)
return updatedV;
throw new RuntimeException("Invalid function output being requested");
}
boolean isDense(MatrixBlock X) {
return !X.isInSparseFormat() && X.getDenseBlock() != null;
}
@Override
public void execute() {
try {
MatrixBlock X = ((Matrix) getFunctionInput(0)).getMatrixObject().acquireRead();
MatrixBlock dX = ((Matrix) getFunctionInput(1)).getMatrixObject().acquireRead();
double lr = Double.parseDouble(((Scalar)getFunctionInput(2)).getValue());
double mu = Double.parseDouble(((Scalar)getFunctionInput(3)).getValue());
MatrixBlock v = ((Matrix) getFunctionInput(4)).getMatrixObject().acquireRead();
double lambda = Double.parseDouble(((Scalar)getFunctionInput(5)).getValue());
// v = mu * v - lr * dX - lr*lambda*X
updatedV = new Matrix( "tmp_" + rand.nextLong(), v.getNumRows(), v.getNumColumns(), ValueType.Double );
MatrixBlock updatedVMB = allocateDenseMatrixBlock(updatedV);
double [] updatedVData = updatedVMB.getDenseBlock();
if(isDense(v) && isDense(dX) && isDense(X)) {
double [] vArr = v.getDenseBlock();
double [] dXArr = dX.getDenseBlock();
double [] XArr = X.getDenseBlock();
int nnz = 0;
for(int i = 0; i < updatedVData.length; i++) {
updatedVData[i] = mu*vArr[i] - lr*dXArr[i] - lr*lambda*XArr[i];
nnz += (updatedVData[i]!=0) ? 1 : 0;
}
updatedVMB.setNonZeros(nnz);
}
else {
multiplyByConstant(v, mu, updatedVData);
multiplyByConstant(dX, -lr, updatedVData);
multiplyByConstant(X, -lr*lambda, updatedVData);
updatedVMB.recomputeNonZeros();
}
updatedV.setMatrixDoubleArray(updatedVMB, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
// X = X - mu * v_prev + (1 + mu) * v
updatedX = new Matrix( "tmp_" + rand.nextLong(), X.getNumRows(), X.getNumColumns(), ValueType.Double );
MatrixBlock updatedXMB = allocateDenseMatrixBlock(updatedX);
double [] updatedXData = updatedXMB.getDenseBlock();
if(isDense(X) && isDense(v)) {
double [] XArr = X.getDenseBlock();
double [] vPrevArr = v.getDenseBlock();
int nnz = 0; double muPlus1 = mu+1;
for(int i = 0; i < updatedXData.length; i++) {
updatedXData[i] = XArr[i] - mu*vPrevArr[i] + muPlus1*updatedVData[i];
nnz += (updatedXData[i]!=0) ? 1 : 0;
}
updatedXMB.setNonZeros(nnz);
}
else if(isDense(v)) {
copy(X, updatedXData);
double [] vPrevArr = v.getDenseBlock();
int nnz = 0; double muPlus1 = mu+1;
for(int i = 0; i < updatedXData.length; i++) {
updatedXData[i] += - mu*vPrevArr[i] + muPlus1*updatedVData[i];
nnz += (updatedXData[i]!=0) ? 1 : 0;
}
updatedXMB.setNonZeros(nnz);
}
else {
copy(X, updatedXData);
multiplyByConstant(v, -mu, updatedXData);
multiplyByConstant(updatedVData, 1+mu, updatedXData);
updatedXMB.recomputeNonZeros();
}
updatedX.setMatrixDoubleArray(updatedXMB, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
((Matrix) getFunctionInput(0)).getMatrixObject().release();
((Matrix) getFunctionInput(1)).getMatrixObject().release();
((Matrix) getFunctionInput(4)).getMatrixObject().release();
} catch (CacheException e) {
throw new RuntimeException("Exception while executing SGDNesterovUpdate", e);
} catch (IOException e) {
throw new RuntimeException("Exception while executing SGDNesterovUpdate", e);
}
}
private MatrixBlock allocateDenseMatrixBlock(Matrix mat) {
int rows = (int) mat.getNumRows();
int cols = (int) mat.getNumCols();
MatrixBlock mb = new MatrixBlock(rows, cols, false);
mb.allocateDenseBlock();
return mb;
}
// out += constant*in
private void multiplyByConstant(double [] in, double constant, double [] out) {
for(int i = 0; i < out.length; i++) {
out[i] += in[i]*constant;
}
}
// out += constant*in
private void multiplyByConstant(MatrixBlock in, double constant, double [] out) {
if(in.isInSparseFormat()) {
Iterator<IJV> iter = in.getSparseBlockIterator();
while(iter.hasNext()) {
IJV ijv = iter.next();
out[ijv.getI()*ijv.getJ()] += ijv.getV() * constant;
}
}
else {
double [] denseBlock = in.getDenseBlock();
if(denseBlock != null) {
// If not empty block
for(int i = 0; i < out.length; i++) {
out[i] += denseBlock[i]*constant;
}
}
}
}
// Assumption dest is zero-ed out.
private void copy(MatrixBlock src, double [] dest) {
if(src.isInSparseFormat()) {
Iterator<IJV> iter = src.getSparseBlockIterator();
while(iter.hasNext()) {
IJV ijv = iter.next();
dest[ijv.getI()*ijv.getJ()] = ijv.getV();
}
}
else {
double [] denseBlock = src.getDenseBlock();
if(denseBlock != null) {
// If not empty block
System.arraycopy(denseBlock, 0, dest, 0, dest.length);
}
}
}
}