package org.jactr.tools.utility;
/*
* default logging
*/
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.regex.Pattern;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jactr.core.concurrent.ExecutorServices;
import org.jactr.core.model.IModel;
import org.jactr.core.module.procedural.event.IProceduralModuleListener;
import org.jactr.core.module.procedural.event.ProceduralModuleEvent;
import org.jactr.core.module.procedural.event.ProceduralModuleListenerAdaptor;
import org.jactr.core.production.IInstantiation;
import org.jactr.core.production.IProduction;
import org.jactr.core.production.six.ISubsymbolicProduction6;
import org.jactr.core.runtime.ACTRRuntime;
import org.jactr.core.utils.parameter.IParameterized;
import org.jactr.instrument.IInstrument;
/**
 * Tracks the utility of a set of productions over time.
 * <p>
 * Each time the procedural module reports that a conflict set was assembled,
 * the utilities of all tracked productions found in that set are written as one
 * tab separated row (time first) to a temporary file. Productions become
 * tracked the first time their name matches one of the configured patterns.
 * When the tracker is uninstalled, a header row naming each column is written
 * to the final output file, followed by the contents of the temporary file.
 * <p>
 * Parameters: {@link #FILE_NAME_PARAM} (output file, resolved against the
 * runtime's working directory), {@link #PATTERN_PARAM} (regular expression
 * matched against production names, may be set repeatedly) and
 * {@link #TRACK_INSTANTIATIONS_PARAM} (log the instantiation's utility rather
 * than the production's).
 *
 * @author harrison
 */
public class UtilityTracker implements IInstrument, IParameterized
{
  /**
   * Logger definition
   */
  static private final transient Log LOGGER = LogFactory
      .getLog(UtilityTracker.class);

  static public final String FILE_NAME_PARAM = "FileName";

  static public final String PATTERN_PARAM = "Pattern";

  static public final String TRACK_INSTANTIATIONS_PARAM = "TrackInstantiations";

  /**
   * fixed prefix for the temporary data file. File.createTempFile rejects
   * prefixes shorter than three characters, so the user supplied file name is
   * not a safe choice.
   */
  static private final String TEMP_FILE_PREFIX = "utilityTracker";

  /** column (index into a row) assigned to each tracked production */
  private final Map<IProduction, Integer> _productionColumnIndices;

  /** name patterns that select which productions get tracked */
  private final Collection<Pattern> _productionNamePatterns;

  private final IProceduralModuleListener _prodListener;

  /** the row currently being assembled, index 0 is the time */
  private final ArrayList<Double> _utilities;

  /** temporary file that receives the header-less rows */
  private File _currentFile;

  private String _fileName;

  private PrintWriter _outputStream;

  private IModel _attachedTo;

  /**
   * log the production's utility (sans noise) or the instantiation's utility
   * (with noise)
   */
  private boolean _logInstantiationUtility = true;

  public UtilityTracker()
  {
    _productionColumnIndices = new HashMap<IProduction, Integer>();
    _productionNamePatterns = new ArrayList<Pattern>();
    _utilities = new ArrayList<Double>();
    _utilities.add(null); // for time

    _prodListener = new ProceduralModuleListenerAdaptor() {

      private final SortedMap<String, IProduction> _sorter = new TreeMap<String, IProduction>();

      public void conflictSetAssembled(ProceduralModuleEvent pme)
      {
        try
        {
          for (IProduction candidate : pme.getProductions())
          {
            if (!_logInstantiationUtility) candidate = unwrap(candidate);

            if (isTracked(candidate) || matchesPattern(candidate))
              _sorter.put(candidate.getSymbolicProduction().getName(),
                  candidate);
          }

          /**
           * we use the sorter to maintain a consistent (alphabetical) order
           * of the productions as they are added, otherwise they will be
           * initially tracked in their initial utility order
           */
          if (!_sorter.isEmpty())
          {
            for (IProduction production : _sorter.values())
            {
              if (!isTracked(production)) addTrackedProduction(production);
              logUtility(production);
            }

            flushUtility(pme.getSimulationTime());
          }
        }
        finally
        {
          /*
           * always reset, so that a failure while logging cannot leak this
           * cycle's productions into the next row
           */
          _sorter.clear();
        }
      }
    };
  }

  /**
   * @return the production behind an instantiation, or the argument itself if
   *         it is not an instantiation
   */
  static private IProduction unwrap(IProduction production)
  {
    if (production instanceof IInstantiation)
      return ((IInstantiation) production).getProduction();
    return production;
  }

  /**
   * @return true if the production (or the production behind the
   *         instantiation) already has a column assigned
   */
  private boolean isTracked(IProduction production)
  {
    return _productionColumnIndices.containsKey(unwrap(production));
  }

  /**
   * @return true if the production's name fully matches any configured
   *         pattern
   */
  private boolean matchesPattern(IProduction production)
  {
    String name = production.getSymbolicProduction().getName();
    for (Pattern pattern : _productionNamePatterns)
      if (pattern.matcher(name).matches()) return true;
    return false;
  }

  /**
   * assign the next free column to the production (instantiations are mapped
   * to their production)
   */
  private void addTrackedProduction(IProduction production)
  {
    production = unwrap(production);

    int column = _utilities.size();
    _productionColumnIndices.put(production, column);
    _utilities.add(null);

    if (LOGGER.isDebugEnabled())
      LOGGER.debug("Adding production " + production + " tracking " + column);
  }

  /**
   * store the utility of the production in its column of the current row. The
   * value is read from the argument as passed in (production or
   * instantiation); only the column lookup uses the underlying production.
   */
  private void logUtility(IProduction production)
  {
    ISubsymbolicProduction6 subsymbolic = (ISubsymbolicProduction6) production
        .getSubsymbolicProduction();

    double utility = subsymbolic.getExpectedUtility();

    /**
     * perfectly legit if there has been no learning and we're tracking the
     * production and not the instantiation
     */
    if (Double.isNaN(utility)) utility = subsymbolic.getUtility();

    /*
     * make sure we get the column index right..
     */
    _utilities.set(_productionColumnIndices.get(unwrap(production)), utility);
  }

  /**
   * write the current row (time followed by one tab terminated cell per
   * column, empty where nothing was logged) and reset every cell
   */
  private void flushUtility(double when)
  {
    _utilities.set(0, when);

    StringBuilder row = new StringBuilder();
    for (int i = 0; i < _utilities.size(); i++)
    {
      Double cell = _utilities.set(i, null);
      if (cell != null) row.append(cell);
      row.append("\t");
    }

    _outputStream.println(row.toString());
  }

  public void initialize()
  {
    // nothing to do, all the work happens in install/uninstall
  }

  /**
   * attach to the model's procedural module and open the temporary data file.
   * Only one model can be attached at a time, further calls just log a
   * warning.
   *
   * @throws IllegalStateException
   *           if no file name has been configured or the temporary file
   *           cannot be created
   */
  synchronized public void install(IModel model)
  {
    if (_attachedTo != null)
    {
      if (LOGGER.isWarnEnabled())
        LOGGER.warn("UtilityTracker is already attached to " + _attachedTo
            + ". Can only be attched to one model.");
      return;
    }

    if (_fileName == null || _fileName.length() == 0)
      throw new IllegalStateException(FILE_NAME_PARAM
          + " must be set before the UtilityTracker is installed");

    /*
     * open the file before listening, otherwise a failure here would leave a
     * listener attached that has nothing to write to (and that could never be
     * removed since _attachedTo would still be null)
     */
    try
    {
      _currentFile = File.createTempFile(TEMP_FILE_PREFIX, ".tmp");
      _currentFile.deleteOnExit();
      _outputStream = new PrintWriter(new BufferedWriter(new FileWriter(
          _currentFile)));
    }
    catch (IOException e)
    {
      throw new IllegalStateException("Could not create temp file", e);
    }

    model.getProceduralModule().addListener(_prodListener,
        ExecutorServices.INLINE_EXECUTOR);

    _attachedTo = model;
  }

  /**
   * detach from the model and produce the final output file: a header row
   * followed by everything that was written to the temporary file. Does
   * nothing if the model is not the one currently attached.
   *
   * @throws IllegalStateException
   *           if the output file could not be written
   */
  synchronized public void uninstall(IModel model)
  {
    if (model == null || _attachedTo != model) return;

    _attachedTo = null;
    model.getProceduralModule().removeListener(_prodListener);

    /*
     * now that we are done, merge the file with a header file
     */
    _outputStream.flush();
    _outputStream.close();

    PrintWriter merged = null;
    BufferedReader reader = null;
    try
    {
      File target = new File(ACTRRuntime.getRuntime().getWorkingDirectory(),
          _fileName);
      merged = new PrintWriter(new BufferedWriter(new FileWriter(target)));

      writeHeader(merged);

      /*
       * readLine returning null is the only reliable end of file signal,
       * ready() merely says whether a read would block
       */
      reader = new BufferedReader(new FileReader(_currentFile));
      String line = null;
      while ((line = reader.readLine()) != null)
        merged.println(line);

      /*
       * PrintWriter never throws, so ask explicitly (this also flushes)
       */
      if (merged.checkError())
        throw new IOException("Error while writing to " + target);
    }
    catch (Exception e)
    {
      throw new IllegalStateException("Could not generate file output file ", e);
    }
    finally
    {
      if (merged != null) merged.close();
      if (reader != null) try
      {
        reader.close();
      }
      catch (IOException e)
      {
        if (LOGGER.isWarnEnabled())
          LOGGER.warn("Could not close " + _currentFile, e);
      }
    }

    /*
     * only reached when the merge succeeded. on failure the data stays around
     * until the jvm exits
     */
    _currentFile.delete();
  }

  /**
   * write the tab terminated column names, "Time" followed by the name of each
   * tracked production at its column index
   */
  private void writeHeader(PrintWriter out)
  {
    String[] header = new String[_productionColumnIndices.size() + 1];
    Arrays.fill(header, "");
    header[0] = "Time";

    for (Map.Entry<IProduction, Integer> entry : _productionColumnIndices
        .entrySet())
      header[entry.getValue()] = entry.getKey().getSymbolicProduction()
          .getName();

    for (String column : header)
    {
      out.print(column);
      out.print("\t");
    }
    out.println();
  }

  /**
   * @return the current value of {@link #FILE_NAME_PARAM} or
   *         {@link #TRACK_INSTANTIATIONS_PARAM}. null for anything else,
   *         including {@link #PATTERN_PARAM}, since any number of patterns
   *         may have been added and there is no single value to report
   */
  public String getParameter(String key)
  {
    if (FILE_NAME_PARAM.equalsIgnoreCase(key)) return _fileName;
    if (TRACK_INSTANTIATIONS_PARAM.equalsIgnoreCase(key))
      return Boolean.toString(_logInstantiationUtility);
    return null;
  }

  public Collection<String> getPossibleParameters()
  {
    return getSetableParameters();
  }

  public Collection<String> getSetableParameters()
  {
    return Arrays.asList(FILE_NAME_PARAM, PATTERN_PARAM,
        TRACK_INSTANTIATIONS_PARAM);
  }

  /**
   * set the file name, add a production name pattern, or toggle instantiation
   * tracking. Keys are compared ignoring case, unknown keys are ignored.
   */
  public void setParameter(String key, String value)
  {
    if (FILE_NAME_PARAM.equalsIgnoreCase(key))
      _fileName = value;
    else if (PATTERN_PARAM.equalsIgnoreCase(key))
      _productionNamePatterns.add(Pattern.compile(value));
    else if (TRACK_INSTANTIATIONS_PARAM.equalsIgnoreCase(key))
      _logInstantiationUtility = Boolean.parseBoolean(value);
  }
}