package hex;

import com.jogamp.opencl.*;
import com.jogamp.opencl.CLMemory.Mem;
import hex.Layer.*;
import jsr166y.CountedCompleter;
import water.*;
import water.H2O.H2OCountedCompleter;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.Log;
import water.util.Utils;

import java.io.IOException;
import java.nio.FloatBuffer;
import java.util.Arrays;
import java.util.Map.Entry;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Trains a neural network.
 *
 * @author cypof
 */
public abstract class Trainer {
  Trainer() {
  }

  public abstract Layer[] layers();

  public abstract void start();

  public abstract void join();

  public long processed() {
    throw new UnsupportedOperationException();
  }

  public static class Base extends Trainer {
    final Layer[] _ls;

    public Base(Layer[] ls) {
      _ls = ls;
    }

    @Override public Layer[] layers() {
      return _ls;
    }

    @Override public void start() {
      throw new UnsupportedOperationException();
    }

    @Override public void join() {
      throw new UnsupportedOperationException();
    }

    final void step(long seed) {
      // Log.info("step with seed " + seed);
      fprop(seed);
      for( int i = 1; i < _ls.length - 1; i++ )
        Arrays.fill(_ls[i]._e, 0);
      bprop();
    }

    final void fprop(long seed) {
      for( Layer _l : _ls )
        _l.fprop(seed, true);
    }

    final void bprop() {
      for( int i = _ls.length - 1; i > 0; i-- )
        _ls[i].bprop();
    }
  }

  /**
   * Trains NN on current thread.
   */
  public static class Direct extends Base {
    long _processed, _limit;
    Thread _thread;
    Key _job;

    public Direct(Layer[] ls, double epochs, Key job) {
      super(ls);
      _limit = (long) Math.ceil(epochs * ((Input) ls[0])._len);
      _job = job;
    }

    @Override public Layer[] layers() {
      return _ls;
    }

    public void run() {
      Training training = new Training() {
        @Override long processed() {
          return _processed;
        }
      };
      for( Layer _l : _ls )
        _l._training = training;

      Input input = (Input) _ls[0];
      for( ; _limit == 0 || _processed < _limit; _processed++ ) {
        step(_processed);
        input.move();
        if( _job != null && (!Job.isRunning(_job) || !NeuralNet.running) )
          break;
      }
    }

    @Override public long processed() {
      return _processed;
    }

    @Override public void start() {
      _thread = new Thread() {
        @Override public void run() {
          Direct.this.run();
        }
      };
      _thread.start();
    }

    @Override public void join() {
      try {
        _thread.join();
      } catch( InterruptedException e ) {
        throw new RuntimeException(e);
      }
    }
  }

  /**
   * Runs several trainers in parallel on the same weights, using threads. Only works on one node.
   * The weight and bias arrays are shared between the threads without locking; see the
   * Hogwild-style sketch at the end of this file.
   */
  public static class Threaded extends Trainer {
    final Base[] _trainers;
    final Thread[] _threads;
    final long _stepsPerThread;
    final AtomicLong _processed = new AtomicLong();

    public Threaded(Layer[] ls, double epochs, final Key job, int threads) {
      int num_threads = threads > 0 ? threads : Runtime.getRuntime().availableProcessors();
      _trainers = new Base[num_threads];
      _threads = new Thread[num_threads];
      _stepsPerThread = (long) (epochs * ((Input) ls[0])._len / num_threads);
      Log.info("Starting " + num_threads + " threads.");

      for( int t = 0; t < num_threads; t++ ) {
        Layer[] clones = new Layer[ls.length];
        for( int y = 0; y < clones.length; y++ )
          clones[y] = ls[y].clone();
        for( int y = 0; y < clones.length; y++ ) {
          clones[y].init(clones, y, false);
          clones[y]._training = new Training() {
            @Override long processed() {
              return _processed.get();
            }
          };
        }
        final Input input = (Input) clones[0];
        input._pos = input._len * t / num_threads;
        _trainers[t] = new Base(clones);
        final Base trainer = _trainers[t];
        final int thread_num = t;

        _threads[t] = new Thread("H2O Trainer " + t) {
          @Override public void run() {
            for( long i = 0; _stepsPerThread == 0 || i < _stepsPerThread; i++ ) {
              if( job != null && (!Job.isRunning(job) || !NeuralNet.running) )
                break;
              try {
                // long seed = thread_num * _stepsPerThread + input._pos; //BAD
                long seed = new Random().nextLong(); //GOOD
                // long seed = thread_num * _stepsPerThread + _processed.get(); //TRY
                trainer.step(seed);
                input.move();
                _processed.incrementAndGet();
              } catch( Exception e ) {
                e.printStackTrace();
              }
            }
          }
        };
      }
    }

    @Override public Layer[] layers() {
      return _trainers[0].layers();
    }

    @Override public long processed() {
      return _processed.get();
    }

    @Override public void start() {
      for( Thread _thread : _threads )
        _thread.start();
    }

    @Override public void join() {
      for( Thread _thread : _threads ) {
        try {
          _thread.join();
        } catch( InterruptedException e ) {
          throw new RuntimeException(e);
        }
      }
    }

    public void run() {
      start();
      join();
    }
  }

  /**
   * Distributed trainer. All tasks on a node update the same weights, like Threaded. Updates
   * between nodes are synchronized at regular intervals by exchanging messages between the
   * initiating machine and the others (see the delta-exchange sketch at the end of this file).
   * Requires the input to be a Frame.
   */
  public static class MapReduce extends Trainer {
    static final ConcurrentHashMap<Key, MapReduce> _instances = new ConcurrentHashMap<Key, MapReduce>();

    Layer[] _ls;
    double _epochs;
    Key _job;
    AtomicIntegerArray _counts;
    transient Key _key;
    transient Descent _task;

    public MapReduce(Layer[] ls, double epochs, Key job) {
      _ls = ls;
      _epochs = epochs;
      _job = job;

      _key = Key.make((byte) 1, Key.DFJ_INTERNAL_USER, H2O.SELF);
      _instances.put(_key, this);
      DKV.put(_key, new Value(_key, new byte[0]));

      Vec[] vecs = ((VecsInput) ls[0]).vecs;
      assert ls[0]._a.length == VecsInput.expand(vecs);
      // assert vecs[0].nChunks() >= NeuralNet.cores() : "Not enough chunks, c.f. NeuralNet.reChunk";
      _counts = new AtomicIntegerArray(vecs[0].nChunks());
    }

    @Override public Layer[] layers() {
      return _ls;
    }

    @Override public long processed() {
      Vec[] vecs = ((VecsInput) _ls[0]).vecs;
      long n = 0;
      for( int i = 0; i < _counts.length(); i++ )
        n += _counts.get(i) * vecs[0].chunkLen(i);
      return n;
    }

    @Override public void start() {
      // TODO? Chunk weights over all nodes
      // _keys = new Key[H2O.CLOUD._memary.length];
      // Weights[] weights = new Weights[_keys.length];

      _task = new Descent();
      _task._job = _job;
      _task._ls = _ls;
      _task._key = _key;
      _task._epochs = _epochs;
      _task._ws = new float[_ls.length][];
      _task._bs = new float[_ls.length][];
      for( int y = 1; y < _ls.length; y++ ) {
        _task._ws[y] = _ls[y]._w;
        _task._bs[y] = _ls[y]._b;
      }

      Vec[] vecs = ((VecsInput) _ls[0]).vecs;
      Layer out = _ls[_ls.length - 1];
      Vec response = out instanceof VecSoftmax ? ((VecSoftmax) out).vec : ((VecLinear) out)._vec;
      _task.dfork(new Frame(null, Utils.append(vecs, response)));
    }

    @Override public void join() {
      _task.join();
    }

    public void run() {
      start();
      join();
      while( NeuralNet.running ) {
        try {
          Thread.sleep(100);
        } catch( InterruptedException e ) {
          e.printStackTrace();
        }
      }
    }

    void done() {
      NeuralNet.running = false;
      _instances.remove(_key);
      UKV.remove(_key);
      if( _job != null ) {
        Job job = Job.findJob(_job);
        if( job != null ) {
          H2OCountedCompleter task = job._fjtask;
          if( task != null )
            task.tryComplete();
          job.remove();
        }
      }
    }
  }

  static class Descent extends MRTask2<Descent> {
    Key _job;
    Layer[] _ls;
    float[][] _ws;
    float[][] _bs;
    Key _key;
    double _epochs;
    transient NodeDescent _node;
    transient volatile boolean _done;

    @Override protected void setupLocal() {
      _node = new NodeDescent(_job, _ls, _ws, _bs, _key);

      // Separate thread for more regular latency
      final boolean home = _key.home();
      Thread thread = new Thread() {
        @Override public void run() {
          while( _job == null || Job.isRunning(_job) ) {
            if( !home )
              _node.sync();
            else {
              _node._total = _node._trainer.processed();
              try {
                Thread.sleep(1);
              } catch( InterruptedException ex ) {
              }
            }
          }
        }
      };
      thread.setDaemon(true);
      thread.start();
    }

    @Override protected void closeLocal() {
      // Launch actual computation in order, otherwise passes
      // between chunks diverge quickly
      DescentEpoch epoch = new DescentEpoch();
      epoch._node = _node;
      epoch._count = _epochs == 0. ? -1 : (int) Math.ceil(_epochs);
      H2O.submitTask(epoch);
      _ls = null;
      _ws = null;
      _bs = null;
      _key = null;
    }

    @Override public void map(Chunk[] cs) {
      _node._chunks.add(cs);
    }
  }

  private static abstract class NodeTask extends H2OCountedCompleter {
    NodeDescent _node;

    @Override public boolean onExceptionalCompletion(Throwable ex, CountedCompleter caller) {
      String error = Utils.getStackAsString(ex);
      Log.info(error);
      if( _node._job != null )
        Job.findJob(_node._job).cancel(error);
      return super.onExceptionalCompletion(ex, caller);
    }
  }

  private static class DescentEpoch extends NodeTask {
    int _count;

    @Override public void compute2() {
      if( (_count < 0 || --_count >= 0) && (_node._job == null || Job.isRunning(_node._job)) ) {
        for( Chunk[] cs : _node._chunks ) {
          DescentChunk task = new DescentChunk();
          task._node = _node;
          task._cs = cs;
          H2O.submitTask(task);
        }
        reinitialize();
        H2O.submitTask(this);
      } else {
        if( _node._key.home() )
          _node._trainer.done();
      }
    }
  }

  static class DescentChunk extends NodeTask {
    Chunk[] _cs;

    @Override public void compute2() {
      if( _node._job == null || (Job.isRunning(_node._job) && NeuralNet.running) ) {
        Layer[] clones = new Layer[_node._ls.length];
        ChunksInput input = new ChunksInput(Utils.remove(_cs, _cs.length - 1), (VecsInput) _node._ls[0]);
        clones[0] = input;
        for( int y = 1; y < _node._ls.length - 1; y++ )
          clones[y] = _node._ls[y].clone();
        Layer output = _node._ls[_node._ls.length - 1];
        if( output instanceof VecSoftmax )
          clones[clones.length - 1] = new ChunkSoftmax(_cs[_cs.length - 1], (VecSoftmax) output);
        else
          clones[clones.length - 1] = new ChunkLinear(_cs[_cs.length - 1], (VecLinear) output);

        // create new _a and _e, but link to weights/bias from _node (Hogwild)
        for( int y = 0; y < clones.length; y++ ) {
          clones[y].init(clones, y, false);
          clones[y]._w = _node._ws[y];
          clones[y]._b = _node._bs[y];
          clones[y]._wm = _node._wm[y];
          clones[y]._bm = _node._bm[y];
          clones[y]._training = new Training() {
            @Override long processed() {
              return _node._total;
            }
          };
        }
        Base base = new Base(clones);
        for( input._pos = 0; input._pos < _cs[0]._len; input._pos++ )
          base.step(new Random().nextLong()); // warning: no reproducible seeding
        int chunk = _cs[0].cidx();
        _node.stepped(chunk);
      }
      tryComplete();
    }
  }

  static class NodeDescent {
    ConcurrentLinkedQueue<Chunk[]> _chunks = new ConcurrentLinkedQueue<Chunk[]>();
    Key _job;
    Layer[] _ls;
    float[][] _ws; // Current weights
    float[][] _bs; // Current biases
    float[][] _wi; // Initial weights, for synchronization
    float[][] _bi; // Initial biases, for synchronization
    float[][] _wm; // Weight momentums
    float[][] _bm; // Bias momentums
    Key _key;
    ConcurrentHashMap<Integer, Integer> _counters;
    MapReduce _trainer;
    long _total;

    NodeDescent(Key job, Layer[] ls, float[][] ws, float[][] bs, Key key) {
      _job = job;
      _ls = ls;
      _key = key;
      _ws = ws;
      _bs = bs;
      _wi = new float[ws.length][];
      _bi = new float[bs.length][];
      _wm = new float[ws.length][];
      _bm = new float[bs.length][];
      for( int y = 1; y < _ws.length; y++ ) {
        _wi[y] = ws[y].clone();
        _bi[y] = bs[y].clone();
        if( ls[y].params.momentum_start != 0 || ls[y].params.momentum_stable != 0 ) {
          _wm[y] = new float[ws[y].length];
          _bm[y] = new float[bs[y].length];
        }
      }
      _trainer = MapReduce._instances.get(_key);
      assert (_trainer != null) == _key.home();
      if( _trainer == null )
        _counters = new ConcurrentHashMap<Integer, Integer>();
    }

    void stepped(int chunk) {
      assert (_trainer != null) == _key.home();
      if( _trainer != null )
        _trainer._counts.incrementAndGet(chunk);
      else {
        for( ;; ) {
          Integer n = _counters.get(chunk);
          if( n == null ) {
            if( _counters.putIfAbsent(chunk, 1) == null )
              break;
          } else {
            if( _counters.replace(chunk, n, n + 1) )
              break;
          }
        }
      }
    }

    boolean sync() {
      assert !_key.home();

      // Pairs of (chunk index, steps done on that chunk) since the last sync
      int[] counts = new int[10];
      int n = 0;
      for( Entry<Integer, Integer> entry : _counters.entrySet() ) {
        if( n == counts.length ) {
          int[] t = new int[counts.length * 2];
          System.arraycopy(counts, 0, t, 0, counts.length);
          counts = t;
        }
        counts[n++] = entry.getKey();
        counts[n++] = _counters.remove(entry.getKey());
      }
      if( n < counts.length ) { // Trim to the used length
        int[] t = new int[n];
        System.arraycopy(counts, 0, t, 0, t.length);
        counts = t;
      }

      if( n > 0 ) {
        // Send the deltas accumulated since the last sync, and remember the values that were sent
        Shuttle s = new Shuttle();
        s._w = new float[_ws.length][];
        s._b = new float[_bs.length][];
        for( int y = 1; y < _ws.length; y++ ) {
          s._w[y] = new float[_ws[y].length];
          for( int i = 0; i < _ws[y].length; i++ ) {
            s._w[y][i] = _ws[y][i] - _wi[y][i];
            _wi[y][i] = _ws[y][i];
          }
          s._b[y] = new float[_bs[y].length];
          for( int i = 0; i < _bs[y].length; i++ ) {
            s._b[y][i] = _bs[y][i] - _bi[y][i];
            _bi[y][i] = _bs[y][i];
          }
        }
        s._counts = counts;
        s.invoke(_key);
        _total = s._processed;

        // Rebase the progress made while the exchange was in flight onto the values
        // returned by the home node
        for( int y = 1; y < _ws.length; y++ ) {
          for( int i = 0; i < _ws[y].length; i++ ) {
            float d = _ws[y][i] - _wi[y][i];
            _wi[y][i] = s._w[y][i];
            _ws[y][i] = s._w[y][i] + d;
          }
          for( int i = 0; i < _bs[y].length; i++ ) {
            float d = _bs[y][i] - _bi[y][i];
            _bi[y][i] = s._b[y][i];
            _bs[y][i] = s._b[y][i] + d;
          }
        }
        return true;
      }
      return false;
    }

    static class Shuttle extends Atomic {
      float[][] _w; // Deltas in, values out
      float[][] _b; // Deltas in, values out
      int[] _counts;
      long _processed;

      @Override public Value atomic(Value value) {
        assert _key.home();
        MapReduce trainer = MapReduce._instances.get(_key);
        if( trainer != null ) {
          for( int y = 1; y < trainer._ls.length; y++ ) {
            for( int i = 0; i < _w[y].length; i++ )
              trainer._ls[y]._w[i] += _w[y][i];
            for( int i = 0; i < _b[y].length; i++ )
              trainer._ls[y]._b[i] += _b[y][i];
          }
          for( int y = 1; y < trainer._ls.length; y++ ) {
            _w[y] = trainer._ls[y]._w;
            _b[y] = trainer._ls[y]._b;
          }
          for( int i = 0; i < _counts.length; i += 2 )
            trainer._counts.addAndGet(_counts[i], _counts[i + 1]);
          _counts = null;
          _processed = trainer.processed();
        }
        return null;
      }
    }
  }

  /**
   * GPU based trainer. Alpha code!
   */
  public static class OpenCL extends Trainer {
    final Layer[] _ls;

    public OpenCL(Layer[] ls) {
      _ls = ls;
    }

    @Override public Layer[] layers() {
      return _ls;
    }

    @Override public void start() {
      CLContext context = CLContext.create();
      Log.debug("Created " + context);
      try {
        CLDevice device = context.getMaxFlopsDevice();
        Log.debug("Using " + device);
        CLCommandQueue queue = device.createCommandQueue();
        CLProgram program = context.createProgram(Boot._init.getResource2("/kernels.cl")).build();
        CLKernel[] fprops = new CLKernel[_ls.length];
        CLKernel[] bprops = new CLKernel[_ls.length];
        CLKernel[] resets = new CLKernel[_ls.length];
        CLBuffer<FloatBuffer>[] w = new CLBuffer[_ls.length];
        CLBuffer<FloatBuffer>[] b = new CLBuffer[_ls.length];
        CLBuffer<FloatBuffer>[] a = new CLBuffer[_ls.length];
        CLBuffer<FloatBuffer>[] e = new CLBuffer[_ls.length];
        for( int y = 0; y < _ls.length; y++ ) {
          a[y] = context.createFloatBuffer(_ls[y]._a.length, Mem.READ_WRITE);
          if( y > 0 ) {
            w[y] = context.createFloatBuffer(_ls[y]._w.length, Mem.READ_ONLY);
            b[y] = context.createFloatBuffer(_ls[y]._b.length, Mem.READ_ONLY);
            e[y] = context.createFloatBuffer(_ls[y]._e.length, Mem.READ_ONLY);
            queue.putWriteBuffer(w[y], false);
            queue.putWriteBuffer(b[y], false);
            fprops[y] = program.createCLKernel(_ls[y].getClass().getSimpleName() + "_fprop");
            fprops[y].putArg(_ls[y - 1]._a.length);
            fprops[y].putArgs(a[y - 1], w[y], b[y], a[y]);
            bprops[y] = program.createCLKernel(_ls[y].getClass().getSimpleName() + "_bprop");
            bprops[y].putArg(_ls[y - 1]._a.length);
            bprops[y].putArgs(a[y - 1], w[y], b[y], a[y], e[y]);
            // bprops[y].putArg(_ls[y]._r);
            if( e[y - 1] != null )
              bprops[y].putArg(e[y - 1]);
            resets[y] = program.createCLKernel("reset_error");
            resets[y].putArg(e[y]);
          }
        }
        int group = device.getMaxWorkGroupSize();
        Input input = (Input) _ls[0];
        while( true ) {
          input.fprop(new Random().nextLong(), true);
          for( int i = 0; i < input._a.length; i++ )
            a[0].getBuffer().put(i, input._a[i]);
          queue.putWriteBuffer(a[0], false);
          for( int y = 1; y < fprops.length; y++ )
            queue.put1DRangeKernel(fprops[y], 0, _ls[y]._a.length, group);
          queue.putReadBuffer(a[_ls.length - 1], true);
          for( int y = 1; y < fprops.length - 1; y++ )
            queue.put1DRangeKernel(resets[y], 0, _ls[y]._a.length, group);
          // softmax(input, a[a.length - 1].getBuffer(), e[e.length - 1].getBuffer());
          queue.putWriteBuffer(a[_ls.length - 1], false);
          queue.putWriteBuffer(e[_ls.length - 1], false);
          for( int y = _ls.length - 1; y > 0; y-- )
            queue.put1DRangeKernel(bprops[y], 0, _ls[y]._a.length, group);
          input.move();
        }
      } catch( IOException ex ) {
        throw new RuntimeException(ex);
      } finally {
        context.release();
      }
    }

    @Override public void join() {
      throw new UnsupportedOperationException();
    }

    // static void softmax(Input input, FloatBuffer a, FloatBuffer e) {
    //   float max = Float.NEGATIVE_INFINITY;
    //   for( int o = 0; o < a.capacity(); o++ )
    //     if( max < a.get(o) )
    //       max = a.get(o);
    //   float scale = 0;
    //   for( int o = 0; o < a.capacity(); o++ ) {
    //     a.put(o, (float) Math.exp(a.get(o) - max));
    //     scale += a.get(o);
    //   }
    //   for( int o = 0; o < a.capacity(); o++ ) {
    //     a.put(o, a.get(o) / scale);
    //     e.put(o, (o == input.label() ? 1 : 0) - a.get(o));
    //   }
    // }
  }
}
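
/*
 * The sketches below are illustrative only; the class names (HogwildSketch, DeltaSyncSketch) are
 * hypothetical and not part of the H2O API. They use plain JDK types to show, in isolation, the
 * two update schemes used above: lock-free sharing of weight arrays between threads (Threaded and
 * DescentChunk), and the delta exchange performed by NodeDescent.sync() and Shuttle.atomic().
 */

/**
 * Minimal sketch, under the assumptions above, of the Hogwild-style scheme used by Threaded: each
 * worker keeps its own activations and errors, but all workers write into the same weight array
 * without locks, accepting occasional lost updates in exchange for avoiding synchronization.
 */
class HogwildSketch {
  public static void main(String[] args) throws InterruptedException {
    final float[] sharedWeights = new float[4]; // shared across workers, updated without locking
    Thread[] workers = new Thread[4];
    for( int t = 0; t < workers.length; t++ ) {
      workers[t] = new Thread() {
        @Override public void run() {
          Random rand = new Random();
          for( int step = 0; step < 10000; step++ ) {
            int i = rand.nextInt(sharedWeights.length);
            // Racy read-modify-write, analogous to the shared _w/_b updates above
            sharedWeights[i] += 1e-3f * (rand.nextFloat() - 0.5f);
          }
        }
      };
      workers[t].start();
    }
    for( Thread w : workers )
      w.join();
    System.out.println(Arrays.toString(sharedWeights));
  }
}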
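
/**
 * Minimal sketch, under the same assumptions, of the synchronization arithmetic in
 * NodeDescent.sync() and Shuttle.atomic(): a remote node sends the delta it accumulated since the
 * last exchange, the home node folds that delta into the master weights and replies with the
 * current master values, and the remote node rebases whatever progress it made while the exchange
 * was in flight onto those values. A single float array stands in for one weight vector; the real
 * code does this per layer, and the exchange travels over the wire via an Atomic on the key.
 */
class DeltaSyncSketch {
  static float[] master = { 0f, 0f, 0f }; // weights held on the home node

  /** Home-node side, as in Shuttle.atomic(): apply the deltas, return the current values. */
  static float[] exchange(float[] deltas) {
    for( int i = 0; i < master.length; i++ )
      master[i] += deltas[i];
    return master.clone();
  }

  public static void main(String[] args) {
    float[] ws = master.clone(); // current local weights (_ws)
    float[] wi = master.clone(); // values at the last sync (_wi)

    ws[0] += 0.5f;  // local training progress since the last sync
    ws[1] -= 0.25f;

    // Remote side of sync(): send deltas since the last exchange, remember what was sent
    float[] deltas = new float[ws.length];
    for( int i = 0; i < ws.length; i++ ) {
      deltas[i] = ws[i] - wi[i];
      wi[i] = ws[i];
    }
    float[] values = exchange(deltas);

    // Rebase progress made while the exchange was in flight (none here) onto the master values
    for( int i = 0; i < ws.length; i++ ) {
      float d = ws[i] - wi[i];
      wi[i] = values[i];
      ws[i] = values[i] + d;
    }
    System.out.println(Arrays.toString(ws)); // prints the synchronized weights
  }
}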