/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.ml.modelinput; import org.elasticsearch.action.preparespec.TransportPrepareSpecAction; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; public class VectorRangesToVectorJSON extends VectorRangesToVector { // number of entries public VectorRangesToVectorJSON(Map<String, Object> source) { assert source.get("sparse") == null || source.get("sparse") instanceof Boolean; sparse = TransportPrepareSpecAction.getSparse(source.get("sparse")); assert (source.containsKey("features")); @SuppressWarnings("unchecked") ArrayList<Map<String, Object>> featuresArray = (ArrayList<Map<String, Object>>) source.get("features"); int offset = 0; for (Map<String, Object> feature : featuresArray) { assert feature.get("field") != null; assert feature.get("type") != null; assert feature.get("type").equals("terms"); // nothing else implemented yet assert feature.get("terms") != null; assert feature.get("number") != null; if (sparse) { vectorRangeList.add(new AnalyzedTextVectorRange.SparseTermVectorRange((String) feature.get("field"), "int", getTerms(feature.get("terms")), (String) feature.get("number"), offset)); } else { vectorRangeList.add(new AnalyzedTextVectorRange.DenseTermVectorRange((String) feature.get("field"), "int", getTerms (feature.get("terms")), (String) feature.get("number"), offset)); } offset += vectorRangeList.get(vectorRangeList.size() - 1).size(); numEntries += vectorRangeList.get(vectorRangeList.size() - 1).size(); } } private String[] getTerms(Object terms) { assert terms instanceof ArrayList; @SuppressWarnings("unchecked") ArrayList<String> termsList = (ArrayList<String>) terms; String[] finalTerms = new String[termsList.size()]; int i = 0; for (String term : termsList) { finalTerms[i] = term; i++; } return finalTerms; } public Object vector(DataSource dataSource) { if (sparse) { int length = 0; List<EsSparseNumericVector> entries = new ArrayList<>(); for (VectorRange fieldEntry : vectorRangeList) { EsSparseNumericVector vec = (EsSparseNumericVector) fieldEntry.getVector(dataSource); entries.add(vec); length += vec.values.v1().length; } Map<String, Object> finalVector = new HashMap<>(); double[] values = new double[length]; int[] indices = new int[length]; int curPos = 0; for (EsSparseNumericVector vector : entries) { int numValues = vector.values.v1().length; System.arraycopy(vector.values.v1(), 0, indices, curPos, numValues); System.arraycopy(vector.values.v2(), 0, values, curPos, numValues); curPos += numValues; } finalVector.put("values", values); finalVector.put("indices", indices); finalVector.put("length", numEntries); return finalVector; } else { int length = 0; List<double[]> entries = new ArrayList<>(); for (VectorRange fieldEntry : vectorRangeList) { EsDenseNumericVector vec = (EsDenseNumericVector) fieldEntry.getVector(dataSource); entries.add(vec.values); length += vec.values.length; } Map<String, Object> finalVector = new HashMap<>(); double[] values = new double[length]; int curPos = 0; for (double[] vals : entries) { int numValues = vals.length; System.arraycopy(vals, 0, values, curPos, numValues); curPos += numValues; } finalVector.put("values", values); finalVector.put("length", numEntries); return finalVector; } } }