/**
* (C) Copyright IBM Corp. 2010, 2015
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package com.ibm.bi.dml.udf.lib;
import java.util.Arrays;
import java.util.Comparator;
import com.ibm.bi.dml.runtime.matrix.data.InputInfo;
import com.ibm.bi.dml.runtime.matrix.data.MatrixBlock;
import com.ibm.bi.dml.runtime.matrix.data.OutputInfo;
import com.ibm.bi.dml.udf.FunctionParameter;
import com.ibm.bi.dml.udf.Matrix;
import com.ibm.bi.dml.udf.PackageFunction;
import com.ibm.bi.dml.udf.PackageRuntimeException;
import com.ibm.bi.dml.udf.Scalar;
import com.ibm.bi.dml.udf.Matrix.ValueType;
/**
* Wrapper class for Sorting and Creating of a Permutation Matrix
*
* Sort single-column matrix and produce a permutation matrix. Pre-multiplying
* the input matrix with the permutation matrix produces a sorted matrix. A
* permutation matrix is a matrix where each row and each column as exactly one
* 1: To From 1
*
* Input: (n x 1)-matrix, and true/false for sorting in descending order Output:
* (n x n)- matrix
*
* permutation_matrix= externalFunction(Matrix[Double] A, Boolean desc) return
* (Matrix[Double] P) implemented in
* (classname="com.ibm.bi.dml.packagesupport.PermutationMatrixWrapper"
* ,exectype="mem"); A = read( "Data/A.mtx"); P = permutation_matrix( A[,2],
* false); B = P %*% A
*
*/
public class PermutationMatrixWrapper extends PackageFunction
{
private static final long serialVersionUID = 1L;
private static final String OUTPUT_FILE = "TMP";
// return matrix
private Matrix _ret;
@Override
public int getNumFunctionOutputs() {
return 1;
}
@Override
public FunctionParameter getFunctionOutput(int pos) {
if (pos == 0)
return _ret;
throw new PackageRuntimeException(
"Invalid function output being requested");
}
@Override
public void execute() {
try {
Matrix inM = (Matrix) getFunctionInput(0);
double[][] inData = inM.getMatrixAsDoubleArray();
boolean desc = Boolean.parseBoolean(((Scalar) getFunctionInput(1))
.getValue());
// add index column as first column
double[][] idxData = new double[(int) inM.getNumRows()][2];
for (int i = 0; i < idxData.length; i++) {
idxData[i][0] = i;
idxData[i][1] = inData[i][0];
}
// sort input matrix (in-place)
if (!desc) // asc
Arrays.sort(idxData, new AscRowComparator(1));
else
// desc
Arrays.sort(idxData, new DescRowComparator(1));
// create and populate sparse matrixblock for result
MatrixBlock mb = new MatrixBlock(idxData.length, idxData.length,
true, idxData.length);
for (int i = 0; i < idxData.length; i++) {
mb.quickSetValue(i, (int) idxData[i][0], 1.0);
}
mb.examSparsity();
// set result
String dir = createOutputFilePathAndName(OUTPUT_FILE);
_ret = new Matrix(dir, mb.getNumRows(), mb.getNumColumns(),
ValueType.Double);
_ret.setMatrixDoubleArray(mb, OutputInfo.BinaryBlockOutputInfo,
InputInfo.BinaryBlockInputInfo);
} catch (Exception e) {
throw new PackageRuntimeException(
"Error executing external permutation_matrix function", e);
}
}
/**
*
*
*/
private static class AscRowComparator implements Comparator<double[]> {
private int _col = -1;
public AscRowComparator(int col) {
_col = col;
}
@Override
public int compare(double[] arg0, double[] arg1) {
return (arg0[_col] < arg1[_col] ? -1
: (arg0[_col] == arg1[_col] ? 0 : 1));
}
}
/**
*
*
*/
private static class DescRowComparator implements Comparator<double[]> {
private int _col = -1;
public DescRowComparator(int col) {
_col = col;
}
@Override
public int compare(double[] arg0, double[] arg1) {
return (arg0[_col] > arg1[_col] ? -1
: (arg0[_col] == arg1[_col] ? 0 : 1));
}
}
}