package water.fvec;
import water.*;
import water.util.ArrayUtils;
import water.util.IcedHashMap;
import water.util.IcedLong;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
/**
*
* This class represents an interaction between two Vec instances.
*
* Another virtual Vec (akin to TransformWrappedVec) used to represent a
* never-materialized interaction between two columns.
*
* There are 3 types of interactions to consider: Num-Num, Num-Enum, and Enum-Enum.
* Creation of these Vec instances is cheap except for the Enum-Enum case (since the
* co-occurrence of levels between any two categorical columns is not known a priori).
* So in this specific case, an MRTask is run to collect the domain.
*
*
* @author spencer
*/
public class InteractionWrappedVec extends WrappedVec {
  // Keys of the two wrapped ("master") vecs; the Vec instances are cached lazily in the transient fields.
  private final Key<Vec> _masterVecKey1;
  private final Key<Vec> _masterVecKey2;
  private transient Vec _masterVec1;
  private transient Vec _masterVec2;
  // Domains of the master vecs; only populated when at least one master vec is categorical.
  private String[] _v1Domain;
  private String[] _v2Domain;
  public boolean _useAllFactorLevels;
  public boolean _skipMissing;
  public boolean _standardize;
  private long[] _bins;             // per-level counts, aligned with the combined domain (Enum-Enum only)
  private String[] _missingDomains; // combined levels skipped because !_useAllFactorLevels (Enum-Enum only)
  public transient GetMeanTask t;   // per-level mean/sigma, only built for a Num-Enum interaction with standardize
  private String[] _v1Enums; // only interact these enums from vec 1
  private String[] _v2Enums; // only interact these enums from vec 2

  /**
   * Builds the interaction vec, computes its domain, publishes it to the DKV and, when a
   * {@link GetMeanTask} was prepared by the domain setup, runs it over this vec.
   *
   * @param key                key of the new vec
   * @param rowLayout          row layout passed to the super constructor
   * @param vec1DomainLimit    if non-null, only these levels of vec 1 are interacted
   * @param vec2DomainLimit    if non-null, only these levels of vec 2 are interacted
   * @param useAllFactorLevels when false, level 0 of a categorical master vec is left out
   * @param skipMissing        when true, rows with an NA in either master vec are ignored while collecting the domain
   * @param standardize        when true and exactly one master vec has a domain, per-level means/sigmas are computed
   * @param masterVecKey1      key of the first master vec
   * @param masterVecKey2      key of the second master vec
   */
  public InteractionWrappedVec(Key<Vec> key, int rowLayout, String[] vec1DomainLimit, String[] vec2DomainLimit, boolean useAllFactorLevels, boolean skipMissing, boolean standardize, Key<Vec> masterVecKey1, Key<Vec> masterVecKey2) {
    super(key, rowLayout, null);
    _masterVecKey1 = masterVecKey1;
    _masterVecKey2 = masterVecKey2;
    _v1Enums = vec1DomainLimit;
    _v2Enums = vec2DomainLimit;
    _masterVec1 = _masterVecKey1.get();
    _masterVec2 = _masterVecKey2.get();
    _useAllFactorLevels = useAllFactorLevels;
    _skipMissing = skipMissing;
    _standardize = standardize;
    setupDomain(standardize); // performs MRTask if both vecs are categorical!!
    DKV.put(this);
    if (t != null) t.doAll(this);
  }

  /** @return the level limit for vec 1 if one was given, otherwise the domain of vec 1 (may be null) */
  public String[] v1Domain() { return _v1Enums == null ? _v1Domain : _v1Enums; }

  /** @return the level limit for vec 2 if one was given, otherwise the domain of vec 2 (may be null) */
  public String[] v2Domain() { return _v2Enums == null ? _v2Domain : _v2Enums; }

  /**
   * Always returns the "correct" domain, so an accidental mixup of domain vs domains is ok:
   * the other side's domain when one side has none, otherwise the domain stored on this vec.
   */
  @Override public String[] domain() {
    String[] res1 = v1Domain();
    String[] res2 = v2Domain();
    return res1 == null ? res2 : res2 == null ? res1 : super.domain();
  }

  /** @return the first master vec, fetched from its key on first use */
  public Vec v1() { return _masterVec1 == null ? (_masterVec1 = _masterVecKey1.get()) : _masterVec1; }

  /** @return the second master vec, fetched from its key on first use */
  public Vec v2() { return _masterVec2 == null ? (_masterVec2 = _masterVecKey2.get()) : _masterVec2; }

  /**
   * Obtain the length of the expanded (i.e. one-hot expanded) interaction column.
   */
  public int expandedLength() {
    if (_v1Domain == null && _v2Domain == null) return 1; // 2 numeric columns -> 1 column
    if (isCategorical()) return domain().length;          // 2 cat -> domains (limited) length
    // Exactly one side has a domain: one column per (limited) level, minus one when a level is dropped.
    String[] levels = _v1Domain != null ? v1Domain() : v2Domain();
    return levels.length - (_useAllFactorLevels ? 0 : 1);
  }

  /**
   * @return for two domains, an array of NaN sized to the combined domain; for no domains, a
   *         single-element array holding the mean of this vec; otherwise the per-level means
   *         computed by a fresh {@link GetMeanTask}
   */
  public double[] getMeans() {
    if (null != _v1Domain && null != _v2Domain) {
      double[] res = new double[domain().length];
      Arrays.fill(res, Double.NaN);
      return res;
    }
    if (null == _v1Domain && null == _v2Domain) return new double[]{super.mean()};
    return new GetMeanTask(singleSideLevelCount()).doAll(this)._d;
  }

  /** @return the value to subtract for expanded column {@code i}: the per-level mean if available, else {@link #mean()} */
  public double getSub(int i) {
    if (t == null) return mean();
    return t._d[i];
  }

  /** @return the multiplier for expanded column {@code i}: 1/sigma, or 1.0 when sigma is 0 */
  public double getMul(int i) {
    double sigma = (t == null) ? sigma() : t._sigma[i];
    return sigma == 0 ? 1.0 : 1.0 / sigma;
  }

  // Number of levels on whichever side reports a domain (vec 1 is checked first).
  private int singleSideLevelCount() {
    String[] d1 = v1Domain();
    return d1 == null ? v2Domain().length : d1.length;
  }

  /**
   * Computes, per level of the categorical side, the mean and standard deviation of the
   * expanded interaction column (the numeric value where the row has that level, 0 elsewhere).
   * Rows with an NA on either side are skipped.
   */
  private static class GetMeanTask extends MRTask<GetMeanTask> {
    private double[] _d;     // means, NA skipped
    private double[] _sigma; // sds, NA skipped (holds sums of squared deviations until postGlobal)
    private long _rows;
    private final int _len;

    GetMeanTask(int len) { _len = len; }

    @Override public void map(Chunk c) {
      _d = new double[_len];
      _sigma = new double[_len];
      InteractionWrappedChunk cc = (InteractionWrappedChunk) c;
      Chunk lC = cc._c[0]; Chunk rC = cc._c[1];    // get the "left" chk and the "rite" chk
      if (cc._c2IsCat) { lC = rC; rC = cc._c[0]; } // left is always cat
      long rows = 0;
      for (int rid = 0; rid < c._len; ++rid) {
        if (lC.isNA(rid) || rC.isNA(rid)) continue; // skipmissing
        int idx = (int) lC.at8(rid);
        rows++;
        // Running update of mean and sum of squared deviations for every expanded column.
        for (int i = 0; i < _d.length; ++i) {
          double x = i == idx ? rC.atd(rid) : 0;
          double delta = x - _d[i];
          _d[i] += delta / rows;
          _sigma[i] += delta * (x - _d[i]);
        }
      }
      _rows = rows;
    }

    @Override public void reduce(GetMeanTask t) {
      if (_rows == 0) { _d = t._d; _sigma = t._sigma; } // nothing seen here: adopt the other side's stats
      else if (t._rows != 0) {
        // Merge two partial (mean, sum of squared deviations) pairs weighted by their row counts.
        for (int i = 0; i < _d.length; ++i) {
          double delta = _d[i] - t._d[i];
          _d[i] = (_d[i] * _rows + t._d[i] * t._rows) / (_rows + t._rows);
          _sigma[i] += t._sigma[i] + delta * delta * _rows * t._rows / (_rows + t._rows);
        }
      }
      _rows += t._rows;
    }

    @Override public void postGlobal() {
      // Turn the accumulated sums of squared deviations into sample standard deviations.
      for (int i = 0; i < _sigma.length; ++i)
        _sigma[i] = Math.sqrt(_sigma[i] / (_rows - 1));
    }
  }

  /** @return the mean of this vec when no per-level stats exist and neither side has a domain, otherwise 0 */
  @Override public double mean() {
    if (null == t && null == v1Domain() && null == v2Domain())
      return super.mean();
    return 0;
  }

  /** @return the sigma of this vec when no per-level stats exist and neither side has a domain, otherwise 1 */
  @Override public double sigma() {
    if (null == t && null == v1Domain() && null == v2Domain())
      return super.sigma();
    return 1;
  }

  /** @return the index of the largest bin; only implemented for a categorical interaction */
  @Override public int mode() {
    if (!isCategorical()) throw H2O.unimpl();
    return ArrayUtils.maxIndex(_bins);
  }

  public long[] getBins() { return _bins; }

  public String[] missingDomains() { return _missingDomains; }

  /**
   * Reads the master domains and, for an Enum-Enum interaction, runs {@link CombineDomainTask}
   * to collect the combined domain, its bins and the skipped levels. For a Num-Enum interaction
   * it prepares (but does not run) a {@link GetMeanTask} when {@code standardize} is set.
   */
  private void setupDomain(boolean standardize) {
    if (_masterVec1.isCategorical() || _masterVec2.isCategorical()) {
      _v1Domain = _masterVec1.domain();
      _v2Domain = _masterVec2.domain();
      if (_v1Domain != null && _v2Domain != null) {
        // Named distinctly from the field 't', which holds the GetMeanTask.
        CombineDomainTask combined = new CombineDomainTask(_v1Domain, _v2Domain, _v1Enums, _v2Enums, _useAllFactorLevels, _skipMissing).doAll(_masterVec1, _masterVec2);
        setDomain(combined._dom);
        _bins = combined._bins;
        _type = Vec.T_CAT; // vec is T_NUM up to this point
        _missingDomains = combined._missingDom;
      } else {
        t = standardize ? new GetMeanTask(singleSideLevelCount()) : null;
      }
    }
    if (null == _v1Domain && null == _v2Domain) _useAllFactorLevels = true; // just makes life easier to have this when the vec is categorical
  }

  /**
   * Collects the co-occurring level pairs of two categorical vecs as {@code "left_rite"} keys
   * with their counts. An NA on exactly one side is recorded as {@code "NA"} unless missing
   * values are skipped; rows with both sides NA are never counted.
   */
  private static class CombineDomainTask extends MRTask<CombineDomainTask> {
    private String[] _dom;        // out, sorted (uses Arrays.sort)
    private long[] _bins;         // out, sorted according to _dom
    private String[] _missingDom; // out, the missing levels due to !_useAllLvls
    private final String _left[];      // in
    private final String _rite[];      // in
    private final String _leftLimit[]; // in
    private final String _riteLimit[]; // in
    private final boolean _useAllLvls;  // in
    private final boolean _skipMissing; // in
    private IcedHashMap<String, IcedLong> _perChkMap;
    private IcedHashMap<String, String> _perChkMapMissing; // skipped cats

    CombineDomainTask(String[] left, String[] rite, String[] leftLimit, String[] riteLimit, boolean useAllLvls, boolean skipMissing) {
      _left = left;
      _rite = rite;
      _leftLimit = leftLimit;
      _riteLimit = riteLimit;
      _useAllLvls = useAllLvls;
      _skipMissing = skipMissing;
    }

    @Override public void map(Chunk[] c) {
      _perChkMap = new IcedHashMap<>();
      if (!_useAllLvls) _perChkMapMissing = new IcedHashMap<>();
      Chunk left = c[0];
      Chunk rite = c[1];
      HashSet<String> leftAllowed = toSet(_leftLimit);
      HashSet<String> riteAllowed = toSet(_riteLimit);
      for (int i = 0; i < left._len; ++i) {
        boolean leftIsNA = left.isNA(i);
        boolean riteIsNA = rite.isNA(i);
        if (!leftIsNA && !riteIsNA) {
          int lval = (int) left.at8(i);
          int rval = (int) rite.at8(i);
          String l = _left[lval];
          String r = _rite[rval];
          if (!_useAllLvls && (0 == lval || 0 == rval)) { // level 0 is dropped: remember it as skipped
            _perChkMapMissing.putIfAbsent(l + "_" + r, "");
            continue;
          }
          if (leftAllowed != null && !leftAllowed.contains(l)) continue;
          if (riteAllowed != null && !riteAllowed.contains(r)) continue;
          count(l + "_" + r);
        } else if (!_skipMissing && !(leftIsNA && riteIsNA)) { // exactly one side is missing
          if (leftIsNA) {
            int rval = (int) rite.at8(i);
            String r = _rite[rval];
            if (!_useAllLvls && 0 == rval) {
              _perChkMapMissing.putIfAbsent("NA_" + r, "");
              continue;
            }
            if (riteAllowed != null && !riteAllowed.contains(r)) continue;
            count("NA_" + r);
          } else {
            int lval = (int) left.at8(i);
            String l = _left[lval];
            if (!_useAllLvls && 0 == lval) {
              _perChkMapMissing.putIfAbsent(l + "_NA", "");
              continue;
            }
            if (leftAllowed != null && !leftAllowed.contains(l)) continue;
            count(l + "_NA");
          }
        }
      }
    }

    // Builds a lookup set from a level limit; null means "no limit".
    private static HashSet<String> toSet(String[] limit) {
      if (limit == null) return null;
      HashSet<String> set = new HashSet<String>();
      Collections.addAll(set, limit);
      return set;
    }

    // Adds one occurrence of the combined level to the per-chunk counts.
    private void count(String key) {
      if (null != _perChkMap.putIfAbsent(key, new IcedLong(1)))
        _perChkMap.get(key)._val++;
    }

    @Override public void reduce(CombineDomainTask t) {
      for (Map.Entry<String, IcedLong> e : t._perChkMap.entrySet()) {
        IcedLong i = _perChkMap.get(e.getKey());
        if (i != null) i._val += e.getValue()._val;
        else _perChkMap.put(e.getKey(), e.getValue());
      }
      t._perChkMap = null;
      if (t._perChkMapMissing != null) {
        if (_perChkMapMissing == null) _perChkMapMissing = new IcedHashMap<>();
        for (String s : t._perChkMapMissing.keySet())
          _perChkMapMissing.putIfAbsent(s, "");
      }
      t._perChkMapMissing = null;
    }

    @Override public void postGlobal() {
      Arrays.sort(_dom = _perChkMap.keySet().toArray(new String[_perChkMap.size()]));
      int idx = 0;
      _bins = new long[_perChkMap.size()];
      for (String s : _dom)
        _bins[idx++] = _perChkMap.get(s)._val;
      // Guard on the collected map: _missingDom itself is never assigned before this point,
      // so testing it would leave the skipped levels permanently unreported.
      if (_perChkMapMissing != null)
        Arrays.sort(_missingDom = _perChkMapMissing.keySet().toArray(new String[_perChkMapMissing.size()]));
    }
  }

  /** Wraps the two master chunks with the given index into one interaction chunk. */
  @Override public Chunk chunkForChunkIdx(int cidx) {
    Chunk[] cs = new Chunk[2];
    cs[0] = v1().chunkForChunkIdx(cidx);
    cs[1] = v2().chunkForChunkIdx(cidx);
    return new InteractionWrappedChunk(this, cs);
  }

  /** Builds a new interaction vec over the same master vecs and copies the domains onto it. */
  @Override public Vec doCopy() {
    InteractionWrappedVec v = new InteractionWrappedVec(group().addVec(), _rowLayout, _v1Enums, _v2Enums, _useAllFactorLevels, _skipMissing, _standardize, _masterVecKey1, _masterVecKey2);
    if (null != domain()) v.setDomain(domain());
    if (null != _v1Domain) v._v1Domain = _v1Domain.clone();
    if (null != _v2Domain) v._v2Domain = _v2Domain.clone();
    return v;
  }

  // Field order here must mirror readAll_impl.
  @Override protected AutoBuffer writeAll_impl(AutoBuffer ab) {
    ab.putAStr(_v1Domain);
    ab.putAStr(_v2Domain);
    ab.putZ(_useAllFactorLevels);
    ab.putZ(_skipMissing);
    ab.putZ(_standardize);
    ab.putAStr(_missingDomains);
    ab.putAStr(_v1Enums);
    ab.putAStr(_v2Enums);
    return super.writeAll_impl(ab);
  }

  // Field order here must mirror writeAll_impl.
  @Override protected Keyed readAll_impl(AutoBuffer ab, Futures fs) {
    _v1Domain = ab.getAStr();
    _v2Domain = ab.getAStr();
    _useAllFactorLevels = ab.getZ();
    _skipMissing = ab.getZ();
    _standardize = ab.getZ();
    _missingDomains = ab.getAStr();
    _v1Enums = ab.getAStr();
    _v2Enums = ab.getAStr();
    return super.readAll_impl(ab, fs);
  }

  /**
   * Read-only chunk computed on the fly from the two master chunks: the index of the combined
   * level for a categorical interaction, otherwise the product of the numeric sides (a
   * categorical side contributes a factor of 1).
   */
  public static class InteractionWrappedChunk extends Chunk {
    public final transient Chunk[] _c;
    public final boolean _c1IsCat; // left chunk is categorical
    public final boolean _c2IsCat; // rite chunk is categorical
    public final boolean _isCat;   // this vec is categorical

    InteractionWrappedChunk(InteractionWrappedVec transformWrappedVec, Chunk[] c) {
      // set all the chunk fields
      _c = c; set_len(_c[0]._len);
      _start = _c[0]._start; _vec = transformWrappedVec; _cidx = _c[0]._cidx;
      _c1IsCat = _c[0]._vec.isCategorical();
      _c2IsCat = _c[1]._vec.isCategorical();
      _isCat = _vec.isCategorical();
    }

    @Override public double atd_impl(int idx) {
      if (_isCat) {
        if (isNA_impl(idx)) return Double.NaN;
        // Position of the combined level in this vec's domain (negative if the key is absent).
        return Arrays.binarySearch(_vec.domain(), getKey(idx));
      }
      return (_c1IsCat ? 1 : _c[0].atd(idx)) * (_c2IsCat ? 1 : _c[1].atd(idx));
    }

    @Override public long at8_impl(int idx) {
      if (_isCat) return Arrays.binarySearch(_vec.domain(), getKey(idx));
      return (_c1IsCat ? 1 : _c[0].at8(idx)) * (_c2IsCat ? 1 : _c[1].at8(idx));
    }

    // Combined level name "left_rite" for the given row.
    private String getKey(int idx) { return _c[0]._vec.domain()[(int) _c[0].at8(idx)] + "_" + _c[1]._vec.domain()[(int) _c[1].at8(idx)]; }

    // A row is NA when either master chunk is NA at that row.
    @Override public boolean isNA_impl(int idx) { return _c[0].isNA(idx) || _c[1].isNA(idx); }

    // Setters never store anything and always report false.
    @Override public boolean set_impl(int idx, long l) { return false; }
    @Override public boolean set_impl(int idx, double d) { return false; }
    @Override public boolean set_impl(int idx, float f) { return false; }
    @Override public boolean setNA_impl(int idx) { return false; }

    @Override public ChunkVisitor processRows(ChunkVisitor nc, int from, int to) {
      for (int i = from; i < to; i++)
        nc.addValue(atd(i));
      return nc;
    }

    @Override public ChunkVisitor processRows(ChunkVisitor nc, int... rows) {
      for (int i : rows)
        nc.addValue(atd(i));
      return nc;
    }

    @Override protected final void initFromBytes() { throw water.H2O.fail(); }
  }
}