/*
* Copyright [2013-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.dtrain.dataset;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.encog.util.obj.ObjectCloner;
/**
* Copy from {@link BasicFloatMLDataSet} to support float type data.
*/
public class BasicFloatMLDataSet implements Serializable, FloatMLDataSet, Cloneable {
/**
* An iterator to be used with the BasicFloatMLDataSet. This iterator does not
* support removes.
*/
public class BasicMLIterator implements Iterator<FloatMLDataPair> {
/**
* The index that the iterator is currently at.
*/
private int currentIndex = 0;
/**
* {@inheritDoc}
*/
@Override
public final boolean hasNext() {
return this.currentIndex < BasicFloatMLDataSet.this.data.size();
}
/**
* {@inheritDoc}
*/
@Override
public final FloatMLDataPair next() {
if(!hasNext()) {
return null;
}
return BasicFloatMLDataSet.this.data.get(this.currentIndex++);
}
/**
* {@inheritDoc}
*/
@Override
public final void remove() {
throw new RuntimeException("Called remove, unsupported operation.");
}
}
/**
* The serial id.
*/
private static final long serialVersionUID = -2279722928570071183L;
/**
* The data held by this object.
*/
private List<FloatMLDataPair> data = new ArrayList<FloatMLDataPair>();
/**
* Default constructor.
*/
public BasicFloatMLDataSet() {
}
/**
* Construct a data set from an input and idea array.
*
* @param input
* The input into the machine learning method for training.
* @param ideal
* The ideal output for training.
*/
public BasicFloatMLDataSet(final float[][] input, final float[][] ideal) {
if(ideal != null) {
for(int i = 0; i < input.length; i++) {
final BasicFloatMLData inputData = new BasicFloatMLData(input[i]);
final BasicFloatMLData idealData = new BasicFloatMLData(ideal[i]);
this.add(inputData, idealData);
}
} else {
for(final float[] element: input) {
final BasicFloatMLData inputData = new BasicFloatMLData(element);
this.add(inputData);
}
}
}
/**
* Construct a data set from an already created list. Mostly used to
* duplicate this class.
*
* @param theData
* The data to use.
*/
public BasicFloatMLDataSet(final List<FloatMLDataPair> theData) {
this.data = theData;
}
/**
* Copy whatever dataset type is specified into a memory dataset.
*
* @param set
* The dataset to copy.
*/
public BasicFloatMLDataSet(final FloatMLDataSet set) {
final int inputCount = set.getInputSize();
final int idealCount = set.getIdealSize();
for(final FloatMLDataPair pair: set) {
BasicFloatMLData input = null;
BasicFloatMLData ideal = null;
if(inputCount > 0) {
input = new BasicFloatMLData(inputCount);
BasicFloatMLDataSet.arrayCopy(pair.getInputArray(), input.getData());
}
if(idealCount > 0) {
ideal = new BasicFloatMLData(idealCount);
BasicFloatMLDataSet.arrayCopy(pair.getIdealArray(), ideal.getData());
}
add(new BasicFloatMLDataPair(input, ideal));
}
}
/**
* Copy an array of floats to an array of floats.
*
* @param source
* The source array.
* @param target
* The target array.
*/
public static void arrayCopy(final float[] source, final float[] target) {
for(int i = 0; i < source.length; i++) {
target[i] = source[i];
}
}
/**
* {@inheritDoc}
*/
@Override
public void add(final FloatMLData theData) {
this.data.add(new BasicFloatMLDataPair(theData));
}
/**
* {@inheritDoc}
*/
@Override
public void add(final FloatMLData inputData, final FloatMLData idealData) {
final FloatMLDataPair pair = new BasicFloatMLDataPair(inputData, idealData);
this.data.add(pair);
}
/**
* {@inheritDoc}
*/
@Override
public void add(final FloatMLDataPair inputData) {
this.data.add(inputData);
}
/**
* {@inheritDoc}
*/
@Override
public final Object clone() {
return ObjectCloner.deepCopy(this);
}
/**
* {@inheritDoc}
*/
@Override
public final void close() {
// nothing to close
}
/**
* Get the data held by this container.
*
* @return the data
*/
public final List<FloatMLDataPair> getData() {
return this.data;
}
/**
* {@inheritDoc}
*/
@Override
public final int getIdealSize() {
if(this.data.isEmpty()) {
return 0;
}
final FloatMLDataPair first = this.data.get(0);
if(first.getIdeal() == null) {
return 0;
}
return first.getIdeal().size();
}
/**
* {@inheritDoc}
*/
@Override
public final int getInputSize() {
if(this.data.isEmpty()) {
return 0;
}
final FloatMLDataPair first = this.data.get(0);
return first.getInput().size();
}
/**
* {@inheritDoc}
*/
@Override
public final void getRecord(final long index, final FloatMLDataPair pair) {
final FloatMLDataPair source = this.data.get((int) index);
pair.setInputArray(source.getInputArray());
if(pair.getIdealArray() != null) {
pair.setIdealArray(source.getIdealArray());
}
}
/**
* {@inheritDoc}
*/
@Override
public final long getRecordCount() {
return this.data.size();
}
/**
* {@inheritDoc}
*/
@Override
public final boolean isSupervised() {
if(this.data.size() == 0) {
return false;
}
return this.data.get(0).isSupervised();
}
/**
* {@inheritDoc}
*/
@Override
public final Iterator<FloatMLDataPair> iterator() {
final BasicMLIterator result = new BasicMLIterator();
return result;
}
/**
* {@inheritDoc}
*/
@Override
public final FloatMLDataSet openAdditional() {
return new BasicFloatMLDataSet(this.data);
}
/**
* @param theData
* the data to set
*/
public final void setData(final List<FloatMLDataPair> theData) {
this.data = theData;
}
}