package water.rapids.ast.prims.mungers;
import org.apache.commons.lang.ArrayUtils;
import water.*;
import water.fvec.*;
import water.rapids.Val;
import water.rapids.ast.AstBuiltin;
import water.rapids.vals.ValFrame;
import water.util.VecUtils;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import java.util.Arrays;
/**
 * Rapids builtin {@code (pivot ary index column value)}.
 *
 * <p>Reshapes a "long" frame into a "wide" one: the output has one row per unique value of
 * the {@code index} column and one column per distinct value of the {@code column} column,
 * with cells taken from the {@code value} column. Duplicate (index, column) pairs are
 * aggregated by "first"; missing combinations are NA.
 *
 * <p>Limitations (per the original implementation): all rows sharing one index value must be
 * reachable from the chunk where that index run starts (the frame is sorted on the index
 * first so equal keys are contiguous), and a single index value cannot span more rows than
 * fit in the chunks visible via {@code nextChunk()} chaining.
 */
public class AstPivot extends AstBuiltin<AstPivot> {
  @Override
  public String[] args() {
    return new String[]{"ary", "index", "column", "value"}; // the frame and the three column names
  }

  @Override
  public int nargs() {
    return 1 + 4;
  } // (pivot ary index column value)

  @Override
  public String str() {
    return "pivot";
  }

  /**
   * Executes the pivot.
   *
   * @param args {@code args[1]} = frame, {@code args[2]} = index column name,
   *             {@code args[3]} = pivoted column name, {@code args[4]} = value column name
   * @return a new frame: index column first, then one numeric column per pivoted class
   * @throws IllegalArgumentException if the pivoted column is constant or the index column has NAs
   */
  @Override
  public ValFrame exec(Val[] args) {
    // Distributed parallelized MRTask pivot.
    // Limitations: a single index value can't have more than chunk size * chunk size number of rows,
    // or if all rows of a single index value can't fit on a single node (due to the sort call).
    Frame fr = args[1].getFrame();
    String index = args[2].getStr();
    String column = args[3].getStr();
    String value = args[4].getStr();
    int indexIdx = fr.find(index);
    int colIdx = fr.find(column);
    if (fr.vec(column).isConst())
      // BUGFIX: original message read "'is constant" (missing space)
      throw new IllegalArgumentException("Column: '" + column + "' is constant. Perhaps use transpose?");
    if (fr.vec(index).naCnt() > 0)
      throw new IllegalArgumentException("Index column '" + index + "' has > 0 NAs");
    // Sort on the index key so that all rows with equal keys become contiguous (localized);
    // the MRTask below relies on that to scan each index run front-to-back.
    Frame fr2 = fr.sort(new int[]{indexIdx});
    // Distinct values of the pivoted column; each becomes one output column.
    final long[] classes = new VecUtils.CollectDomain().doAll(fr.vec(colIdx)).domain();
    final int nClass = (fr.vec(colIdx).isNumeric() || fr.vec(colIdx).isTime()) ? classes.length : fr.vec(colIdx).domain().length;
    // Output header: the index column name first, then one name per pivoted class.
    String[] header;
    if (fr.vec(colIdx).isNumeric() && !fr.vec(colIdx).isTime()) {
      header = (String[]) ArrayUtils.addAll(new String[]{index}, Arrays.toString(classes).split("[\\[\\]]")[1].split(", "));
    } else if (fr.vec(colIdx).isTime()) {
      // BUGFIX: original sized this array nClass and omitted the leading index name,
      // mismatching the nClass+1 columns produced by doAll() below.
      header = new String[nClass + 1];
      header[0] = index;
      for (int i = 0; i < nClass; i++)
        header[i + 1] = new DateTime(classes[i], DateTimeZone.UTC).toString();
    } else {
      header = (String[]) ArrayUtils.addAll(new String[]{index}, fr.vec(colIdx).domain());
    }
    Frame initialPass = new pivotTask(fr2.find(index), fr2.find(column), fr2.find(value), classes)
        .doAll(nClass + 1, Vec.T_NUM, fr2)
        .outputFrame(null, header, null);
    fr2.delete();
    // Recreate the index column with its original domain/type (the task emitted it as raw numbers).
    Frame result = new Frame(initialPass.vec(0).makeCopy(fr.vec(indexIdx).domain(), fr.vec(indexIdx).get_type()));
    result._key = Key.<Frame>make();
    result.setNames(new String[]{index});
    initialPass.remove(0);
    result.add(initialPass);
    return new ValFrame(result);
  }

  /**
   * Per-chunk pivot pass over the index-sorted frame. Emits exactly one output row per
   * index run that STARTS in this chunk; runs continuing from the previous chunk are
   * skipped, and runs spilling into following chunks are completed via nextChunk().
   * Aggregation is "first": only NaN cells are overwritten.
   *
   * <p>Declared static so the MRTask does not serialize a hidden reference to the
   * enclosing AstPivot instance when shipped across the cluster.
   */
  private static class pivotTask extends MRTask<AstPivot.pivotTask> {
    int _indexColIdx;  // position of the index column in the input frame
    int _colColIdx;    // position of the pivoted column
    int _valColIdx;    // position of the value column
    long[] _classes;   // distinct pivoted-column values; order defines output columns

    pivotTask(int indexColIdx, int colColIdx, int valColIdx, long[] classes) {
      _indexColIdx = indexColIdx; _colColIdx = colColIdx; _valColIdx = valColIdx; _classes = classes;
    }

    @Override
    public void map(Chunk[] cs, NewChunk[] nc) {
      // Skip the leading rows whose index run started in the previous chunk;
      // that chunk's map() call is responsible for emitting them.
      long firstIdx = cs[_indexColIdx].at8(0);
      long globalIdx = cs[_indexColIdx].start();
      int start = 0;
      if (globalIdx > 0 && cs[_indexColIdx].vec().at8(globalIdx - 1) == firstIdx) {
        while (start < cs[_indexColIdx].len() && firstIdx == cs[_indexColIdx].at8(start)) start++;
      }
      for (int i = start; i < cs[_indexColIdx]._len; i++) {
        long currentIdx = cs[_indexColIdx].at8(i);
        // One output row, all cells missing until filled.
        double[] newRow = new double[nc.length - 1];
        Arrays.fill(newRow, Double.NaN);
        // Fast path: this row is the whole run (the next row — possibly in the next chunk — differs).
        if (((i == cs[_indexColIdx]._len - 1) &&
            (cs[_indexColIdx].nextChunk() == null || cs[_indexColIdx].nextChunk() != null && currentIdx != cs[_indexColIdx].nextChunk().at8(0)))
            || (i < cs[_indexColIdx]._len - 1 && currentIdx != cs[_indexColIdx].at8(i + 1))) {
          newRow[ArrayUtils.indexOf(_classes, cs[_colColIdx].at8(i))] = cs[_valColIdx].atd(i);
          nc[0].addNum(cs[_indexColIdx].at8(i));
          for (int j = 1; j < nc.length; j++) nc[j].addNum(newRow[j - 1]);
          // We're done here since we know the next row has a different index.
          continue;
        }
        // Here we know we have to search ahead within this chunk.
        int count = 1;
        newRow[ArrayUtils.indexOf(_classes, cs[_colColIdx].at8(i))] = cs[_valColIdx].atd(i);
        while (count + i < cs[_indexColIdx]._len && currentIdx == cs[_indexColIdx].at8(i + count)) {
          // Merge the forward row into newRow.
          // Here would be a good place to apply an aggregating function;
          // for now we are aggregating by "first" (only fill NaN cells).
          if (Double.isNaN(newRow[ArrayUtils.indexOf(_classes, cs[_colColIdx].at8(i + count))])) {
            newRow[ArrayUtils.indexOf(_classes, cs[_colColIdx].at8(i + count))] = cs[_valColIdx].atd(i + count);
          }
          count++;
        }
        // The run may spill into one or more following chunks.
        if (i + count == cs[_indexColIdx]._len && cs[_indexColIdx].nextChunk() != null) {
          Chunk indexNC = cs[_indexColIdx].nextChunk(); // for the index
          Chunk colNC = cs[_colColIdx].nextChunk();     // for the pivoted column
          Chunk valNC = cs[_valColIdx].nextChunk();     // for the value column
          int countNC = 0;
          while (indexNC != null && countNC < indexNC._len && currentIdx == indexNC.at8(countNC)) {
            if (Double.isNaN(newRow[ArrayUtils.indexOf(_classes, colNC.at8(countNC))])) {
              // BUGFIX: original wrote newRow[(int) colNC.atd(countNC)], indexing by the raw
              // column VALUE instead of its position in _classes.
              newRow[ArrayUtils.indexOf(_classes, colNC.at8(countNC))] = valNC.atd(countNC);
            }
            // BUGFIX: the increment and chunk-advance below were outside this while loop in
            // the original, so the cursor never moved and the loop spun forever.
            countNC++;
            if (countNC == indexNC._len) { // exhausted this chunk — hop to the next one
              indexNC = indexNC.nextChunk();
              colNC = colNC.nextChunk();
              valNC = valNC.nextChunk();
              countNC = 0;
            }
          }
        }
        nc[0].addNum(currentIdx);
        for (int j = 1; j < nc.length; j++) {
          nc[j].addNum(newRow[j - 1]);
        }
        i += (count - 1); // skip the rows we just consumed
      }
    }
  }
}