package water.rapids;
import water.DKV;
import water.Futures;
import water.Key;
import water.MRTask;
import water.fvec.Frame;
import water.fvec.Vec;
import water.nbhm.*;
import water.rapids.ast.AstFunction;
import water.rapids.ast.AstRoot;
import water.rapids.ast.prims.operators.AstPlus;
import water.util.Log;
import java.util.Collections;
import java.util.Map;
/**
* Session is a long-lasting environment supporting caching and Copy-On-Write optimization of Vecs. This session may
* last over many different Rapids calls (provided they refer to the same session). When the session ends, all the
* cached Vecs will be deleted (except those in user facing Frames).
**/
public class Session {
  // How often to perform sanity checks, out of every 1000 calls to sanity_check_refs().
  // 0 means disable checks, 1000 is to always check.
  private static final int sanityChecksFrequency = 1000;
  // Position of the next sanity_check_refs() call inside the current window of 1000 calls.
  // Always kept in [0, 1000) so that it can never overflow into negative values.
  private static int sanityChecksCounter = 0;

  // Identifier of this session; fixed at construction time.
  private final String id;

  // --------------------------------------------------------------------------
  // Copy On Write optimization
  // --------------------------------------------------------------------------
  // COW optimization: instead of copying Vecs, they are "virtually" copied by
  // simply pointer-sharing, and raising the ref-cnt here.  Losing a copy can
  // lower the ref-cnt, and when it goes to zero the Vec can be removed.  If
  // the Vec needs to be modified, and the ref-cnt is 1 - an update-in-place
  // can happen.  Otherwise a true data copy is made, and the private copy is
  // modified.

  // Ref-counts per Vec.  Always positive; zero is removed from the table;
  // negative is an error.  At the end of any given Rapids expression the
  // counts should match all the Vecs in the FRAMES set.
  private final NonBlockingHashMap<Vec, Integer> REFCNTS = new NonBlockingHashMap<>();

  // Frames tracked by this Session and alive to the next Rapids call.  When
  // the whole session ends, these frames can be removed from the DKV.  These
  // Frames can share Vecs amongst themselves (tracked by the REFCNTS) and also
  // with other global frames.
  private final NonBlockingHashMap<Key, Frame> FRAMES = new NonBlockingHashMap<>();

  // Vec that came from global frames, and are considered immutable.  Rapids
  // will always copy these Vecs before mutating or deleting.  Total visible
  // refcnts are effectively the normal refcnts plus 1 for being in the GLOBALS
  // set.
  private final NonBlockingHashSet<Vec> GLOBALS = new NonBlockingHashSet<>();

  /**
   * Create a session whose id is the string form of a freshly made {@link Key}.
   */
  public Session() {
    this(Key.make().toString());
  }

  /**
   * Create a session with the given id and run the one-shot {@link #cluster_init()}.
   *
   * @param id identifier to be reported by {@link #id()}
   */
  public Session(String id) {
    this.id = id;
    cluster_init();
  }

  /** Return this session's id. */
  public String id() {
    return id;
  }

  /**
   * Execute an AstRoot in the current Session with much assertion-checking: ref-count sanity checks run before and
   * after the execution, and the Env stack is asserted to be balanced afterwards.
   *
   * @param ast Rapids expression to execute
   * @param scope value stored into the {@code _scope} field of the fresh Env used for the execution
   * @return the result from the Rapids expression
   */
  public Val exec(AstRoot ast, AstFunction scope) {
    sanity_check_refs(null);

    // Execute
    Env env = new Env(this);
    env._scope = scope;
    Val val = ast.exec(env);    // Execute
    assert env.sp() == 0;       // Stack balanced at end
    sanity_check_refs(val);
    return val;                 // Can return a frame, which may point to session-shared Vecs
  }

  /**
   * Normal session exit. All tracked temp frames are down-ref-counted and removed from the DKV. If the returned
   * value is a Frame, its refcnts (currently up by 1 for the returned value itself) are lowered, and every Vec it
   * shares with a global is replaced by a copy, so the returned Frame is independent of the disappearing session.
   * The returned Frame is the responsibility of the caller to delete.
   *
   * @param returning value being returned out of the session, may be null
   * @return the {@code returning} argument
   */
  public Val end(Val returning) {
    sanity_check_refs(returning);
    // Remove all temp frames
    Futures fs = new Futures();
    for (Frame fr : FRAMES.values()) {
      fs = downRefCnt(fr, fs);  // Remove internal Vecs one by one
      DKV.remove(fr._key, fs);  // Shallow remove, internal Vecs removed 1-by-1
    }
    fs.blockForPending();
    FRAMES.clear();             // No more temp frames
    // Copy (as needed) so the returning Frame is completely independent of the
    // (disappearing) session.
    if (returning != null && returning.isFrame()) {
      Frame fr = returning.getFrame();
      Vec[] vecs = fr.vecs();
      for (int i = 0; i < vecs.length; i++) {
        _addRefCnt(vecs[i], -1); // Returning frame has refcnt +1, lower it now; should go to zero internal refcnts.
        if (GLOBALS.contains(vecs[i])) { // Copy if shared with globals
          fr.replace(i, vecs[i].makeCopy());
        }
      }
    }
    GLOBALS.clear();            // No longer tracking globals
    sanity_check_refs(null);
    REFCNTS.clear();
    return returning;
  }

  /**
   * The Rapids call threw an exception. Best-effort cleanup, no more exceptions: any Exception raised while cleaning
   * up is logged as a warning and suppressed.
   *
   * <p>Vecs of tracked temp frames are deleted once their internal refcnt drops to zero, unless the Vec is still in
   * the GLOBALS set. Being in GLOBALS counts as one extra reference (see {@link #getRefCnt(Vec)}), exactly as the
   * normal {@link #end(Val)} path treats it, so such a Vec must not be deleted here.
   *
   * @param ex the throwable that aborted the Rapids call
   * @return {@code ex} itself if it is a RuntimeException, otherwise a new RuntimeException wrapping it
   */
  public RuntimeException endQuietly(Throwable ex) {
    try {
      Futures fs = new Futures();
      for (Frame fr : FRAMES.values()) {
        for (Vec vec : fr.vecs()) {
          Integer I = REFCNTS.get(vec);
          int i = (I == null ? 0 : I) - 1;
          if (i > 0) {
            REFCNTS.put(vec, i);
          } else {
            REFCNTS.remove(vec);
            // A Vec shared with a global is kept alive by that global; only delete unshared Vecs.
            if (!GLOBALS.contains(vec)) {
              vec.remove(fs);
            }
          }
        }
        DKV.remove(fr._key, fs);  // Shallow remove, internal Vecs removed 1-by-1
      }
      fs.blockForPending();
      FRAMES.clear();
      REFCNTS.clear();
    } catch (Exception ex2) {
      Log.warn("Exception " + ex2 + " suppressed while cleaning up Rapids Session after already throwing " + ex);
    } finally {
      GLOBALS.clear();          // No longer tracking globals, even if the cleanup above failed part-way
    }
    return ex instanceof RuntimeException ? (RuntimeException) ex : new RuntimeException(ex);
  }

  /**
   * Internal ref cnts (not counting globals - which only ever keep things alive, and have a virtual +1 to refcnts
   * always).
   *
   * @return the count stored in REFCNTS for {@code vec}, or 0 when there is no entry
   */
  private int _getRefCnt(Vec vec) {
    Integer I = REFCNTS.get(vec);
    assert I == null || I > 0;  // No zero or negative counts
    return I == null ? 0 : I;
  }

  /**
   * Store the internal refcnt of a Vec; a count of zero removes the entry from REFCNTS.
   *
   * @return the count {@code i} that was passed in
   */
  private int _putRefCnt(Vec vec, int i) {
    assert i >= 0;              // No negative counts
    if (i > 0) {
      REFCNTS.put(vec, i);
    } else {
      REFCNTS.remove(vec);
    }
    return i;
  }

  /**
   * Bump internal count, not counting globals
   *
   * @return the new internal count
   */
  private int _addRefCnt(Vec vec, int i) {
    return _putRefCnt(vec, _getRefCnt(vec) + i);
  }

  /**
   * External refcnt: internal refcnt plus 1 for being global
   */
  private int getRefCnt(Vec vec) {
    return _getRefCnt(vec) + (GLOBALS.contains(vec) ? 1 : 0);
  }

  /**
   * RefCnt +i this Vec; Global Refs can be alive with zero internal counts
   *
   * @return the new external refcnt (new internal count, plus 1 if the Vec is in GLOBALS)
   */
  private int addRefCnt(Vec vec, int i) {
    return _addRefCnt(vec, i) + (GLOBALS.contains(vec) ? 1 : 0);
  }

  /**
   * RefCnt +i all Vecs this Frame.
   *
   * @param fr frame whose Vecs get their internal refcnt adjusted; null is allowed and ignored
   * @param i amount added to each internal refcnt
   * @return {@code fr}, for flow coding
   */
  Frame addRefCnt(Frame fr, int i) {
    if (fr != null) {           // Allow and ignore null Frame, easier calling convention
      for (Vec vec : fr.vecs()) {
        _addRefCnt(vec, i);
      }
    }
    return fr;                  // Flow coding
  }

  /**
   * Found in the DKV, if not a tracked TEMP make it a global: all Vecs of the frame are added to the GLOBALS set
   * unless the frame's key is tracked in FRAMES.
   *
   * @return {@code fr}, for flow coding
   */
  Frame addGlobals(Frame fr) {
    if (!FRAMES.containsKey(fr._key)) {
      Collections.addAll(GLOBALS, fr.vecs());
    }
    return fr;                  // Flow coding
  }

  /**
   * Track a freshly minted tmp frame. This frame can be removed when the session ends (unlike global frames), or
   * anytime during the session when the client removes it.
   *
   * @param fr frame to track; its key must not be null (checked by assertion)
   * @return {@code fr}, for flow coding
   */
  public Frame track_tmp(Frame fr) {
    assert fr._key != null;     // Temps have names
    FRAMES.put(fr._key, fr);    // Track for session
    addRefCnt(fr, 1);           // Refcnt is also up: these Vecs stick around after single Rapids call for the next one
    DKV.put(fr);                // Into DKV, so e.g. Flow can view for debugging
    return fr;                  // Flow coding
  }

  /**
   * Remove and delete a session-tracked frame.
   * Remove from all session tracking spaces.
   * Remove any newly-unshared Vecs, but keep the shared ones.
   *
   * @param fr frame to remove; null is ignored
   */
  public void remove(Frame fr) {
    if (fr == null) return;
    Futures fs = new Futures();
    if (!FRAMES.containsKey(fr._key)) { // In globals and not temps?
      for (Vec vec : fr.vecs()) {
        GLOBALS.remove(vec);            // Not a global anymore
        if (REFCNTS.get(vec) == null) { // If not shared with temps
          vec.remove(fs);               // Remove unshared dead global
        }
      }
    } else {                    // Else a temp and not a global
      fs = downRefCnt(fr, fs);  // Standard down-ref counting of all Vecs
      FRAMES.remove(fr._key);   // And remove from temps
    }
    DKV.remove(fr._key, fs);    // Shallow remove, internal were Vecs removed 1-by-1
    fs.blockForPending();
  }

  /**
   * Lower refcnt of all Vecs in frame, deleting Vecs that go to zero refs. The external refcnt is used, so a Vec
   * that is still in GLOBALS is not deleted.
   *
   * @param fr frame whose Vecs are down-ref-counted
   * @param fs futures to collect pending removals in; may be null, in which case one is created when something
   *           gets deleted
   * @return {@code fs}, or the newly created Futures
   */
  Futures downRefCnt(Frame fr, Futures fs) {
    for (Vec vec : fr.vecs()) { // Refcnt -1 all Vecs
      if (addRefCnt(vec, -1) == 0) {
        if (fs == null) fs = new Futures();
        vec.remove(fs);
      }
    }
    return fs;
  }

  /**
   * Update a global ID, maintaining sharing of Vecs
   *
   * @param id key of the global frame to (re)assign
   * @param src frame providing the names and Vecs of the new global frame
   * @return the newly installed global frame
   * @throws IllegalArgumentException if {@code id} is the key of a tracked temp frame
   */
  public Frame assign(Key<Frame> id, Frame src) {
    if (FRAMES.containsKey(id)) throw new IllegalArgumentException("Cannot reassign temp " + id);
    Futures fs = new Futures();
    // Vec lifetime invariant: Globals do not share with other globals (but can
    // share with temps).  All the src Vecs are about to become globals.  If
    // the ID already exists, the global Vecs within it are about to die, and
    // thus may be deleted.
    Frame fr = DKV.getGet(id);
    if (fr != null) {           // Prior frame exists
      for (Vec vec : fr.vecs()) {
        if (GLOBALS.remove(vec) && _getRefCnt(vec) == 0) {
          vec.remove(fs);       // Remove unused global vec
        }
      }
    }
    // Copy (defensive) the base vecs array.  Then copy any vecs which are
    // already globals - this new global must be independent of any other
    // global Vecs - because global Vecs get side-effected by unrelated
    // operations.
    Vec[] svecs = src.vecs().clone();
    for (int i = 0; i < svecs.length; i++) {
      if (GLOBALS.contains(svecs[i])) {
        svecs[i] = svecs[i].makeCopy();
      }
    }
    // Make and install new global Frame
    Frame fr2 = new Frame(id, src._names.clone(), svecs);
    DKV.put(fr2, fs);
    addGlobals(fr2);
    fs.blockForPending();
    return fr2;
  }

  /**
   * Support C-O-W optimizations: the following list of columns are about to be updated. Copy them as-needed (when
   * the external refcnt is above 1) and replace in the Frame. If anything was copied and the frame has a key, the
   * frame is put into the DKV again.
   *
   * <p>NOTE(review): the returned array is the one obtained from {@code fr.vecs()} before any replacement; it only
   * reflects the copies if {@code Frame.replace} updates that same array in place - confirm against Frame.
   *
   * @param fr frame about to be modified
   * @param cols indices of the columns about to be modified
   * @return the Frame vecs, for flow-coding
   */
  public Vec[] copyOnWrite(Frame fr, int[] cols) {
    Vec did_copy = null;        // Did a copy?
    Vec[] vecs = fr.vecs();
    for (int col : cols) {
      Vec vec = vecs[col];
      int refcnt = getRefCnt(vec);
      assert refcnt > 0;
      if (refcnt > 1) {         // If refcnt is 1, we allow the update to take in-place
        fr.replace(col, (did_copy = vec.makeCopy()));
      }
    }
    if (did_copy != null && fr._key != null) DKV.put(fr); // Then update frame in the DKV
    return vecs;
  }

  /**
   * Check that ref counts are in a consistent state.
   * This should only be called between calls to Rapids expressions (otherwise may blow false-positives).
   *
   * @param returning If sanity check is done at the end of the session, there is a value being returned. This value
   *                  might be a Frame (which would not be in FRAMES). So we pass the "returning" value explicitly,
   *                  so that all its references can be properly accounted for.
   * @throws IllegalStateException if the cached REFCNTS disagree with the counts computed from the tracked frames
   */
  private void sanity_check_refs(Val returning) {
    // Roll the counter inside [0, 1000) rather than incrementing it forever: an ever-growing int eventually
    // overflows, and the remainder of a negative value is negative, which would let the checks run even when
    // sanityChecksFrequency is 0.
    int tick = sanityChecksCounter;
    sanityChecksCounter = (tick + 1) % 1000;
    if (tick >= sanityChecksFrequency) return;

    // Compute refcnts from tracked frames only.  Since we are between Rapids
    // calls the only tracked Vecs should be those from tracked frames.
    NonBlockingHashMap<Vec, Integer> refcnts = new NonBlockingHashMap<>(REFCNTS.size());
    for (Frame fr : FRAMES.values()) {
      for (Vec vec : fr.vecs()) {
        Integer count = refcnts.get(vec);
        refcnts.put(vec, count == null ? 1 : count + 1);
      }
    }
    // Now account for the returning frame (if it is a Frame). Note that it is entirely possible that this frame is
    // already in the FRAMES list, however we need to account for it anyways -- this is how Env works...
    if (returning != null && returning.isFrame()) {
      for (Vec vec : returning.getFrame().vecs()) {
        Integer count = refcnts.get(vec);
        refcnts.put(vec, count == null ? 1 : count + 1);
      }
    }
    // Now compare computed refcnts to cached REFCNTS.
    // First check that every Vec in computed refcnt is also in REFCNTS, with equal counts.
    for (Map.Entry<Vec,Integer> pair : refcnts.entrySet()) {
      Vec vec = pair.getKey();
      Integer count = pair.getValue();
      Integer savedCount = REFCNTS.get(vec);
      if (savedCount == null) throw new IllegalStateException("REFCNTS missing vec " + vec);
      if (count.intValue() != savedCount.intValue())
        throw new IllegalStateException(
            "Ref-count mismatch for vec " + vec + ": REFCNT = " + savedCount + ", should be " + count);
    }
    // Then check that every cached REFCNT is in the computed set as well.  Every computed Vec was just found in
    // REFCNTS, so extras can only exist when the two sizes differ.
    if (refcnts.size() != REFCNTS.size()) {
      for (Map.Entry<Vec,Integer> pair : REFCNTS.entrySet()) {
        if (!refcnts.containsKey(pair.getKey()))
          throw new IllegalStateException(
              "REFCNTs contains an extra vec " + pair.getKey() + ", count = " + pair.getValue());
      }
    }
  }

  // To avoid a class-circularity hang, we need to force other members of the
  // cluster to load the Rapids & AstRoot classes BEFORE trying to execute code
  // remotely, because e.g. ddply runs functions on all nodes.
  private static volatile boolean _initialized; // One-shot init

  /**
   * One-shot initialization: runs an MRTask on all nodes whose local setup instantiates an AstPlus, to force class
   * loading. Subsequent calls return immediately once the first call has completed.
   */
  static void cluster_init() {
    if (_initialized) return;
    // Touch a common class to force loading
    new MRTask() {
      @Override
      public void setupLocal() {
        new AstPlus();
      }
    }.doAllNodes();
    _initialized = true;
  }
}