/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.ltr.response.transform;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.common.SolrDocument;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.ltr.CSVFeatureLogger;
import org.apache.solr.ltr.FeatureLogger;
import org.apache.solr.ltr.LTRRescorer;
import org.apache.solr.ltr.LTRScoringQuery;
import org.apache.solr.ltr.LTRThreadModule;
import org.apache.solr.ltr.SolrQueryRequestContextUtils;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.ltr.search.LTRQParserPlugin;
import org.apache.solr.ltr.store.FeatureStore;
import org.apache.solr.ltr.store.rest.ManagedFeatureStore;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.response.ResultContext;
import org.apache.solr.response.transform.DocTransformer;
import org.apache.solr.response.transform.TransformerFactory;
import org.apache.solr.search.SolrIndexSearcher;
import org.apache.solr.util.SolrPluginUtils;
/**
* This transformer generates the features declared in the feature store of the
* current reranking model, or in a specified feature store, and appends them to
* the response. Ex. <code>fl=id,[features store=myStore efi.user_text="ibm"]</code>
*
* <h3>Parameters</h3>
* <code>store</code> - The feature store to extract features from. If not provided it
* will default to the features used by your reranking model.<br>
* <code>efi.*</code> - External feature information variables required by the features
* you are extracting.<br>
* <code>format</code> - The format you want the features to be returned in. Supports (dense|sparse). Defaults to the default format configured on this factory, which is dense unless overridden.<br>
*/
public class LTRFeatureLoggerTransformerFactory extends TransformerFactory {
// used inside fl to specify the format (dense|sparse) of the extracted features
private static final String FV_FORMAT = "format";
// used inside fl to specify the feature store to use for the feature extraction
private static final String FV_STORE = "store";
// initial value of loggingModelName; a constant, hence final
private static final String DEFAULT_LOGGING_MODEL_NAME = "logging-model";
// cache name handed to the CSVFeatureLogger; must be configured before a logger can be created
private String fvCacheName;
// name given to the LoggingModel built when features are extracted from a feature store
private String loggingModelName = DEFAULT_LOGGING_MODEL_NAME;
// feature store used when the transformer's local params carry no "store" value (may be null)
private String defaultStore;
// feature format used when the transformer's local params carry no "format" value
private FeatureLogger.FeatureFormat defaultFormat = FeatureLogger.FeatureFormat.DENSE;
// delimiters forwarded to the CSVFeatureLogger
private char csvKeyValueDelimiter = CSVFeatureLogger.DEFAULT_KEY_VALUE_SEPARATOR;
private char csvFeatureSeparator = CSVFeatureLogger.DEFAULT_FEATURE_SEPARATOR;
// assigned in init(NamedList) and passed to the LTRScoringQuery built for logging
private LTRThreadModule threadManager = null;
/**
* Sets the cache name that is handed to the {@link CSVFeatureLogger} created by this factory.
* Creating a feature logger fails with an IllegalArgumentException while this is still null.
*
* @param fvCacheName the cache name to pass to the feature logger
*/
public void setFvCacheName(String fvCacheName) {
this.fvCacheName = fvCacheName;
}
/**
* Sets the name given to the logging model that the transformer builds when it
* extracts features directly from a feature store. Defaults to "logging-model".
*
* @param loggingModelName the name to use for the logging model
*/
public void setLoggingModelName(String loggingModelName) {
this.loggingModelName = loggingModelName;
}
/**
* Sets the feature store name used when the transformer's local params do not
* contain a <code>store</code> value.
*
* @param defaultStore the fallback feature store name
*/
public void setDefaultStore(String defaultStore) {
this.defaultStore = defaultStore;
}
/**
 * Sets the feature format used when the transformer's local params do not
 * contain a <code>format</code> value.
 *
 * @param defaultFormat name of a {@link FeatureLogger.FeatureFormat} constant,
 *        matched case-insensitively
 * @throws IllegalArgumentException if the name matches no FeatureFormat constant
 */
public void setDefaultFormat(String defaultFormat) {
  // upper-case with a fixed locale so the lookup does not depend on the JVM default locale
  final String constantName = defaultFormat.toUpperCase(Locale.ROOT);
  this.defaultFormat = FeatureLogger.FeatureFormat.valueOf(constantName);
}
/**
 * Sets the key/value delimiter forwarded to the {@link CSVFeatureLogger}.
 *
 * @param csvKeyValueDelimiter a string holding exactly one character
 * @throws IllegalArgumentException if the string is not exactly one character long
 */
public void setCsvKeyValueDelimiter(String csvKeyValueDelimiter) {
  final boolean isSingleChar = csvKeyValueDelimiter.length() == 1;
  if (!isSingleChar) {
    throw new IllegalArgumentException("csvKeyValueDelimiter must be exactly 1 character");
  }
  final char delimiter = csvKeyValueDelimiter.charAt(0);
  this.csvKeyValueDelimiter = delimiter;
}
/**
 * Sets the feature separator forwarded to the {@link CSVFeatureLogger}.
 *
 * @param csvFeatureSeparator a string holding exactly one character
 * @throws IllegalArgumentException if the string is not exactly one character long
 */
public void setCsvFeatureSeparator(String csvFeatureSeparator) {
  final boolean isSingleChar = csvFeatureSeparator.length() == 1;
  if (!isSingleChar) {
    throw new IllegalArgumentException("csvFeatureSeparator must be exactly 1 character");
  }
  final char separator = csvFeatureSeparator.charAt(0);
  this.csvFeatureSeparator = separator;
}
/**
* Initializes the factory from its configuration arguments: delegates to the
* superclass, obtains the {@link LTRThreadModule} for these arguments and then
* hands the arguments to {@link SolrPluginUtils#invokeSetters}.
*
* @param args the configuration arguments of this transformer factory
*/
@Override
public void init(@SuppressWarnings("rawtypes") NamedList args) {
super.init(args);
// NOTE(review): keep this before invokeSetters; both calls receive the same args list and
// getInstance presumably consumes its own entries first - confirm before reordering.
threadManager = LTRThreadModule.getInstance(args);
// presumably maps the remaining entries of args onto the set* methods of this class - verify
SolrPluginUtils.invokeSetters(this, args);
}
/**
 * Prepares the request for feature extraction and returns the transformer
 * that will add the feature vector to each document.
 *
 * @param name name of the document field that will hold the feature vector
 * @param localparams local params of the transformer (<code>store</code>,
 *        <code>format</code> and efi values are read from here)
 * @param req the current request; the extraction flag, the feature store name
 *        and the feature logger are recorded on it
 * @return a new {@link FeatureTransformer} for this request
 */
@Override
public DocTransformer create(String name, SolrParams localparams,
    SolrQueryRequest req) {
  // Hint to enable feature vector cache since we are requesting features
  SolrQueryRequestContextUtils.setIsExtractingFeatures(req);

  // Communicate which feature store we are requesting features for
  final String requestedStore = localparams.get(FV_STORE, defaultStore);
  SolrQueryRequestContextUtils.setFvStoreName(req, requestedStore);

  // Create and supply the feature logger to be used
  final String requestedFormat = localparams.get(FV_FORMAT);
  final FeatureLogger logger = createFeatureLogger(requestedFormat);
  SolrQueryRequestContextUtils.setFeatureLogger(req, logger);

  return new FeatureTransformer(name, localparams, req);
}
/**
 * Creates the {@link FeatureLogger} used to log the extracted features.
 *
 * @param formatStr name of the requested feature format (e.g. 'dense' or
 *        'sparse'), matched case-insensitively against
 *        {@link FeatureLogger.FeatureFormat}; when null or empty, the default
 *        format configured on this factory is used (dense unless overridden)
 * @return a feature logger for the format specified.
 * @throws IllegalArgumentException if a non-empty formatStr matches no
 *         FeatureFormat constant, or if no fvCacheName has been configured
 */
private FeatureLogger createFeatureLogger(String formatStr) {
  final FeatureLogger.FeatureFormat format;
  if (formatStr == null || formatStr.isEmpty()) {
    // no explicit format requested: fall back to the factory default
    format = this.defaultFormat;
  } else {
    // fixed locale keeps the enum lookup independent of the JVM default locale
    format = FeatureLogger.FeatureFormat.valueOf(formatStr.toUpperCase(Locale.ROOT));
  }
  if (fvCacheName == null) {
    throw new IllegalArgumentException("a fvCacheName must be configured");
  }
  return new CSVFeatureLogger(fvCacheName, format, csvKeyValueDelimiter, csvFeatureSeparator);
}
/**
 * Document transformer that adds a feature vector field to every returned
 * document. It reuses the scoring query found on the request when there is
 * one for the requested feature store, and otherwise builds a logging-only
 * scoring query over all features of the requested store.
 */
class FeatureTransformer extends DocTransformer {
  private final String name;
  private final SolrParams localparams;
  private final SolrQueryRequest req;
  private List<LeafReaderContext> leafContexts;
  private SolrIndexSearcher searcher;
  private LTRScoringQuery scoringQuery;
  private LTRScoringQuery.ModelWeight modelWeight;
  private FeatureLogger featureLogger;
  // true when the request carried no scoring query, i.e. no reranking took place
  private boolean docsWereNotReranked;

  /**
   * @param name
   *          Name of the field to be added in a document representing the
   *          feature vectors
   * @param localparams
   *          local params of the transformer; efi values are extracted from
   *          them when a logging scoring query has to be built
   * @param req
   *          the request the scoring query, feature store name and feature
   *          logger are read from
   */
  public FeatureTransformer(String name, SolrParams localparams,
      SolrQueryRequest req) {
    this.name = name;
    this.localparams = localparams;
    this.req = req;
  }

  @Override
  public String getName() {
    return name;
  }

  /**
   * Resolves the scoring query, feature logger and model weight used by
   * {@link #transform(SolrDocument, int, float)}. Nothing is set up when the
   * context or its request is null.
   *
   * @param context the result context of the current response
   * @throws SolrException (BAD_REQUEST) if the searcher is null, the logging
   *         scoring query cannot be built, or the model weight cannot be created
   */
  @Override
  public void setContext(ResultContext context) {
    super.setContext(context);
    if (context == null) {
      return;
    }
    if (context.getRequest() == null) {
      return;
    }

    searcher = context.getSearcher();
    if (searcher == null) {
      throw new SolrException(
          SolrException.ErrorCode.BAD_REQUEST,
          "searcher is null");
    }
    leafContexts = searcher.getTopReaderContext().leaves();

    // Setup LTRScoringQuery
    scoringQuery = SolrQueryRequestContextUtils.getScoringQuery(req);
    docsWereNotReranked = (scoringQuery == null);
    String featureStoreName = SolrQueryRequestContextUtils.getFvStoreName(req);
    if (docsWereNotReranked || (featureStoreName != null && (!featureStoreName.equals(scoringQuery.getScoringModel().getFeatureStoreName())))) {
      // if store is set in the transformer we should overwrite the logger
      final ManagedFeatureStore fr = ManagedFeatureStore.getManagedFeatureStore(req.getCore());
      final FeatureStore store = fr.getFeatureStore(featureStoreName);
      featureStoreName = store.getName(); // if featureStoreName was null before this gets actual name
      try {
        final LoggingModel lm = new LoggingModel(loggingModelName,
            featureStoreName, store.getFeatures());
        scoringQuery = new LTRScoringQuery(lm,
            LTRQParserPlugin.extractEFIParams(localparams), // Local transformer efi if provided
            true, // request feature weights to be created for all features
            threadManager);
        scoringQuery.setOriginalQuery(context.getQuery());
      } catch (final Exception e) {
        throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
            "retrieving the feature store "+featureStoreName, e);
      }
    }

    // keep a logger already attached to the scoring query; otherwise use the request's one
    if (scoringQuery.getFeatureLogger() == null) {
      scoringQuery.setFeatureLogger(SolrQueryRequestContextUtils.getFeatureLogger(req));
    }
    scoringQuery.setRequest(req);
    featureLogger = scoringQuery.getFeatureLogger();

    try {
      modelWeight = scoringQuery.createWeight(searcher, true, 1f);
    } catch (final IOException e) {
      throw new SolrException(SolrException.ErrorCode.BAD_REQUEST, e.getMessage(), e);
    }
    if (modelWeight == null) {
      throw new SolrException(SolrException.ErrorCode.BAD_REQUEST,
          "error logging the features, model weight is null");
    }
  }

  /**
   * Adds the feature vector of the given document to <code>doc</code> under
   * this transformer's field name. The vector is first requested from the
   * feature logger; when that yields null the features are extracted with
   * {@link LTRRescorer#extractFeaturesInfo} and formatted by the logger.
   *
   * @param doc the document the field is added to
   * @param docid the internal id of the document
   * @param score the score of the document; it is forwarded to the feature
   *        extraction only when the documents were not reranked, otherwise
   *        null is passed
   */
  @Override
  public void transform(SolrDocument doc, int docid, float score)
      throws IOException {
    Object fv = featureLogger.getFeatureVector(docid, scoringQuery, searcher);
    if (fv == null) { // FV for this document was not in the cache
      // Float.valueOf replaces the deprecated Float(float) constructor
      fv = featureLogger.makeFeatureVector(
          LTRRescorer.extractFeaturesInfo(
              modelWeight,
              docid,
              (docsWereNotReranked ? Float.valueOf(score) : null),
              leafContexts));
    }
    doc.addField(name, fv);
  }
}
/**
 * Scoring model used only to log features: it is created with empty model
 * feature and normalizer lists, carries the full feature list of a store and
 * always scores 0.
 */
private static class LoggingModel extends LTRScoringModel {

  /**
   * @param name name of the logging model
   * @param featureStoreName name of the feature store the features come from
   * @param allFeatures all features to be logged
   */
  public LoggingModel(String name, String featureStoreName, List<Feature> allFeatures) {
    this(name,
        Collections.<Feature>emptyList(),
        Collections.<Normalizer>emptyList(),
        featureStoreName,
        allFeatures,
        Collections.<String,Object>emptyMap());
  }

  /** Passes every argument unchanged to the {@link LTRScoringModel} constructor. */
  protected LoggingModel(String name, List<Feature> features,
      List<Normalizer> norms, String featureStoreName,
      List<Feature> allFeatures, Map<String,Object> params) {
    super(name, features, norms, featureStoreName, allFeatures, params);
  }

  /** Always returns 0, whatever the feature values are. */
  @Override
  public float score(float[] modelFeatureValuesNormalized) {
    return 0f;
  }

  /** Returns a match explanation that states this model only logs features. */
  @Override
  public Explanation explain(LeafReaderContext context, int doc, float finalScore,
      List<Explanation> featureExplanations) {
    final String description = toString()
        + " logging model, used only for logging the features";
    return Explanation.match(finalScore, description);
  }
}
}