/*
* Copyright [2013-2015] PayPal Software Foundation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.shifu.shifu.core.dtrain.dt;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Map.Entry;
import ml.shifu.guagua.io.Bytable;
import ml.shifu.guagua.io.Combinable;
import ml.shifu.guagua.io.HaltBytable;
/**
 * Worker result returned to master.
 *
 * <p>
 * The first part is for error collections: {@link #trainCount}, {@link #validationCount}, {@link #trainError} and
 * {@link #validationError}. All four are summed by {@link #combine(DTWorkerParams)}.
 *
 * <p>
 * {@link #nodeStatsMap} includes node statistics for each node, key is node group index id from master. It may be
 * {@code null}; the serialized form records its presence with a leading boolean flag.
 *
 * @author Zhang David (pengzhang@paypal.com)
 *
 * @see NodeStats
 */
public class DTWorkerParams extends HaltBytable implements Combinable<DTWorkerParams> {

    /**
     * # of weighted training records per such worker.
     */
    private double trainCount;

    /**
     * # of weighted validation records per such worker.
     */
    private double validationCount;

    /**
     * Train error for such worker and such iteration.
     */
    private double trainError;

    /**
     * Validation error for such worker and such iteration.
     */
    private double validationError;

    /**
     * Node statistic map including node group index and node stats object. May be {@code null}.
     */
    private Map<Integer, NodeStats> nodeStatsMap;

    public DTWorkerParams() {
    }

    public DTWorkerParams(double trainCount, double validationCount, double trainError, double validationError,
            Map<Integer, NodeStats> nodeStatsMap) {
        this.trainCount = trainCount;
        this.validationCount = validationCount;
        this.trainError = trainError;
        this.validationError = validationError;
        this.nodeStatsMap = nodeStatsMap;
    }

    /**
     * Serializes the four counters/errors, then a presence flag for {@link #nodeStatsMap}, then (if present) its size
     * followed by each (key, {@link NodeStats}) pair.
     */
    @Override
    public void doWrite(DataOutput out) throws IOException {
        out.writeDouble(trainCount);
        out.writeDouble(validationCount);
        out.writeDouble(trainError);
        out.writeDouble(validationError);
        if(nodeStatsMap == null) {
            out.writeBoolean(false);
        } else {
            out.writeBoolean(true);
            out.writeInt(nodeStatsMap.size());
            for(Entry<Integer, NodeStats> entry: nodeStatsMap.entrySet()) {
                out.writeInt(entry.getKey());
                entry.getValue().write(out);
            }
        }
    }

    /**
     * Reads the layout produced by {@link #doWrite(DataOutput)}. Every field is overwritten, including
     * {@link #nodeStatsMap}, which is set to {@code null} when the stream says it is absent.
     */
    @Override
    public void doReadFields(DataInput in) throws IOException {
        this.trainCount = in.readDouble();
        this.validationCount = in.readDouble();
        this.trainError = in.readDouble();
        this.validationError = in.readDouble();
        if(in.readBoolean()) {
            int len = in.readInt();
            this.nodeStatsMap = new HashMap<Integer, NodeStats>();
            for(int i = 0; i < len; i++) {
                int key = in.readInt();
                NodeStats stats = new NodeStats();
                stats.readFields(in);
                this.nodeStatsMap.put(key, stats);
            }
        } else {
            // Reset explicitly so a map from an earlier read (if this instance is reused) does not survive.
            this.nodeStatsMap = null;
        }
    }

    /**
     * @return the nodeStatsMap, possibly {@code null}
     */
    public Map<Integer, NodeStats> getNodeStatsMap() {
        return nodeStatsMap;
    }

    /**
     * @param nodeStatsMap
     *            the nodeStatsMap to set
     */
    public void setNodeStatsMap(Map<Integer, NodeStats> nodeStatsMap) {
        this.nodeStatsMap = nodeStatsMap;
    }

    /**
     * @return the trainError
     */
    public double getTrainError() {
        return trainError;
    }

    /**
     * @param squareError
     *            the trainError to set
     */
    public void setTrainError(double squareError) {
        this.trainError = squareError;
    }

    @Override
    public String toString() {
        return "DTWorkerParams [trainCount=" + trainCount + ", validationCount=" + validationCount + ", trainError="
                + trainError + ", validationError=" + validationError + ", nodeStatsMap=" + nodeStatsMap + "]";
    }

    /**
     * Node statistics with {@link #featureStatistics} including all statistics for all sub-sampling features.
     *
     * @author Zhang David (pengzhang@paypal.com)
     */
    public static class NodeStats implements Bytable {

        /**
         * Id of the node these statistics belong to (presumably unique within tree {@link #treeId}).
         */
        private int nodeId;

        /**
         * Tree id for such node.
         */
        private int treeId;

        /**
         * Feature statistics for sub-sampling features; keys are presumably feature/column indices. May be
         * {@code null} after the no-arg constructor.
         */
        private Map<Integer, double[]> featureStatistics;

        public NodeStats() {
        }

        public NodeStats(int treeId, int nodeId, Map<Integer, double[]> featureStatistics) {
            this.treeId = treeId;
            this.nodeId = nodeId;
            this.featureStatistics = featureStatistics;
        }

        /**
         * @return the treeId
         */
        public int getTreeId() {
            return treeId;
        }

        /**
         * @return the featureStatistics
         */
        public Map<Integer, double[]> getFeatureStatistics() {
            return featureStatistics;
        }

        /**
         * @param treeId
         *            the treeId to set
         */
        public void setTreeId(int treeId) {
            this.treeId = treeId;
        }

        /**
         * @param featureStatistics
         *            the featureStatistics to set
         */
        public void setFeatureStatistics(Map<Integer, double[]> featureStatistics) {
            this.featureStatistics = featureStatistics;
        }

        /**
         * Writes nodeId, treeId, the number of feature entries, then for each entry its key, array length and values.
         * A {@code null} {@link #featureStatistics} is written as an empty map instead of failing with an NPE.
         */
        @Override
        public void write(DataOutput out) throws IOException {
            out.writeInt(nodeId);
            out.writeInt(treeId);
            if(this.featureStatistics == null) {
                out.writeInt(0);
                return;
            }
            out.writeInt(this.featureStatistics.size());
            for(Entry<Integer, double[]> entry: this.featureStatistics.entrySet()) {
                double[] values = entry.getValue();
                out.writeInt(entry.getKey());
                out.writeInt(values.length);
                for(int i = 0; i < values.length; i++) {
                    out.writeDouble(values[i]);
                }
            }
        }

        /**
         * Reads the layout produced by {@link #write(DataOutput)}; always leaves a non-null
         * {@link #featureStatistics}.
         */
        @Override
        public void readFields(DataInput in) throws IOException {
            this.nodeId = in.readInt();
            this.treeId = in.readInt();
            int len = in.readInt();
            // Load factor 1 with capacity len: the map is sized exactly for the entries about to be read.
            this.featureStatistics = new HashMap<Integer, double[]>(len, 1f);
            for(int i = 0; i < len; i++) {
                int key = in.readInt();
                int vLen = in.readInt();
                double[] values = new double[vLen];
                for(int j = 0; j < vLen; j++) {
                    values[j] = in.readDouble();
                }
                this.featureStatistics.put(key, values);
            }
        }

        /**
         * @return the nodeId
         */
        public int getNodeId() {
            return nodeId;
        }

        /**
         * @param nodeId
         *            the nodeId to set
         */
        public void setNodeId(int nodeId) {
            this.nodeId = nodeId;
        }

        /**
         * Adds {@code that}'s feature statistics element-wise into this instance. Features present only in
         * {@code that} are copied in (equivalent to adding to an all-zero array).
         */
        private void merge(NodeStats that) {
            assert this.nodeId == that.nodeId;
            assert this.treeId == that.treeId;
            if(that.featureStatistics == null) {
                return;
            }
            if(this.featureStatistics == null) {
                this.featureStatistics = new HashMap<Integer, double[]>(that.featureStatistics.size());
            }
            for(Entry<Integer, double[]> entry: that.featureStatistics.entrySet()) {
                double[] thatStats = entry.getValue();
                if(thatStats == null) {
                    continue;
                }
                double[] thisStats = this.featureStatistics.get(entry.getKey());
                if(thisStats == null) {
                    // Copy so later accumulation into this instance never mutates that's arrays.
                    this.featureStatistics.put(entry.getKey(), thatStats.clone());
                } else {
                    assert thisStats.length == thatStats.length;
                    for(int i = 0; i < thisStats.length; i++) {
                        thisStats[i] += thatStats[i];
                    }
                }
            }
        }

        /**
         * Returns a deep copy (map and arrays) of this instance.
         */
        private NodeStats copy() {
            Map<Integer, double[]> statsCopy = null;
            if(this.featureStatistics != null) {
                statsCopy = new HashMap<Integer, double[]>(this.featureStatistics.size());
                for(Entry<Integer, double[]> entry: this.featureStatistics.entrySet()) {
                    double[] values = entry.getValue();
                    statsCopy.put(entry.getKey(), values == null ? null : values.clone());
                }
            }
            return new NodeStats(this.treeId, this.nodeId, statsCopy);
        }

        /*
         * (non-Javadoc)
         *
         * @see java.lang.Object#toString()
         */
        @Override
        public String toString() {
            return "NodeStats [nodeId=" + nodeId + ", treeId=" + treeId + ", featureStatistics="
                    + toString(featureStatistics) + "]";
        }

        /**
         * Renders the map like {@link java.util.AbstractMap#toString()} but with array contents expanded; returns
         * {@code "null"} for a {@code null} map.
         */
        private String toString(Map<Integer, double[]> featureStatistics) {
            if(featureStatistics == null) {
                return "null";
            }
            Iterator<Map.Entry<Integer, double[]>> i = featureStatistics.entrySet().iterator();
            if(!i.hasNext()) {
                return "{}";
            }
            StringBuilder sb = new StringBuilder();
            sb.append('{');
            for(;;) {
                Map.Entry<Integer, double[]> e = i.next();
                sb.append(e.getKey());
                sb.append('=');
                sb.append(Arrays.toString(e.getValue()));
                if(!i.hasNext()) {
                    return sb.append('}').toString();
                }
                sb.append(", ");
            }
        }
    }

    /**
     * Accumulates {@code that} into this instance and returns {@code this}.
     *
     * <p>
     * Counts and errors are summed. Node statistics are merged by key union: entries present in both are added
     * element-wise, entries present only in {@code that} are deep-copied in, and entries present only in this
     * instance are left unchanged. Previously a {@code null} map on this side silently discarded {@code that}'s
     * statistics, and keys missing on either side caused an NPE.
     */
    @Override
    public DTWorkerParams combine(DTWorkerParams that) {
        assert that != null;
        this.trainCount += that.trainCount;
        this.trainError += that.trainError;
        this.validationCount += that.validationCount;
        this.validationError += that.validationError;
        if(that.nodeStatsMap == null) {
            return this;
        }
        if(this.nodeStatsMap == null) {
            this.nodeStatsMap = new HashMap<Integer, NodeStats>(that.nodeStatsMap.size());
        }
        for(Entry<Integer, NodeStats> entry: that.nodeStatsMap.entrySet()) {
            NodeStats thatNodeStats = entry.getValue();
            if(thatNodeStats == null) {
                continue;
            }
            NodeStats nodeStats = this.nodeStatsMap.get(entry.getKey());
            if(nodeStats == null) {
                this.nodeStatsMap.put(entry.getKey(), thatNodeStats.copy());
            } else {
                nodeStats.merge(thatNodeStats);
            }
        }
        return this;
    }

    /**
     * @return the validationError
     */
    public double getValidationError() {
        return validationError;
    }

    /**
     * @param validationError
     *            the validationError to set
     */
    public void setValidationError(double validationError) {
        this.validationError = validationError;
    }

    /**
     * @return the trainCount
     */
    public double getTrainCount() {
        return trainCount;
    }

    /**
     * @return the validationCount
     */
    public double getValidationCount() {
        return validationCount;
    }

    /**
     * @param trainCount
     *            the trainCount to set
     */
    public void setTrainCount(double trainCount) {
        this.trainCount = trainCount;
    }

    /**
     * @param validationCount
     *            the validationCount to set
     */
    public void setValidationCount(double validationCount) {
        this.validationCount = validationCount;
    }
}