/** * Copyright 2015, Emory University * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package edu.emory.clir.clearnlp.vector; import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.function.BiFunction; import java.util.stream.Collectors; import edu.emory.clir.clearnlp.collection.map.ObjectIntHashMap; import edu.emory.clir.clearnlp.collection.pair.ObjectIntPair; import edu.emory.clir.clearnlp.util.DSUtils; import edu.emory.clir.clearnlp.util.Joiner; import edu.emory.clir.clearnlp.util.MathUtils; import edu.emory.clir.clearnlp.util.constant.StringConst; /** * @since 3.0.3 * @author Jinho D. 
Choi ({@code jinho.choi@emory.edu}) */
public class VectorSpaceModel implements Serializable {
    private static final long serialVersionUID = 4172483442205081702L;

    // NOTE: field names are kept as-is even where they break Java naming conventions;
    // this class is Serializable and renaming fields would break previously saved models.

    /** Index = term ID; each pair holds the term ({@code o}) and its document frequency ({@code i}). */
    private List<ObjectIntPair<String>> id_to_term;
    /**
     * Term -> (ID + 1). {@link #getID(String)} subtracts 1, so an absent term
     * (presumably looked up as 0 by ObjectIntHashMap) yields -1.
     */
    private ObjectIntHashMap<String> term_to_id;
    /** Tokens removed from every document before n-grams are generated. */
    private Set<String> stop_words;
    /** Number of documents processed by {@link #toTFIDFs} since construction or the last reset; the N in IDF. */
    private int DOCUMENT_SIZE;

    public VectorSpaceModel() {
        term_to_id = new ObjectIntHashMap<>();
        id_to_term = new ArrayList<>();
        stop_words = new HashSet<>();
    }

    /** Adds all of the given tokens to the stop-word set. */
    public void addStopWords(Set<String> stopWords) {
        stop_words.addAll(stopWords);
    }

    /** Adds a single token to the stop-word set. */
    public void addStopWord(String stopWord) {
        stop_words.add(stopWord);
    }

    /**
     * Converts a document into a bag of n-gram terms, registering every unseen term in the vocabulary.
     *
     * @param document the tokens of the document; stop words are removed before n-grams are built.
     * @param ngram the maximum n-gram length (1 = unigrams only).
     * @param df if {@code true}, the document frequency of each distinct term of this document is incremented by 1.
     * @return the terms with their term frequencies, sorted by {@code Term}'s natural order
     *         (presumably by ID, which the merge-based distance functions of this class require).
     */
    public List<Term> toBagOfWords(List<String> document, int ngram, boolean df) {
        ObjectIntHashMap<String> map = getBagOfWords(document, stop_words, ngram);
        List<Term> terms = new ArrayList<>(map.size());

        for (ObjectIntPair<String> t : map) {
            int id = getID(t.o);

            if (id < 0) {  // unseen term: assign the next ID
                id = term_to_id.size();
                term_to_id.put(t.o, id + 1);  // stored as ID + 1, see getID()
                id_to_term.add(new ObjectIntPair<>(t.o, 0));
            }

            terms.add(new Term(id, t.i));
            if (df) id_to_term.get(id).i++;
        }

        Collections.sort(terms);
        return terms;
    }

    /**
     * Builds the bag of words of every document (updating document frequencies) and then scores every term.
     * All documents are counted before any term is scored, so each score sees the complete document frequencies.
     *
     * @param documents each document is represented as a list of strings (tokens).
     * @param ngram the maximum n-gram length.
     * @param f the weighting function, called with a term (term and document frequency set) and the
     *          total document count; e.g. {@code VectorSpaceModel::getTFIDF} or {@code VectorSpaceModel::getWFIDF}.
     * @return one ID-sorted term list per input document, in input order.
     */
    public List<List<Term>> toTFIDFs(List<List<String>> documents, int ngram, BiFunction<Term,Integer,Double> f) {
        List<List<Term>> list = documents.stream()
                .map(document -> toBagOfWords(document, ngram, true))
                .collect(Collectors.toCollection(ArrayList::new));

        // Document frequencies accumulate across calls until resetDocumentFrequency() is called,
        // so the document count must accumulate too; otherwise df could exceed N and IDF turn negative.
        DOCUMENT_SIZE += documents.size();

        for (List<Term> terms : list) {
            for (Term term : terms) {
                term.setDocumentFrequency(getDocumentFrequency(term.getID()));
                term.setScore(f.apply(term, DOCUMENT_SIZE));
            }
        }

        return list;
    }

    /**
     * Scores a (typically unseen) document against the current vocabulary and document frequencies.
     * Terms not in the vocabulary are dropped; the model itself is not modified.
     *
     * @param document the tokens of the document.
     * @param ngram the maximum n-gram length.
     * @param f the weighting function (see {@link #toTFIDFs}).
     * @return the scored in-vocabulary terms, sorted the same way as {@link #toBagOfWords} so the result
     *         can be compared with {@link #getCosineSimilarity} / {@link #getEuclideanDistance}.
     */
    public List<Term> getTFIDFs(List<String> document, int ngram, BiFunction<Term,Integer,Double> f) {
        ObjectIntHashMap<String> map = getBagOfWords(document, stop_words, ngram);
        List<Term> list = new ArrayList<>(map.size());

        for (ObjectIntPair<String> t : map) {
            int id = getID(t.o);
            if (id < 0) continue;  // out-of-vocabulary term

            Term term = new Term(id, t.i, getDocumentFrequency(id));
            term.setScore(f.apply(term, DOCUMENT_SIZE));
            list.add(term);
        }

        // The similarity functions merge two lists by ID; map iteration order is not ID order.
        Collections.sort(list);
        return list;
    }

    /** @return the term corresponding to the ID if exists; otherwise, null. */
    public String getTerm(int id) {
        return DSUtils.isRange(id_to_term, id) ? id_to_term.get(id).o : null;
    }

    /** @return the ID of the term if exists; otherwise, -1. */
    public int getID(String term) {
        return term_to_id.get(term) - 1;
    }

    /** @return the number of distinct terms in the vocabulary. */
    public int getTermSize() {
        return id_to_term.size();
    }

    /** @return the document frequency of the term with the given ID, or 0 if the ID is out of range. */
    public int getDocumentFrequency(int id) {
        return DSUtils.isRange(id_to_term, id) ? id_to_term.get(id).i : 0;
    }

    /** Sets every document frequency and the document count back to 0; the vocabulary (term IDs) is kept. */
    public void resetDocumentFrequency() {
        for (ObjectIntPair<String> term : id_to_term)
            term.i = 0;

        DOCUMENT_SIZE = 0;
    }

    /**
     * @return {@code log(documentSize / documentFrequency) * termScore}, where the division is done by
     *         {@code MathUtils.divide} (behavior for a zero document frequency depends on that helper).
     */
    public static double getTFIDF(double termScore, int documentFrequency, int documentSize) {
        return Math.log(MathUtils.divide(documentSize, documentFrequency)) * termScore;
    }

    /** @return TF-IDF of the term using its raw term frequency. */
    public static double getTFIDF(Term term, int documentSize) {
        return getTFIDF(term.getTermFrequency(), term.getDocumentFrequency(), documentSize);
    }

    /** @return WF-IDF of the term: the term frequency is replaced by {@code 1 + log(tf)} (0 when tf is 0). */
    public static double getWFIDF(Term term, int documentSize) {
        double termScore = (term.getTermFrequency() > 0) ? 1d + Math.log(term.getTermFrequency()) : 0;
        return getTFIDF(termScore, term.getDocumentFrequency(), documentSize);
    }

    /**
     * Euclidean distance between two sparse vectors.
     * Both lists must be sorted by ascending ID with no duplicate IDs; a missing ID counts as score 0.
     */
    public static double getEuclideanDistance(List<Term> d1, List<Term> d2) {
        int i = 0, j = 0, len1 = d1.size(), len2 = d2.size();
        double sum = 0;

        while (i < len1 && j < len2) {
            Term t1 = d1.get(i), t2 = d2.get(j);

            if (t1.getID() < t2.getID()) {
                sum += MathUtils.sq(t1.getScore());
                i++;
            } else if (t1.getID() > t2.getID()) {
                sum += MathUtils.sq(t2.getScore());
                j++;
            } else {
                sum += MathUtils.sq(t1.getScore() - t2.getScore());
                i++;
                j++;
            }
        }

        // Remaining terms exist in only one vector.
        for (; i < len1; i++) sum += MathUtils.sq(d1.get(i).getScore());
        for (; j < len2; j++) sum += MathUtils.sq(d2.get(j).getScore());

        return Math.sqrt(sum);
    }

    /**
     * Cosine similarity between two sparse vectors.
     * Both lists must be sorted by ascending ID with no duplicate IDs.
     *
     * @return the cosine similarity, or 0 if either vector has zero norm (e.g. an empty list).
     */
    public static double getCosineSimilarity(List<Term> d1, List<Term> d2) {
        int i = 0, j = 0, len1 = d1.size(), len2 = d2.size();
        double num = 0;

        // Dot product over the IDs shared by both vectors.
        while (i < len1 && j < len2) {
            Term t1 = d1.get(i), t2 = d2.get(j);

            if (t1.getID() < t2.getID()) i++;
            else if (t1.getID() > t2.getID()) j++;
            else {
                num += t1.getScore() * t2.getScore();
                i++;
                j++;
            }
        }

        // Norms are computed outside the merge: inside it, the term that does not advance in an
        // iteration would be counted again in the next one.
        double den = Math.sqrt(getSquaredNorm(d1)) * Math.sqrt(getSquaredNorm(d2));
        return (den == 0) ? 0 : num / den;
    }

    /** @return the sum of the squared scores of the terms. */
    private static double getSquaredNorm(List<Term> terms) {
        double sum = 0;
        for (Term t : terms) sum += MathUtils.sq(t.getScore());
        return sum;
    }

    /**
     * Counts the n-grams of a token list after removing stop words.
     * For each position {@code i}, every n-gram of length 1..{@code n} ending at {@code i} is added,
     * joined with {@code StringConst.UNDERSCORE}. Because stop words are removed first, an n-gram may
     * join tokens that were not adjacent in the original list.
     *
     * @param terms the tokens of a document (not modified here, assuming {@code DSUtils.removeAll} returns a new list).
     * @param stopWords tokens to drop before building n-grams.
     * @param n the maximum n-gram length.
     * @return n-gram -> count.
     */
    public static ObjectIntHashMap<String> getBagOfWords(List<String> terms, Set<String> stopWords, int n) {
        ObjectIntHashMap<String> map = new ObjectIntHashMap<>();
        List<String> tokens = DSUtils.removeAll(terms, stopWords);
        int size = tokens.size();

        for (int i = 0; i < size; i++) {
            // k is the start index of the n-gram of length j+1 that ends at i.
            for (int j = 0, k = i; j < n && k >= 0; j++, k--)
                map.add(Joiner.join(tokens, StringConst.UNDERSCORE, k, i + 1));
        }

        return map;
    }

    /**
     * Picks the most widespread terms across the documents as stop words.
     *
     * @param documents each document is represented as a list of tokens.
     * @param cutoff the maximum number of stop words to return; values {@code <= 0} yield an empty set.
     * @return the {@code cutoff} terms that occur in the most documents.
     */
    public static Set<String> generateStopWords(List<List<String>> documents, int cutoff) {
        ObjectIntHashMap<String> map = new ObjectIntHashMap<>();

        // Count each term at most once per document (document frequency).
        for (List<String> document : documents) {
            for (String term : new HashSet<>(document))
                map.add(term);
        }

        List<ObjectIntPair<String>> list = map.toList();
        // Presumably ObjectIntPair orders by its count, so this puts the most frequent terms first.
        Collections.sort(list, Collections.reverseOrder());

        int len = Math.min(Math.max(cutoff, 0), list.size());
        Set<String> set = new HashSet<>();

        for (int i = 0; i < len; i++)
            set.add(list.get(i).o);

        return set;
    }
}