/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysml.udf.lib;
import java.io.IOException;
import java.util.Iterator;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.caching.CacheException;
import org.apache.sysml.runtime.matrix.data.IJV;
import org.apache.sysml.runtime.matrix.data.InputInfo;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.udf.FunctionParameter;
import org.apache.sysml.udf.Matrix;
import org.apache.sysml.udf.PackageFunction;
import org.apache.sysml.udf.Scalar;
import org.apache.sysml.udf.Matrix.ValueType;
/**
* This external built-in function addresses the following two common scenarios:
* 1. cbind (cbind (cbind ( X1, X2 ), X3 ), X4)
* 2. With spagetization: cbind (cbind (cbind ( matrix(X1, rows=length(X1), cols=1), matrix(X2, rows=length(X2), cols=1) ), matrix(X3, rows=length(X3), cols=1) ), matrix(X4, rows=length(X4), cols=1))
*
* The API of this external built-in function is as follows:
*
* func = externalFunction(int numInputs, boolean spagetize, matrix[double] X1, matrix[double] X2, matrix[double] X3, matrix[double] X4) return (matrix[double] out)
* implemented in (classname="org.apache.sysml.udf.lib.MultiInputCbind",exectype="mem");
*
*/
public class MultiInputCbind extends PackageFunction {
    private static final long serialVersionUID = -4266180315672563097L;

    /** Index of the first matrix argument; arguments 0 and 1 are the numInputs and spagetize scalars. */
    private static final int FIRST_MATRIX_INPUT = 2;

    /** Output handle returned to the caller through {@link #getFunctionOutput(int)}. */
    private Matrix ret;
    /** Dense in-memory block that the inputs are copied into. */
    private MatrixBlock retMB;
    /** Dimensions of the output, recomputed on every call to {@link #execute()}. */
    long numRetRows; long numRetCols;
    /** Whether every input is linearized (row-major) into a single output column. */
    boolean spagetize;

    /**
     * @return always 1, this function produces a single matrix
     */
    @Override
    public int getNumFunctionOutputs() {
        return 1;
    }

    /**
     * Returns the result matrix built by {@link #execute()}.
     *
     * @param pos output position; only 0 is valid
     * @return the output matrix (null until {@link #execute()} has allocated it)
     * @throws RuntimeException if {@code pos} is not 0
     */
    @Override
    public FunctionParameter getFunctionOutput(int pos) {
        if (pos == 0) {
            return ret;
        }
        throw new RuntimeException("MultiInputCbind produces only one output");
    }

    /**
     * Reads the scalar arguments (input 0: number of matrices, input 1: spagetize flag),
     * computes the output dimensions, allocates a dense output block and copies every
     * input matrix (inputs 2 .. numInputs+1) into it.
     *
     * <p>With spagetize enabled, input k becomes output column k and all inputs must have
     * the same number of cells. Otherwise the inputs are appended column-wise and must all
     * have the same number of rows.
     *
     * @throws RuntimeException if the inputs have incompatible sizes, if the output does
     *         not fit into a single dense array, or if acquiring/releasing an input or
     *         publishing the output fails
     */
    @Override
    public void execute() {
        int numInputs = Integer.parseInt(((Scalar) getFunctionInput(0)).getValue());
        spagetize = Boolean.parseBoolean(((Scalar) getFunctionInput(1)).getValue());

        try {
            computeOutputDimensions(numInputs);
        } catch (CacheException e) {
            throw new RuntimeException("Error while executing MultiInputCbind", e);
        }

        // The output is addressed as one flat double[], so both dimensions and their
        // product must fit into an int. Checking each dimension first keeps the
        // product itself from overflowing a long.
        if (numRetRows > Integer.MAX_VALUE || numRetCols > Integer.MAX_VALUE
                || numRetRows * numRetCols > Integer.MAX_VALUE) {
            throw new RuntimeException("MultiInputCbind output of " + numRetRows + " x " + numRetCols
                    + " cells does not fit into a single dense block");
        }

        allocateOutput();

        try {
            copyInputs(numInputs, retMB.getDenseBlock(), (int) numRetCols);
        } catch (CacheException e) {
            throw new RuntimeException("Error while executing MultiInputCbind", e);
        }

        retMB.recomputeNonZeros();
        try {
            retMB.examSparsity();
            ret.setMatrixDoubleArray(retMB, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
        } catch (DMLRuntimeException e) {
            throw new RuntimeException("Error while executing MultiInputCbind", e);
        } catch (IOException e) {
            throw new RuntimeException("Error while executing MultiInputCbind", e);
        }
    }

    /**
     * Sets {@link #numRetRows} and {@link #numRetCols} from the input matrices.
     *
     * @param numInputs number of matrix arguments
     * @throws CacheException if an input cannot be acquired or released
     * @throws RuntimeException if, without spagetization, the inputs differ in row count
     */
    private void computeOutputDimensions(int numInputs) throws CacheException {
        // Reset both dimensions so nothing leaks over from a previous invocation.
        numRetRows = 0;
        numRetCols = 0;

        if (spagetize) {
            // Only the first input is inspected here; the remaining inputs are checked
            // against its cell count while they are copied.
            Matrix first = (Matrix) getFunctionInput(FIRST_MATRIX_INPUT);
            MatrixBlock in = first.getMatrixObject().acquireRead();
            try {
                // Widen before multiplying so large inputs cannot overflow int.
                numRetRows = (long) in.getNumRows() * in.getNumColumns();
                numRetCols = numInputs;
            } finally {
                first.getMatrixObject().release();
            }
            return;
        }

        for (int inputID = FIRST_MATRIX_INPUT; inputID < numInputs + FIRST_MATRIX_INPUT; inputID++) {
            Matrix input = (Matrix) getFunctionInput(inputID);
            MatrixBlock in = input.getMatrixObject().acquireRead();
            try {
                if (inputID > FIRST_MATRIX_INPUT && in.getNumRows() != numRetRows) {
                    throw new RuntimeException("Expected the inputs to have the same number of rows"
                            + " when spagetization is turned off.");
                }
                numRetRows = in.getNumRows();
                numRetCols += in.getNumColumns();
            } finally {
                input.getMatrixObject().release();
            }
        }
    }

    /**
     * Copies every input matrix into the output array, releasing each input even when
     * the copy fails.
     *
     * @param numInputs number of matrix arguments
     * @param retData   row-major dense array of the output
     * @param retCols   number of output columns
     * @throws CacheException if an input cannot be acquired or released
     * @throws RuntimeException if, with spagetization, an input's cell count differs
     *         from that of the first input
     */
    private void copyInputs(int numInputs, double[] retData, int retCols) throws CacheException {
        int startColumn = 0;
        for (int inputID = FIRST_MATRIX_INPUT; inputID < numInputs + FIRST_MATRIX_INPUT; inputID++) {
            Matrix input = (Matrix) getFunctionInput(inputID);
            MatrixBlock in = input.getMatrixObject().acquireRead();
            try {
                if (spagetize) {
                    if ((long) in.getNumRows() * in.getNumColumns() != numRetRows) {
                        throw new RuntimeException("Expected the inputs to be of same size when spagetization is turned on.");
                    }
                    // Output column = position of this input among the matrix arguments.
                    spagetizeInto(in, inputID - FIRST_MATRIX_INPUT, retData, retCols);
                } else {
                    cbindInto(in, startColumn, retData, retCols);
                }
                startColumn += in.getNumColumns();
            } finally {
                input.getMatrixObject().release();
            }
        }
    }

    /**
     * Performs matrix(X, rows=length(X), cols=1) followed by a cbind: writes the cells of
     * {@code in}, in row-major order, down output column {@code outputColumn}.
     */
    private void spagetizeInto(MatrixBlock in, int outputColumn, double[] retData, int retCols) {
        if (in.isInSparseFormat()) {
            int inCols = in.getNumColumns();
            Iterator<IJV> iter = in.getSparseBlockIterator();
            while (iter.hasNext()) {
                IJV cell = iter.next();
                int outputRow = cell.getI() * inCols + cell.getJ();
                retData[outputRow * retCols + outputColumn] = cell.getV();
            }
            return;
        }

        double[] dense = in.getDenseBlock();
        if (dense == null) {
            return; // nothing to copy; the freshly allocated output is left untouched
        }
        int cells = (int) numRetRows;
        for (int i = 0; i < cells; i++) {
            retData[i * retCols + outputColumn] = dense[i];
        }
    }

    /**
     * Traditional cbind: copies {@code in} into the output so that its column 0 lands on
     * output column {@code startColumn}; row indexes are unchanged.
     */
    private void cbindInto(MatrixBlock in, int startColumn, double[] retData, int retCols) {
        if (in.isInSparseFormat()) {
            Iterator<IJV> iter = in.getSparseBlockIterator();
            while (iter.hasNext()) {
                IJV cell = iter.next();
                retData[cell.getI() * retCols + cell.getJ() + startColumn] = cell.getV();
            }
            return;
        }

        double[] dense = in.getDenseBlock();
        if (dense == null) {
            return; // nothing to copy; the freshly allocated output is left untouched
        }
        int inCols = in.getNumColumns();
        int rows = in.getNumRows();
        for (int i = 0; i < rows; i++) {
            // Each input row is a contiguous slice of the corresponding output row.
            System.arraycopy(dense, i * inCols, retData, i * retCols + startColumn, inCols);
        }
    }

    /**
     * Creates the output handle under a temporary file name and allocates the dense
     * block sized {@link #numRetRows} x {@link #numRetCols}.
     */
    private void allocateOutput() {
        String dir = createOutputFilePathAndName( "TMP" );
        ret = new Matrix( dir, numRetRows, numRetCols, ValueType.Double );
        retMB = new MatrixBlock((int) numRetRows, (int) numRetCols, false);
        retMB.allocateDenseBlock();
    }
}