/*
 * Licensed to Elasticsearch under one or more contributor
 * license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright
 * ownership. Elasticsearch licenses this file to you under
 * the Apache License, Version 2.0 (the "License"); you may
 * not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.elasticsearch.action.trainmodel;

import org.elasticsearch.ElasticsearchParseException;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ValidateActions;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.ParseFieldMatcher;
import org.elasticsearch.common.ParseFieldMatcherSupplier;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.ml.training.DataSet;
import org.elasticsearch.ml.training.ModelInputField;
import org.elasticsearch.ml.training.ModelTargetField;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Request describing a model to train: the model type, an optional model id, the training
 * set, an optional testing set, model settings, the target (output) field and the list of
 * input fields.
 * <p>
 * A request can be built through the full constructor, through the individual setters, or
 * by parsing a content body with {@link #source(BytesReference)}.
 */
public class TrainModelRequest extends ActionRequest<TrainModelRequest> {

    /**
     * Parser that fills a {@link TrainModelRequest} from the fields {@code type}, {@code id},
     * {@code settings}, {@code target_field}, {@code fields}, {@code training_set} and
     * {@code testing_set}.
     */
    public static final ObjectParser<TrainModelRequest, ParseFieldMatcherSupplier> PARSER =
        new ObjectParser<>("train_model_request", TrainModelRequest::new);

    static {
        PARSER.declareString(TrainModelRequest::setModelType, new ParseField("type"));
        PARSER.declareString(TrainModelRequest::setModelId, new ParseField("id"));
        PARSER.declareField(TrainModelRequest::setModelSettings,
            (p) -> Settings.builder().put(p.mapOrdered()).build(),
            new ParseField("settings"), ObjectParser.ValueType.OBJECT);
        // The target field may be given either as a plain string or as a full object.
        PARSER.declareField(TrainModelRequest::setTargetField, (xContentParser, parseFieldMatcherSupplier) -> {
            if (xContentParser.currentToken() == XContentParser.Token.VALUE_STRING) {
                try {
                    return new ModelTargetField(xContentParser.text());
                } catch (IOException ex) {
                    throw new ElasticsearchParseException("cannot parse target field", ex);
                }
            }
            return ModelTargetField.PARSER.apply(xContentParser, parseFieldMatcherSupplier);
        }, new ParseField("target_field"), ObjectParser.ValueType.OBJECT_OR_STRING);
        // Each input field entry may likewise be a plain string or a full object.
        PARSER.declareObjectArray(TrainModelRequest::setFields, (xContentParser, parseFieldMatcherSupplier) -> {
            if (xContentParser.currentToken() == XContentParser.Token.VALUE_STRING) {
                try {
                    return new ModelInputField(xContentParser.text());
                } catch (IOException ex) {
                    throw new ElasticsearchParseException("cannot parse input field", ex);
                }
            }
            return ModelInputField.PARSER.apply(xContentParser, parseFieldMatcherSupplier);
        }, new ParseField("fields"));
        PARSER.declareObject(TrainModelRequest::setTrainingSet, DataSet.PARSER, new ParseField("training_set"));
        PARSER.declareObject(TrainModelRequest::setTestingSet, DataSet.PARSER, new ParseField("testing_set"));
    }

    private String modelType;

    @Nullable
    private String modelId;

    private DataSet trainingSet;

    @Nullable
    private DataSet testingSet;

    private Settings modelSettings = Settings.EMPTY;

    private ModelTargetField outputField;

    private List<ModelInputField> fields = Collections.emptyList();

    /**
     * Creates an empty request; settings default to {@link Settings#EMPTY} and the input
     * field list defaults to an empty list.
     */
    public TrainModelRequest() {

    }

    /**
     * Creates a fully populated request.
     *
     * @param modelType     the type of model to train
     * @param modelId       the model id, may be {@code null}
     * @param trainingSet   the data set used for training
     * @param testingSet    the data set used for testing, may be {@code null}
     * @param modelSettings the model settings
     * @param outputField   the target field
     * @param fields        the input fields
     */
    public TrainModelRequest(String modelType, String modelId, DataSet trainingSet, DataSet testingSet,
                             Settings modelSettings, ModelTargetField outputField, ModelInputField... fields) {
        this.modelType = modelType;
        this.modelId = modelId;
        this.trainingSet = trainingSet;
        this.testingSet = testingSet;
        this.modelSettings = modelSettings;
        this.outputField = outputField;
        this.fields = Arrays.asList(fields);
    }

    /**
     * Checks that the model type, target field, at least one input field, the model
     * settings and the training set are present.
     *
     * @return the accumulated validation errors, or {@code null} if no check failed
     */
    @Override
    public ActionRequestValidationException validate() {
        ActionRequestValidationException validationException = null;
        if (modelType == null) {
            validationException = ValidateActions.addValidationError("missing model type", validationException);
        }
        if (outputField == null) {
            validationException = ValidateActions.addValidationError("missing output field", validationException);
        }
        if (fields == null || fields.size() == 0) {
            validationException = ValidateActions.addValidationError("at least one input field is required",
                validationException);
        }
        if (modelSettings == null) {
            validationException = ValidateActions.addValidationError("missing model settings", validationException);
        }
        if (trainingSet == null) {
            validationException = ValidateActions.addValidationError("missing training set", validationException);
        }
        return validationException;
    }

    /**
     * Reads the request from the stream, in the same order in which
     * {@link #writeTo(StreamOutput)} writes it.
     */
    @Override
    public void readFrom(StreamInput in) throws IOException {
        super.readFrom(in);
        modelType = in.readString();
        modelId = in.readOptionalString();
        trainingSet = new DataSet(in);
        testingSet = in.readOptionalWriteable(DataSet::new);
        modelSettings = Settings.readSettingsFromStream(in);
        outputField = new ModelTargetField(in);
        fields = in.readList(ModelInputField::new);
    }

    /**
     * Writes the request to the stream. The model id and the testing set are written as
     * optional values; all other members are written unconditionally.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        super.writeTo(out);
        out.writeString(modelType);
        out.writeOptionalString(modelId);
        trainingSet.writeTo(out);
        out.writeOptionalWriteable(testingSet);
        Settings.writeSettingsToStream(modelSettings, out);
        outputField.writeTo(out);
        out.writeList(fields);
    }

    /**
     * Populates this request by parsing the given content with {@link #PARSER}, using
     * {@link ParseFieldMatcher#STRICT}.
     *
     * @param content the serialized request body
     * @throws IOException if the parser cannot be created or reading the content fails
     */
    public void source(BytesReference content) throws IOException {
        try (XContentParser parser = XContentHelper.createParser(content)) {
            TrainModelRequest.PARSER.parse(parser, this, () -> ParseFieldMatcher.STRICT);
        }
    }

    /** Sets the type of model to train. */
    public void setModelType(String modelType) {
        this.modelType = modelType;
    }

    /** Sets the model id; may be {@code null}. */
    public void setModelId(String modelId) {
        this.modelId = modelId;
    }

    /** Sets the model settings. */
    public void setModelSettings(Settings modelSettings) {
        this.modelSettings = modelSettings;
    }

    /** Sets the target (output) field. */
    public void setTargetField(ModelTargetField outputField) {
        this.outputField = outputField;
    }

    /** Sets the input fields. */
    public void setFields(List<ModelInputField> fields) {
        this.fields = fields;
    }

    /** Sets the training data set. */
    public void setTrainingSet(DataSet trainingSet) {
        this.trainingSet = trainingSet;
    }

    /** Sets the testing data set; may be {@code null}. */
    public void setTestingSet(DataSet testingSet) {
        this.testingSet = testingSet;
    }

    /** Returns the type of model to train. */
    public String getModelType() {
        return modelType;
    }

    /** Returns the model id, or {@code null} if none was set. */
    public String getModelId() {
        return modelId;
    }

    /** Returns the model settings. */
    public Settings getModelSettings() {
        return modelSettings;
    }

    /** Returns the target (output) field. */
    public ModelTargetField getTargetField() {
        return outputField;
    }

    /** Returns the input fields. */
    public List<ModelInputField> getFields() {
        return fields;
    }

    /** Returns the training data set. */
    public DataSet getTrainingSet() {
        return trainingSet;
    }

    /** Returns the testing data set, or {@code null} if none was set. */
    public DataSet getTestingSet() {
        return testingSet;
    }
}