/* * Copyright [2013-2015] PayPal Software Foundation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package ml.shifu.shifu.core.dtrain.dt; import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; import java.util.Set; import ml.shifu.guagua.io.Bytable; /** * Split for Both continuous and categorical features. * * <p> * For continuous feature, only a double threshold can be used to split a variable into two splits. While for * categorical features, we only store left node category list, check if in left category list to determine which split. * * @author Zhang David (pengzhang@paypal.com) * * @see FeatureType */ public class Split implements Bytable { /** * Column number in ColumnConfig.json */ private int columnNum; /** * CONTINUOUS or CATEGORICAL, should not be null */ private FeatureType featureType; /** * For CONTINUOUS feature, this should be valid value to split feature */ private double threshold; /** * For categorical feature, if isLeft = true, {@link #leftOrRightCategories} stores left categories. If false, * {@link #leftOrRightCategories} stores right categories. */ private boolean isLeft = true;; /** * Indexes of left categories or right categories, list of categories will be saved in model files or in * TreeModel as short indexes to save space, short is safe so far as max bin size is limit to Short.MAX_VALUE. */ private Set<Short> leftOrRightCategories; public Split() { } public Split(int columnNum, FeatureType featureType, double threshold, boolean isLeft, Set<Short> leftOrRightCategories) { this.columnNum = columnNum; this.featureType = featureType; this.threshold = threshold; this.isLeft = isLeft; this.leftOrRightCategories = leftOrRightCategories; } /** * @return the featureIndex */ public int getColumnNum() { return columnNum; } /** * @return the featureType */ public FeatureType getFeatureType() { return featureType; } /** * @return the threshold */ public double getThreshold() { return threshold; } /** * @return the leftCategories */ public Set<Short> getLeftOrRightCategories() { return leftOrRightCategories; } /** * @param columnNum * the columnNum to set */ public void setColumnNum(int columnNum) { this.columnNum = columnNum; } /** * @param featureType * the featureType to set */ public void setFeatureType(FeatureType featureType) { this.featureType = featureType; } /** * @param threshold * the threshold to set */ public void setThreshold(double threshold) { this.threshold = threshold; } /** * @param leftCategories * the leftCategories to set */ public void setLeftOrRightCategories(Set<Short> leftCategories) { this.leftOrRightCategories = leftCategories; } /** * @return the isLeft */ public boolean isLeft() { return isLeft; } /** * @param isLeft * the isLeft to set */ public void setLeft(boolean isLeft) { this.isLeft = isLeft; } @Override public void write(DataOutput out) throws IOException { out.writeInt(this.columnNum); // use byte type to save space, should not be null out.writeByte(this.featureType.getByteType()); switch(this.featureType) { case CATEGORICAL: out.writeBoolean(this.isLeft); if(leftOrRightCategories == null) { out.writeBoolean(true); } else { out.writeBoolean(false); if(leftOrRightCategories instanceof Bytable) { ((Bytable) leftOrRightCategories).write(out); } } break; case CONTINUOUS: out.writeDouble(this.threshold); break; } } @Override public void readFields(DataInput in) throws IOException { this.columnNum = in.readInt(); this.featureType = FeatureType.of(in.readByte()); switch(this.featureType) { case CATEGORICAL: this.isLeft = in.readBoolean(); boolean isNull = in.readBoolean(); if(isNull) { leftOrRightCategories = null; } else { leftOrRightCategories = new SimpleBitSet<Short>(); ((Bytable) leftOrRightCategories).readFields(in); } break; case CONTINUOUS: this.threshold = in.readDouble(); break; } } @Override public String toString() { return "Split [featureIndex=" + columnNum + ", featureType=" + featureType + ", threshold=" + threshold + ", leftCategories=" + leftOrRightCategories + "]"; } }