package hex.deepwater;
import water.Job;
import water.Key;
import water.MRTask;
import water.fvec.Frame;
/**
 * MRTask-based Deep Water training where every node trains on a full copy of the training data.
 * <p>
 * Each node runs one local {@link DeepWaterTask} pass (see {@link #setupLocal()}). The per-node
 * results are merged in {@link #reduce(DeepWaterTask2)} and finalized in {@link #postGlobal()}.
 * This mode is intended for data that fits on every node; in that case it can give good CPU
 * utilization and training accuracy.
 */
class DeepWaterTask2 extends MRTask<DeepWaterTask2> {
  /** Key of the Job this task runs under; resolved with {@code _jobKey.get()} in setupLocal(). */
  private final Key _jobKey;
  /** Training frame; each node passes over all of it. */
  private final Frame _fr;
  /** Initial model from the caller; replaced by the merged model in postGlobal(). */
  private DeepWaterModelInfo _sharedmodel;
  /** Fraction of the training data to use for one SGD iteration (asserted to be positive). */
  private final float _sync_fraction;
  /** This node's local DeepWaterTask; after reduce() it carries the merged model. */
  private DeepWaterTask _res;

  /**
   * Construct a DeepWaterTask2 where every node trains on the entire training dataset.
   *
   * @param jobKey        key of the Job this task runs under
   * @param train         frame containing the training data
   * @param model_info    initial DeepWaterModelInfo (weights + biases)
   * @param sync_fraction fraction of the training data to use for one SGD iteration; must be {@code > 0}
   * @param iteration     currently unused; kept so existing call sites keep compiling
   */
  DeepWaterTask2(Key jobKey, Frame train, DeepWaterModelInfo model_info, float sync_fraction, int iteration) {
    assert sync_fraction > 0 : "sync_fraction must be > 0, got " + sync_fraction;
    _jobKey = jobKey;
    _fr = train;
    _sharedmodel = model_info;
    _sync_fraction = sync_fraction;
  }

  /**
   * Returns the model held by this task. Before postGlobal() runs, this is the model passed to the
   * constructor. Afterwards, it is the model merged across nodes.
   *
   * @return model_info object
   */
  public DeepWaterModelInfo model_info() { return _sharedmodel; }

  /**
   * Does the node-local work by running one DeepWaterTask iteration over the frame with
   * {@code run_local=true}. The DeepWaterTask is given {@code _sync_fraction}, presumably so it
   * samples that fraction of rows.
   * <p>
   * NOTE(review): the pending count is incremented, but the child task is not constructed with
   * this task as its completer. Confirm that something decrements the pending count, or this task
   * may never complete.
   */
  @Override
  public void setupLocal() {
    super.setupLocal();
    _res = new DeepWaterTask(_sharedmodel, _sync_fraction, (Job) _jobKey.get());
    addToPendingCount(1);
    _res.dfork(null, _fr, true /*run_local*/);
  }

  /**
   * Merges another node's result into this one. When more than one node takes part, this involves
   * network traffic.
   * <p>
   * The first result seen is adopted as-is. Later results are folded in via
   * {@code DeepWaterModelInfo.add}. No averaging happens here (see the NOTE on postGlobal()).
   * After all reduce() calls are done, postGlobal() is called.
   *
   * @param drt task to reduce
   */
  @Override
  public void reduce(DeepWaterTask2 drt) {
    if (_res == null) {
      _res = drt._res;
    } else {
      // Accumulate models; averaging is not applied afterwards (see postGlobal()).
      _res.model_info().add(drt._res.model_info());
    }
    assert _res.model_info().get_params()._replicate_training_data;
  }

  /**
   * Finishes up after all nodes have been merged by reduce(). It moves the locally processed
   * sample count into the global counter and resets the local counter. It then publishes the
   * merged model so that model_info() returns it.
   * <p>
   * NOTE(review): per-node model averaging is currently disabled. It used to divide by a per-node
   * chunk count ({@code _chunk_node_count}) that is no longer accumulated. With more than one node,
   * the published model is therefore whatever the repeated {@code DeepWaterModelInfo.add} calls
   * produced, presumably a sum rather than an average. TODO: confirm this is intended (for
   * example, because Deep Water runs on a single node).
   */
  @Override
  protected void postGlobal() {
    assert _res.model_info().get_params()._replicate_training_data;
    super.postGlobal();
    final DeepWaterModelInfo merged = _res.model_info();
    // Switch from local counters to global counters.
    merged.add_processed_global(merged.get_processed_local());
    merged.set_processed_local(0L);
    _sharedmodel = merged;
  }
}