/**
* Copyright 2013-2015 Pierre Merienne
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.github.pmerienne.trident.ml.testing;
import java.util.Map;
import java.util.Random;
import storm.trident.operation.TridentCollector;
import storm.trident.spout.IBatchSpout;
import backtype.storm.task.TopologyContext;
import backtype.storm.tuple.Fields;
import backtype.storm.tuple.Values;
import com.google.common.base.Function;
/**
 * Batch spout that emits tuples of randomly generated numeric features,
 * optionally preceded by a boolean label.
 *
 * <p>
 * Feature {@code j} of every tuple is computed as
 * {@code j + random.nextGaussian() * variance}. Note that {@code variance} is
 * applied as a linear scale factor on the Gaussian sample, not squared or
 * square-rooted. When labels are enabled, the label is {@code true} if and
 * only if the sum of the features is strictly positive.
 *
 * <p>
 * Output fields are named {@code "label"} (only when labels are enabled)
 * followed by {@code "x0"}, {@code "x1"}, ... one per feature.
 */
public class RandomFeaturesSpout implements IBatchSpout {

    private static final long serialVersionUID = -5293861317274377258L;

    /** Labels a feature vector as {@code true} when its components sum to a strictly positive value. */
    private static final Function<double[], Boolean> FEATURES_TO_LABEL = new Function<double[], Boolean>() {
        @Override
        public Boolean apply(double[] input) {
            double total = 0;
            for (double component : input) {
                total += component;
            }
            return total > 0;
        }
    };

    /** Number of tuples emitted by each call to {@link #emitBatch(long, TridentCollector)}. */
    private int maxBatchSize = 10;

    /** Number of feature values generated per tuple. */
    private int featureSize = 2;

    /** Scale factor multiplied with each Gaussian sample. */
    private double variance = 3.0;

    /** Whether each tuple starts with a boolean label. */
    private boolean withLabel = true;

    /** Source of the Gaussian samples. */
    private Random random = new Random();

    /** Creates a spout using the default settings (labelled, 2 features, scale 3.0). */
    public RandomFeaturesSpout() {
    }

    /**
     * Creates a spout with the default feature settings.
     *
     * @param withLabel
     *            whether each tuple starts with a label
     */
    public RandomFeaturesSpout(boolean withLabel) {
        this.withLabel = withLabel;
    }

    /**
     * Creates a labelled spout.
     *
     * @param featureSize
     *            number of features per tuple
     * @param variance
     *            scale factor applied to each Gaussian sample
     */
    public RandomFeaturesSpout(int featureSize, double variance) {
        this.featureSize = featureSize;
        this.variance = variance;
    }

    /**
     * Creates a fully configured spout.
     *
     * @param withLabel
     *            whether each tuple starts with a label
     * @param featureSize
     *            number of features per tuple
     * @param variance
     *            scale factor applied to each Gaussian sample
     */
    public RandomFeaturesSpout(boolean withLabel, int featureSize, double variance) {
        this(featureSize, variance);
        this.withLabel = withLabel;
    }

    /** No initialization is performed. */
    @SuppressWarnings("rawtypes")
    @Override
    public void open(Map conf, TopologyContext context) {
    }

    /**
     * Emits {@code maxBatchSize} freshly generated tuples to the collector.
     * The batch id is not used.
     */
    @Override
    public void emitBatch(long batchId, TridentCollector collector) {
        for (int emitted = 0; emitted < this.maxBatchSize; emitted++) {
            collector.emit(toTuple(nextFeatures()));
        }
    }

    /** Draws one feature vector: component {@code j} is {@code j} plus a scaled Gaussian sample. */
    private double[] nextFeatures() {
        double[] generated = new double[this.featureSize];
        for (int index = 0; index < generated.length; index++) {
            double noise = this.random.nextGaussian() * this.variance;
            generated[index] = index + noise;
        }
        return generated;
    }

    /** Packs the optional label followed by every feature value into a tuple. */
    private Values toTuple(double[] features) {
        Values tuple = new Values();
        if (this.withLabel) {
            tuple.add(FEATURES_TO_LABEL.apply(features));
        }
        for (double value : features) {
            tuple.add(value);
        }
        return tuple;
    }

    /** Nothing to acknowledge. */
    @Override
    public void ack(long batchId) {
    }

    /** Nothing to release. */
    @Override
    public void close() {
    }

    /**
     * @return always {@code null}
     */
    @SuppressWarnings("rawtypes")
    @Override
    public Map getComponentConfiguration() {
        return null;
    }

    /**
     * Builds the output field names: {@code "label"} first when labels are
     * enabled, then {@code "x0"} to {@code "x<featureSize - 1>"}.
     */
    @Override
    public Fields getOutputFields() {
        // Feature names are shifted by one slot when the label occupies index 0.
        int offset = this.withLabel ? 1 : 0;
        String[] names = new String[this.featureSize + offset];
        if (this.withLabel) {
            names[0] = "label";
        }
        for (int feature = 0; feature < this.featureSize; feature++) {
            names[offset + feature] = "x" + feature;
        }
        return new Fields(names);
    }

    public int getMaxBatchSize() {
        return this.maxBatchSize;
    }

    public void setMaxBatchSize(int maxBatchSize) {
        this.maxBatchSize = maxBatchSize;
    }

    public int getFeatureSize() {
        return this.featureSize;
    }

    public void setFeatureSize(int featureSize) {
        this.featureSize = featureSize;
    }

    public double getVariance() {
        return this.variance;
    }

    public void setVariance(double variance) {
        this.variance = variance;
    }

    public boolean isWithLabel() {
        return this.withLabel;
    }

    public void setWithLabel(boolean withLabel) {
        this.withLabel = withLabel;
    }

    public Random getRandom() {
        return this.random;
    }

    public void setRandom(Random random) {
        this.random = random;
    }
}