/*
* File: BagOfWordsTransform.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright February 10, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.text.term.vector;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.math.matrix.SparseVectorFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.DefaultVectorFactoryContainer;
import gov.sandia.cognition.text.term.DefaultTermIndex;
import gov.sandia.cognition.text.term.Term;
import gov.sandia.cognition.text.term.TermIndex;
import gov.sandia.cognition.text.term.Termable;
/**
* Transforms a list of term occurrences into a vector of counts.
*
* @author Justin Basilico
* @since 3.0
*/
public class BagOfWordsTransform
extends DefaultVectorFactoryContainer
implements Evaluator<Iterable<? extends Termable>, Vector>
{
/** Gets the term index used by the transform. Maps terms to indices in
* the vector. */
protected TermIndex termIndex;
/**
* Creates a new {@code BagOfWordsTransform}. Starts with an empty term
* index.
*/
public BagOfWordsTransform()
{
this(new DefaultTermIndex());
}
/**
* Creates a new {@code BagOfWordsTransform} with the given term index.
*
* @param termIndex
* The term index to use to map terms to vector indices.
*/
public BagOfWordsTransform(
final TermIndex termIndex)
{
this(termIndex, SparseVectorFactory.getDefault());
}
/**
* Creates a new {@code BagOfWordsTransform} with the given term index.
*
* @param termIndex
* The term index to use to map terms to vector indices.
* @param vectorFactory
* The vector factory to use.
*/
public BagOfWordsTransform(
final TermIndex termIndex,
final VectorFactory<? extends Vector> vectorFactory)
{
super();
this.setTermIndex(termIndex);
this.setVectorFactory(vectorFactory);
}
public Vector evaluate(
final Iterable<? extends Termable> terms)
{
return this.convertToVector(terms);
}
/**
* Converts a given list of terms to a vector by counting the occurrence of
* each term.
*
* @param terms
* The terms to count.
* @return
* The bag-of-words vector representation of the terms, which is the
* count of how many times each term occurs in the document.
*/
public Vector convertToVector(
final Iterable<? extends Termable> terms)
{
return this.convertToVector(terms, this.getVectorFactory());
}
/**
* Converts a given list of terms to a vector by counting the occurrence of
* each term.
*
* @param terms
* The terms to count.
* @param vectorFactory
* The vector factory to use to create the vector.
* @return
* The bag-of-words vector representation of the terms, which is the
* count of how many times each term occurs in the document.
*/
public Vector convertToVector(
final Iterable<? extends Termable> terms,
final VectorFactory<?> vectorFactory)
{
return convertToVector(terms, this.getTermIndex(), vectorFactory);
}
/**
* Converts a given list of terms to a vector by counting the occurrence of
* each term.
*
* @param terms
* The terms to count.
* @param termIndex
* The term index to use to map terms to their vector indices.
* @param vectorFactory
* The vector factory to use to create the vector.
* @return
* The bag-of-words vector representation of the terms, which is the
* count of how many times each term occurs in the document.
*/
public static Vector convertToVector(
final Iterable<? extends Termable> terms,
final TermIndex termIndex,
final VectorFactory<?> vectorFactory)
{
// Create the vector to store the result.
final Vector result = vectorFactory.createVector(
termIndex.getTermCount());
for (Termable termable : terms)
{
final Term term = termable.asTerm();
int index = termIndex.getIndex(term);
if (index >= 0)
{
final double count = result.getElement(index);
result.setElement(index, count + 1);
}
// TODO: Ideally we would somehow handle all of the "unknown"
// elements also. Perhaps by using the first vector element for
// unknowns.
}
return result;
}
/**
* Gets the term index that the transform uses to map terms to their vector
* indices.
*
* @return
* The term index used by the transform.
*/
public TermIndex getTermIndex()
{
return this.termIndex;
}
/**
* Sets the term index that the transform is to use to map terms to their
* vector indices.
*
* @param termIndex
* The term index for the transform to use.
*/
public void setTermIndex(
final TermIndex termIndex)
{
this.termIndex = termIndex;
}
}