/* * Copyright (c) 2017 Villu Ruusmann * * This file is part of JPMML-SkLearn * * JPMML-SkLearn is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * JPMML-SkLearn is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>. */ package sklearn.feature_extraction.text; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.Reader; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.Map; import javax.xml.parsers.DocumentBuilder; import com.google.common.base.Joiner; import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; import com.google.common.io.CharStreams; import numpy.DType; import numpy.core.Scalar; import org.dmg.pmml.Apply; import org.dmg.pmml.Constant; import org.dmg.pmml.DataType; import org.dmg.pmml.DefineFunction; import org.dmg.pmml.DerivedField; import org.dmg.pmml.FieldName; import org.dmg.pmml.FieldRef; import org.dmg.pmml.InlineTable; import org.dmg.pmml.OpType; import org.dmg.pmml.ParameterField; import org.dmg.pmml.TextIndex; import org.dmg.pmml.TextIndexNormalization; import org.jpmml.converter.ContinuousFeature; import org.jpmml.converter.DOMUtil; import org.jpmml.converter.Feature; import org.jpmml.converter.FeatureUtil; import org.jpmml.converter.PMMLEncoder; import org.jpmml.converter.PMMLUtil; import org.jpmml.converter.ValueUtil; import org.jpmml.sklearn.ClassDictUtil; import org.jpmml.sklearn.SkLearnEncoder; import sklearn.HasNumberOfFeatures; import sklearn.Transformer; import sklearn2pmml.feature_extraction.text.Splitter; public class CountVectorizer extends Transformer implements HasNumberOfFeatures { public CountVectorizer(String module, String name){ super(module, name); } @Override public OpType getOpType(){ return OpType.CATEGORICAL; } @Override public DataType getDataType(){ return DataType.STRING; } @Override public int getNumberOfFeatures(){ return 1; } @Override public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encoder){ Boolean lowercase = getLowercase(); Map<String, Scalar> vocabulary = getVocabulary(); ClassDictUtil.checkSize(1, features); Feature feature = features.get(0); BiMap<String, Integer> termIndexMap = HashBiMap.create(vocabulary.size()); Collection<Map.Entry<String, Scalar>> entries = vocabulary.entrySet(); for(Map.Entry<String, Scalar> entry : entries){ termIndexMap.put(entry.getKey(), ValueUtil.asInt((Number)(entry.getValue()).getOnlyElement())); } BiMap<Integer, String> indexTermMap = termIndexMap.inverse(); DType dtype = getDType(); if(lowercase){ FieldName name = FeatureUtil.createName("lowercase", feature); DerivedField derivedField = encoder.getDerivedField(name); if(derivedField == null){ Apply apply = PMMLUtil.createApply("lowercase", feature.ref()); derivedField = encoder.createDerivedField(name, OpType.CATEGORICAL, DataType.STRING, apply); } feature = new Feature(encoder, derivedField.getName(), derivedField.getDataType()){ @Override public ContinuousFeature toContinuousFeature(){ throw new UnsupportedOperationException(); } }; } DefineFunction defineFunction = encodeDefineFunction(); encoder.addDefineFunction(defineFunction); List<Feature> result = new ArrayList<>(); for(int i = 0, max = indexTermMap.size(); i < max; i++){ String term = indexTermMap.get(i); final Apply apply = encodeApply(defineFunction.getName(), feature, i, term); Feature termFeature = new Feature(encoder, FieldName.create(defineFunction.getName() + "(" + term + ")"), dtype != null ? dtype.getDataType() : DataType.DOUBLE){ @Override public ContinuousFeature toContinuousFeature(){ PMMLEncoder encoder = ensureEncoder(); DerivedField derivedField = encoder.getDerivedField(getName()); if(derivedField == null){ derivedField = encoder.createDerivedField(getName(), OpType.CONTINUOUS, getDataType(), apply); } return new ContinuousFeature(encoder, derivedField); } }; result.add(termFeature); } return result; } public DefineFunction encodeDefineFunction(){ String analyzer = getAnalyzer(); List<String> stopWords = getStopWords(); Object[] nGramRange = getNGramRange(); Boolean binary = getBinary(); Object preprocessor = getPreprocessor(); String stripAccents = getStripAccents(); Splitter tokenizer = getTokenizer(); switch(analyzer){ case "word": break; default: throw new IllegalArgumentException(analyzer); } if(preprocessor != null){ throw new IllegalArgumentException(); } // End if if(stripAccents != null){ throw new IllegalArgumentException(stripAccents); } ParameterField documentField = new ParameterField(FieldName.create("document")); ParameterField termField = new ParameterField(FieldName.create("term")); TextIndex textIndex = new TextIndex(documentField.getName()) .setTokenize(Boolean.TRUE) .setWordSeparatorCharacterRE(tokenizer.getSeparatorRE()) .setLocalTermWeights(binary ? TextIndex.LocalTermWeights.BINARY : null) .setExpression(new FieldRef(termField.getName())); if((stopWords != null && stopWords.size() > 0) && !Arrays.equals(nGramRange, new Integer[]{1, 1})){ DocumentBuilder documentBuilder = DOMUtil.createDocumentBuilder(); InlineTable inlineTable = new InlineTable() .addRows(DOMUtil.createRow(documentBuilder, Arrays.asList("string", "stem", "regex"), Arrays.asList("(^|\\s+)\\p{Punct}*(" + JOINER.join(stopWords) + ")\\p{Punct}*(\\s+|$)", " ", "true"))); TextIndexNormalization textIndexNormalization = new TextIndexNormalization() .setRecursive(Boolean.TRUE) // Handles consecutive matches. See http://stackoverflow.com/a/25085385 .setInlineTable(inlineTable); textIndex.addTextIndexNormalizations(textIndexNormalization); } DefineFunction defineFunction = new DefineFunction("tf", OpType.CONTINUOUS, null) .setDataType(DataType.DOUBLE) .addParameterFields(documentField, termField) .setExpression(textIndex); return defineFunction; } public Apply encodeApply(String function, Feature feature, int index, String term){ Constant constant = PMMLUtil.createConstant(term) .setDataType(DataType.STRING); return PMMLUtil.createApply(function, feature.ref(), constant); } public String getAnalyzer(){ return (String)get("analyzer"); } public Boolean getBinary(){ return (Boolean)get("binary"); } public Boolean getLowercase(){ return (Boolean)get("lowercase"); } public Object[] getNGramRange(){ return (Object[])get("ngram_range"); } public Object getPreprocessor(){ return get("preprocessor"); } public List<String> getStopWords(){ Object stopWords = get("stop_words"); if(stopWords instanceof String){ return loadStopWords((String)stopWords); } return (List)stopWords; } public String getStripAccents(){ return (String)get("strip_accents"); } public Splitter getTokenizer(){ Object tokenizer = get("tokenizer"); try { if(tokenizer == null){ throw new NullPointerException(); } return (Splitter)tokenizer; } catch(RuntimeException re){ throw new IllegalArgumentException("The tokenizer object (" + ClassDictUtil.formatClass(tokenizer) + ") is not Splitter"); } } public String getTokenPattern(){ return (String)get("token_pattern"); } public Map<String, Scalar> getVocabulary(){ return (Map)get("vocabulary_"); } static private List<String> loadStopWords(String stopWords){ InputStream is = CountVectorizer.class.getResourceAsStream("/stop_words/" + stopWords + ".txt"); if(is == null){ throw new IllegalArgumentException(stopWords); } try(Reader reader = new InputStreamReader(is, "UTF-8")){ return CharStreams.readLines(reader); } catch(IOException ioe){ throw new IllegalArgumentException(stopWords, ioe); } } private static final Joiner JOINER = Joiner.on("|"); }