package water;
import dontweave.gson.JsonElement;
import dontweave.gson.JsonObject;
import dontweave.gson.JsonParser;
import hex.GridSearch;
import water.api.DocGen;
import water.api.Request;
import water.api.RequestArguments;
import water.api.RequestServer.API_VERSION;
import water.fvec.Vec;
import water.util.Log;
import water.util.Utils;
import java.lang.annotation.Annotation;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.*;
public abstract class Request2 extends Request {
// NOTE(review): presumably a marker/version constant consumed by the bytecode weaver -- confirm.
static final int API_WEAVER = 1;
// Field documentation for this request type; not assigned anywhere in this class
// (presumably filled in by generated/woven code -- confirm).
static public DocGen.FieldDoc[] DOC_FIELDS;
// Raw parameters of the current call: set by create(Properties), read by input(String)
// and by serveGrid(). Transient, so it is not part of the serialized request.
protected transient Properties _parms;
// Filled by fillResponseInfo(Response); removed again from the toJSON() output.
@API(help = "Response stats and info.")
public ResponseInfo response_info;
/**
 * Looks up the raw text of a request parameter.
 *
 * @param fieldName name of the parameter
 * @return the raw value, or {@code null} when no parameters are attached to this request or
 *     the parameter is absent
 */
public String input(String fieldName) {
  if( _parms == null )
    return null;
  return _parms.getProperty(fieldName);
}
/**
 * Typeahead argument whose input text is turned into a {@link Key}, optionally restricted to
 * keys of a given value type.
 */
public class TypeaheadKey extends TypeaheadInputText<Key> {
  /** Returned by {@link #defaultValue()} when no input is supplied. */
  transient Key _defaultValue;
  /** Expected value class, or {@code null} when any key is acceptable. */
  transient Class _type;

  /** Creates a required key argument with no type restriction. */
  public TypeaheadKey() {
    this(null, true);
  }

  /**
   * @param type     expected value class, also used to pick the typeahead source; may be null
   * @param required whether a value must be supplied
   */
  public TypeaheadKey(Class type, boolean required) {
    super(mapTypeahead(type), "", required);
    _type = type;
    setRefreshOnChange();
  }

  /** Records {@code key} as the current value and its string form as the original input. */
  public void setValue(Key key) {
    record()._value = key;
    record()._originalValue = key.toString();
  }

  /**
   * Validates the raw text (when a validator is set) and builds the key from it.
   *
   * @throws H2OIllegalArgumentException if the key is absent from the DKV while it must exist,
   *     or while this argument is both typed and required
   */
  @Override protected Key parse(String input) {
    if( _validator != null )
      _validator.validateRaw(input);
    Key key = Key.make(input);
    boolean absent = DKV.get(key) == null;
    boolean absenceIsError = _mustExist || (_type != null && _required);
    if( absent && absenceIsError )
      throw new H2OIllegalArgumentException(this, "Key '" + input + "' does not exist!");
    return key;
  }

  @Override protected Key defaultValue() {
    return _defaultValue;
  }

  @Override protected String queryDescription() {
    String typePart = (_type == null) ? "" : " of type " + _type.getSimpleName();
    return "A key" + typePart;
  }

  @Override protected String[] errors() {
    if( _type == null )
      return super.errors();
    return new String[] { "Key is not a " + _type.getSimpleName() };
  }
}
/**
 * Filter for a field that depends on another field, e.g. selecting a Vec from a Frame.
 * The filter itself accepts every value; registered() resolves the referenced field and
 * builds the matching argument handler.
 */
public class Dependent implements Filter {
  /** Name of the field this one depends on. */
  public final String _ref;

  protected Dependent(String name) {
    this._ref = name;
  }

  /** Always accepts. */
  @Override public boolean run(Object value) {
    return true;
  }
}
/** Column selection on the frame field named by {@code _ref}; registered() rejects it for now. */
public class ColumnSelect extends Dependent {
  protected ColumnSelect(String key) { super(key); }
}
/** Selects one Vec of the frame field named by {@code _ref}. */
public class VecSelect extends Dependent {
  protected VecSelect(String key) { super(key); }
}
/** A {@link VecSelect} carrying an extra {@code optional} flag. */
public class SpecialVecSelect extends VecSelect {
  /** False unless set through the two-argument constructor. */
  public boolean optional;
  protected SpecialVecSelect(String key) {
    this(key, false);
  }
  protected SpecialVecSelect(String key, boolean optional) {
    super(key);
    this.optional = optional;
  }
}
/**
 * Selects the class (response) Vec of the frame named by {@code _ref}; registered() remembers
 * it so later MultiVecSelect / DoClassBoolean fields on the same frame can refer to it.
 */
public class VecClassSelect extends Dependent {
  protected VecClassSelect(String key) { super(key); }
}
/**
 * Specifies how the tokens of a multi-column selector field are interpreted.
 */
public enum MultiVecSelectType {
/**
 * Treat a token as a column name; failing that, treat it as a 0-based column index if it
 * looks like a positive integer.
 */
NAMES_THEN_INDEXES,
/**
 * Always treat a token as a column name, even when it looks like an integer. Used by the
 * Web UI, which always specifies columns by name.
 */
NAMES_ONLY
}
/** Selects several columns of the frame named by {@code _ref}. */
public class MultiVecSelect extends Dependent {
  /** True for {@link MultiVecSelectType#NAMES_ONLY}, false otherwise. */
  boolean _namesOnly;

  /** Sets {@link #_namesOnly} from the requested interpretation mode. */
  private void init(MultiVecSelectType selectType) {
    switch( selectType ) {
      case NAMES_ONLY:
        _namesOnly = true;
        break;
      default:
        _namesOnly = false;
        break;
    }
  }

  /** Uses {@link MultiVecSelectType#NAMES_THEN_INDEXES}. */
  protected MultiVecSelect(String key) {
    this(key, MultiVecSelectType.NAMES_THEN_INDEXES);
  }

  protected MultiVecSelect(String key, MultiVecSelectType selectType) {
    super(key);
    init(selectType);
  }
}
/**
 * Boolean tied to the class Vec chosen for the frame {@code _ref}; registered() builds a
 * ClassifyBool for it.
 */
public class DoClassBoolean extends Dependent {
  protected DoClassBoolean(String key) { super(key); }
}
/** Boolean tied to the frame key {@code _ref}; registered() builds a DRFCopyDataBool for it. */
public class DRFCopyDataBoolean extends Dependent {
  protected DRFCopyDataBoolean(String key) {
    super(key);
  }
}
/**
 * Builds this request's argument handlers by reflecting over its {@link API}-annotated
 * instance fields.
 *
 * Non-static fields of every class in the hierarchy are visited, root superclass first, so
 * inherited parameters are handled before those of subclasses. For each field whose
 * annotation marks it as an input (see {@link Helper#isInput(API)}), the field's current value
 * on this instance is read as its default. An Argument is then chosen, from the annotation's
 * filter when that is an Argument or a Dependent, otherwise from the field's Java type. Finally
 * it is configured from the annotation. Fields matching none of the cases get no handler here.
 *
 * Declaration order matters: a dependent field looks up the field it refers to (via
 * {@link #find(String)} and the local classVecs map), which only works if that field was
 * processed earlier. NOTE(review): find(String) searches _arguments, so this presumes that
 * constructing an Argument registers it there -- confirm in RequestArguments.
 *
 * @param version the API version being registered; not read by this implementation
 * @throws RuntimeException wrapping any reflection or instantiation failure
 */
@Override protected void registered(API_VERSION version) {
try {
// Collect the class hierarchy, most-derived class first.
ArrayList<Class> classes = new ArrayList<Class>();
{
Class c = getClass();
while( c != null ) {
classes.add(c);
c = c.getSuperclass();
}
}
// Fields from parent classes first
Collections.reverse(classes);
ArrayList<Field> fields = new ArrayList<Field>();
for( Class c : classes )
for( Field field : c.getDeclaredFields() )
if( !Modifier.isStatic(field.getModifiers()) )
fields.add(field);
// TODO remove map, response field already processed specifically
// Class-Vec arguments keyed by the frame field they refer to, so that later
// MultiVecSelect / DoClassBoolean fields on the same frame can be linked to them.
HashMap<String, FrameClassVec> classVecs = new HashMap<String, FrameClassVec>();
for( Field f : fields ) {
Annotation[] as = f.getAnnotations();
API api = find(as, API.class);
if( api != null && Helper.isInput(api) ) {
f.setAccessible(true);
// The field's initial value on this instance becomes the argument default.
Object defaultValue = f.get(this);
// Create an Argument instance to reuse existing Web framework for now
Argument arg = null;
// Simplest case, filter is an Argument
if( Argument.class.isAssignableFrom(api.filter()) ) {
arg = (Argument) newInstance(api);
}
// ColumnSelect (a Dependent subtype, so it must be tested first) is not supported:
// any field using it makes registration throw. Disabled former implementation below.
else if( ColumnSelect.class.isAssignableFrom(api.filter()) ) {
ColumnSelect name = (ColumnSelect) newInstance(api);
throw H2O.fail();
//H2OHexKey key = null;
//for( Argument a : _arguments )
// if( a instanceof H2OHexKey && name._ref.equals(((H2OHexKey) a)._name) )
// key = (H2OHexKey) a;
//arg = new HexAllColumnSelect(f.getName(), key);
}
// Other field-dependent filters: resolve the referenced argument (null if not found)
// and build the matching handler. An unrecognized Dependent subtype leaves arg null.
else if( Dependent.class.isAssignableFrom(api.filter()) ) {
Dependent d = (Dependent) newInstance(api);
Argument ref = find(d._ref);
if( d instanceof VecSelect )
arg = new FrameKeyVec(f.getName(), (TypeaheadKey) ref, api.help(), api.required());
else if( d instanceof VecClassSelect ) {
arg = new FrameClassVec(f.getName(), (TypeaheadKey) ref);
classVecs.put(d._ref, (FrameClassVec) arg);
} else if( d instanceof MultiVecSelect ) {
// response is null unless an earlier VecClassSelect field referred to the same frame.
FrameClassVec response = classVecs.get(d._ref);
boolean names = ((MultiVecSelect) d)._namesOnly;
arg = new FrameKeyMultiVec(f.getName(), (TypeaheadKey) ref, response, api.help(), names,filterNaCols());
} else if( d instanceof DoClassBoolean ) {
FrameClassVec response = classVecs.get(d._ref);
arg = new ClassifyBool(f.getName(), response);
} else if( d instanceof DRFCopyDataBoolean ) {
arg = new DRFCopyDataBool(f.getName(), (TypeaheadKey)ref);
}
}
// Otherwise the handler is chosen from the field's Java type.
// String
else if( f.getType() == String.class )
arg = new Str(f.getName(), (String) defaultValue);
// Real
else if( f.getType() == float.class || f.getType() == double.class ) {
double val = ((Number) defaultValue).doubleValue();
arg = new Real(f.getName(), api.required(), val, api.dmin(), api.dmax(), api.help());
}
// LongInt
else if( f.getType() == int.class || f.getType() == long.class ) {
long val = ((Number) defaultValue).longValue();
arg = new LongInt(f.getName(), api.required(), val, api.lmin(), api.lmax(), api.help());
}
// RSeq: int[] defaults are widened to double[]; a null default stays a null array.
else if( f.getType() == int[].class ) {
int[] val = (int[]) defaultValue;
double[] ds = null;
if( val != null ) {
ds = new double[val.length];
for( int i = 0; i < ds.length; i++ )
ds[i] = val[i];
}
arg = new RSeq(f.getName(), api.required(), new NumberSequence(ds, null, true), false, api.help());
}
// RSeq
else if( f.getType() == double[].class ) {
double[] val = (double[]) defaultValue;
arg = new RSeq(f.getName(), api.required(), new NumberSequence(val, null, false), false, api.help());
}
// RSeq float
else if( f.getType() == float[].class ) {
float[] val = (float[]) defaultValue;
arg = new RSeqFloat(f.getName(), api.required(), new NumberSequenceFloat(val, null, false), false, api.help());
}
// Bool, only when no specific filter was requested
else if( f.getType() == boolean.class && api.filter() == Default.class ) {
boolean val = (Boolean) defaultValue;
arg = new Bool(f.getName(), val, api.help());
}
// Enum
else if( Enum.class.isAssignableFrom(f.getType()) ) {
Enum val = (Enum) defaultValue;
arg = new EnumArgument(f.getName(), val);
}
// Key, untyped; the field's current key is the default
else if( f.getType() == Key.class ) {
TypeaheadKey t = new TypeaheadKey();
t._defaultValue = (Key) defaultValue;
arg = t;
}
// Generic Freezable field: a key restricted to the field's type
else if( Freezable.class.isAssignableFrom(f.getType()) )
arg = new TypeaheadKey(f.getType(), api.required());
// Apply the annotation's settings to whichever handler was built.
if( arg != null ) {
arg._name = f.getName();
arg._displayName = api.displayName().length() > 0 ? api.displayName() : null;
arg._required = api.required();
arg._field = f;
arg._hideInQuery = api.hide();
arg._gridable = api.gridable();
arg._mustExist = api.mustExist();
arg._validator = newValidator(api);
}
}
}
} catch( Exception e ) {
throw new RuntimeException(e);
}
}
/**
 * Returns the first registered argument whose name equals {@code name}, or {@code null} if
 * there is none.
 */
final protected Argument find(String name) {
  Argument found = null;
  for( int i = 0; found == null && i < _arguments.size(); i++ ) {
    Argument candidate = _arguments.get(i);
    if( name.equals(candidate._name) )
      found = candidate;
  }
  return found;
}
/**
 * Input classification, kept in its own class because the Weaver cannot load Request
 * during boot.
 */
static final class Helper {
  /**
   * A field is an input when its annotation names a filter other than the base
   * {@code Filter} type, or lists at least one entry in {@code filters()}.
   */
  static boolean isInput(API api) {
    if( api.filter() != Filter.class )
      return true;
    return api.filters().length != 0;
  }
}
/** Returns the annotation in {@code as} whose type is exactly {@code c}, or {@code null}. */
private static <T> T find(Annotation[] as, Class<T> c) {
  for( int i = 0; i < as.length; i++ ) {
    Annotation candidate = as[i];
    if( candidate.annotationType() == c )
      return c.cast(candidate);
  }
  return null;
}
/**
 * Instantiates the filter class named by {@code api.filter()}.
 *
 * A constructor with a single RequestArguments-typed parameter that this request is an
 * instance of (the shape of a non-static inner class of a request) is preferred and is invoked
 * with {@code this}. Otherwise the no-argument constructor is used. Either constructor may be
 * non-public.
 *
 * @throws Exception if neither constructor shape exists, or instantiation fails
 */
private Filter newInstance(API api) throws Exception {
  Class<?> type = api.filter();
  Constructor[] ctors = type.getDeclaredConstructors();
  for( Constructor c : ctors ) {
    Class[] ps = c.getParameterTypes();
    // Only pass 'this' when it actually fits the parameter; a filter nested in an unrelated
    // request class would otherwise fail with an opaque IllegalArgumentException.
    if( ps.length == 1 && RequestArguments.class.isAssignableFrom(ps[0]) && ps[0].isInstance(this) ) {
      c.setAccessible(true);
      return (Filter) c.newInstance(this);
    }
  }
  for( Constructor c : ctors ) {
    if( c.getParameterTypes().length == 0 ) {
      // Needed for private/package-private constructors, exactly as in the loop above.
      c.setAccessible(true);
      return (Filter) c.newInstance();
    }
  }
  throw new Exception("Class " + type.getName() + " must have an empty constructor");
}
/**
 * Instantiates the validator class named by {@code api.validator()} through its no-argument
 * constructor, which may be non-public.
 *
 * @return the new validator, or {@code null} when the type declares no constructors at all
 *     (an interface, for instance)
 * @throws Exception if constructors exist but none takes zero arguments, or instantiation fails
 */
private Validator newValidator(API api) throws Exception {
  Class<?> type = api.validator();
  Constructor[] ctors = type.getDeclaredConstructors();
  if( ctors.length == 0 )
    return null;
  // Look for the no-arg constructor instead of blindly calling the first declared one,
  // whose parameter list may not be empty.
  for( Constructor c : ctors ) {
    if( c.getParameterTypes().length == 0 ) {
      c.setAccessible(true);
      return (Validator) c.newInstance();
    }
  }
  throw new Exception("Class " + type.getName() + " must have an empty constructor");
}
/**
 * Creates a new instance of the same concrete class for a single call (used instead of
 * ThreadLocals). The new instance shares this instance's argument handlers and holds the
 * call's parameters.
 *
 * @throws RuntimeException wrapping any instantiation failure
 */
@Override protected Request create(Properties parms) {
  try {
    Request2 fresh = getClass().newInstance();
    fresh._arguments = _arguments;
    fresh._parms = parms;
    return fresh;
  } catch( Exception e ) {
    throw new RuntimeException(e);
  }
}
/** Public wrapper around {@code serve()}. */
public Response servePublic() {
  final Response result = serve();
  return result;
}
/**
 * Expand grid search related argument sets.
 *
 * Each gridable argument's raw parameter is split with split(String). An array-typed field
 * keeps its raw text as a single value unless it contains '('. If no argument yields more than
 * one value, this delegates to the regular serving path. Otherwise the destination_key
 * parameter is dropped, one Job is created for every combination of values (enumerated with
 * increment), and the jobs are served through a GridSearch.
 *
 * NOTE(review): values are read from _parms while parms is only forwarded; presumably both are
 * the same object for instances built by create(Properties) -- confirm.
 * NOTE(review): each job's _parms ends up holding only the grid values; confirm that the
 * remaining parameters reach the job some other way.
 */
@Override protected NanoHTTPD.Response serveGrid(NanoHTTPD server, Properties parms, RequestType type) {
// values[i] holds the candidate values of argument i, or null when it does not take part.
String[][] values = new String[_arguments.size()][];
boolean gridSearch = false;
for( int i = 0; i < _arguments.size(); i++ ) {
Argument arg = _arguments.get(i);
if( arg._gridable ) {
String value = _parms.getProperty(arg._name);
if( value != null ) {
// Skips grid if argument is an array, except if imbricated expression
// Little hackish, waiting for real language
boolean imbricated = value.contains("(");
if( !arg._field.getType().isArray() || imbricated ) {
values[i] = split(value);
if( values[i] != null && values[i].length > 1 )
gridSearch = true;
} else if (arg._field.getType().isArray() && !imbricated) { // Copy values which are arrays
values[i] = new String[] { value };
}
}
}
}
if( !gridSearch )
return superServeGrid(server, parms, type);
// Ignore destination key so that each job gets its own
_parms.remove("destination_key");
for( int i = 0; i < _arguments.size(); i++ )
if( _arguments.get(i)._name.equals("destination_key") )
values[i] = null;
// Iterate over all argument combinations (counters acts as an odometer over values)
int[] counters = new int[values.length];
ArrayList<Job> jobs = new ArrayList<Job>();
for( ;; ) {
// Grid search requires this request to be a Job; otherwise this cast throws.
Job job = (Job) create(_parms);
Properties combination = new Properties();
for( int i = 0; i < values.length; i++ ) {
if( values[i] != null ) {
String value = values[i][counters[i]];
value = value.trim();
combination.setProperty(_arguments.get(i)._name, value);
// Re-check the value against the new job; presumably check() applies it to the
// job's field (via set) -- confirm in RequestArguments.
_arguments.get(i).reset();
_arguments.get(i).check(job, value);
}
}
job._parms = combination;
jobs.add(job);
if( !increment(counters, values) )
break;
}
GridSearch grid = new GridSearch();
grid.jobs = jobs.toArray(new Job[jobs.size()]);
return grid.superServeGrid(server, parms, type);
}
/**
 * Splits a comma-separated list of grid values with one level of parentheses: for
 * {@code "4, 5, (2, 3), 7"} the values are {@code "4"}, {@code "5"}, {@code "2,3"} and
 * {@code "7"}. Each value goes through addSplit, which expands ranges containing ':'.
 * TODO: switch to real parser for unified imbricated argument sets, expressions etc.
 *
 * @return the values, or {@code null} when the input contains no non-blank token
 * @throws IllegalArgumentException if a '(' is never closed
 */
public static String[] split(String value) {
  StringTokenizer tokens = new StringTokenizer(value.trim(), ",()", true);
  String[] result = null;
  StringBuilder pending = new StringBuilder();
  for( String tok = getNextToken(tokens); tok != null; tok = getNextToken(tokens) ) {
    if( tok.equals(",") ) {
      result = addSplit(result, pending.toString());
      pending.setLength(0);
    } else if( tok.equals("(") ) {
      // Everything up to the matching ')' (commas included) forms a single value.
      String inner = getNextToken(tokens);
      while( !")".equals(inner) ) {
        if( inner == null )
          throw new IllegalArgumentException("Missing closing parenthesis");
        pending.append(inner);
        inner = getNextToken(tokens);
      }
      result = addSplit(result, pending.toString());
      pending.setLength(0);
    } else {
      pending.append(tok);
    }
  }
  return addSplit(result, pending.toString());
}
/**
 * Appends one split value to {@code values}. A value containing ':' is expanded through
 * NumberSequence.parseGenerator into one entry per generated number. An empty value is
 * skipped, and anything else is appended as-is.
 *
 * @return the possibly reallocated array, which may still be {@code null}
 */
private static String[] addSplit(String[] values, String value) {
  if( !value.contains(":") )
    return value.length() == 0 ? values : Utils.append(values, value);
  String[] expanded = values;
  for( double d : NumberSequence.parseGenerator(value, false, 1) )
    expanded = Utils.append(expanded, String.valueOf(d));
  return expanded;
}
/** Returns the next token that is non-empty after trimming, or {@code null} when exhausted. */
private static String getNextToken(StringTokenizer st) {
  String next = null;
  while( next == null && st.hasMoreTokens() ) {
    String candidate = st.nextToken().trim();
    if( candidate.length() != 0 )
      next = candidate;
  }
  return next;
}
/**
 * Serves through the inherited {@code serveGrid}, skipping this class's grid-search expansion.
 * Used by serveGrid for the non-grid case and on the aggregated GridSearch.
 */
public final NanoHTTPD.Response superServeGrid(NanoHTTPD server, Properties parms, RequestType type) {
  final NanoHTTPD.Response response = super.serveGrid(server, parms, type);
  return response;
}
/**
 * Advances {@code counters} like an odometer over {@code values}, with the lowest index
 * changing fastest. Rows whose {@code values} entry is null are skipped and held at 0.
 *
 * @return {@code false} once every combination has been visited (counters are then all 0)
 */
private static boolean increment(int[] counters, String[][] values) {
  int pos = 0;
  while( pos < counters.length ) {
    boolean canAdvance = values[pos] != null && counters[pos] + 1 < values[pos].length;
    if( canAdvance ) {
      counters[pos] += 1;
      return true;
    }
    counters[pos] = 0;
    pos++;
  }
  return false;
}
/**
 * Arguments to fields casts: stores a parsed argument value into the field backing
 * {@code arg}, converting it to the field's type where needed.
 *
 * A Key is replaced by the value stored under it, unless the field itself is a Key. A Long
 * going into an {@code int} field and a Double going into a {@code float} field are narrowed.
 * A NumberSequence becomes a {@code double[]}, and a NumberSequenceFloat becomes a
 * {@code float[]}; both become {@code int[]} for int-array fields. A sequence without a
 * backing array is stored as {@code null}; registered() builds one for an int[] field whose
 * default is null.
 *
 * @param arg   argument whose {@code _field} receives the value
 * @param input raw input text (not used here)
 * @param value parsed value
 * @throws RuntimeException wrapping any failure to assign the field, naming the field
 */
public void set(Argument arg, String input, Object value) {
  Class<?> type = arg._field.getType();
  if( type != Key.class && value instanceof Key )
    value = UKV.get((Key) value);
  try {
    if( type == int.class && value instanceof Long )
      value = ((Long) value).intValue();
    else if( type == float.class && value instanceof Double )
      value = ((Double) value).floatValue();
    else if( value instanceof NumberSequence )
      value = sequenceToField(((NumberSequence) value)._arr, type);
    else if( value instanceof NumberSequenceFloat )
      value = sequenceToField(((NumberSequenceFloat) value)._arr, type);
    arg._field.set(this, value);
  } catch( Exception e ) {
    throw new RuntimeException("Cannot set field '" + arg._field.getName() + "'", e);
  }
}

/** Converts a double sequence for a field of {@code type}: int[] is truncated, null stays null. */
private static Object sequenceToField(double[] ds, Class<?> type) {
  if( ds == null || type != int[].class )
    return ds;
  int[] is = new int[ds.length];
  for( int i = 0; i < is.length; i++ )
    is[i] = (int) ds[i];
  return is;
}

/** Converts a float sequence for a field of {@code type}: int[] is truncated, null stays null. */
private static Object sequenceToField(float[] fs, Class<?> type) {
  if( fs == null || type != int[].class )
    return fs;
  int[] is = new int[fs.length];
  for( int i = 0; i < is.length; i++ )
    is[i] = (int) fs[i];
  return is;
}
/** Declares the supported API versions: {@code SUPPORTS_ONLY_V2}. */
@Override public API_VERSION[] supportedVersions() {
  final API_VERSION[] versions = SUPPORTS_ONLY_V2;
  return versions;
}
/** Stores the info extracted from {@code response} in {@code response_info}. */
public void fillResponseInfo(Response response) {
  response_info = response.extractInfo();
}
/**
 * Serializes this request with the AutoBuffer JSON writer and re-parses the result into a
 * JsonObject without the "Request2" and "response_info" members. An empty serialization
 * yields an empty object.
 */
public JsonObject toJSON() {
  final String text = new String(writeJSON(new AutoBuffer()).buf());
  if( text.length() == 0 )
    return new JsonObject();
  JsonElement parsed = new JsonParser().parse(text);
  JsonObject result = (JsonObject) parsed;
  result.remove("Request2");
  result.remove("response_info");
  return result;
}
/**
 * Like toJSON(), but keeps only the top-level members whose names are in {@code whitelist}.
 *
 * @param whitelist names of the members to keep
 * @return the filtered JSON object
 */
public JsonObject toJSON(Set<String> whitelist) {
  JsonObject jo = toJSON();
  // Collect first, then remove: in Gson, entrySet() is a live view of the members, so
  // removing while iterating it throws ConcurrentModificationException.
  List<String> rejected = new ArrayList<String>();
  for( Map.Entry<String, JsonElement> entry : jo.entrySet() )
    if( !whitelist.contains(entry.getKey()) )
      rejected.add(entry.getKey());
  for( String key : rejected )
    jo.remove(key);
  return jo;
}
/** Renders toJSON() with {@code GSON_BUILDER}. */
@Override
public String toString() {
  JsonObject json = toJSON();
  return GSON_BUILDER.toJson(json);
}
/** Logs, at info level, a header naming this class and then this request's JSON, line by line. */
protected void logStart() {
  Log.info("Building H2O " + this.getClass().getSimpleName() + " model with these parameters:");
  String[] lines = toString().split("\n");
  for( int i = 0; i < lines.length; i++ )
    Log.info(lines[i]);
}
/**
 * Appends HTML for a "Model Parameters" toggle button followed by a hidden code block holding
 * this request's JSON.
 *
 * @param sb buffer receiving the HTML
 * @return always {@code true}
 */
public boolean makeJsonBox(StringBuilder sb) {
  sb.append("<div class='pull-right'><a href='#' onclick='$(\"#params\").toggleClass(\"hide\");'"
      + " class='btn btn-inverse btn-mini'>Model Parameters</a></div><div class='hide' id='params'>"
      + "<pre><code class=\"language-json\">");
  // Field values can come from request input; escape them so text such as "</code><script>"
  // is displayed rather than interpreted as markup. '&' must be replaced first.
  String json = toString();
  sb.append(json.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;"));
  sb.append("</code></pre></div>");
  return true;
}
/**
 * Flag passed to every FrameKeyMultiVec built by registered(); presumably enables filtering
 * of NA columns -- confirm. Returns {@code false} here; subclasses may override.
 */
protected boolean filterNaCols() {
  return false;
}
}