package water.rapids.ast.prims.advmath; import water.H2O; import water.MRTask; import water.fvec.*; import water.rapids.Env; import water.rapids.Val; import water.rapids.ast.AstBuiltin; import water.rapids.ast.AstPrimitive; import water.rapids.ast.AstRoot; import water.rapids.vals.ValFrame; import water.util.ArrayUtils; import water.util.Log; import java.util.Arrays; /** * Calculate Distance Metric between pairs of rows */ public class AstDistance extends AstBuiltin<AstDistance> { @Override public String[] args() { return new String[]{"ary", "x", "y", "measure"}; } @Override public int nargs() { return 1 + 3; /* (distance X Y measure) */ } @Override public String str() { return "distance"; } @Override public String description() { return "Compute a pairwise distance measure between all rows of two numeric H2OFrames.\n" + "For a given (usually larger) reference frame (N rows x p cols),\n" + "and a (usually smaller) query frame (M rows x p cols), we return a numeric Frame of size (N rows x M cols),\n" + "where the ij-th element is the distance measure between the i-th reference row and the j-th query row.\n" + "Note1: The output frame is symmetric.\n" + "Note2: Since N x M can be very large, it may be more efficient (memory-wise) to make multiple calls with smaller query Frames."; } @Override public Val apply(Env env, Env.StackHelp stk, AstRoot asts[]) { Frame frx = stk.track(asts[1].exec(env)).getFrame(); Frame fry = stk.track(asts[2].exec(env)).getFrame(); String measure = stk.track(asts[3].exec(env)).getStr(); return computeCosineDistances(frx, fry, measure); } public Val computeCosineDistances(Frame references, Frame queries, String distanceMetric) { Log.info("Number of references: " + references.numRows()); Log.info("Number of queries : " + queries.numRows()); String[] options = new String[]{"cosine","cosine_sq","l1","l2"}; if (!ArrayUtils.contains(options, distanceMetric.toLowerCase())) throw new IllegalArgumentException("Invalid distance measure provided: " + distanceMetric + ". Mustbe one of " + Arrays.toString(options)); if (references.numRows() * queries.numRows() * 8 > H2O.CLOUD.free_mem() ) throw new IllegalArgumentException("Not enough free memory to allocate the distance matrix (" + references.numRows() + " rows and " + queries.numRows() + " cols. Try specifying a smaller query frame."); if (references.numCols() != queries.numCols()) throw new IllegalArgumentException("Frames must have the same number of cols, found " + references.numCols() + " and " + queries.numCols()); if (queries.numRows() > Integer.MAX_VALUE) throw new IllegalArgumentException("Queries can't be larger than 2 billion rows."); if (queries.numCols() != references.numCols()) throw new IllegalArgumentException("Queries and References must have the same dimensionality"); for (int i=0;i<queries.numCols();++i) { if (!references.vec(i).isNumeric()) throw new IllegalArgumentException("References column " + references.name(i) + " is not numeric."); if (!queries.vec(i).isNumeric()) throw new IllegalArgumentException("Queries column " + references.name(i) + " is not numeric."); if (references.vec(i).naCnt()>0) throw new IllegalArgumentException("References column " + references.name(i) + " contains missing values."); if (queries.vec(i).naCnt()>0) throw new IllegalArgumentException("Queries column " + references.name(i) + " contains missing values."); } return new ValFrame(new DistanceComputer(queries, distanceMetric).doAll((int)queries.numRows(), Vec.T_NUM, references).outputFrame()); } static public class DistanceComputer extends MRTask<DistanceComputer> { Frame _queries; String _measure; DistanceComputer(Frame queries, String measure) { _queries = queries; _measure = measure; } @Override public void map(Chunk[] cs, NewChunk[] ncs) { int p = cs.length; //dimensionality int Q = (int) _queries.numRows(); int R = cs[0]._len; Vec.Reader[] Qs = new Vec.Reader[p]; for (int i = 0; i < p; ++i) { Qs[i] = _queries.vec(i).new Reader(); } double[] denomR = null; double[] denomQ = null; final boolean cosine = _measure.toLowerCase().equals("cosine"); final boolean cosine_sq = _measure.toLowerCase().equals("cosine_sq"); final boolean l1 = _measure.toLowerCase().equals("l1"); final boolean l2 = _measure.toLowerCase().equals("l2"); if (cosine || cosine_sq) { denomR = new double[R]; denomQ = new double[Q]; for (int r = 0; r < R; ++r) { // Reference row (chunk-local) for (int c = 0; c < p; ++c) { //cols denomR[r] += Math.pow(cs[c].atd(r), 2); } } for (int q = 0; q < Q; ++q) { // Query row (global) for (int c = 0; c < p; ++c) { //cols denomQ[q] += Math.pow(Qs[c].at(q), 2); } } } for (int r = 0; r < cs[0]._len; ++r) { // Reference row (chunk-local) for (int q = 0; q < Q; ++q) { // Query row (global) double distRQ = 0; if (l1) { for (int c = 0; c < p; ++c) { //cols distRQ += Math.abs(cs[c].atd(r) - Qs[c].at(q)); } } else if (l2) { for (int c = 0; c < p; ++c) { //cols distRQ += Math.pow(cs[c].atd(r) - Qs[c].at(q), 2); } } else if (cosine || cosine_sq) { for (int c = 0; c < p; ++c) { //cols distRQ += cs[c].atd(r) * Qs[c].at(q); } if (cosine_sq) { distRQ *= distRQ; distRQ /= denomR[r] * denomQ[q]; } else { distRQ /= Math.sqrt(denomR[r] * denomQ[q]); } } ncs[q].addNum(distRQ); // one Q distance per Reference } } } } }