/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.actor.worker;
import akka.actor.ActorRef;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.container.obj.RawSourceData.SourceType;
import ml.shifu.shifu.core.AbstractTrainer;
import ml.shifu.shifu.fs.ShifuFileUtils;
import ml.shifu.shifu.message.StatsPartRawDataMessage;
import ml.shifu.shifu.message.TrainInstanceMessage;
import ml.shifu.shifu.message.TrainPartDataMessage;
import ml.shifu.shifu.util.Constants;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.data.buffer.BufferedMLDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
import java.io.IOException;
import java.util.List;
/**
* TrainDataPrepWorker prepares the data set for the trainers.
* Note: if the training data is too large to hold in memory, the user can train the model from disk instead.
*/
public class TrainDataPrepWorker extends AbstractWorkerActor {

    // Bound to this class so log output is attributed to (and filterable by) TrainDataPrepWorker.
    private static final Logger log = LoggerFactory.getLogger(TrainDataPrepWorker.class);

    /**
     * Accumulates every training record received so far. Disk-backed ({@link BufferedMLDataSet}) or
     * in-memory ({@link BasicMLDataSet}) depending on the train-on-disk option; reset to {@code null}
     * once the disk-backed set has been closed.
     */
    private MLDataSet masterDataSet;

    /** Number of part messages (of either supported type) handled so far. */
    private int receivedMsgCnt = 0;

    /** Trainers that are handed the assembled data set and forwarded to the next actor. */
    private final List<AbstractTrainer> trainers;

    /** Whether {@code beginLoad} has already been called on the disk-backed data set. */
    private boolean initialized = false;

    /**
     * Creates the worker and allocates the master data set according to the train-on-disk option.
     *
     * @param modelConfig      model configuration; its train-on-disk flag selects the data set type
     * @param columnConfigList column configuration, passed through to the parent class
     * @param parentActorRef   parent actor reference, passed through to the parent class
     * @param nextActorRef     actor that receives one {@link TrainInstanceMessage} per trainer
     * @param trainers         trainers to be fed with the collected data set
     * @throws IOException declared by the original signature; presumably raised when the temporary
     *                     directory cannot be created - confirm against ShifuFileUtils
     */
    public TrainDataPrepWorker(ModelConfig modelConfig, List<ColumnConfig> columnConfigList, ActorRef parentActorRef,
            ActorRef nextActorRef, List<AbstractTrainer> trainers) throws IOException {
        super(modelConfig, columnConfigList, parentActorRef, nextActorRef);
        this.trainers = trainers;

        if(modelConfig.isTrainOnDisk()) {
            log.info("Training Option: Disk");
            ShifuFileUtils.createDirIfNotExists(Constants.TMP, SourceType.LOCAL);
            masterDataSet = new BufferedMLDataSet(new File(Constants.TMP, "master.egb"));
        } else {
            log.info("Training Option: Memory");
            masterDataSet = new BasicMLDataSet();
        }
    }

    /**
     * Dispatches an incoming message to the matching handler.
     * <p>
     * {@link TrainPartDataMessage} and {@link StatsPartRawDataMessage} are processed; any other
     * message is passed to {@code unhandled}.
     *
     * @param message the message delivered to this actor
     * @throws IOException propagated unchanged from the training-data handler
     */
    @Override
    public void handleMsg(Object message) throws IOException {
        if(message instanceof TrainPartDataMessage) {
            handleTrainPartData((TrainPartDataMessage) message);
        } else if(message instanceof StatsPartRawDataMessage) {
            handleStatsPartRawData((StatsPartRawDataMessage) message);
        } else {
            unhandled(message);
        }
    }

    /**
     * Appends the records of one part message to the master data set and, once the expected number
     * of messages has arrived, hands the data set to every trainer and forwards the trainers.
     *
     * @param msg part message carrying a list of training records and the total expected message count
     * @throws IOException if a callee (e.g. the trainer's data-set setup) fails with one
     */
    private void handleTrainPartData(TrainPartDataMessage msg) throws IOException {
        log.debug("Received value object list for training model.");

        for(MLDataPair pair: msg.getMlDataPairList()) {
            // The disk-backed set must be opened for loading before the first add; the record
            // dimensions are taken from the first pair seen.
            if(modelConfig.isTrainOnDisk() && !initialized) {
                int inputSize = pair.getInput().size();
                int idealSize = pair.getIdeal().size();
                ((BufferedMLDataSet) masterDataSet).beginLoad(inputSize, idealSize);
                initialized = true;
            }
            masterDataSet.add(pair);
        }

        receivedMsgCnt++;
        log.debug("Expected {} messages, received {} message(s).", msg.getTotalMsgCnt(), receivedMsgCnt);

        if(receivedMsgCnt != msg.getTotalMsgCnt()) {
            // Still waiting for more parts.
            return;
        }

        // Only finish the load if it was actually started (i.e. at least one record was received).
        if(modelConfig.isTrainOnDisk() && initialized) {
            ((BufferedMLDataSet) masterDataSet).endLoad();
        }

        for(AbstractTrainer trainer: trainers) {
            // if the trainOnDisk is true, setting the "D" option
            if(modelConfig.isTrainOnDisk()) {
                trainer.setTrainingOption("D");
            }
            trainer.setDataSet(masterDataSet);
            nextActorRef.tell(new TrainInstanceMessage(trainer), this.getSelf());
        }

        // NOTE(review): the disk-backed set is closed right after the trainer messages are sent;
        // presumably each trainer has already taken what it needs in setDataSet - confirm.
        if(modelConfig.isTrainOnDisk() && initialized) {
            masterDataSet.close();
            masterDataSet = null;
        }
    }

    /**
     * Counts one raw-data part message and, once the expected number of messages has arrived,
     * forwards every trainer to the next actor. No data set is attached in this path.
     *
     * @param msg part message carrying the total expected message count
     */
    private void handleStatsPartRawData(StatsPartRawDataMessage msg) {
        receivedMsgCnt++;
        log.debug("Expected {} messages, received {} message(s).", msg.getTotalMsgCnt(), receivedMsgCnt);

        if(receivedMsgCnt == msg.getTotalMsgCnt()) {
            for(AbstractTrainer trainer: trainers) {
                // NOTE(review): the original had a commented-out DecisionTreeTrainer#setDataSet call
                // here; trainers are forwarded without a data set - confirm this is intended.
                nextActorRef.tell(new TrainInstanceMessage(trainer), this.getSelf());
            }
        }
    }
}