/**
* (C) Copyright IBM Corp. 2010, 2015
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package com.ibm.bi.dml.runtime.functionobjects;
import com.ibm.bi.dml.runtime.DMLRuntimeException;
import com.ibm.bi.dml.runtime.instructions.cp.CM_COV_Object;
import com.ibm.bi.dml.runtime.instructions.cp.Data;
import com.ibm.bi.dml.runtime.instructions.cp.KahanObject;
import com.ibm.bi.dml.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
/**
 * Value function that incrementally maintains the aggregates for count, mean, variance and
 * the second to fourth central moments. The running state lives in a {@link CM_COV_Object}
 * (weight {@code w}, {@code mean}, {@code m2}, {@code m3}, {@code m4}); every {@code execute}
 * variant updates its first argument in place and returns it. All additions to the
 * aggregates go through {@link KahanPlus}.
 *
 * Instances hold mutable scratch buffers and must therefore not be shared between threads;
 * {@link #getCMFnObject(AggregateOperationTypes)} hands out a fresh instance on every call.
 *
 * GENERAL NOTE:
 * * 05/28/2014: We decided to handle weights consistently with SPSS in an operation-specific manner,
 *   i.e., we (1) round instead of casting where required (e.g. count), and (2) consistently use
 *   fractional weight values elsewhere. In case a count-based interpretation of weights is needed, just
 *   ensure rounding before calling CM/COV/KahanPlus.
 *
 */
public class CM extends ValueFunction
{
    private static final long serialVersionUID = 9177194651533064123L;

    //operation type, fixed at construction time; selects the branch taken by all execute variants
    private final AggregateOperationTypes _type;

    //helper function objects for specific types (allocated on demand by the constructor)
    private KahanPlus _plus = null;
    //scratch buffers for the new m2/m3 values, so that the old cm1.m2/cm1.m3 stay readable
    //until all higher-order terms that depend on them have been computed
    private KahanObject _buff2 = null;
    private KahanObject _buff3 = null;

    /**
     * Creates the function object and allocates only the helpers required by the given type.
     *
     * @param type operation type to compute
     */
    private CM( AggregateOperationTypes type )
    {
        _type = type;

        //helper obj on demand; the fall-through is intentional: higher-order types
        //additionally need all helpers of the lower-order types
        switch( _type )
        {
            case COUNT:
                break;
            case CM4:
            case CM3:
                _buff3 = new KahanObject(0, 0);
                //fall through
            case CM2:
            case VARIANCE:
                _buff2 = new KahanObject(0, 0);
                //fall through
            case MEAN:
                _plus = KahanPlus.getKahanPlusFnObject();
                break;
            default:
                //do nothing
                break;
        }
    }

    /**
     * Factory method for CM function objects.
     *
     * @param type operation type to compute
     * @return a new CM instance for the given type
     */
    public static CM getCMFnObject( AggregateOperationTypes type ) {
        //return new obj, required for correctness in multi-threaded
        //execution due to state in cm object (buff2, buff3)
        return new CM( type );
    }

    /**
     * Cloning is not supported; new instances are obtained via
     * {@link #getCMFnObject(AggregateOperationTypes)}.
     *
     * @throws CloneNotSupportedException always
     */
    @Override
    public Object clone() throws CloneNotSupportedException {
        throw new CloneNotSupportedException();
    }

    /**
     * Initializes an all-zero aggregate object with its first observation: the weight and
     * the mean are taken from the observation, m2, m3 and m4 are reset to zero.
     *
     * @param cm aggregate object to initialize in place
     * @param value first observed value
     * @param weight weight of the first observed value
     */
    private static void initWithFirstValue( CM_COV_Object cm, double value, double weight )
    {
        cm.w=weight;
        cm.mean.set(value, 0);
        cm.m2.set(0,0);
        cm.m3.set(0,0);
        cm.m4.set(0,0);
    }

    /**
     * Special case for weights w2==1: adds a single value with unit weight.
     *
     * If the aggregate object is still all zeros, it is initialized from the value
     * irrespective of the operation type.
     *
     * @param in1 aggregate object (must be a {@link CM_COV_Object}), updated in place
     * @param in2 value to add
     * @return the updated aggregate object {@code in1}
     * @throws DMLRuntimeException if the operation type is not supported
     */
    @Override
    public Data execute(Data in1, double in2)
        throws DMLRuntimeException
    {
        CM_COV_Object cm1=(CM_COV_Object) in1;

        if(cm1.isCMAllZeros())
        {
            initWithFirstValue(cm1, in2, 1);
            return cm1;
        }

        switch( _type )
        {
            case COUNT:
            {
                cm1.w = cm1.w + 1;
                break;
            }
            case MEAN:
            {
                double w= cm1.w + 1;
                double d=in2-cm1.mean._sum; //deviation from the old mean
                cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w);
                cm1.w=w;
                break;
            }
            case CM2:
            case VARIANCE:
            {
                //both types maintain exactly the same aggregates (w, mean, m2)
                double w= cm1.w + 1;
                double d=in2-cm1.mean._sum;
                cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w);
                double t1=cm1.w/w*d;
                double lt1=t1*d; //m2 increment
                _buff2.set(cm1.m2);
                _buff2=(KahanObject) _plus.execute(_buff2, lt1);
                cm1.m2.set(_buff2);
                cm1.w=w;
                break;
            }
            case CM3:
            {
                double w = cm1.w + 1;
                double d=in2-cm1.mean._sum;
                cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w);
                double t1=cm1.w/w*d;
                double t2=-1/cm1.w;
                double lt1=t1*d;
                double lt2=Math.pow(t1, 3)*(1.0-Math.pow(t2, 2));
                double f2=1.0/w;
                _buff2.set(cm1.m2);
                _buff2=(KahanObject) _plus.execute(_buff2, lt1);
                _buff3.set(cm1.m3);
                //the m3 increment reads the old m2, hence cm1.m2 is written back only afterwards
                _buff3=(KahanObject) _plus.execute(_buff3, lt2-3*cm1.m2._sum*f2*d);
                cm1.m2.set(_buff2);
                cm1.m3.set(_buff3);
                cm1.w=w;
                break;
            }
            case CM4:
            {
                double w=cm1.w+1;
                double d=in2-cm1.mean._sum;
                cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w);
                double t1=cm1.w/w*d;
                double t2=-1/cm1.w;
                double lt1=t1*d;
                double lt2=Math.pow(t1, 3)*(1.0-Math.pow(t2, 2));
                double lt3=Math.pow(t1, 4)*(1.0-Math.pow(t2, 3));
                double f2=1.0/w;
                _buff2.set(cm1.m2);
                _buff2=(KahanObject) _plus.execute(_buff2, lt1);
                _buff3.set(cm1.m3);
                _buff3=(KahanObject) _plus.execute(_buff3, lt2-3*cm1.m2._sum*f2*d);
                //the m4 increment reads the old m2 and m3, which are written back only afterwards
                cm1.m4=(KahanObject) _plus.execute(cm1.m4, 6*cm1.m2._sum*Math.pow(-f2*d, 2) + lt3-4*cm1.m3._sum*f2*d);
                cm1.m2.set(_buff2);
                cm1.m3.set(_buff3);
                cm1.w=w;
                break;
            }
            default:
                throw new DMLRuntimeException("Unsupported operation type: "+_type);
        }

        return cm1;
    }

    /**
     * General case for arbitrary weights w2: adds a single weighted value.
     *
     * If the aggregate object is still all zeros, it is initialized from the value and
     * its weight irrespective of the operation type. For COUNT, the accumulated weight
     * is rounded after each addition.
     *
     * @param in1 aggregate object (must be a {@link CM_COV_Object}), updated in place
     * @param in2 value to add
     * @param w2 weight of the value
     * @return the updated aggregate object {@code in1}
     * @throws DMLRuntimeException if the operation type is not supported
     */
    @Override
    public Data execute(Data in1, double in2, double w2)
        throws DMLRuntimeException
    {
        CM_COV_Object cm1=(CM_COV_Object) in1;

        if(cm1.isCMAllZeros())
        {
            initWithFirstValue(cm1, in2, w2);
            return cm1;
        }

        switch( _type )
        {
            case COUNT:
            {
                cm1.w = Math.round(cm1.w + w2);
                break;
            }
            case MEAN:
            {
                double w = cm1.w + w2;
                double d=in2-cm1.mean._sum; //deviation from the old mean
                cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w);
                cm1.w=w;
                break;
            }
            case CM2:
            case VARIANCE:
            {
                //both types maintain exactly the same aggregates (w, mean, m2)
                double w = cm1.w + w2;
                double d=in2-cm1.mean._sum;
                cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w);
                double t1=cm1.w*w2/w*d;
                double lt1=t1*d; //m2 increment
                _buff2.set(cm1.m2);
                _buff2=(KahanObject) _plus.execute(_buff2, lt1);
                cm1.m2.set(_buff2);
                cm1.w=w;
                break;
            }
            case CM3:
            {
                double w = cm1.w + w2;
                double d=in2-cm1.mean._sum;
                cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w);
                double t1=cm1.w*w2/w*d;
                double t2=-1/cm1.w;
                double lt1=t1*d;
                double lt2=Math.pow(t1, 3)*(1/Math.pow(w2, 2)-Math.pow(t2, 2));
                double f2=w2/w;
                _buff2.set(cm1.m2);
                _buff2=(KahanObject) _plus.execute(_buff2, lt1);
                _buff3.set(cm1.m3);
                //the m3 increment reads the old m2, hence cm1.m2 is written back only afterwards
                _buff3=(KahanObject) _plus.execute(_buff3, lt2-3*cm1.m2._sum*f2*d);
                cm1.m2.set(_buff2);
                cm1.m3.set(_buff3);
                cm1.w=w;
                break;
            }
            case CM4:
            {
                double w = cm1.w + w2;
                double d=in2-cm1.mean._sum;
                cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w);
                double t1=cm1.w*w2/w*d;
                double t2=-1/cm1.w;
                double lt1=t1*d;
                double lt2=Math.pow(t1, 3)*(1/Math.pow(w2, 2)-Math.pow(t2, 2));
                double lt3=Math.pow(t1, 4)*(1/Math.pow(w2, 3)-Math.pow(t2, 3));
                double f2=w2/w;
                _buff2.set(cm1.m2);
                _buff2=(KahanObject) _plus.execute(_buff2, lt1);
                _buff3.set(cm1.m3);
                _buff3=(KahanObject) _plus.execute(_buff3, lt2-3*cm1.m2._sum*f2*d);
                //the m4 increment reads the old m2 and m3, which are written back only afterwards
                cm1.m4=(KahanObject) _plus.execute(cm1.m4, 6*cm1.m2._sum*Math.pow(-f2*d, 2) + lt3-4*cm1.m3._sum*f2*d);
                cm1.m2.set(_buff2);
                cm1.m3.set(_buff3);
                cm1.w=w;
                break;
            }
            default:
                throw new DMLRuntimeException("Unsupported operation type: "+_type);
        }

        return cm1;
    }

    /**
     * Merges the partial aggregates of {@code in2} into {@code in1}.
     *
     * If {@code in1} is still all zeros, all aggregates of {@code in2} are copied into it;
     * if {@code in2} is all zeros, {@code in1} is returned unchanged. Both shortcuts apply
     * irrespective of the operation type. {@code in2} is only read, never modified.
     *
     * @param in1 target aggregate object (must be a {@link CM_COV_Object}), updated in place
     * @param in2 aggregate object to merge in (must be a {@link CM_COV_Object})
     * @return the updated aggregate object {@code in1}
     * @throws DMLRuntimeException if the operation type is not supported
     */
    @Override
    public Data execute(Data in1, Data in2) throws DMLRuntimeException
    {
        CM_COV_Object cm1=(CM_COV_Object) in1;
        CM_COV_Object cm2=(CM_COV_Object) in2;

        if(cm1.isCMAllZeros())
        {
            cm1.w=cm2.w;
            cm1.mean.set(cm2.mean);
            cm1.m2.set(cm2.m2);
            cm1.m3.set(cm2.m3);
            cm1.m4.set(cm2.m4);
            return cm1;
        }
        if(cm2.isCMAllZeros())
            return cm1;

        switch( _type )
        {
            case COUNT:
            {
                cm1.w = Math.round(cm1.w + cm2.w);
                break;
            }
            case MEAN:
            {
                double w = cm1.w + cm2.w;
                double d=cm2.mean._sum-cm1.mean._sum; //difference of the two partial means
                cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w);
                cm1.w=w;
                break;
            }
            case CM2:
            case VARIANCE:
            {
                //both types maintain exactly the same aggregates (w, mean, m2)
                double w = cm1.w + cm2.w;
                double d=cm2.mean._sum-cm1.mean._sum;
                cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w);
                double t1=cm1.w*cm2.w/w*d;
                double lt1=t1*d; //m2 cross term
                _buff2.set(cm1.m2);
                _buff2=(KahanObject) _plus.execute(_buff2, cm2.m2._sum, cm2.m2._correction);
                _buff2=(KahanObject) _plus.execute(_buff2, lt1);
                cm1.m2.set(_buff2);
                cm1.w=w;
                break;
            }
            case CM3:
            {
                double w = cm1.w + cm2.w;
                double d=cm2.mean._sum-cm1.mean._sum;
                cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w);
                double t1=cm1.w*cm2.w/w*d;
                double t2=-1/cm1.w;
                double lt1=t1*d;
                double lt2=Math.pow(t1, 3)*(1/Math.pow(cm2.w, 2)-Math.pow(t2, 2));
                double f1=cm1.w/w;
                double f2=cm2.w/w;
                _buff2.set(cm1.m2);
                _buff2=(KahanObject) _plus.execute(_buff2, cm2.m2._sum, cm2.m2._correction);
                _buff2=(KahanObject) _plus.execute(_buff2, lt1);
                _buff3.set(cm1.m3);
                _buff3=(KahanObject) _plus.execute(_buff3, cm2.m3._sum, cm2.m3._correction);
                //the m3 cross term reads the old m2, hence cm1.m2 is written back only afterwards
                _buff3=(KahanObject) _plus.execute(_buff3, 3*(-f2*cm1.m2._sum+f1*cm2.m2._sum)*d + lt2);
                cm1.m2.set(_buff2);
                cm1.m3.set(_buff3);
                cm1.w=w;
                break;
            }
            case CM4:
            {
                double w = cm1.w + cm2.w;
                double d=cm2.mean._sum-cm1.mean._sum;
                cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w);
                double t1=cm1.w*cm2.w/w*d;
                double t2=-1/cm1.w;
                double lt1=t1*d;
                double lt2=Math.pow(t1, 3)*(1/Math.pow(cm2.w, 2)-Math.pow(t2, 2));
                double lt3=Math.pow(t1, 4)*(1/Math.pow(cm2.w, 3)-Math.pow(t2, 3));
                double f1=cm1.w/w;
                double f2=cm2.w/w;
                _buff2.set(cm1.m2);
                _buff2=(KahanObject) _plus.execute(_buff2, cm2.m2._sum, cm2.m2._correction);
                _buff2=(KahanObject) _plus.execute(_buff2, lt1);
                _buff3.set(cm1.m3);
                _buff3=(KahanObject) _plus.execute(_buff3, cm2.m3._sum, cm2.m3._correction);
                _buff3=(KahanObject) _plus.execute(_buff3, 3*(-f2*cm1.m2._sum+f1*cm2.m2._sum)*d + lt2);
                cm1.m4=(KahanObject) _plus.execute(cm1.m4, cm2.m4._sum, cm2.m4._correction);
                //the m4 cross terms read the old m2 and m3, which are written back only afterwards
                cm1.m4=(KahanObject) _plus.execute(cm1.m4, 4*(-f2*cm1.m3._sum+f1*cm2.m3._sum)*d
                    + 6*(Math.pow(-f2, 2)*cm1.m2._sum+Math.pow(f1, 2)*cm2.m2._sum)*Math.pow(d, 2) + lt3);
                cm1.m2.set(_buff2);
                cm1.m3.set(_buff3);
                cm1.w=w;
                break;
            }
            default:
                throw new DMLRuntimeException("Unsupported operation type: "+_type);
        }

        return cm1;
    }
}