/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.ml.models; import org.elasticsearch.ml.modelinput.MapModelInput; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; public class EsTreeModel extends EsModelEvaluator<MapModelInput, String> { private final EsTreeNode startNode; public EsTreeModel(EsTreeNode startNode) { this.startNode = startNode; } @Override public Map<String, Object> evaluateDebug(MapModelInput modelInput) { Map<String, Object> vector = modelInput.getAsMap(); assert startNode.predicate.match(vector); return startNode.evaluate(vector); } @Override public String evaluate(MapModelInput modelInput) { Map<String, Object> vector = modelInput.getAsMap(); assert startNode.predicate.match(vector); return (String)startNode.evaluate(vector).get("class"); } public static class EsTreeNode { EsPredicate predicate; java.util.List<EsTreeNode> childNodes = new ArrayList<>(); String score; public EsTreeNode(List<EsTreeNode> childNodes, EsPredicate predicate, String score) { this.predicate = predicate; this.childNodes = childNodes; this.score = score; } private Map<String, Object> evaluate(Map<String, Object> vector) { for (EsTreeNode childNode : childNodes) { if (childNode.predicate.match(vector)) { return childNode.evaluate(vector); } } Map<String, Object> result = new HashMap<>(); result.put("class", score); return result; } } public interface EsPredicate { boolean match(Map<String, Object> vector); boolean notEnoughValues(Map<String, Object> vector); } public abstract static class EsSimplePredicate<T extends Comparable<T>> implements EsPredicate { protected final T value; protected String field; protected EsSimplePredicate(T value, String field) { this.value = value; this.field = field; } public abstract boolean match(T fieldValue); @SuppressWarnings("unchecked") public boolean match(Map<String, Object> vector) { Object fieldValue = vector.get(field); if (fieldValue instanceof HashSet) { fieldValue = new ComparableSet<>((HashSet<Comparable<T>>) fieldValue); } if (fieldValue == null) { return false; } return match((T) fieldValue); } @Override public boolean notEnoughValues(Map<String, Object> vector) { return vector.containsKey(field) == false; } } public abstract static class EsCompoundPredicate implements EsPredicate { protected List<EsPredicate> predicates; protected EsCompoundPredicate(List<EsPredicate> predicates) { this.predicates = predicates; } public boolean match(Map<String, Object> vector) { return matchList(vector); } protected abstract boolean matchList(Map<String, Object> vector); @Override public boolean notEnoughValues(Map<String, Object> vector) { boolean valuesMissing = false; for (EsPredicate predicate : predicates) { valuesMissing = predicate.notEnoughValues(vector) || valuesMissing; } return valuesMissing; } } public static class EsSimpleSetPredicate<T> implements EsPredicate { protected HashSet<T> values; private String field; public EsSimpleSetPredicate(HashSet<T> values, String field) { this.values = values; this.field = field; } @Override public boolean match(Map<String, Object> vector) { // we do not check for null because HashSet allows null values. for (Object value : (Set)vector.get(field)) { if (values.contains(value)) { return true; } } return false; } @Override public boolean notEnoughValues(Map<String, Object> vector) { return vector.containsKey(field) == false; } } public static class ComparableSet<T> extends HashSet<Comparable<T>> implements Comparable<T> { public ComparableSet(HashSet<Comparable<T>> set) { this.addAll(set); } @SuppressWarnings("unchecked") @Override public int compareTo(T o) { if (this.size()!= 1) { throw new UnsupportedOperationException("cannot really compare sets, I am just pretending!"); } if (o instanceof Comparable == false) { throw new UnsupportedOperationException("cannot compare to object " + o.getClass().getName()); } //noinspection unchecked Comparable<T> first = this.iterator().next(); return first.compareTo(o); } } }