/* * Encog(tm) Core v2.5 - Java Version * http://www.heatonresearch.com/encog/ * http://code.google.com/p/encog-java/ * Copyright 2008-2010 Heaton Research, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * For more information on Heaton Research copyrights, licenses * and trademarks visit: * http://www.heatonresearch.com/copyright */ package org.encog.persist.persistors; import java.util.HashMap; import java.util.Map; import org.encog.EncogError; import org.encog.engine.network.flat.FlatNetwork; import org.encog.neural.networks.BasicNetwork; import org.encog.neural.networks.layers.BasicLayer; import org.encog.neural.networks.layers.Layer; import org.encog.neural.networks.logic.ART1Logic; import org.encog.neural.networks.logic.BAMLogic; import org.encog.neural.networks.logic.BoltzmannLogic; import org.encog.neural.networks.logic.FeedforwardLogic; import org.encog.neural.networks.logic.HopfieldLogic; import org.encog.neural.networks.logic.NeuralLogic; import org.encog.neural.networks.logic.SimpleRecurrentLogic; import org.encog.neural.networks.synapse.Synapse; import org.encog.parse.tags.read.ReadXML; import org.encog.parse.tags.write.WriteXML; import org.encog.persist.EncogPersistedCollection; import org.encog.persist.EncogPersistedObject; import org.encog.persist.Persistor; import org.encog.util.csv.CSVFormat; import org.encog.util.csv.NumberList; /** * The Encog persistor used to persist the BasicNetwork class. * * @author jheaton */ public class BasicNetworkPersistor implements Persistor { /** * The layers tag. */ public static final String TAG_LAYERS = "layers"; /** * The synapses tag. */ public static final String TAG_SYNAPSES = "synapses"; /** * The synapse tag. */ public static final String TAG_SYNAPSE = "synapse"; /** * The properties tag. */ public static final String TAG_PROPERTIES = "properties"; public static final String TAG_OUTPUT = "layerOutput"; /** * The tags tag. */ public static final String TAG_TAGS = "tags"; /** * The tag tag. */ public static final String TAG_TAG = "tag"; /** * The logic tag. */ public static final String TAG_LOGIC = "logic"; /** * The layer synapse. */ public static final String TAG_LAYER = "layer"; /** * The property tag. */ public static final String TAG_PROPERTY = "Property"; /** * The id attribute. */ public static final String ATTRIBUTE_ID = "id"; /** * The name attribute. */ public static final String ATTRIBUTE_NAME = "name"; /** * The value attribute. */ public static final String ATTRIBUTE_VALUE = "value"; /** * The type attribute. */ public static final String ATTRIBUTE_TYPE = "type"; /** * The input layer type. */ public static final String ATTRIBUTE_TYPE_INPUT = "input"; /** * The output layer type. */ public static final String ATTRIBUTE_TYPE_OUTPUT = "output"; /** * The hidden layer type. */ public static final String ATTRIBUTE_TYPE_HIDDEN = "hidden"; /** * The both layer type. */ public static final String ATTRIBUTE_TYPE_BOTH = "both"; /** * The unknown layer type. */ public static final String ATTRIBUTE_TYPE_UNKNOWN = "unknown"; /** * The from attribute. */ public static final String ATTRIBUTE_FROM = "from"; /** * The to attribute. */ public static final String ATTRIBUTE_TO = "to"; /** * The to attribute. */ public static final String ATTRIBUTE_LAYER = "layer"; /** * The network that is being loaded. */ private BasicNetwork currentNetwork; /** * A mapping from layers to index numbers. */ private final Map<Layer, Integer> layer2index = new HashMap<Layer, Integer>(); /** * A mapping from index numbers to layers. */ private final Map<Integer, Layer> index2layer = new HashMap<Integer, Layer>(); /** * Handle any layers that should be loaded. * * @param in * The XML reader. */ private void handleLayers(final ReadXML in) { final String end = in.getTag().getName(); while (in.readToTag()) { if (in.is(BasicNetworkPersistor.TAG_LAYER, true)) { final int num = in.getTag().getAttributeInt( BasicNetworkPersistor.ATTRIBUTE_ID); final String type = in.getTag().getAttributeValue( BasicNetworkPersistor.ATTRIBUTE_TYPE); in.readToTag(); final Persistor persistor = PersistorUtil.createPersistor(in .getTag().getName()); final Layer layer = (Layer) persistor.load(in); this.index2layer.put(num, layer); layer.setID(num); // the type attribute is actually "legacy", but if its there // then use it! if (type != null) { if (type.equals( BasicNetworkPersistor.ATTRIBUTE_TYPE_INPUT)) { this.currentNetwork.tagLayer(BasicNetwork.TAG_INPUT, layer); } else if (type.equals( BasicNetworkPersistor.ATTRIBUTE_TYPE_OUTPUT)) { this.currentNetwork.tagLayer(BasicNetwork.TAG_OUTPUT, layer); } else if (type.equals( BasicNetworkPersistor.ATTRIBUTE_TYPE_BOTH)) { this.currentNetwork.tagLayer(BasicNetwork.TAG_INPUT, layer); this.currentNetwork.tagLayer(BasicNetwork.TAG_OUTPUT, layer); } } // end of legacy processing } if (in.is(end, false)) { break; } } } /** * Handle reading the neural logic information. * @param in The object to read XML from. */ private void handleLogic(final ReadXML in) { final String value = in.readTextToTag(); if (value.equalsIgnoreCase("ART1Logic")) { this.currentNetwork.setLogic(new ART1Logic()); } else if (value.equalsIgnoreCase("BAMLogic")) { this.currentNetwork.setLogic(new BAMLogic()); } else if (value.equalsIgnoreCase("BoltzmannLogic")) { this.currentNetwork.setLogic(new BoltzmannLogic()); } else if (value.equalsIgnoreCase("FeedforwardLogic")) { this.currentNetwork.setLogic(new FeedforwardLogic()); } else if (value.equalsIgnoreCase("HopfieldLogic")) { this.currentNetwork.setLogic(new HopfieldLogic()); } else if (value.equalsIgnoreCase("SimpleRecurrentLogic")) { this.currentNetwork.setLogic(new SimpleRecurrentLogic()); } else { try { final NeuralLogic logic = (NeuralLogic) Class.forName(value) .newInstance(); this.currentNetwork.setLogic(logic); } catch (final ClassNotFoundException e) { throw new EncogError(e); } catch (final InstantiationException e) { throw new EncogError(e); } catch (final IllegalAccessException e) { throw new EncogError(e); } } } /** * Handle reading network properties. * @param in Where to read network properties from. */ private void handleProperties(final ReadXML in) { final String end = in.getTag().getName(); while (in.readToTag()) { if (in.is(BasicNetworkPersistor.TAG_PROPERTY, true)) { final String name = in.getTag().getAttributeValue( BasicNetworkPersistor.ATTRIBUTE_NAME); final String value = in.readTextToTag(); this.currentNetwork.setProperty(name, value); } if (in.is(end, false)) { break; } } } /** * Process any synapses that should be loaded. * * @param in * The XML reader. */ private void handleSynapses(final ReadXML in) { final String end = in.getTag().getName(); while (in.readToTag()) { if (in.is(BasicNetworkPersistor.TAG_SYNAPSE, true)) { final int from = in.getTag().getAttributeInt( BasicNetworkPersistor.ATTRIBUTE_FROM); final int to = in.getTag().getAttributeInt( BasicNetworkPersistor.ATTRIBUTE_TO); in.readToTag(); final Persistor persistor = PersistorUtil.createPersistor(in .getTag().getName()); final Synapse synapse = (Synapse) persistor.load(in); synapse.setFromLayer(this.index2layer.get(from)); synapse.setToLayer(this.index2layer.get(to)); synapse.getFromLayer().addSynapse(synapse); } if (in.is(end, false)) { break; } } } /** * Handle reading neural network tags. * @param in Where to read tag XML from. */ private void handleTags(final ReadXML in) { final String end = in.getTag().getName(); while (in.readToTag()) { if (in.is(BasicNetworkPersistor.TAG_TAG, true)) { final String name = in.getTag().getAttributeValue( BasicNetworkPersistor.ATTRIBUTE_NAME); final String layerStr = in.getTag().getAttributeValue( BasicNetworkPersistor.ATTRIBUTE_LAYER); final int layerInt = Integer.parseInt(layerStr); final Layer layer = this.index2layer.get(layerInt); this.currentNetwork.tagLayer(name, layer); in.readToTag(); } if (in.is(end, false)) { break; } } } /** * Load the specified Encog object from an XML reader. * * @param in * The XML reader to use. * @return The loaded object. */ public EncogPersistedObject load(final ReadXML in) { double[] output = null; final String name = in.getTag().getAttributes().get( EncogPersistedCollection.ATTRIBUTE_NAME); final String description = in.getTag().getAttributes().get( EncogPersistedCollection.ATTRIBUTE_DESCRIPTION); this.currentNetwork = new BasicNetwork(); this.currentNetwork.setName(name); this.currentNetwork.setDescription(description); while (in.readToTag()) { if (in.is(BasicNetworkPersistor.TAG_LAYERS, true)) { handleLayers(in); } else if (in.is(BasicNetworkPersistor.TAG_SYNAPSES, true)) { handleSynapses(in); } else if (in.is(BasicNetworkPersistor.TAG_PROPERTIES, true)) { handleProperties(in); } else if (in.is(BasicNetworkPersistor.TAG_LOGIC, true)) { handleLogic(in); } else if (in.is(BasicNetworkPersistor.TAG_TAGS, true)) { handleTags(in); } else if (in.is(BasicNetworkPersistor.TAG_OUTPUT, true)) { output = handleOutput(in); } else if (in.is(EncogPersistedCollection.TYPE_BASIC_NET, false)) { break; } } this.currentNetwork.getStructure().finalizeStructure(); return this.currentNetwork; } /** * Save the specified Encog object to an XML writer. * * @param obj * The object to save. * @param out * The XML writer to save to. */ public void save(final EncogPersistedObject obj, final WriteXML out) { PersistorUtil.beginEncogObject(EncogPersistedCollection.TYPE_BASIC_NET, out, obj, true); this.currentNetwork = (BasicNetwork) obj; this.currentNetwork.getStructure().finalizeStructure(); // save the layers out.beginTag(BasicNetworkPersistor.TAG_LAYERS); saveLayers(out); out.endTag(); // save the structure of these layers out.beginTag(BasicNetworkPersistor.TAG_SYNAPSES); saveSynapses(out); out.endTag(); saveProperties(out); saveTags(out); saveLogic(out); saveOutput(out); out.endTag(); } /** * Save the layers to the specified XML writer. * * @param out * The XML writer. */ private void saveLayers(final WriteXML out) { for (final Layer layer : this.currentNetwork.getStructure().getLayers()) { out.addAttribute(BasicNetworkPersistor.ATTRIBUTE_ID, "" + layer.getID()); out.beginTag(BasicNetworkPersistor.TAG_LAYER); final Persistor persistor = layer.createPersistor(); persistor.save(layer, out); out.endTag(); this.layer2index.put(layer, layer.getID()); } } /** * Save the neural logic. * @param out The output stream. */ private void saveLogic(final WriteXML out) { out.beginTag(BasicNetworkPersistor.TAG_LOGIC); final NeuralLogic logic = this.currentNetwork.getLogic(); if ((logic instanceof FeedforwardLogic) || (logic instanceof SimpleRecurrentLogic) || (logic instanceof BoltzmannLogic) || (logic instanceof ART1Logic) || (logic instanceof BAMLogic) || (logic instanceof HopfieldLogic)) { out.addText(logic.getClass().getSimpleName()); } else { out.addText(logic.getClass().getName()); } out.endTag(); } /** * Save the network properties. * @param out The object to write XML to. */ private void saveProperties(final WriteXML out) { // save any properties out.beginTag(BasicNetworkPersistor.TAG_PROPERTIES); for (final String key : this.currentNetwork.getProperties().keySet()) { final String value = this.currentNetwork.getProperties().get(key); out.addAttribute(BasicNetworkPersistor.ATTRIBUTE_NAME, key); out.beginTag(BasicNetworkPersistor.TAG_PROPERTY); out.addText(value.toString()); out.endTag(); } out.endTag(); } /** * Save the synapses to the specified XML writer. * * @param out * The XML writer. */ private void saveSynapses(final WriteXML out) { for (final Synapse synapse : this.currentNetwork.getStructure() .getSynapses()) { out.addAttribute(BasicNetworkPersistor.ATTRIBUTE_FROM, "" + this.layer2index.get(synapse.getFromLayer())); out.addAttribute(BasicNetworkPersistor.ATTRIBUTE_TO, "" + this.layer2index.get(synapse.getToLayer())); out.beginTag(BasicNetworkPersistor.TAG_SYNAPSE); final Persistor persistor = synapse.createPersistor(); persistor.save(synapse, out); out.endTag(); } } /** * Save the tags. * @param out The writer to save the tags to. */ private void saveTags(final WriteXML out) { // save any properties out.beginTag(BasicNetworkPersistor.TAG_TAGS); for (final String key : this.currentNetwork.getLayerTags().keySet()) { final Layer value = this.currentNetwork.getLayerTags().get(key); out.addAttribute(BasicNetworkPersistor.ATTRIBUTE_NAME, key); out.addAttribute(BasicNetworkPersistor.ATTRIBUTE_LAYER, "" + this.layer2index.get(value)); out.beginTag(BasicNetworkPersistor.TAG_TAG); out.endTag(); } out.endTag(); } private void saveOutput(final WriteXML out) { FlatNetwork flat = this.currentNetwork.getStructure().getFlat(); if (flat != null) { out.beginTag(BasicNetworkPersistor.TAG_OUTPUT); final StringBuilder result = new StringBuilder(); NumberList.toList(CSVFormat.EG_FORMAT, result, flat.getLayerOutput()); out.addProperty(BasicLayerPersistor.PROPERTY_THRESHOLD, result.toString()); out.endTag(); } } private double[] handleOutput(final ReadXML in) { String output = in.readTextToTag(); return NumberList.fromList(CSVFormat.EG_FORMAT, output); } }