/*
* avenir: Predictive analytic based on Hadoop Map Reduce
* Author: Pranab Ghosh
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package storm.applications.topology;
import backtype.storm.Config;
import backtype.storm.generated.StormTopology;
import backtype.storm.tuple.Fields;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static storm.applications.constants.ReinforcementLearnerConstants.*;
import storm.applications.bolt.ReinforcementLearnerBolt;
import storm.applications.sink.BaseSink;
import storm.applications.spout.AbstractSpout;
/**
 * Assembles the Storm topology for reinforcement learning: an event spout and
 * a reward spout feed a {@link ReinforcementLearnerBolt}, and the bolt's output
 * is consumed by a sink.
 *
 * @author pranab
 */
public class ReinforcementLearnerTopology extends AbstractTopology {
    private static final Logger LOG = LoggerFactory.getLogger(ReinforcementLearnerTopology.class);

    /** Name passed to {@code loadSpout} and {@code getConfigKey} for the event source. */
    private static final String EVENT_SOURCE = "event";

    /** Name passed to {@code loadSpout} and {@code getConfigKey} for the reward source. */
    private static final String REWARD_SOURCE = "reward";

    // Components, populated by initialize().
    private AbstractSpout eventSpout;
    private AbstractSpout rewardSpout;
    private BaseSink actionSink;

    // Per-component thread counts, populated by initialize().
    private int eventSpoutThreads;
    private int rewardSpoutThreads;
    private int learnerThreads;
    private int sinkThreads;

    /**
     * Creates the topology by delegating both arguments to the superclass constructor.
     *
     * @param topologyName name forwarded to {@code AbstractTopology}
     * @param config       configuration forwarded to {@code AbstractTopology}
     */
    public ReinforcementLearnerTopology(String topologyName, Config config) {
        super(topologyName, config);
    }

    /**
     * Loads the two spouts and the sink, then reads the thread count of every
     * component from the configuration. Each lookup passes {@code 1} as its
     * second argument (presumably the value used when the key is absent --
     * confirm against the {@code getInt} implementation).
     */
    @Override
    public void initialize() {
        this.eventSpout = loadSpout(EVENT_SOURCE);
        this.rewardSpout = loadSpout(REWARD_SOURCE);
        this.actionSink = loadSink();

        final String eventThreadsKey = getConfigKey(Conf.SPOUT_THREADS, EVENT_SOURCE);
        this.eventSpoutThreads = config.getInt(eventThreadsKey, 1);

        final String rewardThreadsKey = getConfigKey(Conf.SPOUT_THREADS, REWARD_SOURCE);
        this.rewardSpoutThreads = config.getInt(rewardThreadsKey, 1);

        // The learner key is used as-is; it is not passed through getConfigKey.
        this.learnerThreads = config.getInt(Conf.LEARNER_THREADS, 1);

        final String sinkThreadsKey = getConfigKey(Conf.SINK_THREADS);
        this.sinkThreads = config.getInt(sinkThreadsKey, 1);
    }

    /**
     * Declares the output fields of both spouts, registers the spouts, the
     * learner bolt and the sink on the builder, and creates the topology.
     *
     * <p>The learner subscribes to the event spout with shuffle grouping and to
     * the reward spout with all grouping (in Storm, all grouping replicates each
     * tuple to every task of the subscribing bolt). The sink subscribes to the
     * learner with shuffle grouping.
     *
     * @return the topology created by the builder
     */
    @Override
    public StormTopology buildTopology() {
        final Fields eventFields = new Fields(Field.EVENT_ID, Field.ROUND_NUM);
        this.eventSpout.setFields(eventFields);

        final Fields rewardFields = new Fields(Field.ACTION_ID, Field.REWARD);
        this.rewardSpout.setFields(rewardFields);

        // The third argument of setSpout/setBolt is Storm's parallelism hint.
        builder.setSpout(Component.EVENT_SPOUT, this.eventSpout, this.eventSpoutThreads);
        builder.setSpout(Component.REWARD_SPOUT, this.rewardSpout, this.rewardSpoutThreads);

        builder
                .setBolt(Component.LEARNER, new ReinforcementLearnerBolt(), this.learnerThreads)
                .shuffleGrouping(Component.EVENT_SPOUT)
                .allGrouping(Component.REWARD_SPOUT);

        builder
                .setBolt(Component.SINK, this.actionSink, this.sinkThreads)
                .shuffleGrouping(Component.LEARNER);

        return builder.createTopology();
    }

    /**
     * @return the logger of this class
     */
    @Override
    public Logger getLogger() {
        return LOG;
    }

    /**
     * @return the {@code PREFIX} constant, used as this topology's configuration prefix
     */
    @Override
    public String getConfigPrefix() {
        return PREFIX;
    }
}