package com.knowledgebooks.nlp;
import com.knowledgebooks.nlp.util.NameValue;
import com.knowledgebooks.public_domain.Stemmer;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;
import javax.xml.parsers.SAXParser;
import javax.xml.parsers.SAXParserFactory;
import org.xml.sax.Attributes;
import org.xml.sax.SAXException;
import org.xml.sax.helpers.DefaultHandler;
/**
* Associates pre-trained classification categories (tags) with input text, such as
* news story types, technology categories, and social information types.
*/
/**
* Copyright Mark Watson 2008-2010. All Rights Reserved.
* License: LGPL version 3 (http://www.gnu.org/licenses/lgpl-3.0.txt)
*/
public class AutoTagger {
    // The following fields are package-visible (not private) so that the class
    // ExtractSearchTerms can use this data.

    // Filled in by TagsSAXHandler while the tag file is parsed; set back to null at
    // the end of the static initializer once copied into tagClassNames/hashes.
    static Hashtable<String, Hashtable<String, Float>> tagClasses;
    // Category (tag) names; tagClassNames[i] names the table hashes.get(i).
    static String[] tagClassNames;
    // One table per category, mapping a word stem to its weight for that category.
    static List<Hashtable<String, Float>> hashes = new ArrayList<Hashtable<String, Float>>();

    // Tag statistics file: tried first as a classpath resource, then relative to
    // the current working directory.
    private static final String TAGS_RESOURCE = "data/tags.xml";
    // Tags scoring at or below this fraction of the best tag's score are dropped.
    private static final float RELATIVE_CUTOFF = 0.2f;
    // Lower bound for the best score, so that normalizing never divides by zero.
    private static final float MIN_MAX_SCORE = 0.001f;

    /*
     * Static initialization of data from an XML file that contains word count
     * statistics for several common topics.
     *
     * Loading is best-effort: on failure the stack trace is printed and loading
     * continues with whatever was parsed before the failure (possibly nothing),
     * instead of failing class initialization with a NullPointerException.
     */
    static {
        System.err.println("Loading tags.xml for auto classification...");
        DefaultHandler handler = new TagsSAXHandler();
        InputStream xmlInputStream = null;
        try {
            xmlInputStream = openTagsStream(handler.getClass().getClassLoader());
            // Use the default non-validating parser.
            SAXParser saxParser = SAXParserFactory.newInstance().newSAXParser();
            saxParser.parse(xmlInputStream, handler);
        } catch (Throwable t) {
            t.printStackTrace();
        } finally {
            closeQuietly(xmlInputStream);
        }
        if (tagClasses == null) {
            // Parsing never reached the root element: run with no categories.
            System.err.println("AutoTagger: no tag data was loaded; no tags will be assigned.");
            tagClasses = new Hashtable<String, Hashtable<String, Float>>();
        }
        tagClassNames = new String[tagClasses.size()];
        int count = 0;
        for (Enumeration<String> e = tagClasses.keys(); e.hasMoreElements(); ) {
            String cname = e.nextElement();
            hashes.add(tagClasses.get(cname));
            tagClassNames[count++] = cname;
        }
        tagClasses = null; // only needed while loading
    }

    /**
     * Opens the tag statistics file. It is tried first as a classpath resource and
     * then as a file under the current working directory.
     *
     * @param loader class loader used to look up the resource
     * @return an open stream; the caller must close it
     * @throws IOException if the file cannot be found in either location
     */
    private static InputStream openTagsStream(ClassLoader loader) throws IOException {
        InputStream in = loader.getResourceAsStream(TAGS_RESOURCE);
        if (in == null) {
            in = new FileInputStream(System.getProperty("user.dir") + "/" + TAGS_RESOURCE);
        }
        return in;
    }

    /** Closes {@code in} if it is non-null, ignoring any error from close(). */
    private static void closeQuietly(InputStream in) {
        if (in == null) {
            return;
        }
        try {
            in.close();
        } catch (IOException ignored) {
            // Nothing useful to do: reading has already finished or failed.
        }
    }

    public AutoTagger() {
    }

    /**
     * Assigns categories (tags) to the input text.
     *
     * @param text input text to classify
     * @return the tags whose score exceeds 20% of the best tag's score, each paired
     *         with its score normalized so the best tag is 1.0. The list is sorted by
     *         descending score and is empty when no word stem matches any category.
     */
    public List<NameValue<String, Float>> getTags(String text) {
        List<NameValue<String, Float>> results = new ArrayList<NameValue<String, Float>>();
        for (SFtriple triple : getTagsHelper(text)) {
            results.add(new NameValue<String, Float>(triple.getS(), triple.getF()));
        }
        return results;
    }

    /**
     * Same as {@link #getTags(String)}, but formats each result as a string.
     *
     * @param text input text to classify
     * @return strings of the form {@code "tagName:normalizedScore"}, best tag first
     */
    public List<String> getTagsAsStrings(String text) {
        List<String> results = new ArrayList<String>();
        for (SFtriple triple : getTagsHelper(text)) {
            results.add(triple.getS() + ":" + triple.getF());
        }
        return results;
    }

    /**
     * Stems the text and scores it against every category.
     *
     * @param text input text processed to identify categories
     * @return the best categories, sorted by descending normalized score
     */
    private List<SFtriple> getTagsHelper(String text) {
        Stemmer stemmer = new Stemmer();
        List<String> stems = stemmer.stemString(text);
        return getTagsHelper(stems);
    }

    /**
     * Scores each category by summing its weights for every stem in the input.
     * The scores are then normalized by the best score, and categories scoring
     * at or below {@link #RELATIVE_CUTOFF} of the best are dropped.
     *
     * @param stems word stems of the input text
     * @return the surviving categories, sorted by descending normalized score
     */
    private List<SFtriple> getTagsHelper(List<String> stems) {
        List<SFtriple> ret = new ArrayList<SFtriple>();
        int size = tagClassNames.length;
        float[] scores = new float[size];
        for (String stem : stems) {
            for (int i = 0; i < size; i++) {
                Float f = hashes.get(i).get(stem);
                if (f != null) {
                    scores[i] += f;
                }
            }
        }
        float maxScore = MIN_MAX_SCORE;
        for (int i = 0; i < size; i++) {
            if (maxScore < scores[i]) {
                maxScore = scores[i];
            }
        }
        float cutoff = RELATIVE_CUTOFF * maxScore;
        for (int i = 0; i < size; i++) {
            if (scores[i] > cutoff) {
                ret.add(new SFtriple(tagClassNames[i], scores[i] / maxScore, i));
            }
        }
        Collections.sort(ret, new SFtripleComparator());
        return ret;
    }

    /** Orders SFtriples by descending score. */
    class SFtripleComparator implements Comparator<SFtriple> {
        @Override
        public int compare(SFtriple o1, SFtriple o2) {
            // Float.compare replaces (int) (1000 * (f2 - f1)). The truncating form
            // treats scores closer than 0.001 as equal, which is not transitive and
            // violates the Comparator contract that Collections.sort relies on.
            return Float.compare(o2.getF(), o1.getF());
        }
    }

    /**
     * Computes, for every word of the text, how strongly it contributes to the
     * categories that best match the text.
     *
     * @param text input text
     * @return one weight per stem of the text, in stem order
     */
    float[] getWordImportanceWeights(String text) {
        return getWordImportanceWeights(new Stemmer().stemString(text));
    }

    /**
     * Computes, for every stem, how strongly it contributes to the categories
     * that best match the stem list.
     *
     * @param stems word stems of the input text
     * @return one weight per stem, in the same order as {@code stems}
     */
    float[] getWordImportanceWeights(List<String> stems) {
        List<SFtriple> bestTags = getTagsHelper(stems);
        return getWordImportanceWeights(stems, bestTags);
    }

    /**
     * Finds the words that are most important for determining tags. This
     * information shows which words in the input text matter most for
     * summarization, semantic understanding, etc.
     *
     * @param stems stems for words in text
     * @param best_tags the best tags for this text
     * @return for each stem position, the sum of that stem's weights in the best
     *         tags' tables divided by the number of best tags (all zeros when
     *         there are no best tags)
     */
    private float[] getWordImportanceWeights(List<String> stems, List<SFtriple> best_tags) {
        int num = stems.size();
        float[] ret = new float[num];
        if (best_tags.isEmpty()) {
            return ret;
        }
        float scale = 1.0f / best_tags.size();
        for (SFtriple tag : best_tags) {
            Hashtable<String, Float> h = hashes.get(tag.getTopic_index());
            for (int i = 0; i < num; i++) {
                Float f = h.get(stems.get(i));
                if (f != null) {
                    ret[i] += f * scale;
                }
            }
        }
        return ret;
    }

    /**
     * Test program.
     *
     * @param args not used
     */
    public static void main(String[] args) {
        AutoTagger test = new AutoTagger();
        List<NameValue<String, Float>> results = test.getTags("The President went to Congress to argue for his tax bill before leaving on a vacation to Las Vegas to see some shows and gamble.");
        for (NameValue<String, Float> result : results) {
            System.out.println(result);
        }
    }

    /**
     * SAX handler that fills {@link AutoTagger#tagClasses} from the tag file.
     * Structure is recognized by nesting depth only; element names are not checked.
     * <ul>
     * <li>depth 0 (root): creates a new, empty {@code tagClasses} table;</li>
     * <li>depth 1: starts a category named by the element's first attribute;</li>
     * <li>depth 2: adds a stem (first attribute) with a float weight (second
     * attribute) to the current category.</li>
     * </ul>
     */
    static class TagsSAXHandler extends org.xml.sax.helpers.DefaultHandler {
        int depth = 0;
        String last_topic = "";
        Hashtable<String, Float> hash;

        @Override
        public void startElement(String uri, String localName,
                                 String qName, Attributes attributes)
                throws SAXException {
            if (depth == 0) {
                tagClasses = new Hashtable<String, Hashtable<String, Float>>();
            }
            if (depth == 1) {
                last_topic = attributes.getValue(0);
                hash = new Hashtable<String, Float>();
                tagClasses.put(last_topic, hash);
            }
            if (depth == 2) {
                hash.put(attributes.getValue(0), Float.parseFloat(attributes.getValue(1)));
            }
            depth++;
        }

        @Override
        public void endElement(String uri, String localName, String qName)
                throws SAXException {
            depth--;
        }

        @Override
        public void characters(char ch[], int start, int length)
                throws SAXException {
            // Element text is not used; all data lives in attributes.
        }
    }

    /**
     * A scored category: its name (s), its normalized score (f), and its index
     * into {@code tagClassNames}/{@code hashes} (topic_index).
     * Natural ordering is by descending score.
     */
    class SFtriple implements Comparable {
        private String s;
        private float f;
        private int topic_index;

        public SFtriple(String s, float f, int topic_index) {
            this.s = s;
            this.f = f;
            this.topic_index = topic_index;
        }

        @Override
        public String toString() {
            return "[SFtriple: " + s + " : " + f + " : " + topic_index + "]";
        }

        /** Descending by score; see SFtripleComparator for why Float.compare is used. */
        @Override
        public int compareTo(Object o) {
            return Float.compare(((SFtriple) o).getF(), f);
        }

        public String getS() {
            return s;
        }

        public void setS(String s) {
            this.s = s;
        }

        public float getF() {
            return f;
        }

        public void setF(float f) {
            this.f = f;
        }

        public int getTopic_index() {
            return topic_index;
        }

        public void setTopic_index(int topic_index) {
            this.topic_index = topic_index;
        }
    }
}