/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.ml.common; import rapaio.data.Frame; import rapaio.data.VRange; import rapaio.data.Var; import rapaio.data.VarType; import rapaio.printer.Printable; import java.util.Arrays; import java.util.Collections; import java.util.List; import static java.util.stream.Collectors.*; /** * Capabilities describes what a machine learning algorithm can train and fit. * <p> * Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> at 12/1/14. */ public class Capabilities implements Printable { private Integer minInputCount; private Integer maxInputCount; private List<VarType> inputTypes; private Boolean allowMissingInputValues; private Integer minTargetCount; private Integer maxTargetCount; private List<VarType> targetTypes; private Boolean allowMissingTargetValues; public Capabilities withInputCount(int minInputCount, int maxInputCount) { this.minInputCount = minInputCount; this.maxInputCount = maxInputCount; return this; } public Capabilities withInputTypes(VarType... types) { this.inputTypes = Arrays.stream(types).collect(toList()); Collections.sort(this.inputTypes); return this; } public Capabilities withTargetCount(int minTargetCount, int maxTargetCount) { this.minTargetCount = minTargetCount; this.maxTargetCount = maxTargetCount; return this; } public Capabilities withTargetTypes(VarType... types) { this.targetTypes = Arrays.stream(types).collect(toList()); Collections.sort(this.inputTypes); return this; } public Capabilities withAllowMissingInputValues(boolean allow) { this.allowMissingInputValues = allow; return this; } public Capabilities withAllowMissingTargetValues(boolean allow) { this.allowMissingTargetValues = allow; return this; } /** * This method evaluates the capabilities of the algorithm at the learning phase. * * @param df data frame to be learned * @param weights weights of the data frame * @param targetVars target variable names */ public void checkAtLearnPhase(Frame df, Var weights, String... targetVars) { // check if capabilities are well-specified if (inputTypes == null) { throw new IllegalArgumentException("Capabilities not initialized completely: missing inputTypes"); } if (minInputCount == null) { throw new IllegalArgumentException("Capabilities not initialized completely: missing minInputCount"); } if (maxInputCount == null) { throw new IllegalArgumentException("Capabilities not initialized completely: missing maxInputCount"); } if (allowMissingInputValues == null) { throw new IllegalArgumentException("Capabilities not initialized completely: missing allowMissingInputValues"); } if (targetTypes == null) { throw new IllegalArgumentException("Capabilities not initialized completely: miaaing targetTypes"); } if (minTargetCount == null) { throw new IllegalArgumentException("Capabilities not initialized completely: missing minTargetCount"); } if (maxTargetCount == null) { throw new IllegalArgumentException("Capabilities not initialized completely: missing maxTargetCount"); } if (allowMissingTargetValues == null) { throw new IllegalArgumentException("Capabilities not initialized completely: missing allowMissingTargetValues"); } // check target type checkInputCount(df, weights, targetVars); checkInputTypes(df, weights, targetVars); checkMissingInputValues(df, weights, targetVars); checkTargetCount(df, weights, targetVars); checkTargetTypes(df, weights, targetVars); checkMissingTargetValues(df, weights, targetVars); } private void checkTargetCount(Frame df, Var weights, String... targetVarNames) { List<String> varList = VRange.of(targetVarNames).parseVarNames(df); int size = varList.size(); if (size < minTargetCount) { throw new IllegalArgumentException("Algorithm requires more than " + minInputCount + " target variables."); } if (size > maxTargetCount) { throw new IllegalArgumentException("Algorithm does not allow more than " + maxInputCount + " target variables"); } } private void checkTargetTypes(Frame df, Var weights, String... targetVarNames) { List<String> varList = VRange.of(targetVarNames).parseVarNames(df); for (String varName : varList) { if (!targetTypes.contains(df.var(varName).type())) { throw new IllegalArgumentException("Algorithm does not allow " + df.var(varName).type().name() + " as target type vor var: " + varName); } } } private void checkMissingTargetValues(Frame df, Var weights, String... targetVarNames) { if (allowMissingTargetValues) return; List<String> varList = VRange.of(targetVarNames).parseVarNames(df); StringBuilder sb = new StringBuilder(); for (String targetName : varList) { if (df.var(targetName).stream().complete().count() != df.var(targetName).rowCount()) { if (sb.length() != 0) { sb.append(", "); } sb.append(targetName); } } if (sb.length() > 0) throw new IllegalArgumentException("Algorithm does not allow target variables with missing values; see : " + sb.toString()); } private void checkInputCount(Frame df, Var weights, String... targetVars) { List<String> inputNames = VRange.of(targetVars).parseInverseVarNames(df); int size = inputNames.size(); if (size < minInputCount) { throw new IllegalArgumentException("Algorithm requires more than " + minInputCount + " input variables."); } if (size > maxInputCount) { throw new IllegalArgumentException("Algorithm does not allow more than " + maxInputCount + " input variables"); } } void checkInputTypes(Frame df, Var weights, String... targetVars) { List<String> inputNames = VRange.of(targetVars).parseInverseVarNames(df); StringBuilder sb = new StringBuilder(); for (String inputName : inputNames) { Var inputVar = df.var(inputName); if (!inputTypes.contains(inputVar.type())) { if (sb.length() != 0) { sb.append(", "); } sb.append(inputName).append("[").append(inputVar.type().name()).append("]"); } } if (sb.length() > 0) { throw new IllegalArgumentException("Algorithm does not allow input variables of give types: " + sb.toString()); } } private void checkMissingInputValues(Frame df, Var weights, String... targetVarNames) { if (allowMissingInputValues) return; List<String> varList = VRange.of(targetVarNames).parseInverseVarNames(df); StringBuilder sb = new StringBuilder(); for (String inputName : varList) { if (df.var(inputName).stream().complete().count() != df.var(inputName).rowCount()) { if (sb.length() != 0) { sb.append(", "); } sb.append(inputName); } } if (sb.length() > 0) throw new IllegalArgumentException("Algorithm does not allow input variables with missing values; see : " + sb.toString()); } public Integer getMinInputCount() { return minInputCount; } public Integer getMaxInputCount() { return maxInputCount; } public List<VarType> getInputTypes() { return inputTypes; } public Boolean getAllowMissingInputValues() { return allowMissingInputValues; } public Integer getMinTargetCount() { return minTargetCount; } public Integer getMaxTargetCount() { return maxTargetCount; } public List<VarType> getTargetTypes() { return targetTypes; } public Boolean getAllowMissingTargetValues() { return allowMissingTargetValues; } @Override public String summary() { StringBuilder sb = new StringBuilder(); sb.append("types inputs/targets: ").append(inputTypes.stream().map(Enum::name).collect(joining(","))).append("/").append(targetTypes.stream().map(Enum::name).collect(joining(","))).append("\n"); sb.append("counts inputs/targets: [").append(minInputCount).append(",").append(maxInputCount).append("] / [") .append(minTargetCount).append(",").append(maxTargetCount).append("]\n"); sb.append("missing inputs/targets: ").append(allowMissingInputValues).append("/").append(allowMissingTargetValues).append("\n"); return sb.toString(); } }