package hex.deepwater; import deepwater.backends.BackendModel; import deepwater.backends.BackendTrain; import hex.FrameTask; import water.Futures; import water.H2O; import water.Job; import water.fvec.Chunk; import water.fvec.NewChunk; import water.parser.BufferedString; import water.util.ArrayUtils; import water.util.Log; import water.util.RandomUtils; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.Random; public class DeepWaterTask extends FrameTask<DeepWaterTask> { private DeepWaterModelInfo _localmodel; //per-node state (to be reduced) private DeepWaterModelInfo _sharedmodel; //input/output private int _chunk_node_count = 1; private float _useFraction; private boolean _shuffle; private final Job _job; /** * Accessor to the object containing the (final) state of the Deep Learning model * Should only be queried after calling this.doAll(Frame training) * @return "The" final model after one Map/Reduce iteration */ final public DeepWaterModelInfo model_info() { assert(_sharedmodel != null); return _sharedmodel; } /** * The only constructor * @param inputModel Initial model state * @param fraction Fraction of rows of the training to train with */ DeepWaterTask(DeepWaterModelInfo inputModel, float fraction, Job job) { super(job._key,inputModel._dataInfo); _sharedmodel = inputModel; _useFraction=fraction; _shuffle = model_info().get_params()._shuffle_training_data; _job = job; } /** * Transfer ownership from global (shared) model to local model which will be worked on */ @Override protected void setupLocal(){ // long start = System.currentTimeMillis(); assert(_localmodel == null); _localmodel = _sharedmodel; _sharedmodel = null; _localmodel.set_processed_local(0); final int weightIdx =_fr.find(_localmodel.get_params()._weights_column); final int respIdx =_fr.find(_localmodel.get_params()._response_column); final int batchSize = _localmodel.get_params()._mini_batch_size; // long nativetime = 0; DeepWaterIterator iter = null; long seed = 0xDECAF + 0xD00D * _localmodel.get_processed_global(); Random rng = RandomUtils.getRNG(seed); if (_fr.numRows()>Integer.MAX_VALUE) { throw H2O.unimpl("Need to implement batching into int-sized chunks."); } int len = (int)_fr.numRows(); int j=0; Futures fs = new Futures(); ArrayList trainLabels = new ArrayList<>(); ArrayList trainData = new ArrayList<>(); try { // Binary data (Images/Documents/etc.) if (_localmodel.get_params()._problem_type == DeepWaterParameters.ProblemType.image || _localmodel.get_params()._problem_type == DeepWaterParameters.ProblemType.text) { int dataIdx = 0; //must be the first column //FIXME Log.debug("Using column " + _fr.name(dataIdx) + " for " + ((_localmodel.get_params()._problem_type == DeepWaterParameters.ProblemType.image) ? "path to image data" :((_localmodel.get_params()._problem_type == DeepWaterParameters.ProblemType.text) ? "text data" : "path to arbitrary bytes"))); // full passes over the data BufferedString bs = new BufferedString(); int fullpasses = (int)_useFraction; // Example: train_samples_per_iteration = 4700, and train.numRows()=1000 -> _useFraction = 4.7 -> fullpasses = 4 while (j++ < fullpasses) { for (int i=0; i<_fr.numRows(); ++i) { double weight = weightIdx == -1 ? 1 : _fr.vec(weightIdx).at(i); if (weight == 0) continue; BufferedString file = _fr.vec(dataIdx).atStr(bs, i); if (file!=null) trainData.add(file.toString()); float response = (float) _fr.vec(respIdx).at(i); trainLabels.add(response); } } // fractional passes // 0.7 while (trainData.size() < _useFraction*len || trainData.size() % batchSize != 0) { assert(_shuffle); int i = rng.nextInt(len); double weight = weightIdx == -1 ? 1 : _fr.vec(weightIdx).at(i); if (weight == 0) continue; BufferedString file = _fr.vec(dataIdx).atStr(bs, i); if (file!=null) trainData.add(file.toString()); float response = (float) _fr.vec(respIdx).at(i); trainLabels.add(response); } } // Numeric data (H2O Frame full with numeric columns) else if (_localmodel.get_params()._problem_type == DeepWaterParameters.ProblemType.dataset) { double mul = _localmodel._dataInfo._normRespMul!=null ? _localmodel._dataInfo._normRespMul[0] : 1; double sub = _localmodel._dataInfo._normRespSub!=null ? _localmodel._dataInfo._normRespSub[0] : 0; // full passes over the data int fullpasses = (int) _useFraction; while (j++ < fullpasses) { for (int i = 0; i < _fr.numRows(); ++i) { double weight = weightIdx == -1 ? 1 : _fr.vec(weightIdx).at(i); if (weight == 0) continue; float response = (float)((_fr.vec(respIdx).at(i) - sub) / mul); trainData.add(i); trainLabels.add(response); } } // fractional passes while (trainData.size() < _useFraction * len || trainData.size() % batchSize != 0) { int i = rng.nextInt(len); double weight = weightIdx == -1 ? 1 : _fr.vec(weightIdx).at(i); if (weight == 0) continue; float response = (float)((_fr.vec(respIdx).at(i) - sub) / mul); trainData.add(i); trainLabels.add(response); } } // shuffle the (global) list if (_shuffle) { rng.setSeed(seed); Collections.shuffle(trainLabels, rng); rng.setSeed(seed); Collections.shuffle(trainData, rng); } if (_localmodel.get_params()._problem_type == DeepWaterParameters.ProblemType.image) { iter = new DeepWaterImageIterator(trainData, trainLabels, _localmodel._meanData, batchSize, _localmodel._width, _localmodel._height, _localmodel._channels, _localmodel.get_params()._cache_data); } else if (_localmodel.get_params()._problem_type == DeepWaterParameters.ProblemType.dataset) { assert (_localmodel._dataInfo != null); iter = new DeepWaterDatasetIterator(trainData, trainLabels, _localmodel._dataInfo, batchSize, _localmodel.get_params()._cache_data); } else if (_localmodel.get_params()._problem_type == DeepWaterParameters.ProblemType.text) { iter = new DeepWaterTextIterator(trainData, trainLabels, batchSize, 56/*FIXME*/, _localmodel.get_params()._cache_data); } NativeTrainTask ntt; while (iter.Next(fs) && !_job.isStopping()) { // if (ntt != null) nativetime += ntt._timeInMillis; long n = _localmodel.get_processed_total(); // if(!_localmodel.get_params()._quiet_mode) // Log.info("Trained " + n + " samples. Training on " + Arrays.toString(((DeepWaterImageIterator)iter).getFiles())); _localmodel._backend.setParameter(_localmodel.getModel().get(), "learning_rate", _localmodel.get_params().learningRate((double) n)); _localmodel._backend.setParameter(_localmodel.getModel().get(), "momentum", _localmodel.get_params().momentum((double) n)); //fork off GPU work, but let the iterator.Next() wait on completion before swapping again //System.err.println("data: " + Arrays.toString(iter.getData())); /* float[] preds = _localmodel._backend.predict(_localmodel._model, iter.getData()); if (Float.isNaN(ArrayUtils.sum(preds))) { Log.err(DeepWaterModel.unstable_msg); throw new UnsupportedOperationException(DeepWaterModel.unstable_msg); } */ // System.err.println("pred: " + Arrays.toString(preds)); ntt = new NativeTrainTask(_localmodel._backend, _localmodel.getModel().get(), iter.getData(), iter.getLabel()); fs.add(H2O.submitTask(ntt)); _localmodel.add_processed_local(iter._batch_size); } fs.blockForPending(); // nativetime += ntt._timeInMillis; } catch (IOException e) { e.printStackTrace(); //gracefully continue if we can't find files etc. } // long end = System.currentTimeMillis(); // if (!_localmodel.get_params()._quiet_mode) { // Log.info("Time for one iteration: " + PrettyPrint.msecs(end - start, true)); // Log.info("Time for native training : " + PrettyPrint.msecs(nativetime, true)); // } } @Override public void map(Chunk [] chunks, NewChunk [] outputs) { } static private class NativeTrainTask extends H2O.H2OCountedCompleter<NativeTrainTask> { long _timeInMillis; final BackendTrain _backend; final BackendModel _model; float[] _data; float[] _labels; NativeTrainTask(BackendTrain backend, BackendModel model, float[] data, float[] label) { _backend = backend; _model = model; _data = data; _labels = label; } @Override public void compute2() { long start = System.currentTimeMillis(); _backend.train(_model, _data,_labels); //ignore predictions long end = System.currentTimeMillis(); _timeInMillis += end-start; tryComplete(); } } /** * After all maps are done on a node, this is called to store the per-node model into DKV (for elastic averaging) * Otherwise, do nothing. */ @Override protected void closeLocal() { _sharedmodel = null; } /** * Average the per-node models (for elastic averaging, already wrote them to DKV in postLocal()) * This is a no-op between F/J worker threads (operate on the same weights/biases) * @param other Other DeepWaterTask to reduce */ @Override public void reduce(DeepWaterTask other){ if (_localmodel != null && other._localmodel != null && other._localmodel.get_processed_local() > 0 //other DLTask was active (its model_info should be used for averaging) && other._localmodel != _localmodel) //other DLTask worked on a different model_info { // avoid adding remote model info to unprocessed local data, still random // (this can happen if we have no chunks on the master node) if (_localmodel.get_processed_local() == 0) { _localmodel = other._localmodel; _chunk_node_count = other._chunk_node_count; } else { _localmodel.add(other._localmodel); _chunk_node_count += other._chunk_node_count; } } } private static long _lastWarn; private static long _warnCount; /** * After all reduces are done, the driver node calls this method to clean up * This is only needed if we're not inside a DeepWaterTask2 (which will do the reduction between replicated data workers). * So if replication is disabled, and every node works on partial data, then we have work to do here (model averaging). */ @Override protected void postGlobal(){ DeepWaterParameters dlp = _localmodel.get_params(); if (H2O.CLOUD.size() > 1 && !dlp._replicate_training_data) { long now = System.currentTimeMillis(); if (_chunk_node_count < H2O.CLOUD.size() && (now - _lastWarn > 5000) && _warnCount < 3) { // Log.info("Synchronizing across " + _chunk_node_count + " H2O node(s)."); Log.warn(H2O.CLOUD.size() - _chunk_node_count + " node(s) (out of " + H2O.CLOUD.size() + ") are not contributing to model updates. Consider setting replicate_training_data to true or using a larger training dataset (or fewer H2O nodes)."); _lastWarn = now; _warnCount++; } } // Check that we're not inside a DeepWaterTask2 assert ((!dlp._replicate_training_data || H2O.CLOUD.size() == 1) == !_run_local); if (!_run_local) { _localmodel.add_processed_global(_localmodel.get_processed_local()); //move local sample counts to global ones _localmodel.set_processed_local(0L); // model averaging if (_chunk_node_count > 1) _localmodel.div(_chunk_node_count); } else { //Get ready for reduction in DeepWaterTask2 //Just swap the local and global models _sharedmodel = _localmodel; } if (_sharedmodel == null) _sharedmodel = _localmodel; _localmodel = null; } }