package hex.deepwater;
import hex.DataInfo;
import water.*;
import water.fvec.C4FChunk;
import water.fvec.C8DChunk;
import water.fvec.Chunk;
import water.fvec.NewChunk;
import water.util.UnsafeUtils;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
/**
 * Batch iterator that fills the double-buffered data/label arrays of {@link DeepWaterIterator}
 * with rows taken from the adapted frame of a {@link DataInfo}.
 *
 * <p>Each call to {@link #Next(Futures)} converts {@code _batch_size} rows in parallel (one
 * {@link FrameDataConverter} task per batch slot). When caching is enabled, every converted row
 * is also stored in the DKV as an {@link IcedRow} so later passes can skip the conversion.
 */
class DeepWaterDatasetIterator extends DeepWaterIterator {

  /** Global row indices (into the adapted frame) to iterate over, in iteration order. */
  private final ArrayList<Integer> _rows_lst;
  /** Labels parallel to {@link #_rows_lst}; may be {@code null}, in which case -1 is used as label. */
  private final ArrayList<Float> _label_lst;
  /** Describes how frame rows are expanded into {@code fullN()} dense values. */
  private final DataInfo _dinfo;

  /**
   * @param rows       global row indices to iterate over
   * @param labels     per-row labels (same order as {@code rows}), or {@code null} if there are none
   * @param dinfo      data layout used to extract dense rows; {@code dinfo.fullN()} is the row width
   * @param batch_size number of rows per batch
   * @param cache      whether converted rows are cached in (and read back from) the DKV
   * @throws IOException declared for compatibility with the superclass constructor
   */
  DeepWaterDatasetIterator(ArrayList<Integer> rows, ArrayList<Float> labels, DataInfo dinfo, int batch_size, boolean cache) throws IOException {
    super(batch_size, dinfo.fullN(), cache);
    _rows_lst = rows;
    _label_lst = labels;
    _dinfo = dinfo;
  }

  /** Row-based storage, dense - direct mapping to float[]. */
  static class IcedRow extends Iced<IcedRow> {
    // Deliberately non-final and kept under its original name; NOTE(review): presumably the
    // Iced serialization machinery populates it reflectively - confirm before changing.
    private float[] _data;

    // NOTE(review): the public no-arg constructor is presumably required for deserialization.
    public IcedRow() {}

    /** Wraps (does not copy) the given values. */
    IcedRow(float[] fs) {
      _data = fs;
    }

    /** @return number of values stored in this row */
    int size() { return _data.length; }

    /** @return the value at position {@code i} */
    float getVal(int i) { return _data[i]; }

    /**
     * Copies all values of this row into {@code vals}, starting at {@code vals[offset]}.
     *
     * @param vals   destination array
     * @param offset index in {@code vals} that receives the first value of this row
     */
    void insertValuesIntoArray(float[] vals, int offset) {
      System.arraycopy(_data, 0, vals, offset, _data.length);
    }
  }

  /**
   * Task that converts one frame row into floats and writes it (plus its label) into one slot of
   * the destination batch buffers.
   */
  static class FrameDataConverter extends H2O.H2OCountedCompleter<FrameDataConverter> {
    /** Slot of this row inside the batch (0-based). */
    int _index;
    /** Row index inside the adapted frame. */
    int _globalIndex;
    DataInfo _dinfo;
    float _label;
    /** Batch data buffer; this task writes {@code fullN()} values starting at {@code _index * fullN()}. */
    float[] _destData;
    /** Batch label buffer; this task writes position {@code _index}. */
    float[] _destLabel;
    /** Whether to read/write the converted row from/to the DKV. */
    boolean _cache;

    FrameDataConverter(int index, int globalIndex, DataInfo dinfo, float label, float[] destData, float[] destLabel, boolean cache) {
      _index = index;
      _globalIndex = globalIndex;
      _dinfo = dinfo;
      _label = label;
      _destData = destData;
      _destLabel = destLabel;
      _cache = cache;
    }

    /**
     * Stores the label, then fills this task's slice of the data buffer - from the DKV cache when
     * a cached row exists, otherwise by extracting the dense row from the frame (and caching it
     * if caching is enabled).
     */
    @Override
    public void compute2() {
      _destLabel[_index] = _label;
      final int width = _dinfo.fullN();
      final int start = _index * width;
      // The key encodes frame, row width and row index, so rows cached for a different layout are not reused.
      final Key rowKey = Key.make(_dinfo._adaptedFrame._key + "_" + width + "_row_" + _globalIndex + "_" + DeepWaterModel.CACHE_MARKER);

      boolean filledFromCache = false;
      if (_cache) {
        IcedRow icedRow = DKV.getGet(rowKey);
        if (icedRow != null) {
          icedRow.insertValuesIntoArray(_destData, start);
          filledFromCache = true;
        }
      }

      if (!filledFromCache) { // only do this the first time
        DataInfo.Row row = _dinfo.newDenseRow();
        // Collect, for every column, the chunk that holds the requested global row.
        Chunk[] chks = new Chunk[_dinfo._adaptedFrame.numCols()];
        for (int c = 0; c < chks.length; ++c) {
          chks[c] = _dinfo._adaptedFrame.vec(c).chunkForRow(_globalIndex);
        }
        // extractDenseRow takes a chunk-relative row id, hence the subtraction of the chunk start.
        _dinfo.extractDenseRow(chks, _globalIndex - (int) chks[0].start(), row);
        for (int i = 0; i < width; ++i) {
          _destData[start + i] = (float) row.get(i);
        }
        if (_cache) {
          // Cache a private copy: _destData is a shared batch buffer that gets overwritten later.
          Value v = new Value(rowKey, new IcedRow(Arrays.copyOfRange(_destData, start, start + width)));
          DKV.put(rowKey, v);
          // NOTE(review): presumably drops the in-memory payload once the value is stored - confirm against Value.freeMem().
          v.freeMem();
        }
      }
      tryComplete();
    }
  }

  /**
   * Prepares the next batch: converts {@code _batch_size} rows in parallel into the current
   * buffers, waits for all conversions, then flips the buffers and advances the start index.
   *
   * <p>If fewer than {@code _batch_size} rows remain, the start index is moved back so that the
   * batch ends at the last row (i.e. some rows of the previous batch are repeated).
   *
   * @param fs futures collection used to track (and block on) the conversion tasks
   * @return {@code true} if a batch was prepared, {@code false} if all rows were already consumed
   * @throws IndexOutOfBoundsException if the row list is non-empty but holds fewer rows than one batch
   * @throws IOException declared for compatibility with the superclass contract
   */
  public boolean Next(Futures fs) throws IOException {
    final int numRows = _rows_lst.size();
    if (_start_index >= numRows) {
      return false;
    }
    if (_start_index + _batch_size > numRows) {
      if (_batch_size > numRows) {
        // Previously this produced a negative start index and an opaque failure inside ArrayList.get().
        throw new IndexOutOfBoundsException(
            "Batch size " + _batch_size + " exceeds the number of rows (" + numRows + "); cannot assemble a full batch");
      }
      _start_index = numRows - _batch_size;
    }
    // Multi-Threaded data preparation
    for (int slot = 0; slot < _batch_size; slot++) {
      final int pos = _start_index + slot;
      final int globalRow = _rows_lst.get(pos);
      final float label = (_label_lst == null) ? -1 : _label_lst.get(pos);
      fs.add(H2O.submitTask(new FrameDataConverter(slot, globalRow, _dinfo, label, _data[which()], _label[which()], _cache)));
    }
    fs.blockForPending();
    flip();
    _start_index += _batch_size;
    return true;
  }
}