/*
* Copyright [2013-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.dtrain.dt;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import ml.shifu.guagua.io.HaltBytable;
/**
* Master parameters transferred from master to all workers in all iterations.
*
* <p>
* Every iteration, tree root nodes {@link #trees} are transferred to avoid maintaining such updated trees in workers.
*
* <p>
* Every time for Random Forest, all {@link #trees} will be transferred to workers. While for GBDT, only the current
* tree will be transferred to workers. Workers recover checkpointed trees in each iteration on the worker side.
*
* <p>
* {@link #tmpTrees} is transient and only for GBDT, in {@link DTOutput}, {@link #tmpTrees} is used to save model to
* HDFS while not sent to workers.
*
* @author Zhang David (pengzhang@paypal.com)
*/
public class DTMasterParams extends HaltBytable {

    /**
     * All updated trees.
     */
    private List<TreeNode> trees;

    /**
     * nodeIndexInGroup => (treeId, Node); nodeIndexInGroup is starting from 0 in each iteration.
     */
    private Map<Integer, TreeNode> todoNodes;

    /**
     * Sum of weighted training counts accumulated by workers.
     */
    private double trainCount;

    /**
     * Sum of weighted validation counts accumulated by workers.
     */
    private double validationCount;

    /**
     * Sum of train error accumulated by workers.
     */
    private double trainError;

    /**
     * Sum of validation error accumulated by workers.
     */
    private double validationError;

    /**
     * For GBDT only, in GBDT, this means move compute to next tree.
     */
    private boolean isSwitchToNextTree = false;

    /**
     * Tree depth per tree index which is used to show on each iteration. Not written by {@link #doWrite(DataOutput)}.
     */
    private List<Integer> treeDepth = new ArrayList<Integer>();

    /**
     * If it is continuous running at first iteration in master
     */
    private boolean isContinuousRunningStart = false;

    /**
     * Check if it is the first tree
     */
    private boolean isFirstTree = false;

    /**
     * Tmp trees and reference from DTMaster#trees, which cannot and will not be serialized from master worker
     * iteration, only for DTOutput reference.
     */
    private List<TreeNode> tmpTrees;

    /**
     * Creates an empty instance; all fields keep their default values until set or read via
     * {@link #doReadFields(DataInput)}.
     */
    public DTMasterParams() {
    }

    /**
     * Creates an instance carrying only the aggregated counts and errors; {@link #trees} and {@link #todoNodes} stay
     * {@code null}.
     *
     * @param trainCount
     *            the weighted training count
     * @param trainError
     *            the training error
     * @param validationCount
     *            the weighted validation count
     * @param validationError
     *            the validation error
     */
    public DTMasterParams(double trainCount, double trainError, double validationCount, double validationError) {
        this.trainCount = trainCount;
        this.trainError = trainError;
        this.validationCount = validationCount;
        this.validationError = validationError;
    }

    /**
     * Creates an instance carrying only the trees and the nodes to be processed.
     *
     * @param trees
     *            the trees
     * @param todoNodes
     *            the todo nodes keyed by node index
     */
    public DTMasterParams(List<TreeNode> trees, Map<Integer, TreeNode> todoNodes) {
        this.trees = trees;
        this.todoNodes = todoNodes;
    }

    /**
     * @return the trees
     */
    public List<TreeNode> getTrees() {
        return trees;
    }

    /**
     * @return the todoNodes
     */
    public Map<Integer, TreeNode> getTodoNodes() {
        return todoNodes;
    }

    /**
     * @param trees
     *            the trees to set
     */
    public void setTrees(List<TreeNode> trees) {
        this.trees = trees;
    }

    /**
     * @param todoNodes
     *            the todoNodes to set
     */
    public void setTodoNodes(Map<Integer, TreeNode> todoNodes) {
        this.todoNodes = todoNodes;
    }

    /**
     * Serializes counts, errors, flags, trees and todo nodes. {@link #treeDepth} and {@link #tmpTrees} are not
     * written.
     *
     * <p>
     * A {@code null} {@link #trees} or {@link #todoNodes} is written as a zero-sized collection.
     *
     * @param out
     *            the output to write to
     * @throws IOException
     *             if writing to {@code out} fails
     */
    @Override
    public void doWrite(DataOutput out) throws IOException {
        out.writeDouble(trainCount);
        out.writeDouble(validationCount);
        out.writeDouble(trainError);
        out.writeDouble(validationError);
        out.writeBoolean(this.isSwitchToNextTree);

        // Assertions are disabled by default, so an assert cannot protect against null trees here; treat null like
        // an empty list, the same way todoNodes is handled below.
        if(trees == null) {
            out.writeInt(0);
        } else {
            out.writeInt(trees.size());
            for(TreeNode node: trees) {
                node.writeWithoutFeatures(out);
            }
        }

        if(todoNodes == null) {
            out.writeInt(0);
        } else {
            out.writeInt(todoNodes.size());
            for(Map.Entry<Integer, TreeNode> node: todoNodes.entrySet()) {
                out.writeInt(node.getKey());
                // for todo nodes, no left and right node, so node serialization not waste space
                node.getValue().write(out);
            }
        }

        out.writeBoolean(isContinuousRunningStart);
        out.writeBoolean(isFirstTree);
    }

    /**
     * Reads the fields in the same order as {@link #doWrite(DataOutput)} writes them.
     *
     * <p>
     * {@link #trees} is always replaced by a new list (possibly empty). {@link #todoNodes} is replaced by a new map
     * when at least one entry was serialized, otherwise it is reset to {@code null} so that no entries from an
     * earlier read on the same instance are kept.
     *
     * @param in
     *            the input to read from
     * @throws IOException
     *             if reading from {@code in} fails
     */
    @Override
    public void doReadFields(DataInput in) throws IOException {
        this.trainCount = in.readDouble();
        this.validationCount = in.readDouble();
        this.trainError = in.readDouble();
        this.validationError = in.readDouble();
        this.isSwitchToNextTree = in.readBoolean();

        int treeNum = in.readInt();
        this.trees = new ArrayList<TreeNode>(treeNum);
        for(int i = 0; i < treeNum; i++) {
            TreeNode treeNode = new TreeNode();
            treeNode.readFieldsWithoutFeatures(in);
            this.trees.add(treeNode);
        }

        int todoNodesSize = in.readInt();
        if(todoNodesSize > 0) {
            Map<Integer, TreeNode> readNodes = new HashMap<Integer, TreeNode>(todoNodesSize, 1f);
            for(int i = 0; i < todoNodesSize; i++) {
                int key = in.readInt();
                TreeNode treeNode = new TreeNode();
                treeNode.readFields(in);
                readNodes.put(key, treeNode);
            }
            this.todoNodes = readNodes;
        } else {
            // Same value a fresh instance has; prevents stale entries if this instance is deserialized again.
            this.todoNodes = null;
        }

        this.isContinuousRunningStart = in.readBoolean();
        this.isFirstTree = in.readBoolean();
    }

    /**
     * @return the trainCount
     */
    public double getTrainCount() {
        return trainCount;
    }

    /**
     * @return the validationCount
     */
    public double getValidationCount() {
        return validationCount;
    }

    /**
     * @param trainCount
     *            the trainCount to set
     */
    public void setTrainCount(double trainCount) {
        this.trainCount = trainCount;
    }

    /**
     * @param validationCount
     *            the validationCount to set
     */
    public void setValidationCount(double validationCount) {
        this.validationCount = validationCount;
    }

    /**
     * @return the trainError
     */
    public double getTrainError() {
        return trainError;
    }

    /**
     * @param squareError
     *            the value to store as trainError
     */
    public void setTrainError(double squareError) {
        this.trainError = squareError;
    }

    /**
     * @return the isSwitchToNextTree
     */
    public boolean isSwitchToNextTree() {
        return isSwitchToNextTree;
    }

    /**
     * @param isSwitchToNextTree
     *            the isSwitchToNextTree to set
     */
    public void setSwitchToNextTree(boolean isSwitchToNextTree) {
        this.isSwitchToNextTree = isSwitchToNextTree;
    }

    /**
     * @return the treeDepth
     */
    public List<Integer> getTreeDepth() {
        return treeDepth;
    }

    /**
     * @param treeDepth
     *            the treeDepth to set
     */
    public void setTreeDepth(List<Integer> treeDepth) {
        this.treeDepth = treeDepth;
    }

    /**
     * @return the validationError
     */
    public double getValidationError() {
        return validationError;
    }

    /**
     * @param validationError
     *            the validationError to set
     */
    public void setValidationError(double validationError) {
        this.validationError = validationError;
    }

    /**
     * @return the isContinuousRunningStart
     */
    public boolean isContinuousRunningStart() {
        return isContinuousRunningStart;
    }

    /**
     * @param isContinuousRunningStart
     *            the isContinuousRunningStart to set
     */
    public void setContinuousRunningStart(boolean isContinuousRunningStart) {
        this.isContinuousRunningStart = isContinuousRunningStart;
    }

    /**
     * @return the tmpTrees
     */
    public List<TreeNode> getTmpTrees() {
        return tmpTrees;
    }

    /**
     * @param tmpTrees
     *            the tmpTrees to set
     */
    public void setTmpTrees(List<TreeNode> tmpTrees) {
        this.tmpTrees = tmpTrees;
    }

    /**
     * @return the isFirstTree
     */
    public boolean isFirstTree() {
        return isFirstTree;
    }

    /**
     * @param isFirstTree
     *            the isFirstTree to set
     */
    public void setFirstTree(boolean isFirstTree) {
        this.isFirstTree = isFirstTree;
    }
}