/*
 * Encog(tm) Core v3.4 - Java Version
 * http://www.heatonresearch.com/encog/
 * https://github.com/encog/encog-java-core
 *
 * Copyright 2008-2016 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * For more information on Heaton Research copyrights, licenses
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.ml.data.basic;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;

import org.encog.EncogError;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataError;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.MLSequenceSet;
import org.encog.util.EngineArray;
import org.encog.util.obj.ObjectCloner;

/**
 * A basic implementation of the MLSequenceSet. The set holds an ordered list
 * of sequences, each of which is itself an {@link MLDataSet}. Records are
 * always added to the "current" sequence; {@link #startNewSequence()} begins a
 * new one. When treated as a flat {@link MLDataSet}, the records of all
 * sequences are exposed back to back, in sequence order.
 */
public class BasicMLSequenceSet implements Serializable, MLSequenceSet {

    /**
     * An iterator to be used with the BasicMLSequenceSet. It walks over the
     * records of every sequence in order, skipping sequences that contain no
     * records. This iterator does not support removes.
     *
     * @author jheaton
     */
    public class BasicMLSeqIterator implements Iterator<MLDataPair> {

        /**
         * The index, inside the current sequence, of the next record to
         * return.
         */
        private int currentIndex = 0;

        /**
         * The index of the sequence that is currently being read.
         */
        private int currentSequenceIndex = 0;

        /**
         * Determine whether any unread record remains, either in the current
         * sequence or in a later one. This method does not change the
         * position of the iterator.
         *
         * @return True if a call to next would return a record.
         */
        @Override
        public boolean hasNext() {
            int sequenceIndex = this.currentSequenceIndex;
            int recordIndex = this.currentIndex;

            while (sequenceIndex < sequences.size()) {
                if (recordIndex < sequences.get(sequenceIndex).getRecordCount()) {
                    return true;
                }
                // This sequence is empty or exhausted, look at the next one.
                sequenceIndex++;
                recordIndex = 0;
            }
            return false;
        }

        /**
         * Return the next record and advance the iterator.
         *
         * @return The next record, or null if there are no more records.
         *         Null (rather than an exception) is returned to stay
         *         compatible with existing callers.
         */
        @Override
        public MLDataPair next() {
            if (!hasNext()) {
                return null;
            }

            // hasNext guarantees that a readable record exists, so this loop
            // stops on a valid sequence; it only steps over sequences that
            // are empty or already fully read.
            while (this.currentIndex >= sequences.get(
                    this.currentSequenceIndex).getRecordCount()) {
                this.currentIndex = 0;
                this.currentSequenceIndex++;
            }

            final MLDataSet target = sequences.get(this.currentSequenceIndex);
            final MLDataPair result = ((BasicMLDataSet) target).getData().get(
                    this.currentIndex);
            this.currentIndex++;
            return result;
        }

        /**
         * Removes are not supported.
         *
         * @throws EncogError
         *             Always thrown.
         */
        @Override
        public void remove() {
            throw new EncogError("Called remove, unsupported operation.");
        }
    }

    /**
     * The serial id.
     */
    private static final long serialVersionUID = -2279722928570071183L;

    /**
     * The data held by this object, one data set per sequence.
     */
    private List<MLDataSet> sequences = new ArrayList<MLDataSet>();

    /**
     * The sequence that newly added records are appended to.
     */
    private MLDataSet currentSequence;

    /**
     * Default constructor. Creates a set with a single, empty sequence.
     */
    public BasicMLSequenceSet() {
        this.currentSequence = new BasicMLDataSet();
        sequences.add(this.currentSequence);
    }

    /**
     * Construct a sequence set that shares its sequence list and current
     * sequence with another set. No data is copied; changes made through
     * either object are visible through both.
     *
     * @param other
     *            The set whose sequences should be shared.
     */
    public BasicMLSequenceSet(BasicMLSequenceSet other) {
        this.sequences = other.sequences;
        this.currentSequence = other.currentSequence;
    }

    /**
     * Construct a data set from an input and ideal array. All records are
     * placed in a single sequence.
     *
     * @param input
     *            The input into the machine learning method for training.
     * @param ideal
     *            The ideal output for training.
     */
    public BasicMLSequenceSet(final double[][] input, final double[][] ideal) {
        this.currentSequence = new BasicMLDataSet(input, ideal);
        this.sequences.add(this.currentSequence);
    }

    /**
     * Construct a data set from an already created list. Mostly used to
     * duplicate this class. All records are placed in a single sequence.
     *
     * @param theData
     *            The data to use.
     */
    public BasicMLSequenceSet(final List<MLDataPair> theData) {
        this.currentSequence = new BasicMLDataSet(theData);
        this.sequences.add(this.currentSequence);
    }

    /**
     * Copy whatever dataset type is specified into a memory dataset. The
     * input and ideal arrays of every pair are copied into new objects, and
     * all records are placed in a single sequence.
     *
     * @param set
     *            The dataset to copy.
     */
    public BasicMLSequenceSet(final MLDataSet set) {
        this.currentSequence = new BasicMLDataSet();
        this.sequences.add(this.currentSequence);

        final int inputCount = set.getInputSize();
        final int idealCount = set.getIdealSize();

        for (final MLDataPair pair : set) {

            BasicMLData input = null;
            BasicMLData ideal = null;

            if (inputCount > 0) {
                input = new BasicMLData(inputCount);
                EngineArray.arrayCopy(pair.getInputArray(), input.getData());
            }

            if (idealCount > 0) {
                ideal = new BasicMLData(idealCount);
                EngineArray.arrayCopy(pair.getIdealArray(), ideal.getData());
            }

            this.currentSequence.add(new BasicMLDataPair(input, ideal));
        }
    }

    /**
     * Add a record to the current sequence.
     *
     * @param theData
     *            The data to add.
     */
    @Override
    public void add(final MLData theData) {
        this.currentSequence.add(theData);
    }

    /**
     * Add an input/ideal record to the current sequence.
     *
     * @param inputData
     *            The input data.
     * @param idealData
     *            The ideal data.
     */
    @Override
    public void add(final MLData inputData, final MLData idealData) {
        final MLDataPair pair = new BasicMLDataPair(inputData, idealData);
        this.currentSequence.add(pair);
    }

    /**
     * Add a pair to the current sequence.
     *
     * @param inputData
     *            The pair to add.
     */
    @Override
    public void add(final MLDataPair inputData) {
        this.currentSequence.add(inputData);
    }

    /**
     * Clone this object by delegating to {@link ObjectCloner#deepCopy}.
     *
     * @return The copy produced by the object cloner.
     */
    @Override
    public Object clone() {
        return ObjectCloner.deepCopy(this);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public void close() {
        // nothing to close
    }

    /**
     * Get the ideal size, as reported by the first sequence.
     *
     * @return The ideal size, or zero if the first sequence has no records.
     */
    @Override
    public int getIdealSize() {
        if (this.sequences.get(0).getRecordCount() == 0) {
            return 0;
        }
        return this.sequences.get(0).getIdealSize();
    }

    /**
     * Get the input size, as reported by the first sequence.
     *
     * @return The input size, or zero if the first sequence has no records.
     */
    @Override
    public int getInputSize() {
        if (this.sequences.get(0).getRecordCount() == 0) {
            return 0;
        }
        return this.sequences.get(0).getInputSize();
    }

    /**
     * Read a record by its flat index, that is, its position when the
     * records of all sequences are laid end to end in sequence order.
     *
     * @param index
     *            The zero-based flat index of the record.
     * @param pair
     *            The pair that the record is copied into.
     * @throws MLDataError
     *             If the index is negative or not less than the total number
     *             of records.
     */
    @Override
    public void getRecord(final long index, final MLDataPair pair) {
        if (index < 0) {
            throw new MLDataError("Record out of range: " + index);
        }

        // Walk the sequences, converting the flat index into an index that
        // is local to the sequence that contains the record.
        long recordIndex = index;
        for (final MLDataSet sequence : this.sequences) {
            final long count = sequence.getRecordCount();
            if (recordIndex < count) {
                sequence.getRecord(recordIndex, pair);
                return;
            }
            recordIndex -= count;
        }

        throw new MLDataError("Record out of range: " + index);
    }

    /**
     * Get the total number of records, summed over all sequences.
     *
     * @return The total record count.
     */
    @Override
    public long getRecordCount() {
        long result = 0;
        for (final MLDataSet ds : this.sequences) {
            result += ds.getRecordCount();
        }
        return result;
    }

    /**
     * Determine if the data is supervised, as reported by the first
     * sequence.
     *
     * @return False if the first sequence has no records, otherwise whatever
     *         the first sequence reports.
     */
    @Override
    public boolean isSupervised() {
        if (this.sequences.get(0).getRecordCount() == 0) {
            return false;
        }
        return this.sequences.get(0).isSupervised();
    }

    /**
     * Create an iterator over the records of all sequences.
     *
     * @return A new iterator positioned at the first record.
     */
    @Override
    public Iterator<MLDataPair> iterator() {
        final BasicMLSeqIterator result = new BasicMLSeqIterator();
        return result;
    }

    /**
     * Open an additional view of this data. The returned object shares the
     * sequence list of this object rather than copying it.
     *
     * @return A new sequence set backed by the same sequences.
     */
    @Override
    public MLDataSet openAdditional() {
        return new BasicMLSequenceSet(this);
    }

    /**
     * Start a new sequence. Nothing happens if the current sequence does not
     * contain any records yet.
     */
    @Override
    public void startNewSequence() {
        if (this.currentSequence.getRecordCount() > 0) {
            this.currentSequence = new BasicMLDataSet();
            this.sequences.add(this.currentSequence);
        }
    }

    /**
     * @return The number of sequences, including an empty current sequence.
     */
    @Override
    public int getSequenceCount() {
        return this.sequences.size();
    }

    /**
     * Get a sequence by index.
     *
     * @param i
     *            The zero-based index of the sequence.
     * @return The sequence at the specified index.
     */
    @Override
    public MLDataSet getSequence(int i) {
        return this.sequences.get(i);
    }

    /**
     * @return The internal list of sequences. The list is not copied.
     */
    @Override
    public Collection<MLDataSet> getSequences() {
        return this.sequences;
    }

    /**
     * @return The total record count, narrowed to an int.
     */
    @Override
    public int size() {
        return (int) getRecordCount();
    }

    /**
     * Get a record by its flat index. A new pair is created and filled by
     * {@link #getRecord(long, MLDataPair)}.
     *
     * @param index
     *            The zero-based flat index of the record.
     * @return A newly created pair holding the record.
     * @throws MLDataError
     *             If the index is out of range.
     */
    @Override
    public MLDataPair get(int index) {
        final MLDataPair result = BasicMLDataPair.createPair(getInputSize(),
                getIdealSize());
        this.getRecord(index, result);
        return result;
    }

    /**
     * Add every pair of the specified data set to the current sequence. This
     * does not start a new sequence.
     *
     * @param sequence
     *            The data set whose pairs should be added.
     */
    @Override
    public void add(MLDataSet sequence) {
        for (final MLDataPair pair : sequence) {
            add(pair);
        }
    }
}