/*
 * Copyright 2013 State University of New York at Oswego
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package edu.oswego.csc480_hci521_2013.shared.h2o.urlbuilders;

import java.util.HashMap;
import java.util.Map.Entry;

/**
 * Fluent builder that collects the arguments of an "RF" request.
 *
 * <p>Every setter validates its input, registers the value under the matching
 * argument name through the inherited {@code addArg}/{@code addMultiArg}
 * methods and returns {@code this} so calls can be chained.
 *
 * <p>NOTE(review): the list of ignored column names kept by
 * {@link #storeIgnore(String)} lives in a static field and is therefore shared
 * by all instances of this class.
 */
public class RFBuilder extends AbstractBuilder {

    /** Values accepted by {@link #setStatType(StatType)}. */
    public enum StatType {
        GINI, ENTROPY
    }

    /** Values accepted by {@link #setSamplingStrategy(SamplingStrategy)}. */
    public enum SamplingStrategy {
        RANDOM, STRATIFIED_LOCAL
    }

    /** Name handed to the superclass constructor. */
    static final String NAME = "RF";

    /** Placeholder returned by {@link #getIgnores()} while nothing is stored. */
    private static final String NO_IGNORES = "None";

    /** Comma separated names of ignored columns; shared by all instances. */
    static String ignoring = NO_IGNORES;

    private String responseValue;
    private int nTree = 50;
    private String weights;

    /**
     * No-argument constructor; does not pass {@link #NAME} to the superclass.
     * NOTE(review): presumably required for serialization - confirm.
     */
    RFBuilder() {
    }

    /**
     * Creates a builder pre-populated with the given arguments.
     *
     * @param args the arguments handed to {@code setArgs}
     */
    RFBuilder(HashMap<String, String> args) {
        super(NAME);
        setArgs(args);
    }

    /**
     * Creates a builder for the given data set.
     *
     * @param dataKey value registered as the {@code data_key} argument
     */
    public RFBuilder(String dataKey) {
        super(NAME);
        addArg("data_key", dataKey);
    }

    /**
     * @return the value last passed to {@link #setResponseVariable(String)},
     *         or {@code null} if it was never called
     */
    public String getResponseVariable() {
        return this.responseValue;
    }

    /**
     * @param value Column name. The output classification (also known as
     *              'response variable') that is being learned.
     * @return this
     */
    public RFBuilder setResponseVariable(String value) {
        this.responseValue = value;
        addArg("response_variable", value);
        return this;
    }

    /**
     * @return the value last passed to {@link #setNtree(Integer)}; 50 if it
     *         was never called
     */
    public int getNtree() {
        return this.nTree;
    }

    /**
     * @param value Integer from 0 to 2147483647
     * @return this
     * @throws IllegalArgumentException if {@code value} is negative
     */
    public RFBuilder setNtree(Integer value) {
        requireNonNegative(value);
        this.nTree = value;
        addArg("ntree", value.toString());
        return this;
    }

    /**
     * @param type the stat type to use.
     * @return this
     */
    public RFBuilder setStatType(final StatType type) {
        addArg("stat_type", type.name());
        return this;
    }

    /**
     * @return the {@code toString()} form of the map last accepted by
     *         {@link #setClassWeights(HashMap)}, or {@code null} if none was
     */
    public String getClassWeights() {
        return this.weights;
    }

    /**
     * Registers the class weights as a comma separated list of
     * {@code category=weight} pairs.
     *
     * @param values Category weight (positive)
     * @return this
     * @throws IllegalArgumentException if {@code values} is empty or contains
     *                                  a negative weight
     */
    public RFBuilder setClassWeights(HashMap<String, Double> values) {
        // Validate everything first so a rejected map leaves the builder untouched.
        for (Entry<String, Double> pair : values.entrySet()) {
            if (pair.getValue() < 0) {
                throw new IllegalArgumentException("values must be positive");
            }
        }
        String joined = joinPairs(values);
        this.weights = values.toString();
        addArg("class_weights", joined);
        return this;
    }

    /**
     * @param value the type of sampling to use
     * @return this
     */
    public RFBuilder setSamplingStrategy(SamplingStrategy value) {
        addArg("sampling_strategy", value.name());
        return this;
    }

    /**
     * Registers the strata samples as a comma separated list of
     * {@code category=count} pairs.
     *
     * @param values Category strata (integer)
     * @return this
     * @throws IllegalArgumentException if {@code values} is empty
     */
    public RFBuilder setStrataSamples(HashMap<String, Integer> values) {
        addArg("strata_samples", joinPairs(values));
        return this;
    }

    /**
     * @param value Valid H2O key
     * @return this
     */
    public RFBuilder setModelKey(String value) {
        addArg("model_key", value);
        return this;
    }

    /**
     * @param value Out of bag errors; sent as {@code "1"} or {@code "0"}
     * @return this
     */
    public RFBuilder setOutOfBagErrorEstimate(boolean value) {
        addArg("out_of_bag_error_estimate", value ? "1" : "0");
        return this;
    }

    /**
     * @param value Integer from 0 to 2147483647
     * @return this
     * @throws IllegalArgumentException if {@code value} is negative
     */
    public RFBuilder setFeatures(Integer value) {
        requireNonNegative(value);
        addArg("features", value.toString());
        return this;
    }

    /**
     * can be used multiple times
     *
     * @param value Columns to select
     * @return this
     */
    public RFBuilder setIgnore(Integer value) {
        addMultiArg("ignore", value.toString());
        return this;
    }

    /**
     * Appends a column name to the shared, comma separated list of ignored
     * columns. The first stored name replaces the initial "None" placeholder
     * (or an empty list) instead of being appended to it.
     *
     * @param name the column name to remember
     */
    public void storeIgnore(String name) {
        // The list starts out as the "None" placeholder, so testing only for
        // the empty string would produce "None, <name>" on the first call.
        if (ignoring.length() == 0 || NO_IGNORES.equals(ignoring)) {
            ignoring = name;
        } else {
            ignoring += ", " + name;
        }
    }

    /**
     * @return the comma separated names stored through
     *         {@link #storeIgnore(String)}; "None" while nothing was stored
     */
    public String getIgnores() {
        return ignoring;
    }

    /**
     * @param value Integer from 1 to 100
     * @return this
     * @throws IllegalArgumentException if {@code value} is outside 1..100
     */
    public RFBuilder setSample(Integer value) {
        // Lower bound is 1, matching the documented range and the message.
        if (value < 1 || value > 100) {
            throw new IllegalArgumentException("value must be between 1 and 100 (inclusive)");
        }
        addArg("sample", value.toString());
        return this;
    }

    /**
     * @param value Integer from 0 to 65535
     * @return this
     * @throws IllegalArgumentException if {@code value} is outside 0..65535
     */
    public RFBuilder setBinLimit(Integer value) {
        if (value < 0 || value > 65535) {
            throw new IllegalArgumentException("value must be between 0 and 65535 (inclusive)");
        }
        addArg("bin_limit", value.toString());
        return this;
    }

    /**
     * @param value Integer from 0 to 2147483647
     * @return this
     * @throws IllegalArgumentException if {@code value} is negative
     */
    public RFBuilder setDepth(Integer value) {
        requireNonNegative(value);
        addArg("depth", value.toString());
        return this;
    }

    /**
     * @param value Integer value
     * @return this
     */
    public RFBuilder setSeed(Integer value) {
        addArg("seed", value.toString());
        return this;
    }

    /**
     * @param value Build trees in parallel; sent as {@code "1"} or {@code "0"}
     * @return this
     */
    public RFBuilder setParallel(boolean value) {
        addArg("parallel", value ? "1" : "0");
        return this;
    }

    /**
     * This allows for the use of {@code ==} conditionals in the trees and not
     * just {@code <=}.
     *
     * @param value Integer from 0 to 2147483647
     * @return this
     * @throws IllegalArgumentException if {@code value} is negative
     */
    public RFBuilder setExclusiveSplitLimit(Integer value) {
        requireNonNegative(value);
        addArg("exclusive_split_limit", value.toString());
        return this;
    }

    /** Rejects negative values with the exception used by all range checks. */
    private static void requireNonNegative(Integer value) {
        if (value < 0) {
            throw new IllegalArgumentException("value must be positive");
        }
    }

    /**
     * Joins the map entries as {@code key=value} pairs separated by commas,
     * in the map's iteration order; rejects an empty map, which has no
     * meaningful argument form.
     */
    private static String joinPairs(HashMap<String, ?> values) {
        if (values.isEmpty()) {
            throw new IllegalArgumentException("values must not be empty");
        }
        StringBuilder joined = new StringBuilder();
        for (Entry<String, ?> pair : values.entrySet()) {
            if (joined.length() > 0) {
                joined.append(',');
            }
            joined.append(pair.getKey()).append('=').append(pair.getValue());
        }
        return joined.toString();
    }
}