package org.streaminer.stream.classifier;

import org.streaminer.stream.data.Data;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * <p>
 * This classifier predicts a random class for each of a set of label
 * attributes. For every label attribute it collects the distinct values
 * observed in {@link #learn(Data)} and, in {@link #predict(Data)}, picks one
 * of those values at random.
 * </p>
 * <p>
 * The label attributes can be configured via
 * {@link #setLabelAttributes(String)} (comma-separated). If none are
 * configured, the keys of the first learned item are used as label
 * attributes.
 * </p>
 *
 * @author Christian Bockermann <chris@jwall.org>
 *
 */
public class MultiRandomClassifier extends AbstractClassifier<Data, Map<String,String>> {

    /** The unique class ID */
    private static final long serialVersionUID = 3687537399872562759L;

    static final Logger log = LoggerFactory.getLogger( MultiRandomClassifier.class );

    /** Source of randomness used when guessing a class. */
    Random rnd = new Random();

    /* The raw label attribute specification, as last passed to the setter */
    String labelAttribute = null;

    /*
     * Maps each label attribute to the distinct class values observed for it.
     * A LinkedHashMap keeps the configuration / first-seen order, so the keys
     * of the prediction map come out in a stable order.
     */
    Map<String,Set<String>> classes = new LinkedHashMap<String,Set<String>>();

    /**
     * @return the label attribute specification as last set (may be
     *         {@code null} or a comma-separated list)
     */
    public String getLabelAttributes() {
        return labelAttribute;
    }

    /**
     * Sets the label attributes to learn and predict. The value is a
     * comma-separated list of attribute names; a single name without a
     * comma is accepted as well. Blank entries are ignored. Labels that are
     * already registered keep the class values learned so far.
     *
     * @param labelAttribute the comma-separated label attributes; {@code null}
     *                       only records the value and registers nothing
     */
    public void setLabelAttributes(String labelAttribute) {
        this.labelAttribute = labelAttribute;
        if( labelAttribute == null )
            return;

        // split() on a string without a comma yields the whole string, so a
        // single label is registered too (previously it was silently ignored)
        for( String label : labelAttribute.split( "," ) ){
            String l = label.trim();
            if( !"".equals( l ) && !classes.containsKey( l ) )
                classes.put( l, new HashSet<String>() );
        }
    }

    /**
     * Records the value of every label attribute found in the given item.
     * If no label attributes are known yet, all keys of the item become
     * label attributes. Label attributes missing from the item (null value)
     * are skipped rather than being recorded as the class {@code "null"}.
     *
     * @param item the item to learn from; {@code null} is ignored
     */
    @Override
    public void learn(Data item) {
        if( item == null )
            return;

        if( classes.isEmpty() ){
            for( String key : item.keySet() )
                classes.put( key, new HashSet<String>() );
        }

        for( Map.Entry<String,Set<String>> entry : classes.entrySet() ){
            Object value = item.get( entry.getKey() );
            if( value != null )
                entry.getValue().add( String.valueOf( value ) );
        }
    }

    /**
     * Predicts a random class for each label attribute. Label attributes for
     * which no class value has been learned yet are left out of the result.
     *
     * @param item the item to classify (its content is not used)
     * @return a map from label attribute to the guessed class; empty if
     *         nothing has been learned
     */
    @Override
    public Map<String,String> predict(Data item) {
        Map<String,String> pred = new LinkedHashMap<String,String>();
        for( String label : classes.keySet() ){
            String prediction = guess( label );
            if( prediction != null )
                pred.put( label, prediction );
        }
        return pred;
    }

    /**
     * Picks one of the class values observed for the given label attribute.
     *
     * @param label the label attribute
     * @return a randomly chosen class value, or {@code null} if the label is
     *         unknown or no value has been observed for it yet
     */
    protected String guess( String label ){
        Set<String> candidates = classes.get( label );
        if( candidates == null || candidates.isEmpty() ){
            // avoids the former ArithmeticException (modulo by zero)
            return null;
        }

        List<String> list = new ArrayList<String>( candidates );
        // nextInt(n) is always in [0, n); Math.abs(nextInt()) could be
        // negative for Integer.MIN_VALUE and produce an invalid index
        String guess = list.get( rnd.nextInt( list.size() ) );
        log.debug( "Guessing {} => {}", list, guess );
        return guess;
    }
}