/*
* Copyright [2012-2014] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.actor;
import akka.actor.*;
import ml.shifu.shifu.container.obj.ColumnConfig;
import ml.shifu.shifu.container.obj.EvalConfig;
import ml.shifu.shifu.container.obj.ModelConfig;
import ml.shifu.shifu.core.AbstractTrainer;
import ml.shifu.shifu.exception.ShifuErrorCode;
import ml.shifu.shifu.exception.ShifuException;
import ml.shifu.shifu.message.AkkaActorInputMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.List;
import java.util.Scanner;
/**
* AkkaSystemExecutor class
* The executor for AKKA system. It's singleton.
*/
public class AkkaSystemExecutor {
private static Logger log = LoggerFactory.getLogger(AkkaSystemExecutor.class);
private static AkkaSystemExecutor instance = new AkkaSystemExecutor();
private ActorSystem actorSystem;
// singleton
private AkkaSystemExecutor() {
}
/**
* Get executor for AKKA System
*
* @return - executor
*/
public static AkkaSystemExecutor getExecutor() {
return instance;
}
/**
* Submit job to calculate column stats
* Column stats including:
* - Binning of value range
* - max/min/average
* - ks/iv
*
* @param modelConfig
* - configuration for model
* @param columnConfigList
* - configurations for columns
* @param scanners
* - scanners of training data
*/
public void submitStatsCalJob(final ModelConfig modelConfig, final List<ColumnConfig> columnConfigList,
List<Scanner> scanners) {
actorSystem = ActorSystem.create("ShifuActorSystem");
final AkkaExecStatus akkaStatus = new AkkaExecStatus(true);
log.info("Create Akka system to calculate stats");
ActorRef statsCalRef = actorSystem.actorOf(new Props(new UntypedActorFactory() {
private static final long serialVersionUID = -1437127862571741369L;
public UntypedActor create() {
return new CalculateStatsActor(modelConfig, columnConfigList, akkaStatus);
}
}), "stats-calculator");
statsCalRef.tell(new AkkaActorInputMessage(scanners), statsCalRef);
// wait for termination and check the status
actorSystem.awaitTermination();
checkAkkaStatus(akkaStatus);
}
/**
* Submit job to normalize training data
*
* @param modelConfig
* - configuration for model
* @param columnConfigList
* - configurations for columns
* @param scanners
* - scanners of training data
*/
public void submitNormalizeJob(final ModelConfig modelConfig, final List<ColumnConfig> columnConfigList,
List<Scanner> scanners) {
actorSystem = ActorSystem.create("ShifuActorSystem");
final AkkaExecStatus akkaStatus = new AkkaExecStatus(true);
log.info("Create Akka system to normalize data");
ActorRef dataNormalizeRef = actorSystem.actorOf(new Props(new UntypedActorFactory() {
private static final long serialVersionUID = -2123098236012879296L;
public UntypedActor create() throws IOException {
return new NormalizeDataActor(modelConfig, columnConfigList, akkaStatus);
}
}), "data-normalizer");
dataNormalizeRef.tell(new AkkaActorInputMessage(scanners), dataNormalizeRef);
// wait for termination
actorSystem.awaitTermination();
checkAkkaStatus(akkaStatus);
}
/**
* Submit job to training model
*
* @param modelConfig
* - configuration for model
* @param columnConfigList
* - configurations for columns
* @param scanners
* - scanners of normalized training data
* @param trainers
* - model trainer
*/
public void submitModelTrainJob(final ModelConfig modelConfig, final List<ColumnConfig> columnConfigList,
List<Scanner> scanners, final List<AbstractTrainer> trainers) {
actorSystem = ActorSystem.create("ShifuActorSystem");
final AkkaExecStatus akkaStatus = new AkkaExecStatus(true);
log.info("Create Akka system to train model");
ActorRef modelTrainerRef = actorSystem.actorOf(new Props(new UntypedActorFactory() {
private static final long serialVersionUID = -1437127862571741369L;
public UntypedActor create() {
return new TrainModelActor(modelConfig, columnConfigList, akkaStatus, trainers);
}
}), "model-trainer");
modelTrainerRef.tell(new AkkaActorInputMessage(scanners), modelTrainerRef);
// wait for termination
actorSystem.awaitTermination();
checkAkkaStatus(akkaStatus);
}
/**
* Submit job to training decision-tree model
*
* @param modelConfig
* - configuration for model
* @param columnConfigList
* - configurations for columns
* @param scanners
* - scanners of normalized training data
* @param trainers
* - model trainer
*/
public void submitDecisionTreeTrainJob(final ModelConfig modelConfig, final List<ColumnConfig> columnConfigList,
List<Scanner> scanners, final List<AbstractTrainer> trainers) {
actorSystem = ActorSystem.create("ShifuActorSystem");
final AkkaExecStatus akkaStatus = new AkkaExecStatus(true);
log.info("Create Akka system to train dt-model");
ActorRef modelTrainerRef = actorSystem.actorOf(new Props(new UntypedActorFactory() {
private static final long serialVersionUID = 2394968604729416422L;
public UntypedActor create() {
return new TrainDtModelActor(modelConfig, columnConfigList, akkaStatus, trainers);
}
}), "dt-model-trainer");
modelTrainerRef.tell(new AkkaActorInputMessage(scanners), modelTrainerRef);
// wait for termination
actorSystem.awaitTermination();
checkAkkaStatus(akkaStatus);
}
/**
* Submit job to post-train the model
*
* @param modelConfig
* - configuration for model
* @param columnConfigList
* - configurations for columns
* @param scanners
* - scanners of select data that are normalized
*/
public void submitPostTrainJob(final ModelConfig modelConfig, final List<ColumnConfig> columnConfigList,
List<Scanner> scanners) {
actorSystem = ActorSystem.create("ShifuActorSystem");
final AkkaExecStatus akkaStatus = new AkkaExecStatus(true);
log.info("Create Akka system to post-train model");
ActorRef postTrainerRef = actorSystem.actorOf(new Props(new UntypedActorFactory() {
private static final long serialVersionUID = -1437127862571741369L;
public UntypedActor create() {
return new PostTrainActor(modelConfig, columnConfigList, akkaStatus);
}
}), "model-posttrainer");
postTrainerRef.tell(new AkkaActorInputMessage(scanners), postTrainerRef);
// wait for termination
actorSystem.awaitTermination();
checkAkkaStatus(akkaStatus);
}
/**
* Submit job to run model evaluation
*
* @param modelConfig
* - configuration for model
* @param columnConfigList
* - configurations for columns
* @param evalConfig
* the eval config instance
* @param scanners
* - scanners of evaluation data
*/
public void submitModelEvalJob(final ModelConfig modelConfig, final List<ColumnConfig> columnConfigList,
final EvalConfig evalConfig, List<Scanner> scanners) {
actorSystem = ActorSystem.create("ShifuActorSystem");
final AkkaExecStatus akkaStatus = new AkkaExecStatus(true);
log.info("Create Akka system to evaluate model");
ActorRef modelEvalRef = actorSystem.actorOf(new Props(new UntypedActorFactory() {
private static final long serialVersionUID = -1437127862571741369L;
public UntypedActor create() {
return new EvalModelActor(modelConfig, columnConfigList, akkaStatus, evalConfig);
}
}), "model-evaluator");
modelEvalRef.tell(new AkkaActorInputMessage(scanners), modelEvalRef);
// wait for termination
actorSystem.awaitTermination();
checkAkkaStatus(akkaStatus);
}
/**
* check the execute status of AKKA, if there is any Exceptions, wrap it with ShifuException and throw it
*
* @param akkaStatus
*/
private void checkAkkaStatus(final AkkaExecStatus akkaStatus) {
if(!akkaStatus.getStatus()) {
throw new ShifuException(ShifuErrorCode.ERROR_AKKA_EXECUTE_EXCEPTION, akkaStatus.getException());
}
}
}