/*
 * Ivory: A Hadoop toolkit for web-scale information retrieval
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you
 * may not use this file except in compliance with the License. You may
 * obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package ivory.smrf.model.expander;

import ivory.core.ConfigurationException;
import ivory.core.RetrievalEnvironment;
import ivory.core.RetrievalException;
import ivory.core.data.document.IntDocVector;
import ivory.core.data.document.IntDocVector.Reader;
import ivory.core.util.XMLTools;
import ivory.smrf.model.MarkovRandomField;
import ivory.smrf.model.Parameter;
import ivory.smrf.model.VocabFrequencyPair;
import ivory.smrf.model.builder.MRFBuilder;
import ivory.smrf.model.importance.ConceptImportanceModel;
import ivory.smrf.retrieval.Accumulator;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;

import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

import tl.lin.data.map.HMapIV;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

/**
 * @author Don Metzler
 */
public abstract class MRFExpander {
  protected RetrievalEnvironment env = null;  // Ivory retrieval environment.
  protected int numFeedbackDocs;              // Number of feedback documents.
  protected int numFeedbackTerms;             // Number of feedback terms.
  protected Set<String> stopwords = null;     // Stopwords list.

  // The expansion MRF cliques should be scaled according to this weight.
  protected float expanderWeight;

  // Maximum number of candidates to consider for expansion; non-positive numbers result in all
  // candidates being considered.
  protected int maxCandidates = 0;

  /**
   * Builds an expanded MRF from the original MRF and an initial set of retrieval results.
   *
   * @param mrf the original query MRF
   * @param results initial retrieval results used as feedback
   */
  public abstract MarkovRandomField getExpandedMRF(MarkovRandomField mrf, Accumulator[] results)
      throws ConfigurationException;

  /**
   * @param words list of words to ignore when constructing expansion concepts
   */
  public void setStopwordList(Set<String> words) {
    this.stopwords = Preconditions.checkNotNull(words);
  }

  public void setMaxCandidates(int maxCandidates) {
    this.maxCandidates = maxCandidates;
  }

  /**
   * Constructs an expander from an XML model specification.
   *
   * @param env retrieval environment
   * @param model XML node specifying the expander
   * @throws ConfigurationException
   */
  public static MRFExpander getExpander(RetrievalEnvironment env, Node model)
      throws ConfigurationException {
    Preconditions.checkNotNull(env);
    Preconditions.checkNotNull(model);

    // Get model type.
    String expanderType = XMLTools.getAttributeValueOrThrowException(model, "type",
        "Expander type must be specified!");

    // Get normalized model type.
    String normExpanderType = expanderType.toLowerCase().trim();

    // Build the expander.
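    // For reference, a unigram latent concept expander might be specified by an XML fragment
    // along these lines (hypothetical attribute values; element and attribute names match what
    // is parsed below, and additional conceptscore attributes are consumed elsewhere):
    //
    //   <expander type="unigramlatentconcept" fbDocs="10" fbTerms="10" weight="0.5">
    //     <conceptscore id="score1" weight="1.0" ... />
    //   </expander>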
    MRFExpander expander = null;

    if ("unigramlatentconcept".equals(normExpanderType)) {
      int fbDocs = XMLTools.getAttributeValue(model, "fbDocs", 10);
      int fbTerms = XMLTools.getAttributeValue(model, "fbTerms", 10);
      float expanderWeight = XMLTools.getAttributeValue(model, "weight", 1.0f);

      List<Parameter> parameters = Lists.newArrayList();
      List<Node> scoreFunctionNodes = Lists.newArrayList();
      List<ConceptImportanceModel> importanceModels = Lists.newArrayList();

      // Get the conceptscore nodes, which describe how to score expansion concepts.
      NodeList children = model.getChildNodes();
      for (int i = 0; i < children.getLength(); i++) {
        Node child = children.item(i);
        if ("conceptscore".equals(child.getNodeName())) {
          String paramID = XMLTools.getAttributeValueOrThrowException(child, "id",
              "conceptscore node must specify an id attribute!");
          float weight = XMLTools.getAttributeValue(child, "weight", 1.0f);
          parameters.add(new Parameter(paramID, weight));
          scoreFunctionNodes.add(child);

          // Get concept importance source (if applicable).
          ConceptImportanceModel importanceModel = null;
          String importanceSource = XMLTools.getAttributeValue(child, "importance", null);
          if (importanceSource != null) {
            importanceModel = env.getImportanceModel(importanceSource);
            if (importanceModel == null) {
              throw new RetrievalException("Error: importancemodel " + importanceSource
                  + " not found!");
            }
          }
          importanceModels.add(importanceModel);
        }
      }

      // Make sure there's at least one expansion model specified.
      if (scoreFunctionNodes.isEmpty()) {
        throw new ConfigurationException("No conceptscore specified!");
      }

      // Create the expander.
      expander = new UnigramLatentConceptExpander(env, fbDocs, fbTerms, expanderWeight,
          parameters, scoreFunctionNodes, importanceModels);

      // Maximum number of candidate expansion terms to consider per query.
      int maxCandidates = XMLTools.getAttributeValue(model, "maxCandidates", 0);
      if (maxCandidates > 0) {
        expander.setMaxCandidates(maxCandidates);
      }
    } else if ("latentconcept".equals(normExpanderType)) {
      int defaultFbDocs = XMLTools.getAttributeValue(model, "fbDocs", 10);
      int defaultFbTerms = XMLTools.getAttributeValue(model, "fbTerms", 10);

      List<Integer> gramList = new ArrayList<Integer>();
      List<MRFBuilder> builderList = new ArrayList<MRFBuilder>();
      List<Integer> fbDocsList = new ArrayList<Integer>();
      List<Integer> fbTermsList = new ArrayList<Integer>();

      // Get the expansionmodel nodes, which describe how to actually build the expanded MRF.
      NodeList children = model.getChildNodes();
      for (int i = 0; i < children.getLength(); i++) {
        Node child = children.item(i);
        if ("expansionmodel".equals(child.getNodeName())) {
          int gramSize = XMLTools.getAttributeValue(child, "gramSize", 1);
          int fbDocs = XMLTools.getAttributeValue(child, "fbDocs", defaultFbDocs);
          int fbTerms = XMLTools.getAttributeValue(child, "fbTerms", defaultFbTerms);

          // Set MRF builder parameters.
          gramList.add(gramSize);
          builderList.add(MRFBuilder.get(env, child));
          fbDocsList.add(fbDocs);
          fbTermsList.add(fbTerms);
        }
      }

      // Make sure there's at least one expansion model specified.
      if (builderList.isEmpty()) {
        throw new ConfigurationException("No expansionmodel specified!");
      }

      // Create the expander.
      expander = new NGramLatentConceptExpander(env, gramList, builderList, fbDocsList,
          fbTermsList);
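      // For reference, this branch would be reached by a specification along these lines
      // (hypothetical values; one expansionmodel child per gram size, with its MRF builder
      // configuration carried in attributes consumed by MRFBuilder.get):
      //
      //   <expander type="latentconcept" fbDocs="10" fbTerms="10">
      //     <expansionmodel gramSize="1" fbTerms="20" ... />
      //     <expansionmodel gramSize="2" fbTerms="10" ... />
      //   </expander>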
      // Maximum number of candidate expansion terms to consider per query.
      int maxCandidates = XMLTools.getAttributeValue(model, "maxCandidates", 0);
      if (maxCandidates > 0) {
        expander.setMaxCandidates(maxCandidates);
      }
    } else {
      throw new ConfigurationException("Unrecognized expander type -- " + expanderType);
    }

    return expander;
  }

  @SuppressWarnings("unchecked")
  protected TfDoclengthStatistics getTfDoclengthStatistics(IntDocVector[] docVecs)
      throws IOException {
    Preconditions.checkNotNull(docVecs);

    Map<String, Integer> vocab = Maps.newHashMap();
    Map<String, Short>[] tfs = new HashMap[docVecs.length];
    int[] doclens = new int[docVecs.length];

    for (int i = 0; i < docVecs.length; i++) {
      IntDocVector doc = docVecs[i];
      Map<String, Short> docTfs = new HashMap<String, Short>();
      int doclen = 0;

      Reader dvReader = doc.getReader();
      while (dvReader.hasMoreTerms()) {
        int termid = dvReader.nextTerm();
        String stem = env.getTermFromId(termid);
        short tf = dvReader.getTf();
        doclen += tf;

        // Count document frequencies for in-vocabulary, non-stopword stems.
        if (stem != null && (stopwords == null || !stopwords.contains(stem))) {
          Integer df = vocab.get(stem);
          if (df != null) {
            vocab.put(stem, df + 1);
          } else {
            vocab.put(stem, 1);
          }
        }
        docTfs.put(stem, tf);
      }

      tfs[i] = docTfs;
      doclens[i] = doclen;
    }

    // Sort the vocab hashmap according to tf.
    VocabFrequencyPair[] entries = new VocabFrequencyPair[vocab.size()];
    int entryNum = 0;
    for (Entry<String, Integer> entry : vocab.entrySet()) {
      entries[entryNum++] = new VocabFrequencyPair(entry.getKey(), entry.getValue());
    }
    Arrays.sort(entries);

    return new TfDoclengthStatistics(entries, tfs, doclens);
  }

  /**
   * Extracts the n-gram vocabulary of the given documents, sorted by frequency.
   *
   * @param docVecs document vectors to extract the vocabulary from
   * @param gramSize size (in terms) of the grams to extract
   * @throws IOException
   */
  protected VocabFrequencyPair[] getVocabulary(IntDocVector[] docVecs, int gramSize)
      throws IOException {
    Map<String, Integer> vocab = new HashMap<String, Integer>();

    for (IntDocVector doc : docVecs) {
      HMapIV<String> termMap = new HMapIV<String>();
      int maxPos = Integer.MIN_VALUE;

      // Map each position in the document to its stem.
      Reader dvReader = doc.getReader();
      while (dvReader.hasMoreTerms()) {
        int termid = dvReader.nextTerm();
        String stem = env.getTermFromId(termid);
        int[] pos = dvReader.getPositions();
        for (int i = 0; i < pos.length; i++) {
          termMap.put(pos[i], stem);
          if (pos[i] > maxPos) {
            maxPos = pos[i];
          }
        }
      }

      // Grab all grams of size gramSize that do not contain any out-of-vocabulary terms.
      for (int pos = 0; pos <= maxPos + 1 - gramSize; pos++) {
        String concept = "";
        boolean toAdd = true;
        for (int offset = 0; offset < gramSize; offset++) {
          String stem = termMap.get(pos + offset);
          if (stem == null || (stopwords != null && stopwords.contains(stem))) {
            toAdd = false;
            break;
          }
          if (offset == gramSize - 1) {
            concept += stem;
          } else {
            concept += stem + " ";
          }
        }

        if (toAdd) {
          Integer tf = vocab.get(concept);
          if (tf != null) {
            vocab.put(concept, tf + 1);
          } else {
            vocab.put(concept, 1);
          }
        }
      }
    }
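    // Hypothetical illustration: with gramSize = 2 and a document whose stems at positions
    // 0..2 are "new", "york", "city", the loop above adds the bigram concepts "new york" and
    // "york city" to vocab (assuming none of the stems is a stopword).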
    // Sort the vocab hashmap according to tf.
    VocabFrequencyPair[] entries = new VocabFrequencyPair[vocab.size()];
    int entryNum = 0;
    for (Entry<String, Integer> entry : vocab.entrySet()) {
      entries[entryNum++] = new VocabFrequencyPair(entry.getKey(), entry.getValue());
    }
    Arrays.sort(entries);

    return entries;
  }

  protected class TfDoclengthStatistics {
    private VocabFrequencyPair[] vocab = null;
    private Map<String, Short>[] tfs = null;
    private int[] doclengths = null;

    public TfDoclengthStatistics(VocabFrequencyPair[] entries, Map<String, Short>[] tfs,
        int[] doclengths) {
      this.vocab = Preconditions.checkNotNull(entries);
      this.tfs = Preconditions.checkNotNull(tfs);
      this.doclengths = Preconditions.checkNotNull(doclengths);
    }

    public VocabFrequencyPair[] getVocab() {
      return vocab;
    }

    public Map<String, Short>[] getTfs() {
      return tfs;
    }

    public int[] getDoclens() {
      return doclengths;
    }
  }
}
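// Usage sketch (hypothetical; assumes an initialized RetrievalEnvironment "env", a parsed
// <expander> DOM node "expanderNode", the original query MRF "mrf", and initial results):
//
//   MRFExpander expander = MRFExpander.getExpander(env, expanderNode);
//   expander.setStopwordList(stopwords);  // optional; stopwords is a Set<String>
//   MarkovRandomField expandedMrf = expander.getExpandedMRF(mrf, results);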