/*
* File: DivergencesEvaluator.java
* Authors: Justin Basilico
* Project: Cognitive Foundry Learning Core
*
* Copyright 2011 Cognitive Foundry. All rights reserved.
*/
package gov.sandia.cognition.learning.function.distance;
import gov.sandia.cognition.data.convert.vector.AbstractToVectorEncoder;
import gov.sandia.cognition.learning.algorithm.AbstractBatchLearnerContainer;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.math.DivergenceFunction;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.VectorOutputEvaluator;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
/**
 * Evaluates the divergence (distance) between an input and a collection of
 * values, storing the resulting divergence values in a vector. This can be
 * used as a feature representation built from something like a clustering
 * algorithm or from a set of prototype/basis elements.
 *
 * <p>The output vector has one element per value, written in the iteration
 * order of the value collection: element {@code i} is
 * {@code divergenceFunction.evaluate(value_i, input)}.</p>
 *
 * @param <InputType>
 *      The type of input value that the class evaluates. It is the second
 *      parameter passed to the divergence function. Typically a type like
 *      a Vector.
 * @param <ValueType>
 *      The type of value that the divergence is computed from. It is the
 *      first parameter passed to the divergence function. It is typically a
 *      type like Vector or CentroidCluster.
 * @author Justin Basilico
 * @since 3.3.3
 */
public class DivergencesEvaluator<InputType, ValueType>
    extends AbstractToVectorEncoder<InputType>
    implements VectorOutputEvaluator<InputType, Vector>,
        DivergenceFunctionContainer<ValueType, InputType>
{

    /** The divergence function to apply between the data and the input. */
    protected DivergenceFunction<? super ValueType, ? super InputType> divergenceFunction;

    /** The data to evaluate the divergence from. */
    protected Collection<ValueType> values;

    /**
     * Creates a new {@code DivergencesEvaluator} with a null divergence
     * function and an empty set of values.
     */
    public DivergencesEvaluator()
    {
        this(null, new ArrayList<ValueType>());
    }

    /**
     * Creates a new {@code DivergencesEvaluator} with the given divergence
     * and values, using the default vector factory.
     *
     * @param divergenceFunction
     *      The divergence function to use.
     * @param values
     *      The values to calculate the divergence from. The collection is
     *      stored directly, not copied.
     */
    public DivergencesEvaluator(
        final DivergenceFunction<? super ValueType, ? super InputType> divergenceFunction,
        final Collection<ValueType> values)
    {
        this(divergenceFunction, values, VectorFactory.getDefault());
    }

    /**
     * Creates a new {@code DivergencesEvaluator} with the given divergence,
     * values, and vector factory.
     *
     * @param divergenceFunction
     *      The divergence function to use.
     * @param values
     *      The values to calculate the divergence from. The collection is
     *      stored directly, not copied.
     * @param vectorFactory
     *      The vector factory to use.
     */
    public DivergencesEvaluator(
        final DivergenceFunction<? super ValueType, ? super InputType> divergenceFunction,
        final Collection<ValueType> values,
        final VectorFactory<?> vectorFactory)
    {
        super(vectorFactory);
        this.setDivergenceFunction(divergenceFunction);
        this.setValues(values);
    }

    /**
     * Creates a copy of this evaluator. The divergence function is cloned
     * with {@code ObjectUtil.cloneSmart} and the values are copied into a new
     * list with {@code ObjectUtil.cloneSmartElementsAsArrayList}.
     *
     * @return
     *      A copy of this evaluator.
     */
    @Override
    public DivergencesEvaluator<InputType, ValueType> clone()
    {
        @SuppressWarnings("unchecked")
        final DivergencesEvaluator<InputType, ValueType> clone =
            (DivergencesEvaluator<InputType, ValueType>) super.clone();
        clone.divergenceFunction = ObjectUtil.cloneSmart(this.divergenceFunction);
        clone.values = ObjectUtil.cloneSmartElementsAsArrayList(this.values);
        return clone;
    }

    /**
     * Writes the divergence from each value to the given input into
     * {@code result}, one element per value in the iteration order of the
     * values, beginning at {@code startIndex}.
     *
     * @param input
     *      The input to compute the divergences to. It is passed as the
     *      second argument of the divergence function.
     * @param result
     *      The vector to write the divergences into. It must have room for
     *      {@link #getOutputDimensionality()} elements starting at
     *      {@code startIndex}.
     * @param startIndex
     *      The index in {@code result} of the first element to write.
     */
    @Override
    public void encode(
        final InputType input,
        final Vector result,
        final int startIndex)
    {
        // Go through the values and compute the divergence to each one.
        int index = startIndex;
        for (final ValueType cluster : this.getValues())
        {
            final double distance =
                this.divergenceFunction.evaluate(cluster, input);
            result.setElement(index, distance);
            index++;
        }
    }

    /**
     * Gets the output dimensionality, which is the number of values.
     *
     * @return
     *      The number of values the divergence is computed from.
     */
    @Override
    public int getOutputDimensionality()
    {
        return this.getValues().size();
    }

    @Override
    public DivergenceFunction<? super ValueType, ? super InputType> getDivergenceFunction()
    {
        return this.divergenceFunction;
    }

    /**
     * Sets the divergence function to use from the values to the inputs.
     *
     * @param divergenceFunction
     *      The divergence function to use.
     */
    public void setDivergenceFunction(
        final DivergenceFunction<? super ValueType, ? super InputType> divergenceFunction)
    {
        this.divergenceFunction = divergenceFunction;
    }

    /**
     * Gets the values that the divergence is computed from using the
     * divergence function to the input.
     *
     * @return
     *      The values that the distance is computed from.
     */
    public Collection<ValueType> getValues()
    {
        return this.values;
    }

    /**
     * Sets the values that the divergence is computed from using the
     * divergence function to the input.
     *
     * @param values
     *      The values that the distance is computed from. The collection is
     *      stored directly, not copied.
     */
    public void setValues(
        final Collection<ValueType> values)
    {
        this.values = values;
    }

    /**
     * Convenience method for creating a {@code DivergencesEvaluator}.
     *
     * @param <InputType>
     *      The type of input value that the class evaluates.
     * @param <ValueType>
     *      The type of value that the divergence is computed from.
     * @param divergenceFunction
     *      The divergence function to use.
     * @param values
     *      The values to calculate the divergence from.
     * @return
     *      A new evaluator.
     */
    public static <InputType, ValueType> DivergencesEvaluator<InputType, ValueType> create(
        final DivergenceFunction<? super ValueType, ? super InputType> divergenceFunction,
        final Collection<ValueType> values)
    {
        return new DivergencesEvaluator<InputType, ValueType>(
            divergenceFunction, values);
    }

    /**
     * A learner adapter for the {@code DivergencesEvaluator}. It calls a
     * base learner and then wraps the learned collection of values in an
     * evaluator that uses the given divergence function.
     *
     * @param <DataType>
     *      The data type for learning. Passed to the wrapped learner.
     * @param <InputType>
     *      The input type for the evaluator.
     * @param <ValueType>
     *      The value type that is the output of learning and is used as the
     *      values in the learned evaluator.
     */
    public static class Learner<DataType, InputType, ValueType>
        extends AbstractBatchLearnerContainer<BatchLearner<? super DataType, ? extends Collection<ValueType>>>
        implements BatchLearner<DataType, DivergencesEvaluator<InputType, ValueType>>,
            DivergenceFunctionContainer<ValueType, InputType>,
            VectorFactoryContainer
    {

        /** The divergence function to apply between the data and the input. */
        protected DivergenceFunction<? super ValueType, ? super InputType> divergenceFunction;

        /** The vector factory to use. */
        protected VectorFactory<?> vectorFactory;

        /**
         * Creates a new {@code DivergencesEvaluator.Learner} with a null
         * base learner and a null divergence function.
         */
        public Learner()
        {
            this(null, null);
        }

        /**
         * Creates a new {@code DivergencesEvaluator.Learner} with the given
         * properties and the default vector factory.
         *
         * @param learner
         *      The base learner to use. It may accept any supertype of the
         *      data type.
         * @param divergenceFunction
         *      The divergence function to use.
         */
        public Learner(
            final BatchLearner<? super DataType, ? extends Collection<ValueType>> learner,
            final DivergenceFunction<? super ValueType, ? super InputType> divergenceFunction)
        {
            this(learner, divergenceFunction, VectorFactory.getDefault());
        }

        /**
         * Creates a new {@code DivergencesEvaluator.Learner} with the given
         * properties.
         *
         * @param learner
         *      The base learner to use. It may accept any supertype of the
         *      data type.
         * @param divergenceFunction
         *      The divergence function to use.
         * @param vectorFactory
         *      The vector factory to use.
         */
        public Learner(
            final BatchLearner<? super DataType, ? extends Collection<ValueType>> learner,
            final DivergenceFunction<? super ValueType, ? super InputType> divergenceFunction,
            final VectorFactory<?> vectorFactory)
        {
            super(learner);
            this.setDivergenceFunction(divergenceFunction);
            this.setVectorFactory(vectorFactory);
        }

        /**
         * Creates a copy of this learner. The divergence function is cloned
         * with {@code ObjectUtil.cloneSmart}. The vector factory is not
         * cloned here.
         *
         * @return
         *      A copy of this learner.
         */
        @Override
        public Learner<DataType, InputType, ValueType> clone()
        {
            @SuppressWarnings("unchecked")
            final Learner<DataType, InputType, ValueType> clone =
                (Learner<DataType, InputType, ValueType>) super.clone();
            clone.divergenceFunction = ObjectUtil.cloneSmart(this.divergenceFunction);
            return clone;
        }

        /**
         * Runs the base learner on the data. The learned collection of values
         * is wrapped in a new {@code DivergencesEvaluator}, which uses this
         * learner's divergence function and vector factory.
         *
         * @param data
         *      The data to pass to the base learner.
         * @return
         *      A new evaluator over the learned values. The learned
         *      collection is used directly, not copied.
         */
        @Override
        public DivergencesEvaluator<InputType, ValueType> learn(
            final DataType data)
        {
            return new DivergencesEvaluator<InputType, ValueType>(
                this.getDivergenceFunction(),
                this.getLearner().learn(data),
                this.getVectorFactory());
        }

        @Override
        public DivergenceFunction<? super ValueType, ? super InputType> getDivergenceFunction()
        {
            return this.divergenceFunction;
        }

        /**
         * Sets the divergence function to use from the values to the inputs.
         *
         * @param divergenceFunction
         *      The divergence function to use.
         */
        public void setDivergenceFunction(
            final DivergenceFunction<? super ValueType, ? super InputType> divergenceFunction)
        {
            this.divergenceFunction = divergenceFunction;
        }

        @Override
        public VectorFactory<? extends Vector> getVectorFactory()
        {
            return this.vectorFactory;
        }

        /**
         * Sets the vector factory to use.
         *
         * @param vectorFactory
         *      The vector factory to use.
         */
        public void setVectorFactory(
            final VectorFactory<?> vectorFactory)
        {
            this.vectorFactory = vectorFactory;
        }

        /**
         * Convenience method for creating a
         * {@code DivergencesEvaluator.Learner}.
         *
         * @param <DataType>
         *      The data type for learning. Passed to the wrapped learner.
         * @param <InputType>
         *      The input type for the evaluator.
         * @param <ValueType>
         *      The value type that is the output of learning and is used as
         *      the values in the learned evaluator.
         * @param learner
         *      The base learner to use. It may accept any supertype of the
         *      data type.
         * @param divergenceFunction
         *      The divergence function to use.
         * @return
         *      A new learner.
         */
        public static <DataType, InputType, ValueType> Learner<DataType, InputType, ValueType> create(
            final BatchLearner<? super DataType, ? extends Collection<ValueType>> learner,
            final DivergenceFunction<? super ValueType, ? super InputType> divergenceFunction)
        {
            return new Learner<DataType, InputType, ValueType>(
                learner, divergenceFunction);
        }

    }

}