package water;
import java.util.ArrayList;
import java.util.concurrent.*;
import jsr166y.CountedCompleter;
import jsr166y.ForkJoinPool;
import water.DException.DistributedException;
import water.Job.JobCancelledException;
import water.util.Log;
/** A Distributed DTask.
* Executes over a set of Keys, with the work for each Key running on that
* Key's home node.  Limited to a map/reduce style of computation.
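*
* A minimal usage sketch (a hypothetical {@code SumTask} subclass and a
* hypothetical {@code valueOf} helper, neither part of this file):
* <pre>
*   class SumTask extends DRemoteTask&lt;SumTask&gt; {
*     long _sum;                                   // result, combined across nodes
*     public void lcompute() {                     // runs on each Key's home node
*       for( Key k : _keys ) _sum += valueOf(k);   // valueOf() is a stand-in helper
*       tryComplete();                             // signal local completion
*     }
*     public void reduce( SumTask drt ) { _sum += drt._sum; }
*   }
*   long total = new SumTask().invoke(keys)._sum;  // blocks until all nodes finish
* </pre>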
*/
public abstract class DRemoteTask<T extends DRemoteTask> extends DTask<T> implements Cloneable, ForkJoinPool.ManagedBlocker {
// Keys to be worked over
protected Key[] _keys;
// One-time flips from false to true
transient protected boolean _is_local, _top_level;
// Other RPCs we are waiting on
transient private RPC<T> _lo, _hi;
// Local work we are waiting on
transient private T _local;
// We can add more things to block on - in case we want a bunch of lazy tasks
// produced by children to all end before this top-level task ends.
// Semantically, these will all complete before we return from the top-level
// task. Pragmatically, we block on a finer grained basis.
transient protected volatile Futures _fs; // More things to block on
// Combine results from 'drt' into 'this' DRemoteTask
abstract public void reduce( T drt );
// Support for fluid-programming with strong types
private final T self() { return (T)this; }
// Super-class init on the 1st remote instance of this object. Caller may
// choose to clone/fork new instances, but then is responsible for setting up
// those instances.
public void init() { }
// Invokes the task on all nodes
public T invokeOnAllNodes() {
H2O cloud = H2O.CLOUD;
Key[] args = new Key[cloud.size()];
String skey = "RunOnAll"+Key.rand();
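// Make one internal system key homed on each node, so dcompute() will route
// exactly one piece of work to every node in the cloud.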
for( int i = 0; i < args.length; ++i )
args[i] = Key.make(skey,(byte)0,Key.DFJ_INTERNAL_USER,cloud._memary[i]);
invoke(args);
for( Key arg : args ) DKV.remove(arg);
return self();
}
// Top-level fork over a set of keys: dfork() forks without blocking; invoke() blocks for the result.
public T dfork ( Key... keys ) { keys(keys); _top_level=true; compute2(); return self(); }
public void keys( Key... keys ) { _keys = flatten(keys); }
public T invoke( Key... keys ) {
try {
ForkJoinPool.managedBlock(dfork(keys));
} catch(InterruptedException iex) { Log.errRTExcept(iex); }
// The intent is a quiet join: fork the work, then QUIETLY join so that
// local exceptions are not propagated out of this call.
return self();
}
// Return true if blocking is unnecessary, which is true if the Task isDone.
@Override public boolean isReleasable() { return isDone(); }
// Possibly blocks the current thread. Returns true if isReleasable would
// return true. Used by the FJ Pool management to spawn threads to prevent
// deadlock if otherwise all threads would block on waits.
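// Note: ForkJoinPool.managedBlock() repeatedly calls isReleasable() and
// block() until one of them returns true, possibly activating a spare worker
// thread first so the pool keeps making progress while this thread waits.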
@Override public boolean block() throws InterruptedException {
while( !isDone() ) {
try { get(); }
catch(ExecutionException eex) { // unwrap and rethrow the underlying cause
Throwable tex = eex.getCause();
if( tex instanceof Error) throw ( Error)tex;
if( tex instanceof DistributedException) throw ( DistributedException)tex;
if( tex instanceof JobCancelledException) throw (JobCancelledException)tex;
throw new RuntimeException(tex);
}
catch(CancellationException cex) { Log.errRTExcept(cex); }
}
return true;
}
// Decide to do local-work or remote-work
@Override public final void compute2() {
if( _is_local )
lcompute();
else
dcompute();
}
// Decide to do local-completion or remote-completion
@Override public final void onCompletion( CountedCompleter caller ) {
if( _is_local ) lonCompletion(caller);
else donCompletion(caller);
}
// Real Work(tm)!
public abstract void lcompute(); // Override to specify local work
private final void dcompute() { // Distribute the keys across the cloud
// Split out the keys into disjointly-homed sets of keys.
// Find the split point. First find the range of home-indices.
H2O cloud = H2O.CLOUD;
int lo=cloud._memary.length, hi=-1;
for( Key k : _keys ) {
int i = k.home(cloud);
if( i<lo ) lo=i;
if( i>hi ) hi=i; // lo <= home(keys) <= hi
}
// Classic fork/join, but on CPUs.
// Split into 3 arrays of keys: lo keys, hi keys and self keys
final ArrayList<Key> locals = new ArrayList<Key>();
final ArrayList<Key> lokeys = new ArrayList<Key>();
final ArrayList<Key> hikeys = new ArrayList<Key>();
int self_idx = cloud.nidx(H2O.SELF);
int mid = (lo+hi)>>>1; // Mid-point
for( Key k : _keys ) {
int idx = k.home(cloud);
if( idx == self_idx ) locals.add(k);
else if( idx < mid ) lokeys.add(k);
else hikeys.add(k);
}
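// Example of the split (assuming a 4-node cloud with keys homed on nodes
// 0..3 and SELF at index 2): lo=0, hi=3, mid=1; keys on node 0 go to
// 'lokeys', keys on nodes 1 and 3 go to 'hikeys', and keys on node 2 stay
// in 'locals'.  Each remote child repeats the same split on its own subset
// until every key is processed on its home node.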
// Launch off 2 tasks for the other sets of keys, and get a place-holder
// for results to block on.
_lo = remote_compute(lokeys);
_hi = remote_compute(hikeys);
// Setup for local recursion: just use the local keys.
if( locals.size() != 0 ) { // Any locally-homed keys?
_local = clone(); // 'this' is completer for '_local', so awaits _local completion
_local._is_local = true;
_local._keys = locals.toArray(new Key[locals.size()]); // Just the locally-homed keys
_local.init(); // One-time top-level init
H2O.submitTask(_local); // Begin normal execution on a FJ thread
} else {
tryComplete(); // No local work, so just immediate tryComplete
}
}
// Real Completion(tm)!
public void lonCompletion( CountedCompleter caller ) { } // Override for local completion
private final void donCompletion( CountedCompleter caller ) { // Distributed completion
assert _lo == null || _lo.isDone();
assert _hi == null || _hi.isDone();
// Fold up results from left & right subtrees
if( _lo != null ) reduce2(_lo.get());
if( _hi != null ) reduce2(_hi.get());
if( _local != null ) reduce2(_local );
// Note: in theory (valid semantics) we could push these "over the wire"
// and block for them as we're blocking for the top-level initial split.
// However, that would require sending "isDone" flags over the wire also.
// MUCH simpler to just block for them all now, and send over the empty set
// of not-yet-blocked things.
if(_local != null && _local._fs != null )
_local._fs.blockForPending(); // Block on all other pending tasks, also
_keys = null; // Do not return _keys over wire
if( _top_level ) postGlobal();
}
// Override to do work after all the forks have returned
protected void postGlobal(){}
// 'Reduce' left & right answers, skipping missing (null) children
private void reduce2( T drt ) {
if( drt == null ) return;
reduce(drt);
}
private final RPC<T> remote_compute( ArrayList<Key> keys ) {
if( keys.size() == 0 ) return null;
DRemoteTask rpc = clone();
rpc.setCompleter(null);
rpc._keys = keys.toArray(new Key[keys.size()]);
addToPendingCount(1); // Block until the RPC returns
// Set self up as needing completion by this RPC: when the ACK comes back
// we'll get a wakeup.
return new RPC(keys.get(0).home_node(), rpc).addCompleter(this).call();
}
private static Key[] flatten( Key[] args ) { return args; }
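// Lazily create the Futures.  Double-checked locking is safe here because
// _fs is declared volatile above.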
public Futures getFutures() {
if( _fs == null ) synchronized(this) { if( _fs == null ) _fs = new Futures(); }
return _fs;
}
public void alsoBlockFor( Future f ) {
if( f == null ) return;
getFutures().add(f);
}
public void alsoBlockFor( Futures fs ) {
if( fs == null ) return;
getFutures().add(fs);
}
protected void reduceAlsoBlock( T drt ) {
reduce(drt);
alsoBlockFor(drt._fs);
}
@Override public T clone() {
T dt = (T)super.clone();
dt.setCompleter(this); // Set completer, what used to be a final field
dt._fs = null; // Clone does not inherit the parent's pending Futures
dt.setPendingCount(0); // Volatile write for completer field; reset pending count also
return dt;
}
}