/*
* File: AbstractFrequencyBasedGlobalTermWeighter.java
* Authors: Justin Basilico
* Company: Sandia National Laboratories
* Project: Cognitive Foundry
*
* Copyright April 20, 2009, Sandia Corporation.
* Under the terms of Contract DE-AC04-94AL85000, there is a non-exclusive
* license for use of this work by or on behalf of the U.S. Government. Export
* of this program may require a license from the United States Government.
* See CopyrightHistory.txt for complete details.
*
*/
package gov.sandia.cognition.text.term.vector.weighter.global;
import gov.sandia.cognition.math.matrix.DimensionalityMismatchException;
import gov.sandia.cognition.math.matrix.SparseVectorFactory;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorEntry;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.util.ObjectUtil;
/**
 * An abstract {@code GlobalTermWeighter} that keeps track of term frequencies
 * in documents. For each term, it keeps track of both the document frequency
 * (the number of documents the term appears in) and the global frequency
 * (the total number of times the term appears). It also keeps track of the
 * total number of documents.
 *
 * The two frequency vectors start out {@code null}. They are created by the
 * first call to {@link #add(Vector)} and grown when a later document has a
 * larger dimensionality than the vectors currently held.
 *
 * @author  Justin Basilico
 * @since   3.0
 */
public abstract class AbstractFrequencyBasedGlobalTermWeighter
    extends AbstractGlobalTermWeighter
{

    /** The number of documents the weight is computed over. */
    protected int documentCount;

    /** The vector containing the number of documents that each term occurs
     *  in. It is {@code null} until the first document is added. */
    protected Vector termDocumentFrequencies;

    /** A vector containing the total number of times that each term occurred
     *  in the document set. It is {@code null} until the first document is
     *  added. */
    protected Vector termGlobalFrequencies;

    /**
     * Creates a new {@code AbstractFrequencyBasedGlobalTermWeighter} that
     * uses the default sparse vector factory.
     */
    public AbstractFrequencyBasedGlobalTermWeighter()
    {
        this(SparseVectorFactory.getDefault());
    }

    /**
     * Creates a new {@code AbstractFrequencyBasedGlobalTermWeighter} with a
     * document count of zero and no frequency vectors.
     *
     * @param   vectorFactory
     *      The vector factory to use.
     */
    public AbstractFrequencyBasedGlobalTermWeighter(
        final VectorFactory<? extends Vector> vectorFactory)
    {
        super(vectorFactory);

        // Assign the fields directly instead of going through the protected
        // (overridable) setters, so that no subclass code can run against a
        // partially constructed object.
        this.documentCount = 0;
        this.termDocumentFrequencies = null;
        this.termGlobalFrequencies = null;
    }

    /**
     * Creates a copy of this object. On top of whatever {@code super.clone()}
     * copies, the two frequency vectors are passed through
     * {@code ObjectUtil.cloneSafe} so that the clone gets its own copies of
     * them rather than sharing the references held by this object.
     *
     * @return
     *      A clone of this object.
     */
    @Override
    public AbstractFrequencyBasedGlobalTermWeighter clone()
    {
        final AbstractFrequencyBasedGlobalTermWeighter clone =
            (AbstractFrequencyBasedGlobalTermWeighter) super.clone();
        clone.termDocumentFrequencies = ObjectUtil.cloneSafe(
            this.termDocumentFrequencies);
        clone.termGlobalFrequencies = ObjectUtil.cloneSafe(
            this.termGlobalFrequencies);
        return clone;
    }

    /**
     * Adds the term counts of one document to the statistics. The document
     * count is incremented, the counts are added to the global frequencies,
     * and the document frequency of every term with a nonzero count is
     * incremented by one.
     *
     * The frequency vectors are created on the first call and are grown when
     * the given counts have a larger dimensionality than the vectors
     * currently held.
     *
     * @param   counts
     *      The term counts of the document to add. Its dimensionality must
     *      be at least that of the current frequency vectors.
     * @throws  DimensionalityMismatchException
     *      If the dimensionality of the counts is smaller than that of the
     *      current frequency vectors.
     */
    public void add(
        final Vector counts)
    {
        final int dimensionality = counts.getDimensionality();

        if (this.termDocumentFrequencies == null)
        {
            // Initialize the internal vectors for the given dimensionality.
            this.initializeVectors(dimensionality);
        }
        else
        {
            final int currentDimensionality =
                this.termDocumentFrequencies.getDimensionality();

            if (dimensionality < currentDimensionality)
            {
                throw new DimensionalityMismatchException(
                    "dimensionality must be at least "
                    + currentDimensionality);
            }
            else if (dimensionality > currentDimensionality)
            {
                // We need to grow the vectors to support the new
                // dimensionality.
                this.growVectors(dimensionality);
            }
        }

        this.documentCount++;

        // Increment the global frequencies.
        this.termGlobalFrequencies.plusEquals(counts);

        // Increment the count of the number of documents a term occurs in
        // for the nonzero entries of the counts.
        for (VectorEntry entry : counts)
        {
            if (entry.getValue() != 0.0)
            {
                final int index = entry.getIndex();
                final double count =
                    this.termDocumentFrequencies.getElement(index) + 1;
                this.termDocumentFrequencies.setElement(index, count);
            }
        }
    }

    /**
     * Removes the term counts of one document from the statistics. The
     * document count is decremented, the counts are subtracted from the
     * global frequencies, and the document frequency of every term with a
     * nonzero count is decremented by one.
     *
     * If there is nothing to remove from (no frequency vectors exist yet or
     * the document count is not positive), nothing is changed and
     * {@code false} is returned. Otherwise the counts are checked with
     * {@code assertSameDimensionality} against the document frequency vector
     * before anything is modified.
     *
     * @param   counts
     *      The term counts of the document to remove.
     * @return
     *      True if the counts were subtracted from the statistics; false if
     *      there were no documents to remove from.
     */
    public boolean remove(
        final Vector counts)
    {
        // TODO: Should we first verify that the document was in the
        // collection? For example, making sure that subtracting all the
        // document counts makes the count >= 0?

        // With no vectors (nothing was ever added) or no documents left there
        // is nothing to remove. Bail out rather than dereference null or
        // drive the counts negative.
        if (this.termDocumentFrequencies == null
            || this.termGlobalFrequencies == null
            || this.documentCount <= 0)
        {
            return false;
        }

        // Make sure that the dimensionalities match.
        this.termDocumentFrequencies.assertSameDimensionality(counts);

        // We're removing the document, so decrease the total count.
        this.documentCount--;

        // Update the global frequencies.
        this.termGlobalFrequencies.minusEquals(counts);

        // Decrement the count of the number of documents a term occurs in
        // for the nonzero entries of the counts.
        for (VectorEntry entry : counts)
        {
            if (entry.getValue() != 0.0)
            {
                final int index = entry.getIndex();
                final double count =
                    this.termDocumentFrequencies.getElement(index) - 1;
                this.termDocumentFrequencies.setElement(index, count);
            }
        }

        return true;
    }

    /**
     * Initializes internal vectors to the given dimensionality. Both
     * frequency vectors are replaced by new vectors from the vector factory.
     *
     * @param   dimensionality
     *      The dimensionality to initialize to.
     */
    protected void initializeVectors(
        final int dimensionality)
    {
        this.termDocumentFrequencies = this.getVectorFactory().createVector(
            dimensionality);
        this.termGlobalFrequencies = this.getVectorFactory().createVector(
            dimensionality);
    }

    /**
     * Called when the dimensionality of the term vector grows. Both frequency
     * vectors are replaced by the result of stacking a newly created vector
     * of the missing size onto their current value.
     *
     * @param   newDimensionality
     *      The new dimensionality. It is expected to be larger than the
     *      current dimensionality of the frequency vectors.
     */
    protected void growVectors(
        final int newDimensionality)
    {
        // We need to grow the vector to hold more data.
        // TODO: This is an ugly way of growing a vector.
        final Vector difference = this.getVectorFactory().createVector(
            newDimensionality - this.termDocumentFrequencies.getDimensionality());
        this.termDocumentFrequencies = this.termDocumentFrequencies.stack(
            difference);
        this.termGlobalFrequencies = this.termGlobalFrequencies.stack(
            difference);
    }

    /**
     * Gets the number of documents the statistics are currently computed
     * over.
     *
     * @return
     *      The document count.
     */
    public int getDocumentCount()
    {
        return this.documentCount;
    }

    /**
     * Sets the document count.
     *
     * @param   documentCount
     *      The document count.
     */
    protected void setDocumentCount(
        final int documentCount)
    {
        this.documentCount = documentCount;
    }

    /**
     * Gets the vector containing the number of documents that each term
     * appears in.
     *
     * @return
     *      The term document frequencies. May be null if no document has
     *      been added yet.
     */
    public Vector getTermDocumentFrequencies()
    {
        return this.termDocumentFrequencies;
    }

    /**
     * Sets the vector containing the number of documents that each term
     * appears in.
     *
     * @param   termDocumentFrequencies
     *      The document frequencies.
     */
    protected void setTermDocumentFrequencies(
        final Vector termDocumentFrequencies)
    {
        this.termDocumentFrequencies = termDocumentFrequencies;
    }

    /**
     * Gets the vector containing the number of times that each term appears.
     *
     * @return
     *      The term global frequencies. May be null if no document has been
     *      added yet.
     */
    public Vector getTermGlobalFrequencies()
    {
        return this.termGlobalFrequencies;
    }

    /**
     * Sets the vector containing the number of times that each term appears.
     *
     * @param   termGlobalFrequencies
     *      The term global frequencies.
     */
    protected void setTermGlobalFrequencies(
        final Vector termGlobalFrequencies)
    {
        this.termGlobalFrequencies = termGlobalFrequencies;
    }

}