/** * Copyright 2015, Emory University * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package edu.emory.clir.clearnlp.experiment; import java.io.BufferedReader; import java.io.InputStream; import java.io.OutputStream; import java.io.PrintStream; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.StringJoiner; import java.util.function.BiFunction; import java.util.stream.Collectors; import java.util.stream.IntStream; import edu.emory.clir.clearnlp.collection.pair.DoubleIntPair; import edu.emory.clir.clearnlp.collection.pair.ObjectIntPair; import edu.emory.clir.clearnlp.tokenization.AbstractTokenizer; import edu.emory.clir.clearnlp.tokenization.EnglishTokenizer; import edu.emory.clir.clearnlp.util.IOUtils; import edu.emory.clir.clearnlp.util.Splitter; import edu.emory.clir.clearnlp.util.StringUtils; import edu.emory.clir.clearnlp.util.constant.PatternConst; import edu.emory.clir.clearnlp.vector.Term; import edu.emory.clir.clearnlp.vector.VectorSpaceModel; /** * @since 3.0.3 * @author Jinho D. Choi ({@code jinho.choi@emory.edu}) */ public class SymbolStrip { final int[] CATEGORIES = {2,3,6,8,10,12,13,18}; private List<List<DoubleIntPair>> category_list; private VectorSpaceModel vs_model; List<List<Term>> train_vectors; public SymbolStrip() { category_list = IntStream.range(0, CATEGORIES.length).mapToObj(i -> new ArrayList<DoubleIntPair>()).collect(Collectors.toCollection(ArrayList::new)); vs_model = new VectorSpaceModel(); } public void initVectors(InputStream in, int ngram, BiFunction<Term,Integer,Double> f, boolean normalize) throws Exception { BufferedReader reader = IOUtils.createBufferedReader(in); AbstractTokenizer tokenizer = new EnglishTokenizer(); List<List<String>> documents = new ArrayList<>(); int id, i, j, c, len = CATEGORIES.length; String line = reader.readLine(); List<String> document; String[] t; String s; for (i=0; (line = reader.readLine()) != null; i++) { line = StringUtils.toLowerCase(line); t = PatternConst.TAB.split(line); document = StringUtils.stripPunctuation(tokenizer.tokenize(t[1])); if (document.isEmpty()) System.err.println("Empty document: "+i); else { id = documents.size(); documents.add(document); for (j=0; j<len; j++) { c = CATEGORIES[j]; if (c >= t.length) break; if (!(s = t[c]).isEmpty()) category_list.get(j).add(new DoubleIntPair(Double.parseDouble(s), id)); } } } reader.close(); train_vectors = vs_model.toTFIDFs(documents, ngram, f); if (normalize) { for (List<DoubleIntPair> list : category_list) normalize(list); } // for (i=0; i<len; i++) // System.out.printf("%2d: %d\n", CATEGORIES[i], category_list.get(i).size()); } private void normalize(List<DoubleIntPair> list) { double max = list.get(0).d, min = list.get(0).d; int i, len = list.size(); DoubleIntPair p; for (i=1; i<len; i++) { max = Math.max(max, list.get(i).d); min = Math.min(min, list.get(i).d); } for (i=0; i<len; i++) { p = list.get(i); p.d = (p.d - min) / (max - min); } } public void measureCategories(InputStream in, OutputStream out, int ngram, BiFunction<Term,Integer,Double> f) throws Exception { Map<String,ObjectIntPair<double[]>> map = new HashMap<>(); BufferedReader reader = IOUtils.createBufferedReader(in); int i, len = CATEGORIES.length; ObjectIntPair<double[]> p; List<String> document; List<Term> d1; String line; String[] t; while ((line = reader.readLine()) != null) { t = Splitter.splitTabs(line); p = map.computeIfAbsent(t[0], k -> new ObjectIntPair<>(new double[len], 0)); document = StringUtils.stripPunctuation(Splitter.splitSpace(t[2])); d1 = vs_model.getTFIDFs(document, ngram, f); if (d1.isEmpty()) continue; for (i=0; i<len; i++) p.o[i] += getScore(category_list.get(i), d1, ngram, f); p.i++; } reader.close(); PrintStream fout = IOUtils.createBufferedPrintStream(out); List<String> states = new ArrayList<>(map.keySet()); StringJoiner joiner = new StringJoiner(","); Collections.sort(states); joiner.add("State"); for (i=0; i<len; i++) joiner.add(Integer.toString(CATEGORIES[i])); fout.println(joiner.toString()); for (String state : states) { joiner = new StringJoiner(","); p = map.get(state); joiner.add(state); for (i=0; i<len; i++) joiner.add(Double.toString(p.o[i] / p.i)); fout.println(joiner.toString()); } fout.close(); } private double getScore(List<DoubleIntPair> cluster, List<Term> d1, int ngram, BiFunction<Term,Integer,Double> f) { double sum = 0; List<Term> d2; for (DoubleIntPair p : cluster) { d2 = train_vectors.get(p.i); sum += VectorSpaceModel.getCosineSimilarity(d1, d2) * p.d; } return sum / cluster.size(); } public void split(String inputFile, String outputFile) throws Exception { PrintStream fout = IOUtils.createBufferedPrintStream(outputFile); BufferedReader reader = IOUtils.createBufferedReader(inputFile); String line; String[] t; int i; for (i=0; (line = reader.readLine()) != null; i++) { if (i%10000 == 0) System.out.print("."); t = Splitter.splitTabs(line.trim()); if (!StringUtils.containsPunctuationOrDigitsOrWhiteSpacesOnly(t[1])) fout.println(t[0]+"\t"+i+"\t"+t[1]); } } static public void main(String[] args) throws Exception { final String inputDir = args[0]; final String func = args[1]; final int ngram = Integer.parseInt(args[2]); final boolean normalize = Boolean.parseBoolean(args[3]); final String trainFile = inputDir+"/mind_wandering_and_axiety.txt"; final String tweetFile = inputDir+"/tweetsByStateSplittedCleaned.csv.out"; final String outputFile = inputDir+"/"+func+"-"+ngram+"-"+normalize+".csv"; final BiFunction<Term,Integer,Double> f = func.equals("tf") ? VectorSpaceModel::getTFIDF : VectorSpaceModel::getWFIDF; SymbolStrip vs = new SymbolStrip(); vs.initVectors(IOUtils.createFileInputStream(trainFile), ngram, f, normalize); vs.measureCategories(IOUtils.createFileInputStream(tweetFile), IOUtils.createFileOutputStream(outputFile), ngram, f); } }