package com.yc.nlp.textrank; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; public class KeyWordTextRank { private List<List<String>> docs; private Map<String, List<String>> words; private Map<String, Double> vertex; private Double d; private Integer maxIter; private Double minDiff; private Map<String, Double> top; public KeyWordTextRank(List<List<String>> docs) { this.docs = docs; this.words = new HashMap<String, List<String>>(); this.vertex = new HashMap<String, Double>(); this.d = 0.85; this.maxIter = 200; this.minDiff = 0.001; this.top = new LinkedHashMap<String, Double>(); } public void solve() { for (List<String> doc : docs) { List<String> que = new ArrayList<String>(); for (String ch : doc) { String word = ch.toString(); List<String> value = this.words.get(word); if (value == null) { value = new ArrayList<String>(); this.words.put(word, value); this.vertex.put(word, 1.0); } que.add(word); if (que.size() > 5) { que.remove(0); } for (String w1 : que) { for (String w2 : que) { if (w1.equals(w2)) { continue; } this.words.get(w1).add(w2); this.words.get(w2).add(w1); } } } } Integer iterNum = 0; while (iterNum < this.maxIter) { iterNum++; Map<String, Double> m = new HashMap<String, Double>(); double maxDiff = 0; for (Map.Entry<String, List<String>> entry : this.words.entrySet()) { String key = entry.getKey(); m.put(key, 1 - this.d); for (String j : entry.getValue()) { if (key.equals(j) || this.words.get(j).size() == 0) { continue; } m.put(key, m.get(key) + (this.d / this.words.get(j).size() * this.vertex.get(j))); } if (Math.abs(m.get(key) - this.vertex.get(key)) > maxDiff) { maxDiff = Math.abs(m.get(key) - this.vertex.get(key)); } } this.vertex = m; if (maxDiff <= this.minDiff) { break; } } List<Map.Entry<String, Double>> list = new ArrayList<Map.Entry<String, Double>>(this.vertex.entrySet()); Collections.sort(list, new Comparator<Map.Entry<String, Double>>() { public int compare(Entry<String, Double> o1, Entry<String, Double> o2) { return o2.getValue().compareTo(o1.getValue()); } }); for (Iterator<Entry<String, Double>> it = list.iterator(); it.hasNext();) { Entry<String, Double> entry = it.next(); this.top.put(entry.getKey(), entry.getValue()); } } public List<String> topIndex(Integer limit) { List<String> indexes = new ArrayList<String>(); Integer num = 0; for (Map.Entry<String, Double> entry : this.top.entrySet()) { if (num == limit) { break; } indexes.add(entry.getKey()); num++; } return indexes; } //TODO: /*public List<Set<String>> top(Integer limit) { List<Set<String>> docs = new ArrayList<Set<String>>(); Integer num = 0; for (Map.Entry<String, Double> entry : this.top.entrySet()) { if (num == limit) { break; } //docs.add(this.docs.get(entry.getKey())); num++; } return docs; }*/ }