/**
 * Copyright 2013-2015 Pierre Merienne
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.github.pmerienne.trident.ml.nlp;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;

import org.codehaus.jackson.map.ObjectMapper;

import storm.trident.operation.BaseFunction;
import storm.trident.operation.TridentCollector;
import storm.trident.tuple.TridentTuple;
import backtype.storm.tuple.Values;

import com.github.pmerienne.trident.ml.classification.Classifier;
import com.github.pmerienne.trident.ml.classification.PAClassifier;
import com.github.pmerienne.trident.ml.core.TextInstance;
import com.github.pmerienne.trident.ml.preprocessing.TwitterTokenizer;
import com.github.pmerienne.trident.ml.testing.data.Datasets;

/**
 * Trident function that classifies the text found in the first field of each
 * incoming tuple and emits the resulting boolean prediction.
 * <p>
 * The feature extractor and the classifier are deserialized from two JSON
 * classpath resources when the function is constructed (see {@link Builder}).
 * NOTE(review): the prediction presumably encodes positive/negative sentiment;
 * the meaning of {@code true} is not visible from this file.
 */
public class TwitterSentimentClassifier extends BaseFunction implements Serializable {

    private static final long serialVersionUID = 1553274753609262633L;

    /** Turns a list of tokens into a numeric feature vector. */
    protected TextFeaturesExtractor featuresExtractor;

    /** Binary classifier applied to the extracted feature vector. */
    protected Classifier<Boolean> classifier;

    private TwitterTokenizer tokenizer = new TwitterTokenizer(2, 2);

    /**
     * Creates the function and eagerly loads the pre-trained feature extractor
     * and classifier from the classpath.
     *
     * @throws RuntimeException
     *             wrapping the {@link IOException} if either model resource is
     *             missing or cannot be deserialized
     */
    public TwitterSentimentClassifier() {
        try {
            this.featuresExtractor = Builder.loadFeatureExtractor();
            this.classifier = Builder.loadClassifier();
        } catch (IOException e) {
            throw new RuntimeException("Unable to load TwitterSentimentClassifier : " + e.getMessage(), e);
        }
    }

    /**
     * Classifies the string held in field 0 of {@code tuple} and emits a
     * single-value tuple containing the prediction.
     *
     * @param tuple
     *            input tuple; its first field is read as a string
     * @param collector
     *            collector receiving the prediction
     */
    @Override
    public void execute(TridentTuple tuple, TridentCollector collector) {
        String text = tuple.getString(0);
        boolean prediction = this.classify(text);
        collector.emit(new Values(prediction));
    }

    /**
     * Runs the full pipeline on a raw text: tokenization, feature extraction,
     * then classification.
     *
     * @param text
     *            raw text to classify
     * @return the classifier's prediction for {@code text}
     */
    protected Boolean classify(String text) {
        List<String> tokens = this.tokenizer.tokenize(text);
        double[] features = this.featuresExtractor.extractFeatures(tokens);
        Boolean prediction = this.classifier.classify(features);
        return prediction;
    }

    /**
     * Trains, persists and reloads the models used by
     * {@link TwitterSentimentClassifier}.
     * <p>
     * Models are stored as JSON classpath resources. Loading goes through the
     * resource {@link URL}, so it works both from an exploded classes
     * directory and from inside a jar. Saving requires the resources to be
     * plain files on disk.
     */
    protected static class Builder {

        private final static String TEXT_FEATURES_EXTRACTOR_RESOURCE = "/twitter-sentiment-classifier-extractor.json";
        private final static String CLASSIFIER_RESOURCE = "/twitter-sentiment-classifier-classifier.json";

        private final static ObjectMapper MAPPER = new ObjectMapper();

        /**
         * Trains a TF-IDF feature extractor and a PA classifier on the bundled
         * Twitter samples, then writes both models to their resource files.
         *
         * @param args
         *            unused
         * @throws IOException
         *             if the models cannot be written
         */
        public static void main(String[] args) throws IOException {
            // Get some tweets
            List<TextInstance<Boolean>> dataset = Datasets.getTwitterSamples();
            List<List<String>> documents = new ArrayList<List<String>>();
            for (TextInstance<Boolean> instance : dataset) {
                documents.add(instance.tokens);
            }

            // Init feature extractor
            TFIDF featuresExtractor = new TFIDF(documents, 10000);

            // Init and train classifier
            PAClassifier classifier = new PAClassifier();
            double[] features;
            for (TextInstance<Boolean> instance : dataset) {
                features = featuresExtractor.extractFeatures(instance.tokens);
                classifier.update(instance.label, features);
            }

            // save them
            save(featuresExtractor, classifier);
        }

        /**
         * Serializes both models as JSON over their existing resource files.
         *
         * @param featuresExtractor
         *            trained feature extractor to persist
         * @param classifier
         *            trained classifier to persist
         * @throws IOException
         *             if a resource is missing, is not a plain file (e.g. it
         *             lives inside a jar), or cannot be written
         */
        protected static void save(TFIDF featuresExtractor, PAClassifier classifier) throws IOException {
            // Resolve both targets before writing anything so that a missing or
            // read-only resource does not leave a half-updated model pair behind.
            File extractorFile = resolveWritableFile(TEXT_FEATURES_EXTRACTOR_RESOURCE);
            File classifierFile = resolveWritableFile(CLASSIFIER_RESOURCE);

            MAPPER.writeValue(extractorFile, featuresExtractor);
            MAPPER.writeValue(classifierFile, classifier);
        }

        /**
         * Loads the persisted feature extractor from the classpath.
         *
         * @return the deserialized {@link TFIDF} extractor
         * @throws IOException
         *             if the resource is missing or cannot be parsed
         */
        public static TextFeaturesExtractor loadFeatureExtractor() throws IOException {
            return MAPPER.readValue(locate(TEXT_FEATURES_EXTRACTOR_RESOURCE), TFIDF.class);
        }

        /**
         * Loads the persisted classifier from the classpath.
         *
         * @return the deserialized {@link PAClassifier}
         * @throws IOException
         *             if the resource is missing or cannot be parsed
         */
        public static Classifier<Boolean> loadClassifier() throws IOException {
            return MAPPER.readValue(locate(CLASSIFIER_RESOURCE), PAClassifier.class);
        }

        /** Looks up a classpath resource, failing with an IOException instead of a null URL. */
        private static URL locate(String resourceName) throws IOException {
            URL url = Builder.class.getResource(resourceName);
            if (url == null) {
                throw new IOException("Classpath resource not found : " + resourceName);
            }
            return url;
        }

        /** Maps a classpath resource to the on-disk file backing it, rejecting non-file locations. */
        private static File resolveWritableFile(String resourceName) throws IOException {
            URL url = locate(resourceName);
            if (!"file".equals(url.getProtocol())) {
                throw new IOException("Resource " + resourceName + " is not a plain file and cannot be overwritten : " + url);
            }
            try {
                // Going through the URI decodes escaped characters such as %20.
                return new File(url.toURI());
            } catch (URISyntaxException e) {
                // Malformed escape in the URL: fall back to the raw path.
                return new File(url.getFile());
            }
        }
    }
}