/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.ml.modelinput;
import org.elasticsearch.common.collect.Tuple;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
public abstract class AnalyzedTextVectorRange extends VectorRange {
int offset;
public static final EsSparseNumericVector EMPTY_SPARSE = new EsSparseNumericVector(new Tuple<>(new int[]{}, new double[]{}));
public AnalyzedTextVectorRange(String field, String type) {
super(field, field, type);
}
public enum FeatureType {
OCCURRENCE,
TF,
TF_IDF,
BM25;
public String toString() {
switch (this.ordinal()) {
case 0:
return "occurrence";
case 1:
return "tf";
case 2:
return "tf_idf";
case 3:
return "bm25";
}
throw new IllegalStateException("There is no toString() for ordinal " + this.ordinal() +
" - someone forgot to implement toString().");
}
public static FeatureType fromString(String s) {
if (s.equals(OCCURRENCE.toString())) {
return OCCURRENCE;
} else if (s.equals(TF.toString())) {
return TF;
} else if (s.equals(TF_IDF.toString())) {
return TF_IDF;
} else if (s.equals(BM25.toString())) {
return BM25;
} else {
throw new IllegalStateException("Don't know what " + s + " is - choose one of " + OCCURRENCE.toString() + " " +
TF.toString() + " " + TF_IDF.toString() + " " + BM25.toString());
}
}
}
public static class SparseTermVectorRange extends AnalyzedTextVectorRange {
private String number;
Map<String, Integer> wordMap;
public SparseTermVectorRange(String field, String type, String[] terms, String number, int offset) {
super(field, type);
this.number = number;
this.field = field;
wordMap = new HashMap<>();
for (int i = 0; i < terms.length; i++) {
wordMap.put(terms[i], i + offset);
}
}
@Override
public EsVector getVector(DataSource dataSource) {
Tuple<int[], double[]> indicesAndValues;
if (AnalyzedTextVectorRange.FeatureType.fromString(number).equals(AnalyzedTextVectorRange.FeatureType.TF)) {
indicesAndValues = dataSource.getTfSparse(wordMap, field);
} else if (AnalyzedTextVectorRange.FeatureType.fromString(number).equals(AnalyzedTextVectorRange.FeatureType.OCCURRENCE)) {
indicesAndValues = dataSource.getOccurrenceSparse(wordMap, field);
} else if (AnalyzedTextVectorRange.FeatureType.fromString(number).equals(AnalyzedTextVectorRange.FeatureType.TF_IDF)) {
indicesAndValues = dataSource.getTfIdfSparse(wordMap, field);
} else {
throw new IllegalArgumentException(number + " not implemented yet for sparse vector");
}
if (indicesAndValues != null) {
return new EsSparseNumericVector(indicesAndValues);
} else {
return EMPTY_SPARSE;
}
}
@Override
public EsVector getVector(Map<String, List<Object>> fieldValues) {
throw new UnsupportedOperationException("Remove this later, we should not get here.");
}
@Override
public int size() {
return wordMap.size();
}
}
public static class DenseTermVectorRange extends AnalyzedTextVectorRange {
String[] terms;
String number;
public DenseTermVectorRange(String field, String type, String[] terms, String number, int offset) {
super(field, type);
this.terms = terms;
this.number = number;
this.offset = offset;
this.field = field;
}
@Override
public EsVector getVector(DataSource dataSource) {
if (AnalyzedTextVectorRange.FeatureType.fromString(number).equals(AnalyzedTextVectorRange.FeatureType.TF)) {
return new EsDenseNumericVector(dataSource.getTfDense(terms, field));
} else if (AnalyzedTextVectorRange.FeatureType.fromString(number).equals(AnalyzedTextVectorRange.FeatureType.OCCURRENCE)) {
return new EsDenseNumericVector(dataSource.getOccurrenceDense(terms, field));
} else if (AnalyzedTextVectorRange.FeatureType.fromString(number).equals(AnalyzedTextVectorRange.FeatureType.TF_IDF)) {
return new EsDenseNumericVector(dataSource.getTfIdfDense(terms, field));
} else {
throw new IllegalArgumentException(number + " not implemented yet for dense vector");
}
}
@Override
public EsVector getVector(Map<String, List<Object>> fieldValues) {
throw new UnsupportedOperationException("Remove this later, we should not get here.");
}
@Override
public int size() {
return terms.length;
}
}
}