package org.apache.samoa.streams.generators; /* * #%L * SAMOA * %% * Copyright (C) 2014 - 2015 Apache Software Foundation * %% * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * #L% */ import java.io.Serializable; import java.util.ArrayList; import java.util.Random; import org.apache.samoa.instances.Attribute; import org.apache.samoa.instances.DenseInstance; import org.apache.samoa.instances.Instance; import org.apache.samoa.instances.Instances; import org.apache.samoa.instances.InstancesHeader; import org.apache.samoa.moa.core.FastVector; import org.apache.samoa.moa.core.InstanceExample; import org.apache.samoa.moa.core.ObjectRepository; import org.apache.samoa.moa.options.AbstractOptionHandler; import org.apache.samoa.moa.tasks.TaskMonitor; import org.apache.samoa.streams.InstanceStream; import com.github.javacliparser.FloatOption; import com.github.javacliparser.IntOption; /** * Stream generator for a stream based on a randomly generated tree.. * * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz) * @version $Revision: 7 $ */ public class RandomTreeGenerator extends AbstractOptionHandler implements InstanceStream { @Override public String getPurposeString() { return "Generates a stream based on a randomly generated tree."; } private static final long serialVersionUID = 1L; public IntOption treeRandomSeedOption = new IntOption("treeRandomSeed", 'r', "Seed for random generation of tree.", 1); public IntOption instanceRandomSeedOption = new IntOption( "instanceRandomSeed", 'i', "Seed for random generation of instances.", 1); public IntOption numClassesOption = new IntOption("numClasses", 'c', "The number of classes to generate.", 2, 2, Integer.MAX_VALUE); public IntOption numNominalsOption = new IntOption("numNominals", 'o', "The number of nominal attributes to generate.", 5, 0, Integer.MAX_VALUE); public IntOption numNumericsOption = new IntOption("numNumerics", 'u', "The number of numeric attributes to generate.", 5, 0, Integer.MAX_VALUE); public IntOption numValsPerNominalOption = new IntOption( "numValsPerNominal", 'v', "The number of values to generate per nominal attribute.", 5, 2, Integer.MAX_VALUE); public IntOption maxTreeDepthOption = new IntOption("maxTreeDepth", 'd', "The maximum depth of the tree concept.", 5, 0, Integer.MAX_VALUE); public IntOption firstLeafLevelOption = new IntOption( "firstLeafLevel", 'l', "The first level of the tree above maxTreeDepth that can have leaves.", 3, 0, Integer.MAX_VALUE); public FloatOption leafFractionOption = new FloatOption("leafFraction", 'f', "The fraction of leaves per level from firstLeafLevel onwards.", 0.15, 0.0, 1.0); protected static class Node implements Serializable { private static final long serialVersionUID = 1L; public int classLabel; public int splitAttIndex; public double splitAttValue; public Node[] children; } protected Node treeRoot; protected InstancesHeader streamHeader; protected Random instanceRandom; @Override public void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { monitor.setCurrentActivity("Preparing random tree...", -1.0); generateHeader(); generateRandomTree(); restart(); } @Override public long estimatedRemainingInstances() { return -1; } @Override public boolean isRestartable() { return true; } @Override public void restart() { this.instanceRandom = new Random(this.instanceRandomSeedOption.getValue()); } @Override public InstancesHeader getHeader() { return this.streamHeader; } @Override public boolean hasMoreInstances() { return true; } @Override public InstanceExample nextInstance() { double[] attVals = new double[this.numNominalsOption.getValue() + this.numNumericsOption.getValue()]; InstancesHeader header = getHeader(); Instance inst = new DenseInstance(header.numAttributes()); for (int i = 0; i < attVals.length; i++) { attVals[i] = i < this.numNominalsOption.getValue() ? this.instanceRandom.nextInt(this.numValsPerNominalOption .getValue()) : this.instanceRandom.nextDouble(); inst.setValue(i, attVals[i]); } inst.setDataset(header); inst.setClassValue(classifyInstance(this.treeRoot, attVals)); return new InstanceExample(inst); } protected int classifyInstance(Node node, double[] attVals) { if (node.children == null) { return node.classLabel; } if (node.splitAttIndex < this.numNominalsOption.getValue()) { return classifyInstance( node.children[(int) attVals[node.splitAttIndex]], attVals); } return classifyInstance( node.children[attVals[node.splitAttIndex] < node.splitAttValue ? 0 : 1], attVals); } protected void generateHeader() { FastVector<Attribute> attributes = new FastVector<>(); FastVector<String> nominalAttVals = new FastVector<>(); for (int i = 0; i < this.numValsPerNominalOption.getValue(); i++) { nominalAttVals.addElement("value" + (i + 1)); } for (int i = 0; i < this.numNominalsOption.getValue(); i++) { attributes.addElement(new Attribute("nominal" + (i + 1), nominalAttVals)); } for (int i = 0; i < this.numNumericsOption.getValue(); i++) { attributes.addElement(new Attribute("numeric" + (i + 1))); } FastVector<String> classLabels = new FastVector<>(); for (int i = 0; i < this.numClassesOption.getValue(); i++) { classLabels.addElement("class" + (i + 1)); } attributes.addElement(new Attribute("class", classLabels)); this.streamHeader = new InstancesHeader(new Instances( getCLICreationString(InstanceStream.class), attributes, 0)); this.streamHeader.setClassIndex(this.streamHeader.numAttributes() - 1); } protected void generateRandomTree() { Random treeRand = new Random(this.treeRandomSeedOption.getValue()); ArrayList<Integer> nominalAttCandidates = new ArrayList<>( this.numNominalsOption.getValue()); for (int i = 0; i < this.numNominalsOption.getValue(); i++) { nominalAttCandidates.add(i); } double[] minNumericVals = new double[this.numNumericsOption.getValue()]; double[] maxNumericVals = new double[this.numNumericsOption.getValue()]; for (int i = 0; i < this.numNumericsOption.getValue(); i++) { minNumericVals[i] = 0.0; maxNumericVals[i] = 1.0; } this.treeRoot = generateRandomTreeNode(0, nominalAttCandidates, minNumericVals, maxNumericVals, treeRand); } protected Node generateRandomTreeNode(int currentDepth, ArrayList<Integer> nominalAttCandidates, double[] minNumericVals, double[] maxNumericVals, Random treeRand) { if ((currentDepth >= this.maxTreeDepthOption.getValue()) || ((currentDepth >= this.firstLeafLevelOption.getValue()) && (this.leafFractionOption.getValue() >= (1.0 - treeRand .nextDouble())))) { Node leaf = new Node(); leaf.classLabel = treeRand.nextInt(this.numClassesOption.getValue()); return leaf; } Node node = new Node(); int chosenAtt = treeRand.nextInt(nominalAttCandidates.size() + this.numNumericsOption.getValue()); if (chosenAtt < nominalAttCandidates.size()) { node.splitAttIndex = nominalAttCandidates.get(chosenAtt); node.children = new Node[this.numValsPerNominalOption.getValue()]; ArrayList<Integer> newNominalCandidates = new ArrayList<>( nominalAttCandidates); newNominalCandidates.remove(new Integer(node.splitAttIndex)); newNominalCandidates.trimToSize(); for (int i = 0; i < node.children.length; i++) { node.children[i] = generateRandomTreeNode(currentDepth + 1, newNominalCandidates, minNumericVals, maxNumericVals, treeRand); } } else { int numericIndex = chosenAtt - nominalAttCandidates.size(); node.splitAttIndex = this.numNominalsOption.getValue() + numericIndex; double minVal = minNumericVals[numericIndex]; double maxVal = maxNumericVals[numericIndex]; node.splitAttValue = ((maxVal - minVal) * treeRand.nextDouble()) + minVal; node.children = new Node[2]; double[] newMaxVals = maxNumericVals.clone(); newMaxVals[numericIndex] = node.splitAttValue; node.children[0] = generateRandomTreeNode(currentDepth + 1, nominalAttCandidates, minNumericVals, newMaxVals, treeRand); double[] newMinVals = minNumericVals.clone(); newMinVals[numericIndex] = node.splitAttValue; node.children[1] = generateRandomTreeNode(currentDepth + 1, nominalAttCandidates, newMinVals, maxNumericVals, treeRand); } return node; } @Override public void getDescription(StringBuilder sb, int indent) { // TODO Auto-generated method stub } }