package com.knowledgebooks.nlp;

import com.knowledgebooks.nlp.util.NameValue;
import com.knowledgebooks.public_domain.Stemmer;
import org.xml.sax.Attributes;
import org.xml.sax.SAXException;
import org.xml.sax.helpers.DefaultHandler;

import javax.xml.parsers.SAXParser;
import javax.xml.parsers.SAXParserFactory;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;

/*
 * Copyright Mark Watson 2008-2010. All Rights Reserved.
 * License: LGPL version 3 (http://www.gnu.org/licenses/lgpl-3.0.txt)
 */

/**
 * Associate pre-trained classification categories (tags) with input text: assigns
 * categories for news story types, technology category types, social information
 * types, etc. to input text.
 * <p>
 * Word statistics for every category are loaded once, in the static initializer,
 * from the XML resource <code>data/tags.xml</code>. If that load fails the class
 * still initializes, but with no categories, so every query returns an empty result.
 */
public class AutoTagger {

    // Make the following package visible so the class
    // ExtractSearchTerms can use this data:

    /** Scratch table filled by {@link TagsSAXHandler} while parsing; reset to null once loading is done. */
    static Hashtable<String, Hashtable<String, Float>> tagClasses;
    /** Category names; index i corresponds to <code>hashes.get(i)</code>. */
    static String[] tagClassNames;
    /** Per-category table mapping a word stem to its weight; parallel to {@link #tagClassNames}. */
    static List<Hashtable<String, Float>> hashes = new ArrayList<Hashtable<String, Float>>();

    /**
     * Static initialization of data from an XML file that contains
     * word count statistics for several common topics.
     * The file is looked up on the class path first and then relative to the
     * current working directory ("user.dir").
     */
    static {
        DefaultHandler handler = new TagsSAXHandler();
        SAXParserFactory factory = SAXParserFactory.newInstance();  // Use the default non-validating parser
        InputStream xml_input_stream = null;
        try {
            System.err.println("Loading tag.xml for auto classification...");
            xml_input_stream = handler.getClass().getClassLoader().getResourceAsStream("data/tags.xml");
            if (xml_input_stream == null) {
                // Not on the class path: fall back to the working directory.
                xml_input_stream = new FileInputStream(System.getProperty("user.dir") + "/" + "data/tags.xml");
            }
            SAXParser saxParser = factory.newSAXParser();
            saxParser.parse(xml_input_stream, handler);
        } catch (Throwable t) {
            // Loading is deliberately best-effort: report the problem and carry on
            // with whatever was parsed so far (possibly nothing).
            t.printStackTrace();
        } finally {
            if (xml_input_stream != null) {
                try {
                    xml_input_stream.close();
                } catch (IOException ignored) {
                    // Nothing useful can be done if closing a read-only stream fails.
                }
            }
        }

        // The handler only creates the table when it sees the root element, so it is
        // still null when the file was missing or unparsable. Use an empty table
        // instead of failing class initialization with a NullPointerException.
        if (tagClasses == null) {
            tagClasses = new Hashtable<String, Hashtable<String, Float>>();
        }

        tagClassNames = new String[tagClasses.size()];
        int count = 0;
        for (Enumeration<String> e = tagClasses.keys(); e.hasMoreElements();) {
            String cname = e.nextElement();
            hashes.add(tagClasses.get(cname));
            tagClassNames[count++] = cname;
        }
        // The name -> table map is no longer needed once the parallel structures are built.
        tagClasses = null;
    }

    public AutoTagger() {
    }

    /**
     * Compute the categories that best match the input text.
     *
     * @param text input text processed to identify categories
     * @return name/score pairs ordered by descending score; scores are relative to the
     *         best-scoring category, and the list is empty when no category qualifies
     */
    public List<NameValue<String, Float>> getTags(String text) {
        List<NameValue<String, Float>> results = new ArrayList<NameValue<String, Float>>();
        for (SFtriple triple : getTagsHelper(text)) {
            results.add(new NameValue<String, Float>(triple.getS(), triple.getF()));
        }
        return results;
    }

    /**
     * Same as {@link #getTags(String)} but with every result rendered as the
     * string <code>name:score</code>.
     *
     * @param text input text processed to identify categories
     * @return one "name:score" string per matching category, best score first
     */
    public List<String> getTagsAsStrings(String text) {
        List<String> results = new ArrayList<String>();
        for (SFtriple triple : getTagsHelper(text)) {
            results.add(triple.getS() + ":" + triple.getF());
        }
        return results;
    }

    /**
     * Stem the text and score the stems against every category.
     *
     * @param text input text processed to identify categories
     * @return matching categories ordered by descending relative score
     */
    private List<SFtriple> getTagsHelper(String text) {
        Stemmer stemmer = new Stemmer();
        List<String> stems = stemmer.stemString(text);
        return getTagsHelper(stems);
    }

    /**
     * Score a list of word stems against every category. A category's raw score is
     * the sum of the weights of the stems found in its table. Categories whose raw
     * score exceeds 20% of the best raw score are returned, with the score divided
     * by the best raw score.
     *
     * @param stems word stems of the input text
     * @return matching categories ordered by descending relative score
     */
    private List<SFtriple> getTagsHelper(List<String> stems) {
        List<SFtriple> ret = new ArrayList<SFtriple>();
        int size = tagClassNames.length;
        float[] scores = new float[size];
        for (String stem : stems) {
            for (int i = 0; i < size; i++) {
                Float f = hashes.get(i).get(stem);
                if (f != null) {
                    scores[i] += f;
                }
            }
        }
        // The small positive floor keeps the division below well defined and
        // makes an all-zero score vector produce no tags at all.
        float max_score = 0.001f;
        for (int i = 0; i < size; i++) {
            if (max_score < scores[i]) {
                max_score = scores[i];
            }
        }
        float cutoff = 0.2f * max_score;
        for (int i = 0; i < size; i++) {
            if (scores[i] > cutoff) {
                ret.add(new SFtriple(tagClassNames[i], scores[i] / max_score, i));
            }
        }
        Collections.sort(ret, new SFtripleComparator());
        return ret;
    }

    /**
     * Orders {@link SFtriple} instances by descending score.
     */
    class SFtripleComparator implements Comparator<SFtriple> {
        public int compare(SFtriple o1, SFtriple o2) {
            // Float.compare gives a consistent total order. Truncating a scaled
            // difference to int would treat scores closer than 0.001 as equal
            // and violate the Comparator contract.
            return Float.compare(o2.getF(), o1.getF());
        }
    }

    /**
     * Compute a per-word importance weight for the words of the given text.
     *
     * @param text input text; it is stemmed before scoring
     * @return one weight per stem, in stem order
     */
    float[] getWordImportanceWeights(String text) {
        List<String> stems = new Stemmer().stemString(text);
        List<SFtriple> best_tags = getTagsHelper(stems);
        return getWordImportanceWeights(stems, best_tags);
    }

    /**
     * Compute a per-word importance weight for already stemmed words.
     *
     * @param stems stems for words in text
     * @return one weight per stem, in stem order
     */
    float[] getWordImportanceWeights(List<String> stems) {
        List<SFtriple> best_tags = getTagsHelper(stems);
        return getWordImportanceWeights(stems, best_tags);
    }

    /**
     * Find the words that are most important for determining tags and use
     * this information to find which words in input text are most important for
     * summarization, semantic understanding, etc.
     * The weight of a stem is the average, over the best tags, of the stem's weight
     * in each tag's table (a stem absent from a table contributes zero).
     *
     * @param stems     stems for words in text
     * @param best_tags the best tags for this text
     * @return one weight per stem, in stem order; all zeros when there are no tags
     */
    private float[] getWordImportanceWeights(List<String> stems, List<SFtriple> best_tags) {
        int num = stems.size();
        float[] ret = new float[num];
        if (best_tags.isEmpty()) {
            return ret;  // nothing to average over; avoids computing 1/0
        }
        float scale = 1.0f / best_tags.size();
        for (SFtriple tag : best_tags) {
            Hashtable<String, Float> h = hashes.get(tag.getTopic_index());
            for (int i = 0; i < num; i++) {
                Float f = h.get(stems.get(i));
                if (f != null) {
                    ret[i] += f * scale;
                }
            }
        }
        return ret;
    }

    /**
     * Test program
     *
     * @param args not used
     */
    public static void main(String[] args) {
        AutoTagger test = new AutoTagger();
        List<NameValue<String, Float>> results = test.getTags("The President went to Congress to argue for his tax bill before leaving on a vacation to Las Vegas to see some shows and gamble.");
        for (NameValue<String, Float> result : results) {
            System.out.println(result);
        }
    }

    /**
     * SAX handler that fills {@link AutoTagger#tagClasses} from the tags XML file.
     * Elements are interpreted purely by nesting depth: the root element (depth 0)
     * creates the table, each depth-1 element starts a new category named by its
     * first attribute, and each depth-2 element adds an entry whose key is its first
     * attribute and whose float value is its second attribute.
     */
    static class TagsSAXHandler extends org.xml.sax.helpers.DefaultHandler {
        int depth = 0;
        String last_topic = "";
        Hashtable<String, Float> hash;

        // override default methods for a few SAX events:

        @Override
        public void startElement(String uri, String localName, String qName, Attributes attributes) throws SAXException {
            if (depth == 0) {
                tagClasses = new Hashtable<String, Hashtable<String, Float>>();
            }
            if (depth == 1) {
                last_topic = attributes.getValue(0);
                hash = new Hashtable<String, Float>();
                tagClasses.put(last_topic, hash);
            }
            if (depth == 2) {
                hash.put(attributes.getValue(0), Float.parseFloat(attributes.getValue(1)));
            }
            depth++;
        }

        @Override
        public void endElement(String uri, String localName, String qName) throws SAXException {
            depth--;
        }
    }

    /**
     * A category name, its relative score, and the category's index into
     * {@link AutoTagger#hashes} / {@link AutoTagger#tagClassNames}.
     * The natural ordering is by descending score.
     */
    class SFtriple implements Comparable<SFtriple> {
        private String s;
        private float f;
        private int topic_index;

        public SFtriple(String s, float f, int topic_index) {
            this.s = s;
            this.f = f;
            this.topic_index = topic_index;
        }

        public String toString() {
            return "[SFtriple: " + s + " : " + f + " : " + topic_index + "]";
        }

        public int compareTo(SFtriple o) {
            // Descending by score; see SFtripleComparator for why subtraction is not used.
            return Float.compare(o.getF(), f);
        }

        public String getS() {
            return s;
        }

        public void setS(String s) {
            this.s = s;
        }

        public float getF() {
            return f;
        }

        public void setF(float f) {
            this.f = f;
        }

        public int getTopic_index() {
            return topic_index;
        }

        public void setTopic_index(int topic_index) {
            this.topic_index = topic_index;
        }
    }
}