package qa.qcri.aidr.predict.featureextraction;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import org.apache.log4j.Logger;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
/**
 * A DocumentFeature implementation consisting of a set of words.
 *
 * <p>Words are kept in a hash set, so duplicates collapse into a single entry
 * and no ordering is preserved.
 *
 * @author jrogstadius
 */
public class WordSet implements DocumentFeature, Serializable {
    private static final Logger logger = Logger.getLogger(WordSet.class);
    private static final long serialVersionUID = 1L;
    private static final String STR_TYPE = "type",
            STR_WORDVECTOR = "wordvector", STR_WORDS = "words";

    /** Backing storage for the distinct words of this feature. */
    HashSet<String> words = new HashSet<String>();

    /**
     * Adds every word of the given collection to this set.
     *
     * @param words the words to add; words already present are ignored
     */
    public void addAll(Collection<String> words) {
        this.words.addAll(words);
    }

    /**
     * Adds every word of the given array to this set.
     *
     * @param words the words to add; words already present are ignored
     */
    public void addAll(String[] words) {
        this.words.addAll(Arrays.asList(words));
    }

    /**
     * Returns the words of this set as a freshly allocated list.
     *
     * @return a new list holding the words; modifying it does not affect
     *         this set
     */
    public List<String> getWords() {
        return new ArrayList<String>(words);
    }

    /**
     * Builds the JSON representation of this word set: an object with a
     * "type" entry set to "wordvector" and a "words" array.
     *
     * @return the JSON object, or {@code null} when this set holds no words
     * @throws RuntimeException wrapping the {@link JSONException} if the
     *         object cannot be built
     */
    public JSONObject toJSONObject() {
        if (words.isEmpty()) {
            return null;
        }

        JSONArray wordsArr = new JSONArray();
        for (String w : words) {
            wordsArr.put(w);
        }

        JSONObject obj = new JSONObject();
        try {
            obj.put(STR_TYPE, STR_WORDVECTOR);
            obj.put(STR_WORDS, wordsArr);
        } catch (JSONException e) {
            // Pass the exception along so the stack trace reaches the log.
            logger.error("Error in json parsing: " + words, e);
            throw new RuntimeException(e);
        }
        return obj;
    }

    /**
     * Creates a new word set holding the union of the given sets.
     *
     * @param sets the word sets to merge
     * @return a new set containing every word found in any of {@code sets}
     */
    public static WordSet join(Collection<WordSet> sets) {
        WordSet set = new WordSet();
        for (WordSet s : sets) {
            set.words.addAll(s.words);
        }
        return set;
    }

    /**
     * Computes the Jaccard coefficient between this set and another one,
     * i.e. the size of the intersection divided by the size of the union.
     *
     * @param other the word set to compare against
     * @return a value in {@code [0, 1]}; {@code 1.0} when both sets are empty
     */
    public double getSimilarity(WordSet other) {
        int l1 = words.size();
        int l2 = other.words.size();

        // Two empty sets are identical. Without this guard the division
        // below is 0 / 0.0, which yields NaN.
        if (l1 == 0 && l2 == 0) {
            return 1.0;
        }

        // Iterate over the smaller set and probe the larger one, which keeps
        // the number of lookups minimal and needs no temporary copy.
        HashSet<String> smaller = l1 < l2 ? words : other.words;
        HashSet<String> larger = l1 < l2 ? other.words : words;
        int l3 = 0;
        for (String w : smaller) {
            if (larger.contains(w)) {
                l3++;
            }
        }

        // |A ∪ B| = |A| + |B| - |A ∩ B|
        return l3 / (double) (l1 + l2 - l3);
    }
}