package com.caseystella.summarize;
import com.caseystella.util.ConversionUtils;
import com.google.common.base.Joiner;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import org.apache.commons.lang3.StringUtils;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.mllib.feature.Word2Vec;
import org.apache.spark.mllib.feature.Word2VecModel;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.*;
import scala.Function1;
import scala.Tuple2;
import java.io.Serializable;
import java.util.*;
public class SynonymHandler implements Serializable {
private transient JavaPairRDD<String, List<Tuple2<String, Double>>> word2SynonymRDD;
public SynonymHandler(DataFrame df
, final List<String> columns
, int minOccurrance
, int vectorSize
, final double scoreCutoff)
{
JavaRDD<List<String>> rows =
df.javaRDD().map(new Function<Row, List<String>>() {
@Override
public List<String> call(Row row) throws Exception {
List<String> str = new ArrayList<>();
for(int i = 0;i < columns.size();++i) {
str.add(columns.get(i) + ":" + row.get(i));
}
return str;
}
}).cache();
// Learn a mapping from words to Vectors.
Word2Vec word2Vec = new Word2Vec()
.setVectorSize(vectorSize)
.setSeed(0)
.setWindowSize(columns.size())
.setMinCount(minOccurrance);
final Word2VecModel model = word2Vec.fit(rows);
JavaRDD<String> vocabulary = rows.flatMap(new FlatMapFunction<List<String>, String>() {
@Override
public Iterable<String> call(List<String> row) throws Exception {
List<String> ret = new ArrayList<String>();
for(String val : row) {
ret.add(val);
}
return ret;
}
}).distinct();
word2SynonymRDD =
vocabulary.flatMapToPair(new PairFlatMapFunction<String, String, List<Tuple2<String, Double>>>() {
@Override
public Iterable<Tuple2<String, List<Tuple2<String, Double>>>> call(String s) throws Exception {
List<Tuple2<String, Double>> synonyms = getSynonyms(model, s, scoreCutoff);
if(synonyms.size() > 0) {
return ImmutableList.of(new Tuple2<>(s, synonyms));
}
return Collections.emptyList();
}
}).cache();
}
public Map<String, String> findAllSynonyms() {
List<Tuple2<String, Tuple2<String, Double>>> allSynonyms =
word2SynonymRDD.flatMapToPair(new PairFlatMapFunction<Tuple2<String,List<Tuple2<String,Double>>>, String, Tuple2<String, Double>>() {
@Override
public Iterable<Tuple2<String, Tuple2<String, Double>>> call(Tuple2<String, List<Tuple2<String, Double>>> t) throws Exception {
List<Tuple2<String, Tuple2<String, Double>>> ret = new ArrayList<>();
for(Tuple2<String, Double> kv : t._2) {
ret.add(new Tuple2<>(t._1, new Tuple2<>(kv._1, kv._2)));
}
return ret;
}
}).takeOrdered(20, new ScoreComparator());
Map<String, String> ret = new LinkedHashMap<>();
for(Tuple2<String, Tuple2<String, Double>> synonym : allSynonyms) {
if(!synonym._1.equals(synonym._2._1)) {
ret.put(synonym._1, synonym._2._1);
}
}
return ret;
}
public static class ScoreComparator implements Comparator<Tuple2<String, Tuple2<String, Double>>>, Serializable {
@Override
public int compare(Tuple2<String, Tuple2<String, Double>> o1, Tuple2<String, Tuple2<String, Double>> o2) {
return -1*Double.compare(o1._2._2, o2._2._2);
}
}
public Map<String, Map<String, String>> findSynonymsByColumn() {
Map<String, List<Tuple2<String, Double>>> wordToSynonym = word2SynonymRDD.collectAsMap();
Map<String, Set<Tuple2<Map.Entry<String, String>, Double>>> synonymMap = new HashMap<>();
for(Map.Entry<String, List<Tuple2<String, Double>>> kv : wordToSynonym.entrySet()) {
Tuple2<String, String> cw = word2columnVal(kv.getKey());
String column = cw._1;
String word= cw._2;
Set<Tuple2<Map.Entry<String, String>, Double>> list = synonymMap.get(column);
if(list == null) {
list = new HashSet<>();
synonymMap.put(column, list);
}
for(Tuple2<String, Double> synonym : kv.getValue()) {
Tuple2<String, String> synonymCw = word2columnVal(synonym._1);
if(synonymCw._1.equals(column)) {
String synonymWord = synonymCw._2;
int comparision = word.compareTo(synonymWord);
if (comparision != 0) {
String left = comparision < 0 ? word : synonymWord;
String right = comparision > 0 ? word : synonymWord;
list.add(new Tuple2<Map.Entry<String, String>, Double>(new AbstractMap.SimpleEntry<>(left, right), synonym._2));
}
}
}
}
Map<String, Map<String, String>> ret= new HashMap<>();
for(Map.Entry<String, Set<Tuple2<Map.Entry<String, String>, Double>>> kv : synonymMap.entrySet()) {
List<Tuple2<Map.Entry<String, String>, Double>> l = new ArrayList<>(kv.getValue());
Collections.sort(l, new Comparator<Tuple2<Map.Entry<String, String>, Double>>() {
@Override
public int compare(Tuple2<Map.Entry<String, String>, Double> o1, Tuple2<Map.Entry<String, String>, Double> o2) {
return -1*Double.compare(o1._2, o2._2);
}
}
);
Map<String, String> v = new LinkedHashMap<>();
for(Tuple2<Map.Entry<String, String>, Double> t : l) {
v.put(t._1.getKey(), t._1.getValue());
}
ret.put(kv.getKey(), v);
}
return ret;
}
public static Tuple2<String, String> word2columnVal(String word ) {
Iterable<String> tokens = Splitter.on(":").split(word);
String column = Iterables.getFirst(tokens, null);
if(column != null) {
return new Tuple2<>(column, Joiner.on(":").join(Iterables.skip(tokens, 1)));
}
return null;
}
public static List<Tuple2<String, Double> > getSynonyms(Word2VecModel model, String word, double scoreCutoff) {
List<Tuple2<String, Double>> ret = new ArrayList<>();
try {
for (Tuple2<String, Object> r : model.findSynonyms(word, 10)) {
String w = r._1;
Double score = Double.parseDouble(r._2.toString());
boolean eitherNonNumbers = (ConversionUtils.convert(word, Double.class) == null || ConversionUtils.convert(w, Double.class) == null );
boolean neitherNull = !StringUtils.isEmpty(word) && !StringUtils.isEmpty(w) && !word.trim().equals("null") && !w.trim().equals("null");
if(score > scoreCutoff && eitherNonNumbers && neitherNull) {
ret.add(new Tuple2<>(w, score));
}
}
}
catch(IllegalStateException ise) {
//in this situation we want to skip. Generally this means that a word isn't in the vocabulary
}
return ret;
}
}