/*
* Created on Oct 25, 2006 Copyright (C) 2001-6, Anthony Harrison anh23@pitt.edu
* (jactr.org) This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of the License,
* or (at your option) any later version. This library is distributed in the
* hope that it will be useful, but WITHOUT ANY WARRANTY; without even the
* implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See
* the GNU Lesser General Public License for more details. You should have
* received a copy of the GNU Lesser General Public License along with this
* library; if not, write to the Free Software Foundation, Inc., 59 Temple
* Place, Suite 330, Boston, MA 02111-1307 USA
*/
package org.jactr.core.module.procedural.six.learning;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.Executor;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jactr.core.event.ACTREventDispatcher;
import org.jactr.core.logging.Logger;
import org.jactr.core.model.IModel;
import org.jactr.core.module.AbstractModule;
import org.jactr.core.module.procedural.IProceduralModule;
import org.jactr.core.module.procedural.event.ProceduralModuleEvent;
import org.jactr.core.module.procedural.event.ProceduralModuleListenerAdaptor;
import org.jactr.core.module.procedural.five.learning.IProductionCompiler;
import org.jactr.core.module.procedural.six.DefaultProceduralModule6;
import org.jactr.core.module.procedural.six.learning.event.IProceduralLearningModule6Listener;
import org.jactr.core.module.procedural.six.learning.event.ProceduralLearningEvent;
import org.jactr.core.production.IInstantiation;
import org.jactr.core.production.IProduction;
import org.jactr.core.production.action.IAction;
import org.jactr.core.production.action.IBufferAction;
import org.jactr.core.production.condition.IBufferCondition;
import org.jactr.core.production.condition.ICondition;
import org.jactr.core.production.six.ISubsymbolicProduction6;
import org.jactr.core.utils.parameter.IParameterized;
import org.jactr.core.utils.parameter.ParameterHandler;
/**
* production learning is accomplished by listening to the procedural module for
* firing events..<br>
* <br>
* This module will update a production's 'ExpectedUtility' based on reward
* signals (if any), and it's Utility. A production may have a 'Reward' value
* which is the reward applied when that production fires and back propogates a
* discounted reward signal all the way back to the most recently rewarded
* production. A 'Reward' value of NaN/'default' is the default and merely marks
* this production as participating in the reward process, but not to start it.
* 'Reward' of 'skip'/-Infinity will allow the production to be skipped during
* reward, or 'stop'/+Inf will permit the production to terminate the reward
* sequence.
*
* @see http://jactr.org/node/67
* @author developer
*/
public class DefaultProceduralLearningModule6 extends AbstractModule implements
IProceduralLearningModule6, IParameterized
{
static public final double SKIP_REWARD = Double.NEGATIVE_INFINITY;
static public final double STOP_REWARD = Double.POSITIVE_INFINITY;
static public final double PARTICIPATE = Double.NaN;
static public final String INCLUDE_BUFFERS_PARAM = "IncludeBuffers";
static final public String PRODUCTION_COMPILER_PARAM = "ProductionCompiler";
/**
* logger definition
*/
static final Log LOGGER = LogFactory
.getLog(DefaultProceduralLearningModule6.class);
protected boolean _productionCompilationEnabled = false;
protected double _parameterLearningRate = Double.NaN;
protected int _optimizationLevel = 0;
protected IExpectedUtilityEquation _utilityEquation;
protected IProductionCompiler _productionCompiler;
protected SortedMap<Double, IProduction> _firedProductions;
/**
* used to track new productions. Keyed on the first parent, if the first
* parent is ever part of the conflict set, but one of its children isn't,
* that child is invalid and should be removed
*/
private Map<IProduction, IProduction> _kindergarden;
private IProduction _justFired;
private IProduction _oneBack;
private IProceduralModule _proceduralModule;
private ACTREventDispatcher<IProceduralLearningModule6, IProceduralLearningModule6Listener> _dispatcher = new ACTREventDispatcher<IProceduralLearningModule6, IProceduralLearningModule6Listener>();
/**
* we only credit productions that reference one of these buffers on the LHS
* or RHS
*/
private Set<String> _includeBuffers;
public DefaultProceduralLearningModule6()
{
super("ProceduralLearningV6");
_firedProductions = new TreeMap<Double, IProduction>();
_kindergarden = new HashMap<IProduction, IProduction>();
_includeBuffers = new TreeSet<String>();
}
public boolean isProductionCompilationEnabled()
{
return _productionCompilationEnabled && _productionCompiler != null;
}
public void setProductionCompilationEnabled(boolean enabled)
{
_productionCompilationEnabled = enabled;
if (enabled = true && _productionCompiler == null)
_productionCompiler = new DefaultProductionCompiler6();
}
public double getParameterLearning()
{
return _parameterLearningRate;
}
public boolean isParameterLearningEnabled()
{
return !Double.isNaN(_parameterLearningRate);
}
public void setParameterLearning(double rate)
{
_parameterLearningRate = rate;
}
public boolean isLearningEnabled()
{
return isParameterLearningEnabled();
}
@Override
public void initialize()
{
/*
* use default set of buffer includes
*/
if (_includeBuffers.size() == 0)
{
_includeBuffers.add("goal");
_includeBuffers.add("retrieval");
_includeBuffers.add("imaginal");
}
_proceduralModule = getModel().getProceduralModule();
_proceduralModule.addListener(new ProceduralModuleListenerAdaptor() {
@Override
public void conflictSetAssembled(ProceduralModuleEvent pme)
{
/**
* check through the productions that might fire, if any are parents,
* check to see if their children are represented too...
*/
if (_kindergarden.size() == 0) return;
HashSet<IProduction> masterList = new HashSet<IProduction>();
for (IProduction production : pme.getProductions())
masterList.add(((IInstantiation) production).getProduction());
IModel model = getModel();
for (IProduction production : masterList)
if (_kindergarden.containsKey(production))
{
IProduction child = _kindergarden.get(production);
if (child != null && !masterList.contains(child))
{
String msg = production
+ " produced "
+ child
+ " who could not be instantiated with its parent. Bad children shall be killed!";
if (Logger.hasLoggers(model))
Logger.log(model, Logger.Stream.PROCEDURAL, msg);
if (LOGGER.isDebugEnabled()) LOGGER.debug(msg);
((DefaultProceduralModule6) model.getProceduralModule())
.removeProduction(child);
}
_kindergarden.remove(production);
}
}
@Override
public void productionFired(ProceduralModuleEvent pme)
{
IInstantiation instantiation = (IInstantiation) pme.getProduction();
DefaultProceduralLearningModule6.this.productionFired(
instantiation.getProduction(), pme.getSimulationTime());
if (isProductionCompilationEnabled())
{
IProduction newProduction = getProductionCompiler().productionFired(
instantiation, pme.getSource());
if (newProduction != null)
{
/**
* add the new production
*/
pme.getSource().addProduction(newProduction);
/**
* hop back two productions to get its initial parent
*/
if (_oneBack != null) _kindergarden.put(_oneBack, newProduction);
}
}
}
}, getExecutor());
}
protected void productionFired(IProduction production, double when)
{
_oneBack = _justFired;
_justFired = production;
if (isParameterLearningEnabled())
{
_firedProductions.put(when, production);
/*
* if the production has a reward value that is finite, we start the
* rewarding
*/
double reward = ((ISubsymbolicProduction6) production
.getSubsymbolicProduction()).getReward();
if (!Double.isNaN(reward) && Double.isFinite(reward)) reward(reward);
}
}
public IProductionCompiler getProductionCompiler()
{
return _productionCompiler;
}
public void setProductionCompiler(IProductionCompiler compiler)
{
_productionCompiler = compiler;
}
public void setExpectedUtilityEquation(IExpectedUtilityEquation equation)
{
_utilityEquation = equation;
}
public IExpectedUtilityEquation getExpectedUtilityEquation()
{
return _utilityEquation;
}
public int getOptimizationLevel()
{
return _optimizationLevel;
}
public void setOptimizationLevel(int level)
{
_optimizationLevel = level;
}
public void reward(double initialReward)
{
IModel model = getModel();
boolean log = LOGGER.isDebugEnabled() || Logger.hasLoggers(model);
if (log)
{
String msg = "Rewarding " + _firedProductions.size() + " productions by "
+ initialReward;
LOGGER.debug(msg);
Logger.log(model, Logger.Stream.PROCEDURAL, msg);
}
// double now = ACTRRuntime.getRuntime().getClock(model).getTime();
double now = model.getAge();
IExpectedUtilityEquation equation = getExpectedUtilityEquation();
if (_dispatcher.hasListeners())
_dispatcher.fire(new ProceduralLearningEvent(this,
ProceduralLearningEvent.Type.START_REWARDING, initialReward));
try
{
for (Map.Entry<Double, IProduction> entry : _firedProductions.entrySet())
{
double discountedReward = initialReward - (now - entry.getKey());
IProduction p = entry.getValue();
if (!shouldInclude(p))
{
if (log)
{
String msg = String.format(
"Excluding %s from rewarding since it doesn't reference %s", p,
_includeBuffers);
if (LOGGER.isDebugEnabled()) LOGGER.debug(msg);
if (Logger.hasLoggers(model))
Logger.log(model, Logger.Stream.PROCEDURAL, msg);
}
continue;
}
ISubsymbolicProduction6 ssp = (ISubsymbolicProduction6) p
.getSubsymbolicProduction();
double productionsReward = ssp.getReward();
/*
* we only apply the utility learning if the production's reward is a
* discrete or NaN value.
*/
if (Double.isFinite(productionsReward)
|| Double.isNaN(productionsReward))
{
double utility = equation.computeExpectedUtility(p, model,
discountedReward);
if (!(Double.isNaN(utility) || Double.isInfinite(utility)))
ssp.setExpectedUtility(utility);
if (log)
{
String msg = "Discounted reward for " + p + " to "
+ discountedReward + " for a learned utility of " + utility;
if (LOGGER.isDebugEnabled()) LOGGER.debug(msg);
if (Logger.hasLoggers(model))
Logger.log(model, Logger.Stream.PROCEDURAL, msg);
}
if (_dispatcher.hasListeners())
_dispatcher.fire(new ProceduralLearningEvent(this, p,
discountedReward));
}
else if (productionsReward < 0) // negative inf, skip
{
if(log)
{
String msg = String.format("Skipping rewarding of %s",p);
if (LOGGER.isDebugEnabled()) LOGGER.debug(msg);
if (Logger.hasLoggers(model))
Logger.log(model, Logger.Stream.PROCEDURAL, msg);
}
continue; //skip
}
else
// pos inf, stop
{
if(log)
{
String msg = String.format("Stopping reward crediation at %s",p);
if (LOGGER.isDebugEnabled()) LOGGER.debug(msg);
if (Logger.hasLoggers(model))
Logger.log(model, Logger.Stream.PROCEDURAL, msg);
}
break; //stop entirely
}
}
}
finally
{
if (_dispatcher.hasListeners())
_dispatcher.fire(new ProceduralLearningEvent(this,
ProceduralLearningEvent.Type.END_REWARDING, initialReward));
_firedProductions.clear();
}
}
protected boolean shouldInclude(IProduction production)
{
for (ICondition condition : production.getSymbolicProduction()
.getConditions())
if (condition instanceof IBufferCondition)
if (_includeBuffers.contains(((IBufferCondition) condition)
.getBufferName())) return true;
for (IAction action : production.getSymbolicProduction().getActions())
if (action instanceof IBufferAction)
if (_includeBuffers.contains(((IBufferAction) action).getBufferName()))
return true;
return false;
}
public String getParameter(String key)
{
if (PARAMETER_LEARNING_RATE.equalsIgnoreCase(key))
return "" + getParameterLearning();
else if (OPTIMIZED_LEARNING.equalsIgnoreCase(key))
return "" + getOptimizationLevel();
else if (PRODUCTION_COMPILATION_PARAM.equalsIgnoreCase(key))
return "" + isProductionCompilationEnabled();
else if (EXPECTED_UTILITY_EQUATION_PARAM.equalsIgnoreCase(key))
return "" + getExpectedUtilityEquation().getClass().getName();
else if (INCLUDE_BUFFERS_PARAM.equalsIgnoreCase(key))
{
StringBuilder sb = new StringBuilder();
for (String bufferName : _includeBuffers)
sb.append(bufferName).append(", ");
if (sb.length() > 2) sb.delete(sb.length() - 2, sb.length());
return sb.toString();
}
return null;
}
public Collection<String> getPossibleParameters()
{
return getSetableParameters();
}
public Collection<String> getSetableParameters()
{
return Arrays.asList(PARAMETER_LEARNING_RATE, OPTIMIZED_LEARNING,
PRODUCTION_COMPILATION_PARAM, EXPECTED_UTILITY_EQUATION_PARAM,
INCLUDE_BUFFERS_PARAM, PRODUCTION_COMPILER_PARAM);
}
public void setParameter(String key, String value)
{
if (EXPECTED_UTILITY_EQUATION_PARAM.equalsIgnoreCase(key))
try
{
setExpectedUtilityEquation((IExpectedUtilityEquation) ParameterHandler
.classInstance().coerce(value).newInstance());
}
catch (Exception e)
{
if (LOGGER.isWarnEnabled())
LOGGER.warn(String.format("Could not instantiate %s, using default",
value));
setExpectedUtilityEquation(new DefaultExpectedUtilityEquation());
}
else if (PRODUCTION_COMPILER_PARAM.equalsIgnoreCase(key))
try
{
setProductionCompiler((IProductionCompiler) ParameterHandler
.classInstance().coerce(value).newInstance());
}
catch (Exception e)
{
if (LOGGER.isWarnEnabled())
LOGGER.warn(String.format("Could not instantiate %s, using default",
value));
setProductionCompiler(new DefaultProductionCompiler6());
}
else if (PARAMETER_LEARNING_RATE.equalsIgnoreCase(key))
setParameterLearning(ParameterHandler.numberInstance().coerce(value)
.doubleValue());
else if (OPTIMIZED_LEARNING.equalsIgnoreCase(key))
setOptimizationLevel(ParameterHandler.numberInstance().coerce(value)
.intValue());
else if (PRODUCTION_COMPILATION_PARAM.equalsIgnoreCase(key))
setProductionCompilationEnabled(ParameterHandler.booleanInstance()
.coerce(value).booleanValue());
else if (INCLUDE_BUFFERS_PARAM.equalsIgnoreCase(key))
{
String[] buffers = value.split(",");
for (String bufferName : buffers)
{
bufferName = bufferName.trim().toLowerCase();
if (bufferName.length() == 0) continue;
_includeBuffers.add(bufferName);
}
}
else if (LOGGER.isWarnEnabled())
LOGGER.warn(String.format(
"%s doesn't recognize %s. Available parameters : %s", getClass()
.getSimpleName(), key, getSetableParameters()));
}
public void addListener(IProceduralLearningModule6Listener listener,
Executor executor)
{
_dispatcher.addListener(listener, executor);
}
public void removeListener(IProceduralLearningModule6Listener listener)
{
_dispatcher.removeListener(listener);
}
public void reset()
{
// noop
}
}