/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.runtime.compress; import java.util.Arrays; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.functionobjects.Builtin; import org.apache.sysml.runtime.functionobjects.KahanFunction; import org.apache.sysml.runtime.functionobjects.KahanPlus; import org.apache.sysml.runtime.functionobjects.KahanPlusSq; import org.apache.sysml.runtime.functionobjects.ReduceAll; import org.apache.sysml.runtime.functionobjects.ReduceCol; import org.apache.sysml.runtime.functionobjects.ReduceRow; import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode; import org.apache.sysml.runtime.instructions.cp.KahanObject; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.operators.AggregateUnaryOperator; /** * Class to encapsulate information about a column group that is encoded with * dense dictionary encoding (DDC). * * NOTE: zero values are included at position 0 in the value dictionary, which * simplifies various operations such as counting the number of non-zeros. */ public abstract class ColGroupDDC extends ColGroupValue { private static final long serialVersionUID = -3204391646123465004L; public ColGroupDDC() { super(); } public ColGroupDDC(int[] colIndices, int numRows, UncompressedBitmap ubm) { super(colIndices, numRows, ubm); } protected ColGroupDDC(int[] colIndices, int numRows, double[] values) { super(colIndices, numRows, values); } @Override public void decompressToBlock(MatrixBlock target, int rl, int ru) { for( int i = rl; i < ru; i++ ) { for( int colIx = 0; colIx < _colIndexes.length; colIx++ ) { int col = _colIndexes[colIx]; double cellVal = getData(i, colIx); target.quickSetValue(i, col, cellVal); } } } @Override public void decompressToBlock(MatrixBlock target, int[] colIndexTargets) { int nrow = getNumRows(); int ncol = getNumCols(); for( int i = 0; i < nrow; i++ ) { for( int colIx = 0; colIx < ncol; colIx++ ) { int origMatrixColIx = getColIndex(colIx); int col = colIndexTargets[origMatrixColIx]; double cellVal = getData(i, colIx); target.quickSetValue(i, col, cellVal); } } } @Override public void decompressToBlock(MatrixBlock target, int colpos) { int nrow = getNumRows(); for( int i = 0; i < nrow; i++ ) { double cellVal = getData(i, colpos); target.quickSetValue(i, 0, cellVal); } } @Override public double get(int r, int c) { //find local column index int ix = Arrays.binarySearch(_colIndexes, c); if( ix < 0 ) throw new RuntimeException("Column index "+c+" not in DDC group."); //get value return getData(r, ix); } @Override protected void countNonZerosPerRow(int[] rnnz, int rl, int ru) { int ncol = getNumCols(); for( int i = rl; i < ru; i++ ) { int lnnz = 0; for( int colIx=0; colIx < ncol; colIx++ ) lnnz += (getData(i, colIx) != 0) ? 1 : 0; rnnz[i-rl] += lnnz; } } @Override public void unaryAggregateOperations(AggregateUnaryOperator op, MatrixBlock result, int rl, int ru) throws DMLRuntimeException { //sum and sumsq (reduceall/reducerow over tuples and counts) if( op.aggOp.increOp.fn instanceof KahanPlus || op.aggOp.increOp.fn instanceof KahanPlusSq ) { KahanFunction kplus = (op.aggOp.increOp.fn instanceof KahanPlus) ? KahanPlus.getKahanPlusFnObject() : KahanPlusSq.getKahanPlusSqFnObject(); if( op.indexFn instanceof ReduceAll ) computeSum(result, kplus); else if( op.indexFn instanceof ReduceCol ) computeRowSums(result, kplus, rl, ru); else if( op.indexFn instanceof ReduceRow ) computeColSums(result, kplus); } //min and max (reduceall/reducerow over tuples only) else if(op.aggOp.increOp.fn instanceof Builtin && (((Builtin)op.aggOp.increOp.fn).getBuiltinCode()==BuiltinCode.MAX || ((Builtin)op.aggOp.increOp.fn).getBuiltinCode()==BuiltinCode.MIN)) { Builtin builtin = (Builtin) op.aggOp.increOp.fn; if( op.indexFn instanceof ReduceAll ) computeMxx(result, builtin, false); else if( op.indexFn instanceof ReduceCol ) computeRowMxx(result, builtin, rl, ru); else if( op.indexFn instanceof ReduceRow ) computeColMxx(result, builtin, false); } } protected void computeSum(MatrixBlock result, KahanFunction kplus) { int nrow = getNumRows(); int ncol = getNumCols(); KahanObject kbuff = new KahanObject(result.quickGetValue(0, 0), result.quickGetValue(0, 1)); for( int i=0; i<nrow; i++ ) for( int j=0; j<ncol; j++ ) kplus.execute2(kbuff, getData(i, j)); result.quickSetValue(0, 0, kbuff._sum); result.quickSetValue(0, 1, kbuff._correction); } protected void computeColSums(MatrixBlock result, KahanFunction kplus) { int nrow = getNumRows(); int ncol = getNumCols(); KahanObject[] kbuff = new KahanObject[getNumCols()]; for( int j=0; j<ncol; j++ ) kbuff[j] = new KahanObject(result.quickGetValue(0, _colIndexes[j]), result.quickGetValue(1, _colIndexes[j])); for( int i=0; i<nrow; i++ ) for( int j=0; j<ncol; j++ ) kplus.execute2(kbuff[j], getData(i, j)); for( int j=0; j<ncol; j++ ) { result.quickSetValue(0, _colIndexes[j], kbuff[j]._sum); result.quickSetValue(1, _colIndexes[j], kbuff[j]._correction); } } protected void computeRowSums(MatrixBlock result, KahanFunction kplus, int rl, int ru) { int ncol = getNumCols(); KahanObject kbuff = new KahanObject(0, 0); for( int i=rl; i<ru; i++ ) { kbuff.set(result.quickGetValue(i, 0), result.quickGetValue(i, 1)); for( int j=0; j<ncol; j++ ) kplus.execute2(kbuff, getData(i, j)); result.quickSetValue(i, 0, kbuff._sum); result.quickSetValue(i, 1, kbuff._correction); } } protected void computeRowMxx(MatrixBlock result, Builtin builtin, int rl, int ru) { double[] c = result.getDenseBlock(); int ncol = getNumCols(); for( int i=rl; i<ru; i++ ) for( int j=0; j<ncol; j++ ) c[i] = builtin.execute2(c[i], getData(i, j)); } protected final void postScaling(double[] vals, double[] c) { final int ncol = getNumCols(); final int numVals = getNumValues(); for( int k=0, valOff=0; k<numVals; k++, valOff+=ncol ) { double aval = vals[k]; for( int j=0; j<ncol; j++ ) { int colIx = _colIndexes[j]; c[colIx] += aval * _values[valOff+j]; } } } /** * Generic get value for byte-length-agnostic access. * * @param r global row index * @param colIx local column index * @return value */ protected abstract double getData(int r, int colIx); /** * Generic set value for byte-length-agnostic write * of encoded value. * * @param r global row index * @param code encoded value */ protected abstract void setData(int r, int code); @Override public long estimateInMemorySize() { return super.estimateInMemorySize(); } }