/* * Copyright [2013-2015] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.dtrain.dataset; import java.io.File; import java.io.Serializable; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import org.encog.ml.data.buffer.BufferedMLDataSet; /** * Copy from {@link BufferedMLDataSet} to support float type data. */ public class BufferedFloatMLDataSet implements FloatMLDataSet, Serializable { /** * The version. */ private static final long serialVersionUID = 2577778772598513566L; /** * Error message for ADD. */ public static final String ERROR_ADD = "Add can only be used after calling beginLoad."; /** * Error message for REMOVE. */ public static final String ERROR_REMOVE = "Remove is not supported for BufferedNeuralDataSet."; /** * True, if we are in the process of loading. */ private transient boolean loading; /** * The file being used. */ private File file; /** * The EGB file we are working wtih. */ private transient EncogFloatEGBFile egb; /** * Additional sets that were opened. */ private transient List<BufferedFloatMLDataSet> additional = new ArrayList<BufferedFloatMLDataSet>(); /** * The owner. */ private transient BufferedFloatMLDataSet owner; /** * Construct the dataset using the specified binary file. * * @param binaryFile * The file to use. */ public BufferedFloatMLDataSet(final File binaryFile) { this.file = binaryFile; this.egb = new EncogFloatEGBFile(binaryFile); if(file.exists()) { this.egb.open(); } } /** * Open the binary file for reading. */ public final void open() { this.egb.open(); } /** * @return An iterator. */ @Override public final Iterator<FloatMLDataPair> iterator() { return new BufferedFloatDataSetIterator(this); } /** * @return The record count. */ @Override public final long getRecordCount() { if(this.egb == null) { return 0; } else { return this.egb.getNumberOfRecords(); } } /** * Read an individual record. * * @param index * The zero-based index. Specify 0 for the first record, 1 for * the second, and so on. * @param pair * THe data to read. */ @Override public final void getRecord(final long index, final FloatMLDataPair pair) { this.egb.setLocation((int) index); float[] inputTarget = pair.getInputArray(); this.egb.read(inputTarget); if(pair.getIdealArray() != null) { float[] idealTarget = pair.getIdealArray(); this.egb.read(idealTarget); } this.egb.read(); } /** * @return An additional training set. */ @Override public final BufferedFloatMLDataSet openAdditional() { BufferedFloatMLDataSet result = new BufferedFloatMLDataSet(this.file); result.setOwner(this); this.additional.add(result); return result; } /** * Add only input data, for an unsupervised dataset. * * @param data1 * The data to be added. */ public final void add(final FloatMLData data1) { if(!this.loading) { throw new RuntimeException(BufferedFloatMLDataSet.ERROR_ADD); } egb.write(data1.getData()); egb.write(1.0f); } /** * Add both the input and ideal data. * * @param inputData * The input data. * @param idealData * The ideal data. */ public final void add(final FloatMLData inputData, final FloatMLData idealData) { if(!this.loading) { throw new RuntimeException(BufferedFloatMLDataSet.ERROR_ADD); } this.egb.write(inputData.getData()); this.egb.write(idealData.getData()); this.egb.write((float) 1.0f); } /** * Add a data pair of both input and ideal data. * * @param pair * The pair to add. */ public final void add(final FloatMLDataPair pair) { if(!this.loading) { throw new RuntimeException(BufferedFloatMLDataSet.ERROR_ADD); } this.egb.write(pair.getInputArray()); this.egb.write(pair.getIdealArray()); this.egb.write(pair.getSignificance()); } /** * Close the dataset. */ @Override public final void close() { Object[] obj = this.additional.toArray(); for(int i = 0; i < obj.length; i++) { BufferedFloatMLDataSet set = (BufferedFloatMLDataSet) obj[i]; set.close(); } this.additional.clear(); if(this.owner != null) { this.owner.removeAdditional(this); } this.egb.close(); this.egb = null; } /** * @return The ideal data size. */ @Override public final int getIdealSize() { if(this.egb == null) { return 0; } else { return this.egb.getIdealCount(); } } /** * @return The input data size. */ @Override public final int getInputSize() { if(this.egb == null) { return 0; } else { return this.egb.getInputCount(); } } /** * @return True if this dataset is supervised. */ @Override public final boolean isSupervised() { if(this.egb == null) { return false; } else { return this.egb.getIdealCount() > 0; } } /** * @return If this dataset was created by openAdditional, the set that * created this object is the owner. Return the owner. */ public final BufferedFloatMLDataSet getOwner() { return owner; } /** * Set the owner of this dataset. * * @param theOwner * The owner. */ public final void setOwner(final BufferedFloatMLDataSet theOwner) { this.owner = theOwner; } /** * Remove an additional dataset that was created. * * @param child * The additional dataset to remove. */ public final void removeAdditional(final BufferedFloatMLDataSet child) { synchronized(this) { this.additional.remove(child); } } /** * Begin loading to the binary file. After calling this method the add * methods may be called. * * @param inputSize * The input size. * @param idealSize * The ideal size. */ public final void beginLoad(final int inputSize, final int idealSize) { this.egb.create(inputSize, idealSize); this.loading = true; } /** * This method should be called once all the data has been loaded. The * underlying file will be closed. The binary fill will then be opened for * reading. */ public final void endLoad() { if(!this.loading) { throw new RuntimeException("Must call beginLoad, before endLoad."); } this.egb.close(); open(); } /** * @return The binary file used. */ public final File getFile() { return this.file; } /** * @return The EGB file to use. */ public final EncogFloatEGBFile getEGB() { return this.egb; } /** * Load the binary dataset to memory. Memory access is faster. * * @return A memory dataset. */ public final FloatMLDataSet loadToMemory() { BasicFloatMLDataSet result = new BasicFloatMLDataSet(); for(FloatMLDataPair pair: this) { result.add(pair); } return result; } /** * Load the specified training set. * * @param training * The training set to load. */ public final void load(final FloatMLDataSet training) { beginLoad(training.getInputSize(), training.getIdealSize()); for(final FloatMLDataPair pair: training) { add(pair); } endLoad(); } }