package hex; import hex.rng.MersenneTwisterRNG; import java.lang.annotation.*; import java.lang.reflect.Field; import java.lang.reflect.Modifier; import java.util.*; import water.util.Log; /** * Looks for parameters on a set of objects and perform random search. */ public class ParamsSearch { @Retention(RetentionPolicy.RUNTIME) public @interface Info { /** * Parameter search will move the value relative to origin. */ double origin() default 0; double min() default Double.NaN; double max() default Double.NaN; } @Retention(RetentionPolicy.RUNTIME) public @interface Ignore { } Param[] _params; Random _rand = new MersenneTwisterRNG(new Random().nextLong()); double _rate = .1; class Param { int _objectIndex; Field _field; Info _info; double _initial, _best, _last; @Info public double defaults; void modify(Object o) throws Exception { if( _field.getType() == boolean.class ) { if( _rand.nextDouble() < _rate ) { _last = _best == 0 ? 1 : 0; _field.set(o, _last == 1); } } else { if( _info == null ) _info = Param.class.getField("defaults").getAnnotation(Info.class); double delta = (_best - _info.origin()) * _rate; double min = _best - delta, max = _best + delta; _last = min + _rand.nextDouble() * (max - min); if( _field.getType() == float.class ) _field.set(o, (float) _last); else if( _field.getType() == int.class ) _field.set(o, (int) _last); } String change = _best + " -> " + _last; Log.info(this + ": " + change); } void write() { Log.info(this + ": " + _best); } String objectName() { return _field.getDeclaringClass().getName() + " " + _objectIndex; } @Override public String toString() { return objectName() + "." + _field.getName(); } } public void run(Object... os) { try { ArrayList<Object> expanded = new ArrayList<Object>(); for( Object o : os ) { if( o instanceof Object[] ) expanded.addAll(Arrays.asList((Object[]) o)); else if( o instanceof Collection ) expanded.addAll((Collection) o); else expanded.add(o); } if( _params == null ) { ArrayList<Param> params = new ArrayList<Param>(); for( int i = 0; i < expanded.size(); i++ ) { Class c = expanded.get(i).getClass(); ArrayList<Field> fields = new ArrayList<Field>(); getAllFields(fields, c); for( Field f : fields ) { f.setAccessible(true); if( (f.getModifiers() & Modifier.STATIC) == 0 && !ignore(f) ) { Object v = f.get(expanded.get(i)); if( v instanceof Number || v instanceof Boolean ) { Param param = new Param(); for( Annotation a : f.getAnnotations() ) if( a.annotationType() == Info.class ) param._info = (Info) a; param._objectIndex = i; param._field = f; if( v instanceof Boolean ) param._initial = ((Boolean) v).booleanValue() ? 1 : 0; else param._initial = ((Number) v).doubleValue(); param._last = param._best = param._initial; params.add(param); param.write(); } } } } _params = params.toArray(new Param[0]); Log.info(toString()); } else { for( int i = 0; i < _params.length; i++ ) modify(expanded, i); } } catch( Exception ex ) { throw new RuntimeException(ex); } } private static boolean ignore(Field f) { for( Annotation a : f.getAnnotations() ) if( a.annotationType() == Ignore.class ) return true; return false; } private static void getAllFields(List<Field> fields, Class<?> type) { for( Field field : type.getDeclaredFields() ) fields.add(field); if( type.getSuperclass() != null ) getAllFields(fields, type.getSuperclass()); } private void modify(ArrayList<Object> expanded, int i) throws Exception { Object o = expanded.get(_params[i]._objectIndex); _params[i].modify(o); } public void save() { for( int i = 0; i < _params.length; i++ ) _params[i]._best = _params[i]._last; } // @Override public String toString() { // StringBuilder sb = new StringBuilder(); // int objectIndex = -1; // for( Param param : _params ) { // if( objectIndex != param._objectIndex ) { // objectIndex = param._objectIndex; // sb.append(param._field.getDeclaringClass().getName() + " " + objectIndex + '\n'); // } // sb.append(" " + param._field.getName() + ": " + param._best + '\n'); // } // return sb.toString(); // } }