/*
 * Encog(tm) Core v3.4 - Java Version
 * http://www.heatonresearch.com/encog/
 * https://github.com/encog/encog-java-core
 * Copyright 2008-2016 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * For more information on Heaton Research copyrights, licenses
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.ml.ea.train.basic;

import java.util.ArrayList;
import java.util.List;

import org.encog.ml.CalculateScore;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.ea.population.Population;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.Strategy;
import org.encog.ml.train.strategy.end.EndTrainingStrategy;
import org.encog.neural.networks.training.TrainingSetScore;
import org.encog.neural.networks.training.propagation.TrainingContinuation;

/**
 * An evolutionary-algorithm trainer exposed through the {@link MLTrain}
 * interface. Each single iteration of the underlying {@link BasicEA} is
 * bracketed by calls to the registered training strategies.
 */
public class TrainEA extends BasicEA implements MLTrain {

    /**
     * The serial ID.
     */
    private static final long serialVersionUID = 1L;

    /**
     * Strategies registered through {@link #addStrategy(Strategy)}, in
     * registration order.
     */
    private final List<Strategy> strategies = new ArrayList<Strategy>();

    /**
     * Build a trainer that evaluates genomes with the supplied score function.
     *
     * @param thePopulation
     *            The population to evolve.
     * @param theScoreFunction
     *            The score function used to evaluate genomes.
     */
    public TrainEA(Population thePopulation, CalculateScore theScoreFunction) {
        super(thePopulation, theScoreFunction);
    }

    /**
     * Build a trainer that scores genomes against a training set. The data is
     * wrapped in a {@link TrainingSetScore}; this class keeps no reference to
     * it.
     *
     * @param thePopulation
     *            The population to evolve.
     * @param trainingData
     *            The training data to score against.
     */
    public TrainEA(Population thePopulation, MLDataSet trainingData) {
        super(thePopulation, new TrainingSetScore(trainingData));
    }

    // ------------------------------------------------------------------
    // Iteration
    // ------------------------------------------------------------------

    /**
     * Run one iteration: notify every strategy, perform the evolutionary
     * step of the superclass, then notify every strategy again.
     */
    @Override
    public void iteration() {
        preIteration();
        super.iteration();
        postIteration();
    }

    /**
     * Run {@link #iteration()} repeatedly. A count of zero or less performs
     * no iterations.
     *
     * @param count
     *            The number of training iterations to perform.
     */
    @Override
    public void iteration(final int count) {
        for (int remaining = count; remaining > 0; remaining--) {
            iteration();
        }
    }

    /**
     * Invoke {@code preIteration()} on every registered strategy, in
     * registration order.
     */
    public void preIteration() {
        for (final Strategy current : strategies) {
            current.preIteration();
        }
    }

    /**
     * Invoke {@code postIteration()} on every registered strategy, in
     * registration order.
     */
    public void postIteration() {
        for (final Strategy current : strategies) {
            current.postIteration();
        }
    }

    /**
     * Delegate to the superclass, then store the current best genome on the
     * population.
     */
    @Override
    public void finishTraining() {
        super.finishTraining();
        getPopulation().setBestGenome(getBestGenome());
    }

    // ------------------------------------------------------------------
    // Strategies
    // ------------------------------------------------------------------

    /**
     * Register a training strategy. The strategy is initialised with this
     * trainer before it is appended to the list, so a strategy whose
     * {@code init} throws is never registered.
     *
     * @param strategy
     *            The strategy to add.
     */
    public void addStrategy(final Strategy strategy) {
        strategy.init(this);
        strategies.add(strategy);
    }

    /**
     * @return The live list of registered strategies (not a copy).
     */
    public List<Strategy> getStrategies() {
        return strategies;
    }

    /**
     * Ask each registered {@link EndTrainingStrategy}, in registration order,
     * whether training should stop. Strategies of any other type are skipped.
     *
     * @return True as soon as one end-training strategy reports that it
     *         should stop; false if none does.
     */
    public boolean isTrainingDone() {
        for (final Strategy candidate : strategies) {
            if (!(candidate instanceof EndTrainingStrategy)) {
                continue;
            }
            if (((EndTrainingStrategy) candidate).shouldStop()) {
                return true;
            }
        }
        return false;
    }

    // ------------------------------------------------------------------
    // MLTrain accessors and unsupported operations
    // ------------------------------------------------------------------

    /**
     * {@inheritDoc}
     */
    @Override
    public TrainingImplementationType getImplementationType() {
        return TrainingImplementationType.Iterative;
    }

    /**
     * @return The population being evolved by this trainer.
     */
    @Override
    public MLMethod getMethod() {
        return getPopulation();
    }

    /**
     * This trainer works from a score function and does not expose a
     * training set.
     *
     * @return Always null.
     */
    @Override
    public MLDataSet getTraining() {
        return null;
    }

    /**
     * Not used; the supplied value is ignored.
     *
     * @param error
     *            Not used.
     */
    @Override
    public void setError(final double error) {
        // Intentionally empty.
    }

    /**
     * {@inheritDoc}
     *
     * @return Always false.
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * {@inheritDoc}
     *
     * @return Always null.
     */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * {@inheritDoc}
     *
     * The supplied state is ignored.
     */
    @Override
    public void resume(final TrainingContinuation state) {
        // Intentionally empty.
    }
}