/*
 * Copyright [2012-2014] PayPal Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.shifu.core.dvarsel.dataset;

import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.Set;

/**
 * In-memory collection of {@link TrainingRecord}s together with the list of
 * column ids that was supplied when the collection was created.
 * <p>
 * The records can be converted to {@link MLDataPair}s and distributed randomly
 * over two {@link MLDataSet}s via
 * {@link #generateValidateData(Set, double, MLDataSet, MLDataSet)}.
 * <p>
 * Created on 11/24/2014.
 */
public class TrainingDataSet {

    /** Random source shared by all instances; seeded from the wall clock at class load. */
    private static Random rd = new Random(System.currentTimeMillis());

    /** Column id list handed to the constructor (stored by reference, not copied). */
    private List<Integer> dataColumnIdList;

    /** Records accumulated through {@link #addTrainingRecord(TrainingRecord)}. */
    private List<TrainingRecord> trainingRecords;

    /**
     * Creates an empty data set bound to the given column id list.
     *
     * @param dataColumnIdList column ids later passed to
     *                         {@code TrainingRecord.toMLDataPair}; the reference is kept as-is
     */
    public TrainingDataSet(List<Integer> dataColumnIdList) {
        this.dataColumnIdList = dataColumnIdList;
        this.trainingRecords = new ArrayList<TrainingRecord>();
    }

    /**
     * Appends a record to this data set.
     *
     * @param trainingRecord the record to append; {@code null} is silently ignored
     */
    public void addTrainingRecord(TrainingRecord trainingRecord) {
        if (trainingRecord == null) {
            return;
        }
        this.trainingRecords.add(trainingRecord);
    }

    /**
     * Converts every stored record to an {@link MLDataPair} and adds it to one of
     * the two supplied data sets, chosen by a random draw per record.
     * <p>
     * For each record a value is drawn with {@link Random#nextDouble()}; when the
     * draw is strictly greater than {@code validationRate} the pair is added to
     * {@code trainingData}, otherwise it is added to {@code testingData}.
     *
     * @param workingColumnIdSet column id set forwarded to {@code TrainingRecord.toMLDataPair}
     * @param validationRate     threshold the random draw is compared against
     * @param trainingData       receives pairs whose draw exceeds {@code validationRate}
     * @param testingData        receives all remaining pairs
     */
    public void generateValidateData(Set<Integer> workingColumnIdSet,
                                     double validationRate,
                                     MLDataSet trainingData,
                                     MLDataSet testingData) {
        for (TrainingRecord record : this.trainingRecords) {
            // Convert first, then draw: keeps the same call order as before.
            MLDataPair dataPair = record.toMLDataPair(this.dataColumnIdList, workingColumnIdSet);

            boolean aboveRate = rd.nextDouble() > validationRate;
            MLDataSet target = aboveRate ? trainingData : testingData;
            target.add(dataPair);
        }
    }

    /**
     * Returns the column id list this data set was constructed with.
     *
     * @return the same list reference that was passed to the constructor
     */
    public List<Integer> getDataColumnIdList() {
        return this.dataColumnIdList;
    }
}