/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.solr.handler;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.common.SolrException;
import org.apache.solr.core.SolrCore;

/**
 * The classify expression retrieves a model trained by the train expression and uses it to
 * classify documents from a stream.
 *
 * Syntax: {@code classify(model(...), anyStream(...), field="body")}
 */
public class ClassifyStream extends TupleStream implements Expressible {

  private TupleStream docStream;
  private TupleStream modelStream;

  private String field;
  private String analyzerField;
  private Tuple modelTuple;

  Analyzer analyzer;
  private Map<CharSequence, Integer> termToIndex;
  private List<Double> idfs;
  private List<Double> modelWeights;

  public ClassifyStream(StreamExpression expression, StreamFactory factory) throws IOException {
    List<StreamExpression> streamExpressions =
        factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class);
    if (streamExpressions.size() != 2) {
      throw new IOException(String.format(Locale.ROOT,
          "Invalid expression %s - expecting two streams but found %d", expression, streamExpressions.size()));
    }

    modelStream = factory.constructStream(streamExpressions.get(0));
    docStream = factory.constructStream(streamExpressions.get(1));

    StreamExpressionNamedParameter fieldParameter = factory.getNamedOperand(expression, "field");
    if (fieldParameter == null) {
      throw new IOException(String.format(Locale.ROOT,
          "Invalid expression %s - field parameter must be specified", expression));
    }

    // The analyzer field defaults to the classified field unless analyzerField is given explicitly.
    analyzerField = field = fieldParameter.getParameter().toString();

    StreamExpressionNamedParameter analyzerFieldParameter = factory.getNamedOperand(expression, "analyzerField");
    if (analyzerFieldParameter != null) {
      analyzerField = analyzerFieldParameter.getParameter().toString();
    }
  }
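
  // Illustrative sketch only: the kind of streaming expression the constructor above is meant to
  // parse, following the syntax shown in the class javadoc. The collection names, the model id,
  // and the search() parameters below are assumptions made for the example, not values defined
  // by this class.
  //
  //   classify(
  //       model(modelCollection, id="classificationModel"),
  //       search(emailCollection, q="*:*", fl="id,body", sort="id asc"),
  //       field="body")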

  @Override
  public void setStreamContext(StreamContext context) {
    Object solrCoreObj = context.get("solr-core");
    if (solrCoreObj == null || !(solrCoreObj instanceof SolrCore)) {
      throw new SolrException(SolrException.ErrorCode.INVALID_STATE,
          "StreamContext must have SolrCore in solr-core key");
    }
    SolrCore solrCore = (SolrCore) solrCoreObj;

    // Resolve the index analyzer for the analyzer field from the core's schema, so documents are
    // tokenized the same way they were at index time.
    analyzer = solrCore.getLatestSchema().getFieldType(analyzerField).getIndexAnalyzer();

    this.docStream.setStreamContext(context);
    this.modelStream.setStreamContext(context);
  }

  @Override
  public List<TupleStream> children() {
    List<TupleStream> l = new ArrayList<>();
    l.add(docStream);
    l.add(modelStream);
    return l;
  }

  @Override
  public void open() throws IOException {
    this.docStream.open();
    this.modelStream.open();
  }

  @Override
  public void close() throws IOException {
    this.docStream.close();
    this.modelStream.close();
  }

  @Override
  public Tuple read() throws IOException {
    // Lazily read the model tuple on the first call and cache its terms, idfs and weights.
    if (modelTuple == null) {
      modelTuple = modelStream.read();
      if (modelTuple == null || modelTuple.EOF) {
        throw new IOException("Model tuple not found for classify stream!");
      }

      termToIndex = new HashMap<>();

      List<String> terms = modelTuple.getStrings("terms_ss");
      for (int i = 0; i < terms.size(); i++) {
        termToIndex.put(terms.get(i), i);
      }

      idfs = modelTuple.getDoubles("idfs_ds");
      modelWeights = modelTuple.getDoubles("weights_ds");
    }

    Tuple docTuple = docStream.read();
    if (docTuple.EOF) {
      return docTuple;
    }

    String text = docTuple.getString(field);

    // Tokenize the document text and count term frequencies for the terms in the model's vocabulary.
    double[] tfs = new double[termToIndex.size()];

    TokenStream tokenStream = analyzer.tokenStream(analyzerField, text);
    CharTermAttribute termAtt = tokenStream.getAttribute(CharTermAttribute.class);
    tokenStream.reset();

    int termCount = 0;
    while (tokenStream.incrementToken()) {
      termCount++;
      if (termToIndex.containsKey(termAtt.toString())) {
        tfs[termToIndex.get(termAtt.toString())]++;
      }
    }

    tokenStream.end();
    tokenStream.close();

    // Build the tf-idf feature vector; the leading 1.0 is the model's intercept term.
    List<Double> tfidfs = new ArrayList<>(termToIndex.size());
    tfidfs.add(1.0);
    for (int i = 0; i < tfs.length; i++) {
      if (tfs[i] != 0) {
        tfs[i] = 1 + Math.log(tfs[i]);
      }
      tfidfs.add(this.idfs.get(i) * tfs[i]);
    }

    // Dot product of the feature vector with the model weights.
    double total = 0.0;
    for (int i = 0; i < tfidfs.size(); i++) {
      total += tfidfs.get(i) * modelWeights.get(i);
    }

    // score_d is the dot product normalized by document length; probability_d is the logistic output.
    double score = total * ((float) (1.0 / Math.sqrt(termCount)));
    double positiveProb = sigmoid(total);

    docTuple.put("probability_d", positiveProb);
    docTuple.put("score_d", score);

    return docTuple;
  }

  private double sigmoid(double in) {
    return 1.0 / (1 + Math.exp(-in));
  }

  @Override
  public StreamComparator getStreamSort() {
    return null;
  }

  @Override
  public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
    return toExpression(factory, true);
  }

  private StreamExpression toExpression(StreamFactory factory, boolean includeStreams) throws IOException {
    // function name
    StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));

    if (includeStreams) {
      if (docStream instanceof Expressible && modelStream instanceof Expressible) {
        expression.addParameter(((Expressible) modelStream).toExpression(factory));
        expression.addParameter(((Expressible) docStream).toExpression(factory));
      } else {
        throw new IOException("This ClassifyStream contains a non-expressible TupleStream - it cannot be converted to an expression");
      }
    }

    expression.addParameter(new StreamExpressionNamedParameter("field", field));
    expression.addParameter(new StreamExpressionNamedParameter("analyzerField", analyzerField));

    return expression;
  }
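
  /**
   * Describes this stream for the explain functionality: a stream decorator whose children are
   * the document stream and the model stream.
   */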
  @Override
  public Explanation toExplanation(StreamFactory factory) throws IOException {
    StreamExplanation explanation = new StreamExplanation(getStreamNodeId().toString());
    explanation.setFunctionName(factory.getFunctionName(this.getClass()));
    explanation.setImplementingClass(this.getClass().getName());
    explanation.setExpressionType(Explanation.ExpressionType.STREAM_DECORATOR);
    explanation.setExpression(toExpression(factory, false).toString());
    explanation.addChild(docStream.toExplanation(factory));
    explanation.addChild(modelStream.toExplanation(factory));

    return explanation;
  }
}