/*
* avenir: Predictive analytic based on Hadoop Map Reduce
* Author: Pranab Ghosh
*
* Licensed under the Apache License, Version 2.0 (the "License"); you
* may not use this file except in compliance with the License. You may
* obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. See the License for the specific language governing
* permissions and limitations under the License.
*/
package org.avenir.reinforce;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.chombo.storm.GenericBolt;
import org.chombo.storm.MessageHolder;
import org.chombo.util.ConfigUtility;
import org.chombo.util.Pair;
import redis.clients.jedis.Jedis;
import backtype.storm.task.TopologyContext;
import backtype.storm.tuple.Tuple;
/**
* Reinforcement learner bolt. Any RL algorithm can be used
* @author pranab
*
*/
public class ReinforcementLearnerBolt extends GenericBolt {
private static final long serialVersionUID = 6746219511729480056L;
public static final String EVENT_ID = "eventID";
public static final String ACTION_ID = "actionID";
public static final String ROUND_NUM = "roundNUm";
public static final String REWARD = "reward";
private List<MessageHolder> messages = new ArrayList<MessageHolder>();
private ReinforcementLearner learner = null;
private Jedis jedis;
private String actionQueue;
private ActionWriter actionWriter;
private RewardReader rewardReader;
private static final Logger LOG = Logger.getLogger(ReinforcementLearnerBolt.class);
@Override
public Map<String, Object> getComponentConfiguration() {
// TODO Auto-generated method stub
return null;
}
/* (non-Javadoc)
* @see org.chombo.storm.GenericBolt#intialize(java.util.Map, backtype.storm.task.TopologyContext)
*/
@Override
public void intialize(Map stormConf, TopologyContext context) {
//intialize learner
String learnerType = ConfigUtility.getString(stormConf, "reinforcement.learner.type");
String[] actions = ConfigUtility.getString(stormConf, "reinforcement.learrner.actions").split(",");
Map<String, Object> typedConf = ConfigUtility.toTypedMap(stormConf);
learner = ReinforcementLearnerFactory.create(learnerType, actions, typedConf);
//action output queue
if (ConfigUtility.getString(stormConf, "reinforcement.learrner.action.writer").equals("redis")) {
actionWriter = new RedisActionWriter();
actionWriter.intialize(stormConf);
rewardReader = new RedisRewardReader();
rewardReader.intialize(stormConf);
}
debugOn = ConfigUtility.getBoolean(stormConf,"debug.on", false);
if (debugOn) {
LOG.setLevel(Level.INFO);;
}
messageCountInterval = ConfigUtility.getInt(stormConf,"log.message.count.interval", 100);
LOG.info("debugOn:" + debugOn);
}
/* (non-Javadoc)
* @see org.chombo.storm.GenericBolt#process(backtype.storm.tuple.Tuple)
*/
@Override
public boolean process(Tuple input) {
if (input.contains(ROUND_NUM)) {
//get rewards
List<Pair<String, Integer>> rewards = rewardReader.readRewards();
for (Pair<String, Integer> reward : rewards) {
learner.setReward(reward.getLeft(), reward.getRight());
}
if (debugOn && rewards.size() > 0) {
LOG.info("number of reward data:" + rewards.size() );
}
//select action for next round
String eventID = input.getStringByField(EVENT_ID);
int roundNum = input.getIntegerByField(ROUND_NUM);
Action[] actions = learner.nextActions();
actionWriter.write(eventID, actions);
if (debugOn) {
if (messageCounter % messageCountInterval == 0)
LOG.info("processed event message - message counter:" + messageCounter );
LOG.info("learner stat:" + learner.getStat());
}
} else {
//reward feedback
String action = input.getStringByField(ACTION_ID);
int reward = input.getIntegerByField(REWARD);
learner.setReward(action, reward);
if (debugOn) {
if (messageCounter % messageCountInterval == 0)
LOG.info("processed reward message - message counter:" + messageCounter );
}
}
return true;
}
@Override
public List<MessageHolder> getOutput() {
return null;
}
}