/* * Created on Oct 25, 2006 Copyright (C) 2001-6, Anthony Harrison anh23@pitt.edu * (jactr.org) This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2.1 of the License, * or (at your option) any later version. This library is distributed in the * hope that it will be useful, but WITHOUT ANY WARRANTY; without even the * implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See * the GNU Lesser General Public License for more details. You should have * received a copy of the GNU Lesser General Public License along with this * library; if not, write to the Free Software Foundation, Inc., 59 Temple * Place, Suite 330, Boston, MA 02111-1307 USA */ package org.jactr.core.module.procedural.six.learning; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.SortedMap; import java.util.TreeMap; import java.util.concurrent.Executor; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.jactr.core.event.ACTREventDispatcher; import org.jactr.core.logging.Logger; import org.jactr.core.model.IModel; import org.jactr.core.module.AbstractModule; import org.jactr.core.module.procedural.IProceduralModule; import org.jactr.core.module.procedural.event.ProceduralModuleEvent; import org.jactr.core.module.procedural.event.ProceduralModuleListenerAdaptor; import org.jactr.core.module.procedural.five.learning.IProductionCompiler; import org.jactr.core.module.procedural.six.DefaultProceduralModule6; import org.jactr.core.module.procedural.six.learning.event.IProceduralLearningModule6Listener; import org.jactr.core.module.procedural.six.learning.event.ProceduralLearningEvent; import org.jactr.core.production.IInstantiation; import org.jactr.core.production.IProduction; import org.jactr.core.production.six.ISubsymbolicProduction6; import org.jactr.core.runtime.ACTRRuntime; import org.jactr.core.utils.parameter.IParameterized; import org.jactr.core.utils.parameter.ParameterHandler; /** * production learning is accomplished by listening to the procedural module for * firing events..<br> * <br> * * @author developer */ public class DefaultProceduralLearningModule6 extends AbstractModule implements IProceduralLearningModule6, IParameterized { /** * logger definition */ static private final Log LOGGER = LogFactory .getLog(DefaultProceduralLearningModule6.class); protected boolean _productionCompilationEnabled = false; protected double _parameterLearningRate = Double.NaN; protected int _optimizationLevel = 0; protected IExpectedUtilityEquation _utilityEquation; protected IProductionCompiler _productionCompiler; protected SortedMap<Double, IProduction> _firedProductions; /** * used to track new productions. Keyed on the first parent, if the first * parent is ever part of the conflict set, but one of its children isn't, * that child is invalid and should be removed */ private Map<IProduction, IProduction> _kindergarden; private IProduction _justFired; private IProduction _oneBack; private IProceduralModule _proceduralModule; private ACTREventDispatcher<IProceduralLearningModule6, IProceduralLearningModule6Listener> _dispatcher = new ACTREventDispatcher<IProceduralLearningModule6, IProceduralLearningModule6Listener>(); public DefaultProceduralLearningModule6() { super("ProceduralLearningV6"); _firedProductions = new TreeMap<Double, IProduction>(); _kindergarden = new HashMap<IProduction, IProduction>(); } public boolean isProductionCompilationEnabled() { return _productionCompilationEnabled && _productionCompiler != null; } public void setProductionCompilationEnabled(boolean enabled) { _productionCompilationEnabled = enabled; } public double getParameterLearning() { return _parameterLearningRate; } public boolean isParameterLearningEnabled() { return !Double.isNaN(_parameterLearningRate); } public void setParameterLearning(double rate) { _parameterLearningRate = rate; } public boolean isLearningEnabled() { return isParameterLearningEnabled(); } @Override public void initialize() { setExpectedUtilityEquation(this.new DefaultExpectedUtilityEquation()); _proceduralModule = getModel().getProceduralModule(); _proceduralModule.addListener(new ProceduralModuleListenerAdaptor() { @Override public void conflictSetAssembled(ProceduralModuleEvent pme) { /** * check through the productions that might fire, if any are parents, * check to see if their children are represented too... */ if (_kindergarden.size() == 0) return; HashSet<IProduction> masterList = new HashSet<IProduction>(); for (IProduction production : pme.getProductions()) masterList.add(((IInstantiation) production).getProduction()); IModel model = getModel(); for (IProduction production : masterList) if (_kindergarden.containsKey(production)) { IProduction child = _kindergarden.get(production); if (child != null && !masterList.contains(child)) { String msg = production + " produced " + child + " who could not be instantiated with its parent. Bad children shall be killed!"; if (Logger.hasLoggers(model)) Logger.log(model, Logger.Stream.PROCEDURAL, msg); if (LOGGER.isDebugEnabled()) LOGGER.debug(msg); ((DefaultProceduralModule6) model.getProceduralModule()) .removeProduction(child); } _kindergarden.remove(production); } } @Override public void productionFired(ProceduralModuleEvent pme) { IInstantiation instantiation = (IInstantiation) pme.getProduction(); DefaultProceduralLearningModule6.this.productionFired(instantiation .getProduction(), pme.getSimulationTime()); if (isProductionCompilationEnabled()) { IProduction newProduction = getProductionCompiler().productionFired( instantiation, pme.getSource()); if (newProduction != null) { /** * add the new production */ pme.getSource().addProduction(newProduction); /** * hop back two productions to get its initial parent */ if (_oneBack != null) _kindergarden.put(_oneBack, newProduction); } } } }, getExecutor()); } protected void productionFired(IProduction production, double when) { _oneBack = _justFired; _justFired = production; if (isParameterLearningEnabled()) { _firedProductions.put(when, production); /* * if the production has a reward value that is not nan.. */ double reward = ((ISubsymbolicProduction6) production .getSubsymbolicProduction()).getReward(); if (!Double.isNaN(reward)) { reward(reward); if (_dispatcher.hasListeners()) _dispatcher .fire(new ProceduralLearningEvent(this, production, reward)); } } } public IProductionCompiler getProductionCompiler() { return _productionCompiler; } public void setProductionCompiler(IProductionCompiler compiler) { _productionCompiler = compiler; } public void setExpectedUtilityEquation(IExpectedUtilityEquation equation) { _utilityEquation = equation; } public IExpectedUtilityEquation getExpectedUtilityEquation() { return _utilityEquation; } public int getOptimizationLevel() { return _optimizationLevel; } public void setOptimizationLevel(int level) { _optimizationLevel = level; } public void reward(double initialReward) { IModel model = getModel(); boolean log = LOGGER.isDebugEnabled() || Logger.hasLoggers(model); if (log) { String msg = "Rewarding " + _firedProductions.size() + " productions by " + initialReward; LOGGER.debug(msg); Logger.log(model, Logger.Stream.PROCEDURAL, msg); } double now = ACTRRuntime.getRuntime().getClock(model).getTime(); IExpectedUtilityEquation equation = getExpectedUtilityEquation(); for (Map.Entry<Double, IProduction> entry : _firedProductions.entrySet()) { double discountedReward = initialReward - (now - entry.getKey()); IProduction p = entry.getValue(); ISubsymbolicProduction6 ssp = (ISubsymbolicProduction6) p .getSubsymbolicProduction(); double utility = equation.computeExpectedUtility(p, model, discountedReward); if (!(Double.isNaN(utility) || Double.isInfinite(utility))) ssp.setExpectedUtility(utility); if (log) { String msg = "Discounted reward for " + p + " to " + discountedReward + " for a learned utility of " + utility; LOGGER.debug(msg); Logger.log(model, Logger.Stream.PROCEDURAL, msg); } } _firedProductions.clear(); } private class DefaultExpectedUtilityEquation implements IExpectedUtilityEquation { public double computeExpectedUtility(IProduction production, IModel model, double reward) { ISubsymbolicProduction6 ssp = (ISubsymbolicProduction6) production .getSubsymbolicProduction(); double previousUtility = ssp.getExpectedUtility(); if (Double.isNaN(previousUtility)) previousUtility = ssp.getUtility(); double partial = 0; if (isParameterLearningEnabled() && !(Double.isNaN(reward) || Double.isInfinite(reward))) partial = getParameterLearning() * (reward - previousUtility); double utility = previousUtility + partial; if (LOGGER.isDebugEnabled()) LOGGER.debug(production + ".expectedUtility=" + utility + " previous=" + previousUtility + " partial=" + partial + " reward=" + reward + " rate=" + getParameterLearning()); return utility; } } public String getParameter(String key) { if (PARAMETER_LEARNING_RATE.equalsIgnoreCase(key)) return "" + getParameterLearning(); else if (OPTIMIZED_LEARNING.equalsIgnoreCase(key)) return "" + getOptimizationLevel(); else if (PRODUCTION_COMPILATION_PARAM.equalsIgnoreCase(key)) return "" + isProductionCompilationEnabled(); return null; } public Collection<String> getPossibleParameters() { return getSetableParameters(); } public Collection<String> getSetableParameters() { return Arrays.asList(PARAMETER_LEARNING_RATE, OPTIMIZED_LEARNING, PRODUCTION_COMPILATION_PARAM); } public void setParameter(String key, String value) { if (PARAMETER_LEARNING_RATE.equalsIgnoreCase(key)) setParameterLearning(ParameterHandler.numberInstance().coerce(value) .doubleValue()); else if (OPTIMIZED_LEARNING.equalsIgnoreCase(key)) setOptimizationLevel(ParameterHandler.numberInstance().coerce(value) .intValue()); else if (PRODUCTION_COMPILATION_PARAM.equalsIgnoreCase(key)) setProductionCompilationEnabled(ParameterHandler.booleanInstance() .coerce(value).booleanValue()); else if (LOGGER.isWarnEnabled()) LOGGER.warn("No clue how to set " + key + "=" + value); } public void addListener(IProceduralLearningModule6Listener listener, Executor executor) { _dispatcher.addListener(listener, executor); } public void removeListener(IProceduralLearningModule6Listener listener) { _dispatcher.removeListener(listener); } public void reset() { // noop } }