/**
* Copyright 2013-2015 Pierre Merienne
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package stormy.pythian.component.classifier;
import static stormy.pythian.model.instance.Instance.INSTANCE_FIELD;
import static stormy.pythian.model.instance.Instance.NEW_INSTANCE_FIELD;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import storm.trident.Stream;
import storm.trident.TridentState;
import storm.trident.operation.TridentCollector;
import storm.trident.state.BaseQueryFunction;
import storm.trident.state.BaseStateUpdater;
import storm.trident.state.StateFactory;
import storm.trident.state.map.MapState;
import storm.trident.tuple.TridentTuple;
import stormy.pythian.model.annotation.InputStream;
import stormy.pythian.model.annotation.ListMapper;
import stormy.pythian.model.annotation.OutputStream;
import stormy.pythian.model.annotation.Property;
import stormy.pythian.model.annotation.State;
import stormy.pythian.model.component.Component;
import stormy.pythian.model.instance.Instance;
import stormy.pythian.model.instance.ListedFeaturesMapper;
import backtype.storm.tuple.Fields;
import backtype.storm.tuple.Values;
import com.github.pmerienne.trident.ml.util.KeysUtil;
/**
 * Base component for classifiers whose model is persisted in a Trident state.
 *
 * <p>Tuples on the {@code update} stream train the model, which is stored in the state built by
 * the configured {@link StateFactory} under the key given by the "Classifier name" property.
 * Tuples on the {@code query} stream are classified against the stored model, and the resulting
 * instances are emitted on the {@code prediction} stream.
 *
 * <p>The model that gets stored is a {@code Classifier} object itself. While the state holds no
 * model yet, this component instance is used as the starting model (it is handed to the updater
 * in {@link #init()}).
 *
 * @param <L> type parameter carried through to the stored model type; presumably the label
 *     type (not used directly by this base class)
 */
@SuppressWarnings("serial")
public abstract class Classifier<L> implements Component {

    /** Training input; transient, so it is only usable while {@link #init()} wires the topology. */
    @InputStream(name = "update")
    private transient Stream update;

    /** Maps {@code update} tuples to {@link Instance}s. */
    @ListMapper(stream = "update")
    protected ListedFeaturesMapper updateInputMapper;

    /** Instances to classify. */
    @InputStream(name = "query")
    private transient Stream query;

    /** Maps {@code query} tuples to {@link Instance}s. */
    @ListMapper(stream = "query")
    protected ListedFeaturesMapper queryInputMapper;

    /** Output stream carrying the classified instances; assigned in {@link #init()}. */
    @OutputStream(name = "prediction", from = "query")
    private transient Stream prediction;

    /** Factory for the state holding the model. */
    @State(name = "Classifier's model")
    private transient StateFactory stateFactory;

    /** Key under which the model is stored in the state. */
    @Property(name = "Classifier name", mandatory = true)
    private String classifierName;

    /**
     * Trains the model with one instance taken from the {@code update} stream.
     *
     * @param instance the training instance
     */
    protected abstract void update(Instance instance);

    /**
     * Classifies one instance taken from the {@code query} stream. The same instance object is
     * emitted on the {@code prediction} stream afterwards, so implementations presumably record
     * their result on it.
     *
     * @param instance the instance to classify
     */
    protected abstract void classify(Instance instance);

    /** Hook invoked first thing in {@link #init()}, before any stream is wired. */
    protected abstract void initClassifier();

    /**
     * Initializes the classifier and wires the training and prediction flows.
     */
    @Override
    public void init() {
        initClassifier();

        // Training: each update batch is folded into the model persisted in the state.
        TridentState classifierState = update.partitionPersist(
                stateFactory,
                new Fields(INSTANCE_FIELD),
                new ClassifierUpdater<L>(classifierName, this, updateInputMapper));

        // Prediction: each query tuple is classified against the model read from that state.
        prediction = query.stateQuery(
                classifierState,
                new Fields(INSTANCE_FIELD),
                new ClassifyQuery<L>(classifierName, queryInputMapper),
                new Fields(NEW_INSTANCE_FIELD));
    }

    /**
     * Returns the first element of a {@code multiGet} result, or {@code null} when the result is
     * {@code null} or empty.
     */
    private static <T> T firstOrNull(List<T> values) {
        return (values == null || values.isEmpty()) ? null : values.get(0);
    }

    /**
     * Folds each batch of update tuples into the model stored in the state, then writes the model
     * back under the same key.
     */
    private static class ClassifierUpdater<L> extends BaseStateUpdater<MapState<Classifier<L>>> {

        private final String classifierName;

        /** Model used when the state does not hold one yet. */
        private final Classifier<L> initialClassifier;

        private final ListedFeaturesMapper mapper;

        public ClassifierUpdater(String classifierName, Classifier<L> initialClassifier, ListedFeaturesMapper mapper) {
            this.classifierName = classifierName;
            this.initialClassifier = initialClassifier;
            this.mapper = mapper;
        }

        @Override
        public void updateState(MapState<Classifier<L>> state, List<TridentTuple> tuples, TridentCollector collector) {
            Classifier<L> classifier = Classifier.firstOrNull(state.multiGet(KeysUtil.toKeys(this.classifierName)));
            if (classifier == null) {
                // Nothing persisted yet: start from the component's own initialized instance.
                classifier = this.initialClassifier;
            }

            for (TridentTuple tuple : tuples) {
                classifier.update(Instance.get(tuple, mapper));
            }

            state.multiPut(KeysUtil.toKeys(this.classifierName), Arrays.asList(classifier));
        }
    }

    /**
     * Classifies query tuples against the model stored in the state and emits the resulting
     * instances.
     */
    private static class ClassifyQuery<L> extends BaseQueryFunction<MapState<Classifier<L>>, Instance> {

        private final String classifierName;
        private final ListedFeaturesMapper mapper;

        public ClassifyQuery(String classifierName, ListedFeaturesMapper mapper) {
            this.classifierName = classifierName;
            this.mapper = mapper;
        }

        /**
         * Classifies every tuple of the batch.
         *
         * @return a list the same length as {@code tuples}; each element is the classified
         *     instance, or {@code null} for every tuple when no model is stored yet
         */
        @Override
        public List<Instance> batchRetrieve(MapState<Classifier<L>> state, List<TridentTuple> tuples) {
            Classifier<L> classifier = Classifier.firstOrNull(state.multiGet(KeysUtil.toKeys(this.classifierName)));

            List<Instance> instances = new ArrayList<Instance>(tuples.size());
            for (TridentTuple tuple : tuples) {
                if (classifier == null) {
                    // No model trained yet: keep results aligned one-to-one with the input tuples.
                    instances.add(null);
                } else {
                    Instance instance = Instance.get(tuple, mapper);
                    classifier.classify(instance);
                    instances.add(instance);
                }
            }
            return instances;
        }

        /**
         * Emits the result computed for {@code tuple}. This may be {@code null} when no model
         * existed at query time.
         */
        @Override
        public void execute(TridentTuple tuple, Instance instance, TridentCollector collector) {
            collector.emit(new Values(instance));
        }
    }
}