package com.blazebit.ai.decisiontree;
import com.blazebit.ai.decisiontree.impl.SimpleAttributeSelector;
import com.blazebit.ai.decisiontree.impl.SimpleAttributeValue;
import com.blazebit.ai.decisiontree.impl.SimpleDecisionTree;
import com.blazebit.ai.decisiontree.impl.SimpleDiscreteAttribute;
import com.blazebit.ai.decisiontree.impl.SimpleExample;
import com.blazebit.ai.decisiontree.impl.SimpleItem;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import org.junit.Test;
/**
 * Tests {@link SimpleDecisionTree} using a small set of vehicles that are
 * described by the discrete attributes {@code brand} and {@code type}.
 *
 * @author Christian Beikov
 */
public class SimpleDecisionTreeTest {

    /**
     * Builds a tree over the brand and type attributes from two car examples
     * (an Audi and a BMW) and checks the results of three partial queries:
     * an item carrying only the type {@code CAR} yields both vehicles, while
     * an item carrying only a brand yields exactly the vehicle of that brand.
     */
    @Test
    public void testCreate() {
        Set<Vehicle> results = new HashSet<Vehicle>();
        Set<Attribute> attributes = new HashSet<Attribute>();
        Vehicle audi = new Vehicle("A3", Brand.AUDI, Type.CAR);
        Vehicle bmw = new Vehicle("M3", Brand.BMW, Type.CAR);

        attributes.add(TestAttributes.BRAND);
        attributes.add(TestAttributes.TYPE);
        results.add(bmw);
        results.add(audi);

        DecisionTree<Vehicle> tree = new SimpleDecisionTree<Vehicle>(
                attributes, examples(results), new SimpleAttributeSelector());

        // Each query is evaluated once so that the size and content
        // assertions inspect the very same result.
        Collection<?> carMatches = tree.apply(item(new Vehicle(null, null, Type.CAR)));
        assertEquals(2, carMatches.size());
        assertTrue(results.containsAll(carMatches));

        Collection<?> audiMatches = tree.apply(item(new Vehicle(null, Brand.AUDI, null)));
        assertEquals(1, audiMatches.size());
        assertTrue(audiMatches.contains(audi));

        Collection<?> bmwMatches = tree.apply(item(new Vehicle(null, Brand.BMW, null)));
        assertEquals(1, bmwMatches.size());
        assertTrue(bmwMatches.contains(bmw));
    }

    /**
     * Wraps the non-null brand and type of the given vehicle into an item
     * that can be passed to the tree.
     *
     * @param v the vehicle to read the attribute values from
     * @return an item holding one value per non-null attribute of {@code v}
     */
    private static Item item(Vehicle v) {
        return new SimpleItem(attributeValues(v));
    }

    /**
     * Creates one example per vehicle; the vehicle itself is passed to the
     * example together with its attribute values.
     *
     * @param vehicles the vehicles to convert
     * @return a set containing one example for every given vehicle
     */
    private static Set<Example<Vehicle>> examples(Collection<Vehicle> vehicles) {
        Set<Example<Vehicle>> examples = new HashSet<Example<Vehicle>>(vehicles.size());
        for (Vehicle v : vehicles) {
            examples.add(new SimpleExample<Vehicle>(attributeValues(v), v));
        }
        return examples;
    }

    /**
     * Collects the attribute values of a vehicle. Attributes whose value is
     * {@code null} are left out of the map; the name is never included.
     *
     * @param v the vehicle to read the attribute values from
     * @return a new map from attribute to the vehicle's value for it
     */
    private static Map<Attribute, AttributeValue> attributeValues(Vehicle v) {
        Map<Attribute, AttributeValue> values = new HashMap<Attribute, AttributeValue>();
        if (v.brand != null) {
            values.put(TestAttributes.BRAND, new SimpleAttributeValue(v.brand));
        }
        if (v.type != null) {
            values.put(TestAttributes.TYPE, new SimpleAttributeValue(v.type));
        }
        return values;
    }

    /**
     * Simple value object used as test fixture. Any of the fields may be
     * {@code null}. Equality and hash code are based on all three fields.
     */
    static class Vehicle {

        // Final because instances are stored in hash based collections and
        // hashCode() is derived from these fields.
        final String name;
        final Brand brand;
        final Type type;

        /**
         * @param name  the name of the vehicle, may be {@code null}
         * @param brand the brand of the vehicle, may be {@code null}
         * @param type  the type of the vehicle, may be {@code null}
         */
        public Vehicle(String name, Brand brand, Type type) {
            this.name = name;
            this.brand = brand;
            this.type = type;
        }

        @Override
        public int hashCode() {
            int hash = 7;
            hash = 37 * hash + (this.name != null ? this.name.hashCode() : 0);
            hash = 37 * hash + (this.brand != null ? this.brand.hashCode() : 0);
            hash = 37 * hash + (this.type != null ? this.type.hashCode() : 0);
            return hash;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            final Vehicle other = (Vehicle) obj;
            final boolean sameName = (this.name == null)
                    ? (other.name == null)
                    : this.name.equals(other.name);
            // Enum constants are compared by identity.
            return sameName && this.brand == other.brand && this.type == other.type;
        }
    }

    /** The vehicle types available to the test. */
    enum Type {
        CAR,
        MOTORCYCLE
    }

    /** The vehicle brands available to the test. */
    enum Brand {
        AUDI,
        MERCEDES,
        BMW,
        SUZUKI
    }

    /**
     * Holder for the attribute definitions shared by the test and its helper
     * methods.
     */
    static class TestAttributes {

        // Not enabled:
        //public static final Attribute NAME = new SimpleContinuousAttribute("name");

        /** Discrete attribute "brand" whose values are all {@link Brand} constants. */
        public static final Attribute BRAND = new SimpleDiscreteAttribute("brand", values(Brand.values()));

        /** Discrete attribute "type" whose values are all {@link Type} constants. */
        public static final Attribute TYPE = new SimpleDiscreteAttribute("type", values(Type.values()));

        /**
         * Wraps every given object into a {@link SimpleAttributeValue}.
         *
         * @param objects the raw values to wrap
         * @return a new set with one attribute value per given object
         */
        public static Set<AttributeValue> values(Object[] objects) {
            Set<AttributeValue> values = new HashSet<AttributeValue>(objects.length);
            for (Object o : objects) {
                values.add(new SimpleAttributeValue(o));
            }
            return values;
        }
    }
}