package org.wikipedia.miner.web.service; import gnu.trove.map.hash.TIntFloatHashMap; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; import javax.servlet.ServletConfig; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import opennlp.tools.util.Span; import org.simpleframework.xml.Attribute; import org.simpleframework.xml.ElementList; import org.wikipedia.miner.comparison.ArticleComparer; import org.wikipedia.miner.model.Wikipedia; import org.wikipedia.miner.util.NGrammer; import org.wikipedia.miner.util.NGrammer.NGramSpan; import org.wikipedia.miner.util.RelatednessCache; import org.dmilne.xjsf.Service; import org.dmilne.xjsf.UtilityMessages.ErrorMessage; import org.dmilne.xjsf.UtilityMessages.ParameterMissingMessage; import org.dmilne.xjsf.param.BooleanParameter; import org.dmilne.xjsf.param.FloatParameter; import org.dmilne.xjsf.param.StringParameter; import com.google.gson.annotations.Expose; import org.wikipedia.miner.web.util.xjsfParameters.StringListParameter; /** * * * * NOTE: this does not support {@link Service.ResponseFormat#DIRECT} */ public class SearchService extends WMService { private static final long serialVersionUID = 5011451347638265017L; Pattern topicPattern = Pattern.compile("\\[\\[(\\d+)\\|(.*?)\\]\\]"); Pattern quotePattern = Pattern.compile("\"(.*?)\""); private StringParameter prmQuery; private BooleanParameter prmComplex; private FloatParameter prmMinPriorProb; private StringListParameter prmQueryList; public SearchService() { super("core", "Lists the senses (wikipedia articles) of terms and phrases", "<p>This service takes a term or phrase, and returns the different Wikipedia articles that these could refer to.</p>" + "<p>By default, it will treat the entire query as one term, but it can be made to break it down into its components " + "(to recognize, for example, that <i>hiking new zealand</i> contains two terms: <i>hiking</i> and <i>new zealand</i>)</p>" + "<p>For each component term, the service will list the different articles (or concepts) that it could refer to, in order of prior probability " + "so that the most obvious senses are listed first.</p>" + "<p>For queries that contain multiple terms, the senses of each term will be compared against each other to disambiguate them. This " + "provides the weight attribute, which is larger for senses that are likely to be the correct interpretation of the query.</p>", false); } @Override public void init(ServletConfig config) throws ServletException { super.init(config); prmQuery = new StringParameter("query", "Your query", null); addGlobalParameter(prmQuery); prmQueryList = new StringListParameter("queryList", "The simple queries you want to run using ',' as separator", null); prmComplex = new BooleanParameter("complex", "<b>true</b> if your query might reference multiple topics, otherwise <b>false</b>", false); addGlobalParameter(prmComplex); prmMinPriorProb = new FloatParameter("minPriorProbability", "the minimum prior probability that a sense must have for it to be returned", 0.01F); addGlobalParameter(prmMinPriorProb); addExample( new ExampleBuilder("List senses of an ambiguous term"). addParam(prmQuery, "kiwi"). build() ); addExample( new ExampleBuilder("Break a complex multi-topic query into its component terms, and list thier senses"). addParam(prmQuery, "hiking new zealand"). addParam(prmComplex, true). build() ); } @Override public Service.Message buildWrappedResponse(HttpServletRequest request) throws Exception { String query = prmQuery.getValue(request); String[] queryList = prmQueryList.getValue(request); if (query == null) { if (queryList == null) return new ParameterMissingMessage(request); else { if (queryList.length > 1) return resolveQueryList(queryList, request); } }else{ if (prmComplex.getValue(request)) return resolveComplexQuery(query, request); else return resolveSimpleQuery(query, request); } return new ParameterMissingMessage(request); } public Service.Message resolveSimpleQuery(String query, HttpServletRequest request) { Wikipedia wikipedia = getWikipedia(request); NGrammer nGrammer = new NGrammer(wikipedia.getConfig().getSentenceDetector(), wikipedia.getConfig().getTokenizer()); NGramSpan span = nGrammer.ngramPosDetect(query)[0]; org.wikipedia.miner.model.Label label = wikipedia.getLabel(span, query); Label rLabel = new Label(label); float minPriorProb = prmMinPriorProb.getValue(request); for (org.wikipedia.miner.model.Label.Sense sense : label.getSenses()) { if (sense.getPriorProbability() < minPriorProb) { break; } rLabel.addSense(new Sense(sense)); } Message msg = new Message(request); msg.addLabel(rLabel); return msg; } private Message resolveQueryList(String[] queryList, HttpServletRequest request) { Wikipedia wikipedia = getWikipedia(request); NGrammer nGrammer = new NGrammer(wikipedia.getConfig().getSentenceDetector(), wikipedia.getConfig().getTokenizer()); Message msg = new Message(request); float minPriorProb = prmMinPriorProb.getValue(request); for (String query : queryList) { NGramSpan span = nGrammer.ngramPosDetect(query)[0]; org.wikipedia.miner.model.Label label = wikipedia.getLabel(span, query); Label rLabel = new Label(label); for (org.wikipedia.miner.model.Label.Sense sense : label.getSenses()) { if (sense.getPriorProbability() < minPriorProb) { break; } rLabel.addSense(new Sense(sense)); } msg.addLabel(rLabel); } return msg; } public Service.Message resolveComplexQuery(String query, HttpServletRequest request) throws Exception { Wikipedia wikipedia = getWikipedia(request); ArticleComparer artComparer = getWMHub().getArticleComparer(getWikipediaName(request)); if (artComparer == null) { return new ErrorMessage(request, "article comparisons are not available with this wikipedia instance"); } ExhaustiveDisambiguator disambiguator = new ExhaustiveDisambiguator(artComparer); float minPriorProb = prmMinPriorProb.getValue(request); Message msg = new Message(request); //resolve query ArrayList<QueryLabel> queryLabels = getReferences(query, wikipedia); queryLabels = resolveCollisions(queryLabels); queryLabels = disambiguator.disambiguate(queryLabels, minPriorProb); for (QueryLabel queryLabel : queryLabels) { Label rLabel = new Label(queryLabel); for (org.wikipedia.miner.model.Label.Sense sense : queryLabel.getLabel().getSenses()) { if (sense.getPriorProbability() < minPriorProb) { break; } Sense rSense = new Sense(sense); rSense.setWeight(disambiguator.getSenseWeight(sense.getId())); if (queryLabel.getSelectedSenseId() != null && queryLabel.getSelectedSenseId() == sense.getId()) { rSense.setIsSelected(true); } rLabel.addSense(rSense); } rLabel.sortSensesByWeight(); msg.addLabel(rLabel); } return msg; } private ArrayList<QueryLabel> getReferences(String query, Wikipedia wikipedia) { ArrayList<QueryLabel> queryLabels = new ArrayList<QueryLabel>(); NGrammer nGrammer = new NGrammer(wikipedia.getConfig().getSentenceDetector(), wikipedia.getConfig().getTokenizer()); //spans that can't be interrupted or intersected ArrayList<Span> contiguousSpans = new ArrayList<Span>(); //spans that have already been disambiguated HashMap<Long, Integer> topicIdsBySpan = new HashMap<Long, Integer>(); String cleanedQuery = cleanTopicMarkup(query, contiguousSpans, topicIdsBySpan); cleanedQuery = cleanQuoteMarkup(cleanedQuery, contiguousSpans); Collections.sort(contiguousSpans); //System.out.println("Cleaned query: " + cleanedQuery) ; for (NGramSpan span : nGrammer.ngramPosDetect(cleanedQuery)) { //System.out.println(" ngram: " + span.getCoveredText(cleanedQuery) + " " + span.getStart() + ", " + span.getEnd()) ; if (!isSpanValid(span, contiguousSpans)) { //System.out.println(" invalid") ; continue; } Integer topicId = topicIdsBySpan.get(getKey(span)); org.wikipedia.miner.model.Label lbl = wikipedia.getLabel(span, cleanedQuery); QueryLabel ql = new QueryLabel(lbl, wikipedia.getConfig().isStopword(span.getNgram(cleanedQuery)), span, topicId); queryLabels.add(ql); //System.out.println(" lp: " + label.getLinkProbability()) ; } return queryLabels; } private ArrayList<QueryLabel> resolveCollisions(ArrayList<QueryLabel> queryLabels) { for (int i = 0; i < queryLabels.size(); i++) { QueryLabel lbl1 = queryLabels.get(i); List<QueryLabel> overlappingTopics = new ArrayList<QueryLabel>(); double qtWeight = lbl1.getWeight(); double overlapWeight = 0; for (int j = i + 1; j < queryLabels.size(); j++) { QueryLabel lbl2 = queryLabels.get(j); //TODO: contains might not be right if (lbl1.getSpan().intersects(lbl2.getSpan())) { overlappingTopics.add(lbl2); if (!lbl2.isStopword) { overlapWeight = overlapWeight + lbl2.getWeight(); } } else { break; } } if (overlappingTopics.size() > 0) { overlapWeight = overlapWeight / overlappingTopics.size(); } if (overlapWeight > qtWeight) { // want to keep the overlapped items queryLabels.remove(i); i = i - 1; } else { //want to keep the overlapping item for (int j = 0; j < overlappingTopics.size(); j++) { queryLabels.remove(i + 1); } } } return queryLabels; } private Long getKey(Span s) { long key = s.getStart() + (s.getEnd() << 30); return key; } private boolean isSpanValid(Span span, ArrayList<Span> contiguousSpans) { for (Span s : contiguousSpans) { if (s.equals(span)) { return true; } if (s.intersects(span) || s.crosses(span) || s.contains(span) || span.contains(s)) { return false; } if (s.getStart() > span.getEnd()) { break; } } return true; } private String cleanTopicMarkup(String query, ArrayList<Span> contiguousSpans, HashMap<Long, Integer> topicIdsBySpan) { StringBuilder sb = new StringBuilder(); int lastCopyPoint = 0; Matcher m = topicPattern.matcher(query); while (m.find()) { sb.append(query.substring(lastCopyPoint, m.start())); Span span = new Span(sb.length(), sb.length() + m.group(2).length()); contiguousSpans.add(span); topicIdsBySpan.put(getKey(span), Integer.parseInt(m.group(1))); sb.append(m.group(2)); lastCopyPoint = m.end(); } sb.append(query.substring(lastCopyPoint)); return sb.toString(); } private String cleanQuoteMarkup(String query, ArrayList<Span> contiguousSpans) { StringBuilder sb = new StringBuilder(); int lastCopyPoint = 0; Matcher m = quotePattern.matcher(query); while (m.find()) { sb.append(query.substring(lastCopyPoint, m.start())); Span span = new Span(sb.length(), sb.length() + m.group(1).length()); contiguousSpans.add(span); sb.append(m.group(1)); lastCopyPoint = m.end(); } sb.append(query.substring(lastCopyPoint)); return sb.toString(); } public class ExhaustiveDisambiguator { //TODO: make this use disambiguator in labelComparer instead. ArrayList<QueryLabel> queryTerms; RelatednessCache rc; Integer[] selectedSenses; org.wikipedia.miner.model.Label.Sense currCombo[]; //org.wikipedia.miner.model.Label.Sense bestCombo[] ; float bestComboWeight; private TIntFloatHashMap bestSenseWeights; public ExhaustiveDisambiguator(ArticleComparer comparer) { rc = new RelatednessCache(comparer); } public ArrayList<QueryLabel> disambiguate(ArrayList<QueryLabel> queryTerms, float minPriorProb) throws Exception { this.queryTerms = queryTerms; this.currCombo = new org.wikipedia.miner.model.Label.Sense[queryTerms.size()]; //this.bestCombo = null ; this.bestComboWeight = 0; this.selectedSenses = new Integer[queryTerms.size()]; for (int i = 0; i < queryTerms.size(); i++) { this.selectedSenses[i] = queryTerms.get(i).selectedSenseId; } bestSenseWeights = new TIntFloatHashMap(); //recursively check and weight every possible combination of senses checkSenses(0, minPriorProb); return queryTerms; } public float getSenseWeight(int id) { return bestSenseWeights.get(id); } private void checkSenses(int termIndex, float minPriorProb) throws Exception { if (termIndex == queryTerms.size()) { //this is a complete (and unique) combination of senses, so lets weight it weightCombo(); } else { // this is not a complete combination of senses, so continue recursion QueryLabel qt = queryTerms.get(termIndex); if (qt.isStopword || qt.getLabel().getSenses().length == 0) { checkSenses(termIndex + 1, minPriorProb); } else { for (org.wikipedia.miner.model.Label.Sense s : qt.getLabel().getSenses()) { if (s.getPriorProbability() < minPriorProb) { break; } currCombo[termIndex] = s; checkSenses(termIndex + 1, minPriorProb); } } } } private void weightCombo() throws Exception { float commoness = 0; float relatedness = 0; int comparisons = 0; for (int i = 0; i < currCombo.length; i++) { if (currCombo[i] != null) { commoness += currCombo[i].getPriorProbability(); for (int j = 0; j < currCombo.length; j++) { if (currCombo[j] != null && i != j) { if (selectedSenses[j] == null || currCombo[j].getId() == selectedSenses[j]) { relatedness += rc.getRelatedness(currCombo[i], currCombo[j]); } comparisons++; } } } i++; } //average commonness and relatedness commoness = commoness / currCombo.length; if (comparisons == 0) { relatedness = (float) 0.5; } else { relatedness = relatedness / comparisons; } //relatedness is three times as important as commonness (hmmm, ad-hoc) float weight = (commoness + (3 * relatedness)) / 4; //check if this is best overall combination if (weight > bestComboWeight) { bestComboWeight = weight; //bestCombo = currCombo.clone() ; } //check if this is best weight for each individual sense for (org.wikipedia.miner.model.Label.Sense s : currCombo) { if (s != null) { double sWeight = bestSenseWeights.get(s.getId()); if (sWeight < weight) { bestSenseWeights.put(s.getId(), weight); } } } } } public class QueryLabel { private org.wikipedia.miner.model.Label label; private Span span; private boolean isStopword; private Integer selectedSenseId; public QueryLabel(org.wikipedia.miner.model.Label label, boolean isStopword, Span span, Integer selectedSenseId) { this.label = label; this.isStopword = isStopword; this.span = span; this.selectedSenseId = selectedSenseId; } /** * @return true if this overlaps the given reference, otherwise false. */ //public boolean overlaps(QueryLabel qt) { // return position.overlaps(qt.getPosition()) ; //} public org.wikipedia.miner.model.Label getLabel() { return label; } /** * @return the position (start and end character locations) in the * document where this reference was found. */ public Span getSpan() { return span; } public Integer getSelectedSenseId() { return selectedSenseId; } public double getWeight() { if (isStopword) { return 0; } else if (selectedSenseId != null) { return 1; } else { return label.getLinkProbability(); } } /* @Override public int compareTo(QueryLabel qt) { //starts first, then goes first int c = new Integer(span.getStart()).compareTo(qt.getSpan().getStart()) ; if (c != 0) return c ; //starts at same time, so longest one goes first c = new Integer(qt.getSpan().getEnd()).compareTo(span.getEnd()) ; return c ; }*/ } public static class Message extends Service.Message { @Expose @ElementList(entry = "label", inline = true) private ArrayList<Label> labels = new ArrayList<Label>(); private Message(HttpServletRequest request) { super(request); } private void addLabel(Label lbl) { labels.add(lbl); } public List<Label> getLabels() { return Collections.unmodifiableList(labels); } } public static class Label { @Expose @Attribute private final String text; @Expose @Attribute private final long linkDocCount; @Expose @Attribute private final long linkOccCount; @Expose @Attribute private final long docCount; @Expose @Attribute private final long occCount; @Expose @Attribute private final double linkProbability; @Expose @Attribute(required = false) private Boolean isStopword; //@Expose //@Attribute(required = false) //private Integer start ; //@Expose //@Attribute(required = false) //private Integer end ; @Expose @ElementList(entry = "sense") private final ArrayList<Sense> senses; private Label(org.wikipedia.miner.model.Label lbl) { text = lbl.getText(); linkDocCount = lbl.getLinkDocCount(); linkOccCount = lbl.getLinkOccCount(); docCount = lbl.getDocCount(); occCount = lbl.getOccCount(); linkProbability = lbl.getLinkProbability(); senses = new ArrayList<Sense>(); } private Label(QueryLabel lbl) { text = lbl.getLabel().getText(); linkDocCount = lbl.getLabel().getLinkDocCount(); linkOccCount = lbl.getLabel().getLinkOccCount(); docCount = lbl.getLabel().getDocCount(); occCount = lbl.getLabel().getOccCount(); linkProbability = lbl.getLabel().getLinkProbability(); this.isStopword = lbl.isStopword; //this.start = lbl.position.getStart() ; //this.end = lbl.position.getEnd() ; senses = new ArrayList<Sense>(); } private void addSense(Sense s) { senses.add(s); } private void sortSensesByWeight() { Collections.sort(senses, new Comparator<Sense>() { @Override public int compare(Sense s1, Sense s2) { int cmp = 0; if (s1.weight != null && s2.weight != null) { cmp = s2.weight.compareTo(s1.weight); } if (cmp != 0) { return cmp; } cmp = s2.priorProbability.compareTo(s1.priorProbability); if (cmp != 0) { return cmp; } return s1.id.compareTo(s2.id); } }); } public String getText() { return text; } public long getLinkDocCount() { return linkDocCount; } public long getLinkOccCount() { return linkOccCount; } public long getDocCount() { return docCount; } public long getOccCount() { return occCount; } public double getLinkProbability() { return linkProbability; } public Boolean getIsStopword() { return isStopword; } public ArrayList<Sense> getSenses() { return senses; } } public static class Sense { @Expose @Attribute private Integer id; @Expose @Attribute private final String title; @Expose @Attribute private final long linkDocCount; @Expose @Attribute private final long linkOccCount; @Expose @Attribute private Double priorProbability; @Expose @Attribute private final boolean fromTitle; @Expose @Attribute private final boolean fromRedirect; @Expose @Attribute private boolean isSelected; @Expose @Attribute(required = false) private Double weight; private Sense(org.wikipedia.miner.model.Label.Sense sense) { id = sense.getId(); title = sense.getTitle(); linkDocCount = sense.getLinkDocCount(); linkOccCount = sense.getLinkOccCount(); priorProbability = sense.getPriorProbability(); fromTitle = sense.isFromTitle(); fromRedirect = sense.isFromTitle(); isSelected = false; } private void setIsSelected(boolean val) { isSelected = val; } private void setWeight(double weight) { this.weight = weight; } public Integer getId() { return id; } public String getTitle() { return title; } public long getLinkDocCount() { return linkDocCount; } public long getLinkOccCount() { return linkOccCount; } public Double getPriorProbability() { return priorProbability; } public boolean isFromTitle() { return fromTitle; } public boolean isFromRedirect() { return fromRedirect; } public Double getWeight() { return weight; } public boolean isSelected() { return isSelected; } } }