package com.yc.nlp.textrank;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import com.yc.nlp.sim.BM25;

/**
 * Ranks a collection of tokenized documents with an iterative, damped
 * vertex-scoring scheme whose edge weights come from {@link BM25#simall}.
 *
 * <p>Typical usage: construct with the documents, call {@link #solve()} once,
 * then read the ranking through {@link #topIndex(Integer)} or
 * {@link #top(Integer)}. Before {@code solve()} has been called both accessors
 * return empty lists.
 *
 * <p>NOTE(review): the code indexes {@code weight.get(i).get(j)} for every
 * {@code j < docs.size()}, so {@code BM25.simall(doc)} is presumed to return
 * one score per document of the corpus - confirm against {@code BM25}.
 */
public class TextRank {

    /** The documents to rank; each document is a list of tokens. */
    private List<List<String>> docs;

    /** Scorer used to obtain the edge weights between documents. */
    private BM25 bm25;

    /** Number of documents, captured when the instance is constructed. */
    private int D;

    /** Damping factor applied to the weighted contribution of other vertices. */
    private double d;

    /** {@code weight.get(i).get(j)}: score of document {@code i} against document {@code j}. */
    private List<List<Double>> weight;

    /** {@code weightSum.get(i)}: sum of all entries of {@code weight.get(i)}. */
    private List<Double> weightSum;

    /** Current score of every document, indexed like {@code docs}. */
    private List<Double> vertex;

    /** Upper bound on the number of iterations performed by {@link #solve()}. */
    private int maxIter;

    /** Iteration stops once no score changes by more than this amount. */
    private double minDiff;

    /** Document index mapped to its score, iterated from highest to lowest score. */
    private Map<Integer, Double> top;

    /**
     * Creates a ranker over the given documents.
     *
     * @param docs the documents to rank, each one a list of tokens
     */
    public TextRank(List<List<String>> docs) {
        this.docs = docs;
        this.bm25 = new BM25(docs);
        this.D = docs.size();
        this.d = 0.85;
        this.weight = new ArrayList<List<Double>>();
        this.weightSum = new ArrayList<Double>();
        this.vertex = new ArrayList<Double>();
        this.maxIter = 200;
        this.minDiff = 0.001;
        this.top = new LinkedHashMap<Integer, Double>();
    }

    /**
     * Computes the score of every document and stores the resulting ranking.
     *
     * <p>All derived state is rebuilt from scratch, so calling this method more
     * than once recomputes the ranking instead of appending to the state left
     * behind by the previous call.
     */
    public void solve() {
        resetState();
        buildGraph();
        iterate();
        rankVertices();
    }

    /**
     * Returns the indexes of the highest scoring documents, best first.
     *
     * @param limit maximum number of indexes to return; {@code null} or a
     *              negative value means "no limit"
     * @return the indexes into the original document list, ordered by
     *         descending score
     */
    public List<Integer> topIndex(Integer limit) {
        int max = effectiveLimit(limit);
        List<Integer> indexes = new ArrayList<Integer>(Math.min(max, this.top.size()));
        for (Integer index : this.top.keySet()) {
            if (indexes.size() >= max) {
                break;
            }
            indexes.add(index);
        }
        return indexes;
    }

    /**
     * Returns the highest scoring documents, best first.
     *
     * @param limit maximum number of documents to return; {@code null} or a
     *              negative value means "no limit"
     * @return the documents ordered by descending score
     */
    public List<List<String>> top(Integer limit) {
        int max = effectiveLimit(limit);
        List<List<String>> ranked = new ArrayList<List<String>>(Math.min(max, this.top.size()));
        for (Integer index : this.top.keySet()) {
            if (ranked.size() >= max) {
                break;
            }
            ranked.add(this.docs.get(index));
        }
        return ranked;
    }

    /** Discards everything computed by a previous {@link #solve()} call. */
    private void resetState() {
        this.weight = new ArrayList<List<Double>>();
        this.weightSum = new ArrayList<Double>();
        this.vertex = new ArrayList<Double>();
        this.top.clear();
    }

    /** Fills the weight rows, their sums, and the initial score (1.0) of every document. */
    private void buildGraph() {
        for (List<String> doc : this.docs) {
            List<Double> scores = this.bm25.simall(doc);
            double sum = 0;
            for (Double score : scores) {
                sum += score;
            }
            this.weight.add(scores);
            this.weightSum.add(sum);
            this.vertex.add(1.0);
        }
    }

    /** Repeats the score update until it converges or {@code maxIter} is reached. */
    private void iterate() {
        for (int iter = 0; iter < this.maxIter; iter++) {
            List<Double> next = new ArrayList<Double>(this.D);
            double maxDiff = 0;
            for (int i = 0; i < this.D; i++) {
                double score = 1 - this.d;
                List<Double> row = this.weight.get(i);
                for (int j = 0; j < this.D; j++) {
                    double outSum = this.weightSum.get(j);
                    // Skip self links, and vertices whose weights sum to zero
                    // (dividing by that sum would yield NaN or infinity).
                    if (j == i || outSum == 0.0) {
                        continue;
                    }
                    score += this.d * row.get(j) / outSum * this.vertex.get(j);
                }
                next.add(score);
                double diff = Math.abs(score - this.vertex.get(i));
                if (diff > maxDiff) {
                    maxDiff = diff;
                }
            }
            this.vertex = next;
            if (maxDiff <= this.minDiff) {
                break;
            }
        }
    }

    /** Stores the document indexes in {@code top}, ordered by descending score. */
    private void rankVertices() {
        final List<Double> scores = this.vertex;
        List<Integer> order = new ArrayList<Integer>(scores.size());
        for (int i = 0; i < scores.size(); i++) {
            order.add(i);
        }
        // Collections.sort is stable, so documents with equal scores keep
        // their original relative order.
        Collections.sort(order, new Comparator<Integer>() {
            public int compare(Integer o1, Integer o2) {
                return scores.get(o2).compareTo(scores.get(o1));
            }
        });
        this.top.clear();
        for (Integer index : order) {
            this.top.put(index, scores.get(index));
        }
    }

    /**
     * Converts a caller supplied limit into a primitive bound.
     *
     * <p>Working on the primitive value avoids comparing boxed
     * {@code Integer} objects by reference, which only happens to work for
     * values inside the {@code Integer} cache.
     */
    private static int effectiveLimit(Integer limit) {
        if (limit == null || limit.intValue() < 0) {
            return Integer.MAX_VALUE;
        }
        return limit.intValue();
    }
}