/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.sysml.runtime.functionobjects; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.instructions.cp.CM_COV_Object; import org.apache.sysml.runtime.instructions.cp.Data; import org.apache.sysml.runtime.instructions.cp.KahanObject; import org.apache.sysml.runtime.matrix.operators.CMOperator.AggregateOperationTypes; /** * GENERAL NOTE: * * 05/28/2014: We decided to do handle weights consistently to SPSS in an operation-specific manner, * i.e., we (1) round instead of casting where required (e.g. count), and (2) consistently use * fractional weight values elsewhere. In case a count-base interpretation of weights is needed, just * ensure rounding before calling CM/COV/KahanPlus. * */ public class CM extends ValueFunction { private static final long serialVersionUID = 9177194651533064123L; private AggregateOperationTypes _type = null; //helper function objects for specific types private KahanPlus _plus = null; private KahanObject _buff2 = null; private KahanObject _buff3 = null; private CM( AggregateOperationTypes type ) { _type = type; switch( _type ) //helper obj on demand { case COUNT: break; case CM4: case CM3: _buff3 = new KahanObject(0, 0); case CM2: _buff2 = new KahanObject(0, 0); case VARIANCE: case MEAN: _plus = KahanPlus.getKahanPlusFnObject(); break; default: //do nothing } } public static CM getCMFnObject( AggregateOperationTypes type ) { //return new obj, required for correctness in multi-threaded //execution due to state in cm object (buff2, buff3) return new CM( type ); } public AggregateOperationTypes getAggOpType() { return _type; } /** * Special case for weights w2==1 */ @Override public Data execute(Data in1, double in2) throws DMLRuntimeException { CM_COV_Object cm1=(CM_COV_Object) in1; if(cm1.isCMAllZeros()) { cm1.w=1; cm1.mean.set(in2, 0); cm1.m2.set(0,0); cm1.m3.set(0,0); cm1.m4.set(0,0); return cm1; } switch( _type ) { case COUNT: { cm1.w = cm1.w + 1; break; } case MEAN: { double w= cm1.w + 1; double d=in2-cm1.mean._sum; cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w); cm1.w=w; break; } case CM2: { double w= cm1.w + 1; double d=in2-cm1.mean._sum; cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w); double t1=cm1.w/w*d; double lt1=t1*d; _buff2.set(cm1.m2); _buff2=(KahanObject) _plus.execute(_buff2, lt1); cm1.m2.set(_buff2); cm1.w=w; break; } case CM3: { double w = cm1.w + 1; double d=in2-cm1.mean._sum; cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w); double t1=cm1.w/w*d; double t2=-1/cm1.w; double lt1=t1*d; double lt2=Math.pow(t1, 3)*(1.0-Math.pow(t2, 2)); double f2=1.0/w; _buff2.set(cm1.m2); _buff2=(KahanObject) _plus.execute(_buff2, lt1); _buff3.set(cm1.m3); _buff3=(KahanObject) _plus.execute(_buff3, lt2-3*cm1.m2._sum*f2*d); cm1.m2.set(_buff2); cm1.m3.set(_buff3); cm1.w=w; break; } case CM4: { double w=cm1.w+1; double d=in2-cm1.mean._sum; cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w); double t1=cm1.w/w*d; double t2=-1/cm1.w; double lt1=t1*d; double lt2=Math.pow(t1, 3)*(1.0-Math.pow(t2, 2)); double lt3=Math.pow(t1, 4)*(1.0-Math.pow(t2, 3)); double f2=1.0/w; _buff2.set(cm1.m2); _buff2=(KahanObject) _plus.execute(_buff2, lt1); _buff3.set(cm1.m3); _buff3=(KahanObject) _plus.execute(_buff3, lt2-3*cm1.m2._sum*f2*d); cm1.m4=(KahanObject) _plus.execute(cm1.m4, 6*cm1.m2._sum*Math.pow(-f2*d, 2) + lt3-4*cm1.m3._sum*f2*d); cm1.m2.set(_buff2); cm1.m3.set(_buff3); cm1.w=w; break; } case VARIANCE: { double w=cm1.w+1; double d=in2-cm1.mean._sum; cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w); double t1=cm1.w/w*d; double lt1=t1*d; cm1.m2=(KahanObject) _plus.execute(cm1.m2, lt1); cm1.w=w; break; } default: throw new DMLRuntimeException("Unsupported operation type: "+_type); } return cm1; } /** * General case for arbitrary weights w2 */ @Override public Data execute(Data in1, double in2, double w2) throws DMLRuntimeException { CM_COV_Object cm1=(CM_COV_Object) in1; if(cm1.isCMAllZeros()) { cm1.w=w2; cm1.mean.set(in2, 0); cm1.m2.set(0,0); cm1.m3.set(0,0); cm1.m4.set(0,0); return cm1; } switch( _type ) { case COUNT: { cm1.w = Math.round(cm1.w + w2); break; } case MEAN: { double w = cm1.w + w2; double d=in2-cm1.mean._sum; cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w); cm1.w=w; break; } case CM2: { double w = cm1.w + w2; double d=in2-cm1.mean._sum; cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w); double t1=cm1.w*w2/w*d; double lt1=t1*d; _buff2.set(cm1.m2); _buff2=(KahanObject) _plus.execute(_buff2, lt1); cm1.m2.set(_buff2); cm1.w=w; break; } case CM3: { double w = cm1.w + w2; double d=in2-cm1.mean._sum; cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w); double t1=cm1.w*w2/w*d; double t2=-1/cm1.w; double lt1=t1*d; double lt2=Math.pow(t1, 3)*(1/Math.pow(w2, 2)-Math.pow(t2, 2)); double f2=w2/w; _buff2.set(cm1.m2); _buff2=(KahanObject) _plus.execute(_buff2, lt1); _buff3.set(cm1.m3); _buff3=(KahanObject) _plus.execute(_buff3, lt2-3*cm1.m2._sum*f2*d); cm1.m2.set(_buff2); cm1.m3.set(_buff3); cm1.w=w; break; } case CM4: { double w = cm1.w + w2; double d=in2-cm1.mean._sum; cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w); double t1=cm1.w*w2/w*d; double t2=-1/cm1.w; double lt1=t1*d; double lt2=Math.pow(t1, 3)*(1/Math.pow(w2, 2)-Math.pow(t2, 2)); double lt3=Math.pow(t1, 4)*(1/Math.pow(w2, 3)-Math.pow(t2, 3)); double f2=w2/w; _buff2.set(cm1.m2); _buff2=(KahanObject) _plus.execute(_buff2, lt1); _buff3.set(cm1.m3); _buff3=(KahanObject) _plus.execute(_buff3, lt2-3*cm1.m2._sum*f2*d); cm1.m4=(KahanObject) _plus.execute(cm1.m4, 6*cm1.m2._sum*Math.pow(-f2*d, 2) + lt3-4*cm1.m3._sum*f2*d); cm1.m2.set(_buff2); cm1.m3.set(_buff3); cm1.w=w; break; } case VARIANCE: { double w = cm1.w + w2; double d=in2-cm1.mean._sum; cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w); double t1=cm1.w*w2/w*d; double lt1=t1*d; cm1.m2=(KahanObject) _plus.execute(cm1.m2, lt1); cm1.w=w; break; } default: throw new DMLRuntimeException("Unsupported operation type: "+_type); } return cm1; } /** * Combining stats from two partitions of the data. */ @Override public Data execute(Data in1, Data in2) throws DMLRuntimeException { CM_COV_Object cm1=(CM_COV_Object) in1; CM_COV_Object cm2=(CM_COV_Object) in2; if(cm1.isCMAllZeros()) { cm1.w=cm2.w; cm1.mean.set(cm2.mean); cm1.m2.set(cm2.m2); cm1.m3.set(cm2.m3); cm1.m4.set(cm2.m4); return cm1; } if(cm2.isCMAllZeros()) return cm1; switch( _type ) { case COUNT: { cm1.w = Math.round(cm1.w + cm2.w); break; } case MEAN: { double w = cm1.w + cm2.w; double d=cm2.mean._sum-cm1.mean._sum; cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w); cm1.w=w; break; } case CM2: { double w = cm1.w + cm2.w; double d=cm2.mean._sum-cm1.mean._sum; cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w); double t1=cm1.w*cm2.w/w*d; double lt1=t1*d; _buff2.set(cm1.m2); _buff2=(KahanObject) _plus.execute(_buff2, cm2.m2._sum, cm2.m2._correction); _buff2=(KahanObject) _plus.execute(_buff2, lt1); cm1.m2.set(_buff2); cm1.w=w; break; } case CM3: { double w = cm1.w + cm2.w; double d=cm2.mean._sum-cm1.mean._sum; cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w); double t1=cm1.w*cm2.w/w*d; double t2=-1/cm1.w; double lt1=t1*d; double lt2=Math.pow(t1, 3)*(1/Math.pow(cm2.w, 2)-Math.pow(t2, 2)); double f1=cm1.w/w; double f2=cm2.w/w; _buff2.set(cm1.m2); _buff2=(KahanObject) _plus.execute(_buff2, cm2.m2._sum, cm2.m2._correction); _buff2=(KahanObject) _plus.execute(_buff2, lt1); _buff3.set(cm1.m3); _buff3=(KahanObject) _plus.execute(_buff3, cm2.m3._sum, cm2.m3._correction); _buff3=(KahanObject) _plus.execute(_buff3, 3*(-f2*cm1.m2._sum+f1*cm2.m2._sum)*d + lt2); cm1.m2.set(_buff2); cm1.m3.set(_buff3); cm1.w=w; break; } case CM4: { double w = cm1.w + cm2.w; double d=cm2.mean._sum-cm1.mean._sum; cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w); double t1=cm1.w*cm2.w/w*d; double t2=-1/cm1.w; double lt1=t1*d; double lt2=Math.pow(t1, 3)*(1/Math.pow(cm2.w, 2)-Math.pow(t2, 2)); double lt3=Math.pow(t1, 4)*(1/Math.pow(cm2.w, 3)-Math.pow(t2, 3)); double f1=cm1.w/w; double f2=cm2.w/w; _buff2.set(cm1.m2); _buff2=(KahanObject) _plus.execute(_buff2, cm2.m2._sum, cm2.m2._correction); _buff2=(KahanObject) _plus.execute(_buff2, lt1); _buff3.set(cm1.m3); _buff3=(KahanObject) _plus.execute(_buff3, cm2.m3._sum, cm2.m3._correction); _buff3=(KahanObject) _plus.execute(_buff3, 3*(-f2*cm1.m2._sum+f1*cm2.m2._sum)*d + lt2); cm1.m4=(KahanObject) _plus.execute(cm1.m4, cm2.m4._sum, cm2.m4._correction); cm1.m4=(KahanObject) _plus.execute(cm1.m4, 4*(-f2*cm1.m3._sum+f1*cm2.m3._sum)*d + 6*(Math.pow(-f2, 2)*cm1.m2._sum+Math.pow(f1, 2)*cm2.m2._sum)*Math.pow(d, 2) + lt3); cm1.m2.set(_buff2); cm1.m3.set(_buff3); cm1.w=w; break; } case VARIANCE: { double w = cm1.w + cm2.w; double d=cm2.mean._sum-cm1.mean._sum; cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w); double t1=cm1.w*cm2.w/w*d; double lt1=t1*d; cm1.m2=(KahanObject) _plus.execute(cm1.m2, cm2.m2._sum, cm2.m2._correction); cm1.m2=(KahanObject) _plus.execute(cm1.m2, lt1); cm1.w=w; break; } default: throw new DMLRuntimeException("Unsupported operation type: "+_type); } return cm1; } }