/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.udf.lib;
import java.io.IOException;
import java.util.Iterator;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.CacheException;
import org.apache.sysml.runtime.matrix.data.IJV;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.udf.FunctionParameter;
import org.apache.sysml.udf.Matrix;
import org.apache.sysml.udf.PackageFunction;
import org.apache.sysml.udf.Scalar;
import org.apache.sysml.udf.Matrix.ValueType;
/**
* Variant of cumsum:<br>
* Computes following two functions:<br>
* <pre>
* <code>
* cumsum_prod = function (Matrix[double] X, Matrix[double] C, double start) return (Matrix[double] Y)
* # Computes the following recurrence in log-number of steps:
* # Y [1, ] = X [1, ] + C [1, ] * start;
* # Y [i+1, ] = X [i+1, ] + C [i+1, ] * Y [i, ]
* {
* Y = X; P = C; m = nrow(X); k = 1;
* Y [1, ] = Y [1, ] + C [1, ] * start;
* while (k < m) {
* Y [k+1 : m, ] = Y [k+1 : m, ] + Y [1 : m-k, ] * P [k+1 : m, ];
* P [k+1 : m, ] = P [1 : m-k, ] * P [k+1 : m, ];
* k = 2 * k;
* }
* }
*
* cumsum_prod_reverse = function (Matrix[double] X, Matrix[double] C, double start) return (Matrix[double] Y)
* # Computes the reverse recurrence in log-number of steps:
* # Y [m, ] = X [m, ] + C [m, ] * start;
* # Y [i-1, ] = X [i-1, ] + C [i-1, ] * Y [i, ]
* {
* Y = X; P = C; m = nrow(X); k = 1;
* Y [m, ] = Y [m, ] + C [m, ] * start;
* while (k < m) {
* Y [1 : m-k, ] = Y [1 : m-k, ] + Y [k+1 : m, ] * P [1 : m-k, ];
* P [1 : m-k, ] = P [k+1 : m, ] * P [1 : m-k, ];
* k = 2 * k;
* }
* }
* </code>
* </pre>
*
* The API of this external built-in function is as follows:<br>
* <pre>
* <code>
* func = externalFunction(matrix[double] X, matrix[double] C, double start, boolean isReverse) return (matrix[double] Y)
* implemented in (classname="org.apache.sysml.udf.lib.CumSumProd",exectype="mem");
* </code>
* </pre>
*/
public class CumSumProd extends PackageFunction {
private static final long serialVersionUID = -7883258699548686065L;
private Matrix ret;
private MatrixBlock retMB, X, C;
private double start;
private boolean isReverse;
@Override
public int getNumFunctionOutputs() {
return 1;
}
@Override
public FunctionParameter getFunctionOutput(int pos) {
if(pos == 0)
return ret;
else
throw new RuntimeException("CumSumProd produces only one output");
}
@Override
public void execute() {
try {
X = ((Matrix) getFunctionInput(0)).getMatrixObject().acquireRead();
C = ((Matrix) getFunctionInput(1)).getMatrixObject().acquireRead();
if(X.getNumRows() != C.getNumRows())
throw new RuntimeException("Number of rows of X and C should match");
if( X.getNumColumns() != C.getNumColumns() && C.getNumColumns() != 1 )
throw new RuntimeException("Incorrect Number of columns of X and C (Expected C to be of same dimension or a vector)");
start = Double.parseDouble(((Scalar)getFunctionInput(2)).getValue());
isReverse = Boolean.parseBoolean(((Scalar)getFunctionInput(3)).getValue());
numRetRows = X.getNumRows();
numRetCols = X.getNumColumns();
allocateOutput();
// Copy X to Y
denseBlock = retMB.getDenseBlock();
if(X.isInSparseFormat()) {
Iterator<IJV> iter = X.getSparseBlockIterator();
while(iter.hasNext()) {
IJV ijv = iter.next();
denseBlock[ijv.getI()*numRetCols + ijv.getJ()] = ijv.getV();
}
}
else {
if(X.getDenseBlock() != null)
System.arraycopy(X.getDenseBlock(), 0, denseBlock, 0, denseBlock.length);
}
if(!isReverse) {
// Y [1, ] = X [1, ] + C [1, ] * start;
// Y [i+1, ] = X [i+1, ] + C [i+1, ] * Y [i, ]
addCNConstant(0, start);
for(int i = 1; i < numRetRows; i++) {
addC(i, true);
}
}
else {
// Y [m, ] = X [m, ] + C [m, ] * start;
// Y [i-1, ] = X [i-1, ] + C [i-1, ] * Y [i, ]
addCNConstant(numRetRows-1, start);
for(int i = numRetRows - 2; i >= 0; i--) {
addC(i, false);
}
}
((Matrix) getFunctionInput(1)).getMatrixObject().release();
((Matrix) getFunctionInput(0)).getMatrixObject().release();
} catch (CacheException e) {
throw new RuntimeException("Error while executing CumSumProd", e);
}
retMB.recomputeNonZeros();
try {
retMB.examSparsity();
ret.setMatrixDoubleArray(retMB, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
} catch (DMLRuntimeException e) {
throw new RuntimeException("Error while executing CumSumProd", e);
} catch (IOException e) {
throw new RuntimeException("Error while executing CumSumProd", e);
}
}
int numRetRows; int numRetCols;
double [] denseBlock;
private void addCNConstant(int i, double constant) {
boolean isCVector = C.getNumColumns() != ret.getNumCols();
if(C.isInSparseFormat()) {
Iterator<IJV> iter = C.getSparseBlockIterator(i, i+1);
while(iter.hasNext()) {
IJV ijv = iter.next();
if(!isCVector)
denseBlock[ijv.getI()*numRetCols + ijv.getJ()] += ijv.getV() * constant;
else {
double val = ijv.getV();
for(int j = ijv.getI()*numRetCols; j < (ijv.getI()+1)*numRetCols; j++) {
denseBlock[j] += val*constant;
}
}
}
}
else {
double [] CBlk = C.getDenseBlock();
if(CBlk != null) {
if(!isCVector) {
for(int j = i*numRetCols; j < (i+1)*numRetCols; j++) {
denseBlock[j] += CBlk[j]*constant;
}
}
else {
for(int j = i*numRetCols; j < (i+1)*numRetCols; j++) {
denseBlock[j] += CBlk[i]*constant;
}
}
}
}
}
private void addC(int i, boolean addPrevRow) {
boolean isCVector = C.getNumColumns() != ret.getNumCols();
if(C.isInSparseFormat()) {
Iterator<IJV> iter = C.getSparseBlockIterator(i, i+1);
while(iter.hasNext()) {
IJV ijv = iter.next();
if(!isCVector) {
if(addPrevRow)
denseBlock[ijv.getI()*numRetCols + ijv.getJ()] += ijv.getV() * denseBlock[(ijv.getI()-1)*numRetCols + ijv.getJ()];
else
denseBlock[ijv.getI()*numRetCols + ijv.getJ()] += ijv.getV() * denseBlock[(ijv.getI()+1)*numRetCols + ijv.getJ()];
}
else {
double val = ijv.getV();
for(int j = ijv.getI()*numRetCols; j < (ijv.getI()+1)*numRetCols; j++) {
double val1 = addPrevRow ? denseBlock[(ijv.getI()-1)*numRetCols + ijv.getJ()] : denseBlock[(ijv.getI()+1)*numRetCols + ijv.getJ()];
denseBlock[j] += val*val1;
}
}
}
}
else {
double [] CBlk = C.getDenseBlock();
if(CBlk != null) {
if(!isCVector) {
for(int j = i*numRetCols; j < (i+1)*numRetCols; j++) {
double val1 = addPrevRow ? denseBlock[j-numRetCols] : denseBlock[j+numRetCols];
denseBlock[j] += CBlk[j]*val1;
}
}
else {
for(int j = i*numRetCols; j < (i+1)*numRetCols; j++) {
double val1 = addPrevRow ? denseBlock[j-numRetCols] : denseBlock[j+numRetCols];
denseBlock[j] += CBlk[i]*val1;
}
}
}
}
}
private void allocateOutput() {
String dir = createOutputFilePathAndName( "TMP" );
ret = new Matrix( dir, numRetRows, numRetCols, ValueType.Double );
retMB = new MatrixBlock((int) numRetRows, (int) numRetCols, false);
retMB.allocateDenseBlock();
}
}