//----------------------------------------------------------------------------//
// //
// L i n e a r E v a l u a t o r //
// //
//----------------------------------------------------------------------------//
// <editor-fold defaultstate="collapsed" desc="hdr"> //
// Copyright © Hervé Bitteur and others 2000-2013. All rights reserved. //
// This software is released under the GNU General Public License. //
// Goto http://kenai.com/projects/audiveris to report bugs or suggestions. //
//----------------------------------------------------------------------------//
// </editor-fold>
package omr.math;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller;
import javax.xml.bind.annotation.XmlAccessType;
import javax.xml.bind.annotation.XmlAccessorType;
import javax.xml.bind.annotation.XmlAttribute;
import javax.xml.bind.annotation.XmlElement;
import javax.xml.bind.annotation.XmlElementWrapper;
import javax.xml.bind.annotation.XmlID;
import javax.xml.bind.annotation.XmlIDREF;
import javax.xml.bind.annotation.XmlRootElement;
import javax.xml.bind.annotation.adapters.XmlAdapter;
import javax.xml.bind.annotation.adapters.XmlJavaTypeAdapter;
/**
* Class {@code LinearEvaluator} is an evaluator using linear regression.
*
* <p>It provides a distance between 2 "patterns". A pattern is a vector of
* parameter values in the input domain.
*
* <p>It provides a distance between a "pattern" from the input domain to a
* "category" in the output range, thus allowing to map patterns to categories.
* This feature can be used for example to map a given Glyph (through the
* pattern of its measured moments values) to the best fitting Shape category.
*
* <p>This evaluator can be trained, by feeding it with sample patterns for each
* defined category.
*
* <p>The evaluator data can be marshalled to and unmarshalled from an XML
* formatted stream.
*
* @author Hervé Bitteur
*/
@XmlAccessorType(XmlAccessType.NONE)
@XmlRootElement(name = "linear-evaluator")
public class LinearEvaluator
{
//~ Static fields/initializers ---------------------------------------------
/** Usual logger utility */
private static final Logger logger = LoggerFactory.getLogger(
LinearEvaluator.class);
/** Un/marshalling context for use with JAXB */
private static volatile JAXBContext jaxbContext;
/** To avoid infinity */
public static final double INFINITE_DISTANCE = 50e50;
/** To detect a near-zero value in a double */
private static final double EPSILON = 1E-10;
//~ Instance fields --------------------------------------------------------
/** A descriptor for each input parameter. */
@XmlElementWrapper(name = "defaults")
@XmlElement(name = "parameter")
private final Parameter[] parameters;
/** A descriptor for each output category. */
@XmlJavaTypeAdapter(CategoryMapAdapter.class)
@XmlElement(name = "categories")
private final SortedMap<String, Category> categories;
/**
* Flag to indicate that some data has changed since unmarshalling
* and that engine internals must be marshalled to disk before
* exiting.
*/
private boolean dataModified = false;
//~ Constructors -----------------------------------------------------------
//-----------------//
// LinearEvaluator //
//-----------------//
/**
* Creates a new LinearEvaluator object.
*
* @param inputNames the parameter names
*/
public LinearEvaluator (String[] inputNames)
{
categories = new TreeMap<>();
parameters = new Parameter[inputNames.length];
for (int i = 0; i < inputNames.length; i++) {
parameters[i] = new Parameter(inputNames[i]);
}
}
//-----------------//
// LinearEvaluator //
//-----------------//
/** Private no-arg constructor meant for the JAXB compiler only */
private LinearEvaluator ()
{
categories = null;
parameters = null;
}
//~ Methods ----------------------------------------------------------------
//
//-------------------//
// getParameterNames //
//-------------------//
/**
* Report the sequence of parameter names.
*
* @return the sequence of parameter names
*/
public String[] getParameterNames ()
{
if (parameters == null) {
return new String[0];
} else {
String[] names = new String[parameters.length];
for (int i = 0; i < parameters.length; i++) {
names[i] = parameters[i].name;
}
return names;
}
}
//------------------//
// getCategoryNames //
//------------------//
/**
* Report the collection of category names (order is irrelevant).
*
* @return the collection of category names
*/
public String[] getCategoryNames ()
{
if (categories == null) {
return new String[0];
} else {
Collection<Category> values = categories.values();
String[] names = new String[values.size()];
int index = 0;
for (Category cat : values) {
names[index++] = cat.getId();
}
return names;
}
}
//--------------//
// getInputSize //
//--------------//
/**
* Report the number of parameters in the input patterns.
*
* @return the count of pattern parameters
*/
public final int getInputSize ()
{
return parameters.length;
}
//------------------//
// categoryDistance //
//------------------//
/**
* Measure the "distance" information between a given pattern and
* (the mean pattern of) a category.
*
* @param pattern the value for each parameter of the pattern to evaluate
* @param categoryId the category id to measure distance from
* @return the measured distance
*/
public double categoryDistance (double[] pattern,
String categoryId)
{
return checkArguments(pattern, categoryId).distance(pattern, parameters);
}
//------//
// dump //
//------//
public void dump ()
{
System.out.println();
System.out.println("LinearEvaluator");
System.out.println("===============");
System.out.println();
// Input size
System.out.println("Inputs : " + getInputSize() + " parameters");
// Output size
System.out.println(
"Outputs : " + categories.keySet().size() + " categories");
// Description of each category
for (Category category : categories.values()) {
category.dump();
}
}
//--------------//
// dumpDistance //
//--------------//
/**
* Print out the "distance" information between a given pattern and
* a category.
* It's a sort of debug information.
*
* @param pattern the pattern at hand
* @param category the category to measure distance from
*/
public void dumpDistance (double[] pattern,
String category)
{
categories.get(category).dumpDistance(pattern, parameters);
}
//------------//
// getMaximum //
//------------//
/**
* Get the constraint test on maximum for a parameter of the
* provided category.
*
* @param paramIndex the impacted parameter
* @param categoryId the targeted category
* @return the current maximum value (null if test is disabled)
*/
public Double getMaximum (int paramIndex,
String categoryId)
{
return getCategoryParam(paramIndex, categoryId).max;
}
//------------//
// getMinimum //
//------------//
/**
* Get the constraint test on minimum for a parameter of the
* provided category.
*
* @param paramIndex the impacted parameter
* @param categoryId the targeted category
* @return the current minimum value (null if test is disabled)
*/
public Double getMinimum (int paramIndex,
String categoryId)
{
return getCategoryParam(paramIndex, categoryId).min;
}
//---------------//
// includeSample //
//---------------//
/**
* Include a new sample (on top of unmarshalled data).
* We use this to widen the min/max constraints, and also to increase
* the population and thus the categories training status.
*
* @param params the parameters
* @param categoryId the targeted category
* @return true if some min/max bound has changed
*/
public boolean includeSample (double[] params,
String categoryId)
{
// Check category label
Category category = categories.get(categoryId);
if (category == null) {
throw new IllegalArgumentException(
"Unknown category: " + categoryId);
}
boolean extended = category.include(params);
// Update categories parameters accordingly
computeCategoriesParams();
dataModified = true;
return extended;
}
//----------------//
// isDataModified //
//----------------//
/**
* @return true if some data has been modified since unmarshalling
*/
public boolean isDataModified ()
{
return dataModified;
}
//---------//
// marshal //
//---------//
/**
* Marshal the LinearEvaluator to its XML file.
*
* @param os the XML output stream, which is not closed by this method
* @exception JAXBException raised when marshalling goes wrong
*/
public void marshal (OutputStream os)
throws JAXBException
{
Marshaller m = getJaxbContext().createMarshaller();
m.setProperty(Marshaller.JAXB_FORMATTED_OUTPUT, true);
m.marshal(this, os);
logger.debug("LinearEvaluator marshalled");
}
//-----------------//
// patternDistance //
//-----------------//
/**
* Measure the "distance" information between two patterns.
*
* @param one the first pattern
* @param two the second pattern
* @return the measured distance between them
*/
public double patternDistance (double[] one,
double[] two)
{
final int inputSize = getInputSize();
// Check sizes
if ((one == null)
|| (one.length != inputSize)
|| (two == null)
|| (two.length != inputSize)) {
throw new IllegalArgumentException(
"Patterns are null or inconsistent with the LinearEvaluator");
}
double dist = 0;
for (int p = 0; p < inputSize; p++) {
double dif = one[p] - two[p];
dist += (dif * dif * parameters[p].defaultWeight);
}
return dist / inputSize;
}
//-------//
// train //
//-------//
/**
* Perform the training of the evaluator.
*
* @param samples a collection of samples (category + pattern)
*/
public void train (Collection<Sample> samples)
{
// Check size consistencies.
if ((samples == null) || samples.isEmpty()) {
throw new IllegalArgumentException(
"samples collection is null or empty");
}
// Reset counters for each category, if needed
for (Category category : categories.values()) {
for (CategoryParam param : category.params) {
param.reset();
}
}
// Accumulate data from samples into categories descriptors
for (Sample sample : samples) {
Category category = categories.get(sample.category);
if (category == null) {
category = new Category(sample.category, parameters);
categories.put(sample.category, category);
}
category.include(sample.pattern);
logger.debug("Accu {} count:{}",
category.getId(), category.getCardinality());
}
computeCategoriesParams();
}
//-----------//
// unmarshal //
//-----------//
/**
* Unmarshal the provided XML stream to allocate the corresponding
* LinearEvaluator.
*
* @param in the input stream that contains the evaluator definition in XML
* format. The stream is not closed by this method
* @return the allocated network.
* @exception JAXBException raised when unmarshalling goes wrong
*/
public static LinearEvaluator unmarshal (InputStream in)
throws JAXBException
{
Unmarshaller um = getJaxbContext().createUnmarshaller();
LinearEvaluator evaluator = (LinearEvaluator) um.unmarshal(in);
logger.debug("LinearEvaluator unmarshalled");
return evaluator;
}
//----------------//
// getJaxbContext //
//----------------//
private static JAXBContext getJaxbContext ()
throws JAXBException
{
// Lazy creation
if (jaxbContext == null) {
jaxbContext = JAXBContext.newInstance(LinearEvaluator.class);
}
return jaxbContext;
}
//----------------//
// checkArguments //
//----------------//
private Category checkArguments (double[] pattern,
String categoryId)
{
// Check sizes
if ((pattern == null) || (pattern.length != getInputSize())) {
throw new IllegalArgumentException(
"Pattern is null or inconsistent with the LinearEvaluator");
}
// Check category label
Category category = categories.get(categoryId);
if (category == null) {
throw new IllegalArgumentException(
"Unknown category: " + categoryId);
}
return category;
}
//-------------------------//
// computeCategoriesParams //
//-------------------------//
private void computeCategoriesParams ()
{
// Compute parameters means & weights for each category
for (Category category : categories.values()) {
logger.debug("Computing {} count:{}",
category.getId(), category.getCardinality());
category.compute();
}
// Compute default weight for each parameter
// (using the sample populations of all categories)
for (int p = 0; p < parameters.length; p++) {
Population paramPop = new Population();
for (Category category : categories.values()) {
CategoryParam param = category.params[p];
if (param.training != CategoryParam.TrainingStatus.NONE) {
paramPop.includePopulation(param.population);
}
}
if (paramPop.getCardinality() > 1) {
double var = paramPop.getVariance();
if (var >= EPSILON) {
parameters[p].defaultWeight = 1 / var;
}
}
}
}
//------------------//
// getCategoryParam //
//------------------//
private CategoryParam getCategoryParam (int paramIndex,
String categoryId)
{
// Check category label
Category category = categories.get(categoryId);
if (category == null) {
throw new IllegalArgumentException(
"Unknown category: " + categoryId);
}
return category.params[paramIndex];
}
//~ Inner Classes ----------------------------------------------------------
//--------//
// Sample //
//--------//
/**
* Meant to host one sample for training, representing pattern
* values for a given category.
*/
public static class Sample
{
//~ Instance fields ----------------------------------------------------
/** The known category */
public final String category;
/** The observed pattern */
public final double[] pattern;
//~ Constructors -------------------------------------------------------
public Sample (String category,
double[] pattern)
{
this.category = category;
this.pattern = pattern;
}
//~ Methods ------------------------------------------------------------
@Override
public String toString ()
{
StringBuilder sb = new StringBuilder("{");
sb.append(getClass().getSimpleName());
sb.append(" ").append(category);
sb.append(" ").append(Arrays.toString(pattern));
sb.append("}");
return sb.toString();
}
}
//---------//
// Printer //
//---------//
/**
* Printouts meant for analysis of behavior of LinearEvaluator.
*/
public class Printer
{
//~ Instance fields ----------------------------------------------------
// Format strings
private final String sf; // For String
private final String df; // For double
//~ Constructors -------------------------------------------------------
public Printer (int width)
{
sf = "%" + width + "s";
df = "%" + width + "f";
}
//~ Methods ------------------------------------------------------------
public String getDashes ()
{
StringBuilder sb = new StringBuilder();
for (int p = 0; p < parameters.length; p++) {
sb.append(String.format(sf, "----------"));
}
return sb.toString();
}
public String getDefaults ()
{
StringBuilder sb = new StringBuilder();
for (Parameter param : parameters) {
sb.append(String.format(df, param.defaultWeight));
}
return sb.toString();
}
public String getDeltas (double[] one,
double[] two)
{
StringBuilder sb = new StringBuilder();
for (int p = 0; p < parameters.length; p++) {
double dif = one[p] - two[p];
sb.append(String.format(df, dif * dif));
}
return sb.toString();
}
public String getNames ()
{
StringBuilder sb = new StringBuilder();
for (Parameter param : parameters) {
sb.append(String.format(sf, param.name));
}
return sb.toString();
}
public String getWeightedDeltas (double[] one,
double[] two)
{
StringBuilder sb = new StringBuilder();
for (int p = 0; p < parameters.length; p++) {
double dif = one[p] - two[p];
sb.append(
String.format(df,
dif * dif * parameters[p].defaultWeight));
}
return sb.toString();
}
}
//----------//
// Category //
//----------//
/**
* Meant to encapsulate the regression data for one category.
*/
private static class Category
{
//~ Instance fields ----------------------------------------------------
/** Category id */
@XmlAttribute(name = "id")
private final String id;
/** A specific descriptor for each parameter */
@XmlElement(name = "parameter")
final CategoryParam[] params;
//~ Constructors -------------------------------------------------------
/**
* Creates a new Category object.
*
* @param id the category id
* @param parameters the sequence of parameter descriptors
*/
public Category (String id,
Parameter[] parameters)
{
this.id = id;
params = new CategoryParam[parameters.length];
for (int p = 0; p < params.length; p++) {
params[p] = new CategoryParam(parameters[p]);
}
}
/**
* Meant to please JAXB
*/
private Category ()
{
id = null;
params = null;
}
//~ Methods ------------------------------------------------------------
public void compute ()
{
if (getCardinality() > 0) {
for (CategoryParam param : params) {
try {
param.compute();
} catch (Exception ex) {
logger.warn(
"Category {} cannot compute parameters ex:{}",
id, ex);
}
}
} else {
logger.warn("Category {} has no sample", id);
}
}
public synchronized double distance (double[] pattern,
Parameter[] parameters)
{
double dist = 0;
for (int p = 0; p < params.length; p++) {
dist += params[p].weightedDelta(
pattern[p],
parameters[p].defaultWeight);
}
dist /= params.length;
return dist;
}
public synchronized void dump ()
{
System.out.println(
"\ncategory:" + id + " cardinality:" + getCardinality());
for (CategoryParam param : params) {
param.dump();
}
}
public synchronized double dumpDistance (double[] pattern,
Parameter[] parameters)
{
if ((pattern == null) || (pattern.length != params.length)) {
throw new IllegalArgumentException(
"dumpDistance."
+ " Pattern array is null or non compatible in length ");
}
if (getCardinality() >= 2) {
double dist = 0;
for (int p = 0; p < params.length; p++) {
CategoryParam param = params[p];
double wDelta = param.weightedDelta(
pattern[p],
parameters[p].defaultWeight);
dist += wDelta;
System.out.printf(
"%2d-> weight:%e wDelta:%e\n",
p,
param.weight,
wDelta);
}
dist /= params.length;
System.out.println("Dist to cat " + id + " = " + dist);
return dist;
} else {
return INFINITE_DISTANCE;
}
}
public int getCardinality ()
{
return params[0].population.getCardinality();
}
/**
* @return the id
*/
public String getId ()
{
return id;
}
/** Include data from the provided pattern into category descriptor */
public synchronized boolean include (double[] pattern)
{
boolean extended = false;
if ((pattern == null) || (pattern.length != params.length)) {
throw new IllegalArgumentException(
"include."
+ " Pattern array is null or non compatible in length ");
}
for (int p = 0; p < params.length; p++) {
if (params[p].includeValue(pattern[p])) {
extended = true;
}
}
return extended;
}
}
//--------------------//
// CategoryMapAdapter //
//--------------------//
/**
* Meant for JAXB support of a map.
*/
private static class CategoryMapAdapter
extends XmlAdapter<Category[], Map<String, Category>>
{
//~ Constructors -------------------------------------------------------
/**
* Meant to please JAXB
*/
public CategoryMapAdapter ()
{
}
//~ Methods ------------------------------------------------------------
//-----------//
// unmarshal //
//-----------//
@Override
public Category[] marshal (Map<String, Category> map)
throws Exception
{
return map.values().toArray(new Category[map.size()]);
}
//-----------//
// unmarshal //
//-----------//
@Override
public Map<String, Category> unmarshal (Category[] categories)
{
SortedMap<String, Category> map = new TreeMap<>();
for (Category category : categories) {
map.put(category.getId(), category);
}
return map;
}
}
//---------------//
// CategoryParam //
//---------------//
/**
* Meant to encapsulate the regression data for one parameter in
* the context of a category.
*/
private static class CategoryParam
{
//~ Static fields/initializers -----------------------------------------
/** Used instead of infinitive weight, when variance is zero */
private static final double HIGH_WEIGHT_FACTOR = 10;
//~ Enumerations -------------------------------------------------------
/** Description of the training done so far on a parameter */
public static enum TrainingStatus
{
//~ Enumeration constant initializers ------------------------------
/**
* Not trained
* => no mean value, no weight
*/
NONE,
/**
* Just one data element
* => a mean value, but artificial (average) weight
*/
SINGLE_DATA,
/**
* Several data elements, but with identical values
* => a mean value, but infinite weight
*/
IDENTICAL_VALUES,
/**
* Several data elements, with some variation in the values
* => a mean value and weight computed as 1/variance
*/
NOMINAL;
}
//~ Instance fields ----------------------------------------------------
/** Population to compute mean value & std deviation */
@XmlElement(name = "population")
private Population population;
/** Maximum value for this parameter */
@XmlAttribute(name = "max")
private Double max = null;
/** Mean value for this parameter */
@XmlAttribute(name = "mean")
private double mean;
/** Minimum value for this parameter */
@XmlAttribute(name = "min")
private Double min = null;
/** Weight for this parameter */
@XmlAttribute(name = "weight")
private double weight;
/** Training status */
@XmlAttribute(name = "training")
private TrainingStatus training = TrainingStatus.NONE;
/** Related parameter descriptor */
@XmlIDREF
@XmlAttribute(name = "name")
private Parameter parameter;
//~ Constructors -------------------------------------------------------
public CategoryParam (Parameter parameter)
{
this.parameter = parameter;
population = new Population();
}
/**
* Meant to please JAXB
*/
public CategoryParam ()
{
}
//~ Methods ------------------------------------------------------------
/** Compute the param characteristics out of its data sample */
public void compute ()
{
int count = population.getCardinality();
if (count > 0) {
mean = population.getMeanValue();
}
if (count == 1) {
training = TrainingStatus.SINGLE_DATA;
} else {
double var = population.getVariance();
if (var < EPSILON) {
training = TrainingStatus.IDENTICAL_VALUES;
} else {
training = TrainingStatus.NOMINAL;
weight = 1 / var;
}
}
}
public void dump ()
{
StringBuilder sb = new StringBuilder();
sb.append(" ").append(parameter);
sb.append(" training=").append(training);
sb.append(" min=").append(min);
sb.append(" mean=").append(mean);
sb.append(" max=").append(max);
sb.append(" weight=").append(weight);
if (population.getCardinality() > 1) {
sb.append(" var=").append(population.getVariance());
}
System.out.println(sb);
}
/**
* Include a new value for this category parameter.
*
* @param val the new value
* @return true if any of the min/max bounds has changed
*/
public boolean includeValue (double val)
{
boolean extended = false;
// Cumulate into Population
population.includeValue(val);
// Handle min value
if (min != null) {
if (val < min) {
min = val;
extended = true;
}
} else {
min = val;
extended = true;
}
// Handle max value
if (max != null) {
if (val > max) {
max = val;
extended = true;
}
} else {
max = val;
extended = true;
}
return extended;
}
public void reset ()
{
population.reset();
min = null;
max = null;
mean = 0;
weight = 0;
}
/**
* Report the weighted square delta of a value vs param mean value
*
* @param val the observed value
* @param stdWeight the standard average weight
* @return the weighted square delta
*/
public double weightedDelta (double val,
double stdWeight)
{
if (training == TrainingStatus.NONE) {
return INFINITE_DISTANCE;
} else {
double dif = mean - val;
return dif * dif * getWeight(stdWeight);
}
}
/**
* Report the proper value to be used for parameter weight.
*
* @param stdWeight the standard average weight
* @return the proper weight value
*/
private double getWeight (double stdWeight)
{
switch (training) {
case NONE:
return 0;
case SINGLE_DATA:
return stdWeight;
case IDENTICAL_VALUES:
return stdWeight * HIGH_WEIGHT_FACTOR;
default:
return weight;
}
}
}
//-----------//
// Parameter //
//-----------//
/**
* Description of an input parameter for the LinearEvaluator.
*/
private static class Parameter
{
//~ Instance fields ----------------------------------------------------
/** Default weight */
@XmlAttribute(name = "weight")
public double defaultWeight;
/** Name used for this parameter */
@XmlID
@XmlAttribute(name = "name")
public final String name;
//~ Constructors -------------------------------------------------------
/**
* Creates a new Parameter object.
*
* @param name the unique name for this parameter
*/
public Parameter (String name)
{
this.name = name;
}
/**
* Needed by JAXB
*/
public Parameter ()
{
name = null;
}
//~ Methods ------------------------------------------------------------
@Override
public String toString ()
{
return "{Param " + name + " defaultWeight:" + defaultWeight + "}";
}
}
}