/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.udf.lib; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.Iterator; import java.util.Map.Entry; import java.util.TreeMap; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.controlprogram.caching.CacheException; import org.apache.sysml.runtime.matrix.data.IJV; import org.apache.sysml.runtime.matrix.data.InputInfo; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.OutputInfo; import org.apache.sysml.udf.FunctionParameter; import org.apache.sysml.udf.Matrix; import org.apache.sysml.udf.PackageFunction; import org.apache.sysml.udf.Matrix.ValueType; /** * Performs following operation: * # Computes the intersection ("meet") of equivalence classes for * # each row of A and B, excluding 0-valued cells. * # INPUT: * # A, B = matrices whose rows contain that row's class labels; * # for each i, rows A [i, ] and B [i, ] define two * # equivalence relations on some of the columns, which * # we want to intersect * # A [i, j] == A [i, k] != 0 if and only if (j ~ k) as defined * # by row A [i, ]; * # A [i, j] == 0 means that j is excluded by A [i, ] * # B [i, j] is analogous * # NOTE 1: Either nrow(A) == nrow(B), or exactly one of A or B * # has one row that "applies" to each row of the other matrix. * # NOTE 2: If ncol(A) != ncol(B), we pad extra 0-columns up to * # max (ncol(A), ncol(B)). * # OUTPUT: * # Both C and N have the same size as (the max of) A and B. * # C = matrix whose rows contain class labels that represent * # the intersection (coarsest common refinement) of the * # corresponding rows of A and B. * # C [i, j] == C [i, k] != 0 if and only if (j ~ k) as defined * # by both A [i, ] and B [j, ] * # C [i, j] == 0 if and only if A [i, j] == 0 or B [i, j] == 0 * # Additionally, we guarantee that non-0 labels in C [i, ] * # will be integers from 1 to max (C [i, ]) without gaps. * # For A and B the labels can be arbitrary. * # N = matrix with class-size information for C-cells * # N [i, j] = count of {C [i, k] | C [i, j] == C [i, k] != 0} * */ public class RowClassMeet extends PackageFunction { private static final long serialVersionUID = 1L; private Matrix CMat, NMat; private MatrixBlock A, B, C, N; private int nr, nc; @Override public int getNumFunctionOutputs() { return 2; } @Override public FunctionParameter getFunctionOutput(int pos) { if(pos == 0) return CMat; else if(pos == 1) return NMat; else throw new RuntimeException("RowClassMeet produces only one output"); } public class ClassLabels { public double aVal; public double bVal; public ClassLabels(double aVal, double bVal) { this.aVal = aVal; this.bVal = bVal; } } public class ClassLabelComparator implements Comparator<ClassLabels> { Integer tmp1, tmp2; @Override public int compare(ClassLabels o1, ClassLabels o2) { if(o1.aVal != o2.aVal) { tmp1 = (int) o1.aVal; tmp2 = (int) o2.aVal; } else { tmp1 = (int) o1.bVal; tmp2 = (int) o2.bVal; } return tmp1.compareTo(tmp2); } } double [] getRow(MatrixBlock B, double [] bRow, int i) { if(B.getNumRows() == 1) i = 0; Arrays.fill(bRow, 0); if(B.isInSparseFormat()) { Iterator<IJV> iter = B.getSparseBlockIterator(i, i+1); while(iter.hasNext()) { IJV ijv = iter.next(); bRow[ijv.getJ()] = ijv.getV(); } } else { double [] denseBlk = B.getDenseBlock(); if(denseBlk != null) System.arraycopy(denseBlk, i*B.getNumColumns(), bRow, 0, B.getNumColumns()); } return bRow; } @Override public void execute() { try { A = ((Matrix) getFunctionInput(0)).getMatrixObject().acquireRead(); B = ((Matrix) getFunctionInput(1)).getMatrixObject().acquireRead(); nr = Math.max(A.getNumRows(), B.getNumRows()); nc = Math.max(A.getNumColumns(), B.getNumColumns()); double [] bRow = new double[B.getNumColumns()]; CMat = new Matrix( createOutputFilePathAndName( "TMP" ), nr, nc, ValueType.Double ); C = new MatrixBlock(nr, nc, false); C.allocateDenseBlock(); NMat = new Matrix( createOutputFilePathAndName( "TMP" ), nr, nc, ValueType.Double ); N = new MatrixBlock(nr, nc, false); N.allocateDenseBlock(); double [] cBlk = C.getDenseBlock(); double [] nBlk = N.getDenseBlock(); if(B.getNumRows() == 1) getRow(B, bRow, 0); for(int i = 0; i < A.getNumRows(); i++) { if(B.getNumRows() != 1) getRow(B, bRow, i); // Create class labels TreeMap<ClassLabels, ArrayList<Integer>> classLabelMapping = new TreeMap<ClassLabels, ArrayList<Integer>>(new ClassLabelComparator()); if(A.isInSparseFormat()) { Iterator<IJV> iter = A.getSparseBlockIterator(i, i+1); while(iter.hasNext()) { IJV ijv = iter.next(); int j = ijv.getJ(); double aVal = ijv.getV(); if(aVal != 0 && bRow[j] != 0) { ClassLabels key = new ClassLabels(aVal, bRow[j]); if(!classLabelMapping.containsKey(key)) classLabelMapping.put(key, new ArrayList<Integer>()); classLabelMapping.get(key).add(j); } } } else { double [] denseBlk = A.getDenseBlock(); if(denseBlk != null) { int offset = i*A.getNumColumns(); for(int j = 0; j < A.getNumColumns(); j++) { double aVal = denseBlk[offset + j]; if(aVal != 0 && bRow[j] != 0) { ClassLabels key = new ClassLabels(aVal, bRow[j]); if(!classLabelMapping.containsKey(key)) classLabelMapping.put(key, new ArrayList<Integer>()); classLabelMapping.get(key).add(j); } } } } int labelID = 1; for(Entry<ClassLabels, ArrayList<Integer>> entry : classLabelMapping.entrySet()) { double nVal = entry.getValue().size(); for(Integer j : entry.getValue()) { nBlk[i*nc + j] = nVal; cBlk[i*nc + j] = labelID; } labelID++; } } ((Matrix) getFunctionInput(0)).getMatrixObject().release(); ((Matrix) getFunctionInput(1)).getMatrixObject().release(); } catch (CacheException e) { throw new RuntimeException("Error while executing RowClassMeet", e); } try { C.recomputeNonZeros(); C.examSparsity(); CMat.setMatrixDoubleArray(C, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo); N.recomputeNonZeros(); N.examSparsity(); NMat.setMatrixDoubleArray(N, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo); } catch (DMLRuntimeException e) { throw new RuntimeException("Error while executing RowClassMeet", e); } catch (IOException e) { throw new RuntimeException("Error while executing RowClassMeet", e); } } }