/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.core.tools; import rapaio.data.Numeric; import rapaio.data.Var; import rapaio.data.VarType; import rapaio.printer.Printable; import rapaio.printer.format.TextTable; import rapaio.sys.WS; import java.io.Serializable; import java.util.Arrays; import static rapaio.math.MathTools.log2; /** * Distribution table. * <p> * Table tool class to facilitate various operations on two variables regarding frequencies. * * @author <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> */ public final class DTable implements Printable, Serializable { public static final String[] NUMERIC_DEFAULT_LABELS = new String[]{"?", "less-equals", "greater"}; private static final long serialVersionUID = 4359080329548577980L; private final String[] rowLevels; private final String[] colLevels; // table with frequencies private final int start; private final double[][] values; // printing info private boolean totalSummary = true; /** * Builds a table with given test columns and target columns * * @param rowLevels labels for rows * @param colLevels labels for columns * @param useFirst true if using the first row and col, false otherwise */ public static DTable empty(String[] rowLevels, String[] colLevels, boolean useFirst) { return new DTable(rowLevels, colLevels, useFirst); } // private constructors /** * Builds a density table from two nominal vectors built from counts * * @param rowVar var on vertical axis * @param colVar var on horizontal axis * @param useFirst true if using the first row and col, false otherwise */ public static DTable fromCounts(Var rowVar, Var colVar, boolean useFirst) { return new DTable(rowVar, colVar, Numeric.fill(rowVar.rowCount(), 1), useFirst); } /** * Builds a density table from two nominal vectors. * If not null, weights are used instead of counts. * * @param rowVar row var * @param colVar col var * @param weights weights used instead of counts, if not null * @param useFirst true if using the first row and col, false otherwise */ public static DTable fromWeights(Var rowVar, Var colVar, Var weights, boolean useFirst) { return new DTable(rowVar, colVar, weights, useFirst); } /** * Builds a density table with a binary split, from two nominal vectors. * The first row contains instances which have test label equal with given testLabel, * second row contains frequencies for the rest of the instances. * * @param rowVar row var * @param colVar col var * @param weights if not null, weights used instead of counts * @param rowLevel row label used for binary split * @param useFirst true if using the first row and col, false otherwise */ public static DTable binaryFromWeights(Var rowVar, Var colVar, Var weights, String rowLevel, boolean useFirst) { return new DTable(rowVar, colVar, weights, rowLevel, useFirst); } private DTable(String[] rowLevels, String[] colLevels, boolean useFirst) { this.rowLevels = rowLevels; this.colLevels = colLevels; this.start = useFirst ? 0 : 1; this.values = new double[rowLevels.length][colLevels.length]; } private DTable(Var rowVar, Var colVar, Var weights, boolean useFirst) { this(rowVar.levels(), colVar.levels(), useFirst); if (!(rowVar.type().isNominal() || rowVar.type().equals(VarType.BINARY))) throw new IllegalArgumentException("row var must be nominal"); if (!(colVar.type().isNominal() || colVar.type().equals(VarType.BINARY))) throw new IllegalArgumentException("col var is not nominal"); if (rowVar.rowCount() != colVar.rowCount()) throw new IllegalArgumentException("row and col vars must have same row count"); int rowOffset = rowVar.type().equals(VarType.BINARY) ? 1 : 0; int colOffset = colVar.type().equals(VarType.BINARY) ? 1 : 0; for (int i = 0; i < rowVar.rowCount(); i++) { update(rowVar.index(i) + rowOffset, colVar.index(i) + colOffset, weights != null ? weights.value(i) : 1); } } private DTable(Var rowVar, Var colVar, Var weights, String rowLevel, boolean useFirst) { this(new String[]{"?", rowLevel, "other"}, colVar.levels(), useFirst); if (!rowVar.type().isNominal()) throw new IllegalArgumentException("row var must be nominal"); if (!colVar.type().isNominal()) throw new IllegalArgumentException("col var is not nominal"); if (rowVar.rowCount() != colVar.rowCount()) throw new IllegalArgumentException("row and col variables must have same size"); for (int i = 0; i < rowVar.rowCount(); i++) { int index = 0; if (!rowVar.missing(i)) { index = (rowVar.label(i).equals(rowLevel)) ? 1 : 2; } update(index, colVar.index(i), weights != null ? weights.value(i) : 1); } } public boolean useFirst() { return start == 0; } public int start() { return start; } public int rowCount() { return rowLevels.length; } public int colCount() { return colLevels.length; } public String[] rowLevels() { return rowLevels; } public String[] colLevels() { return colLevels; } public DTable withTotalSummary(boolean totalSummary) { this.totalSummary = totalSummary; return this; } public double get(int row, int col) { return values[row][col]; } public void reset() { for (double[] line : values) Arrays.fill(line, 0, line.length, 0); } public void update(int row, int col, double weight) { values[row][col] += weight; } public void moveOnCol(int row1, int row2, int col, double weight) { update(row1, col, -weight); update(row2, col, weight); } public void moveOnRow(int row, int col1, int col2, double weight) { update(row, col1, -weight); update(row, col2, weight); } public double totalColEntropy() { double[] totals = new double[colLevels.length]; for (int i = start; i < rowLevels.length; i++) { for (int j = start; j < colLevels.length; j++) { totals[j] += values[i][j]; } } double total = 0; for (int i = start; i < totals.length; i++) { total += totals[i]; } double entropy = 0; for (int i = start; i < totals.length; i++) { if (totals[i] > 0) { entropy += -log2(totals[i] / total) * totals[i] / total; } } return entropy; } public double totalRowEntropy() { double[] totals = new double[rowLevels.length]; for (int i = start; i < rowLevels.length; i++) { for (int j = start; j < colLevels.length; j++) { totals[i] += values[i][j]; } } double total = 0; for (int i = start; i < totals.length; i++) { total += totals[i]; } double entropy = 0; for (int i = start; i < totals.length; i++) { if (totals[i] > 0) { entropy += -log2(totals[i] / total) * totals[i] / total; } } return entropy; } public double splitByRowAverageEntropy() { double[] totals = new double[rowLevels.length]; for (int i = start; i < rowLevels.length; i++) { for (int j = start; j < colLevels.length; j++) { totals[i] += values[i][j]; } } double total = 0; for (int i = start; i < totals.length; i++) { total += totals[i]; } double gain = 0; for (int i = start; i < rowLevels.length; i++) { for (int j = start; j < colLevels.length; j++) { if (values[i][j] > 0) gain += -log2(values[i][j] / totals[i]) * values[i][j] / total; } } return gain; } public double splitByColAverageEntropy() { double[] totals = new double[colLevels.length]; for (int i = start; i < rowLevels.length; i++) { for (int j = start; j < colLevels.length; j++) { totals[j] += values[i][j]; } } double total = 0; for (int i = start; i < totals.length; i++) { total += totals[i]; } double gain = 0; for (int i = start; i < rowLevels.length; i++) { for (int j = start; j < colLevels.length; j++) { if (values[i][j] > 0) gain += -log2(values[i][j] / totals[j]) * values[i][j] / total; } } return gain; } public double splitByRowInfoGain() { return totalColEntropy() - splitByRowAverageEntropy(); } public double splitByColInfoGain() { return totalRowEntropy() - splitByColAverageEntropy(); } public double splitByRowIntrinsicInfo() { double[] totals = new double[rowLevels.length]; for (int i = start; i < rowLevels.length; i++) { for (int j = start; j < colLevels.length; j++) { totals[i] += values[i][j]; } } double total = 0; for (int i = start; i < totals.length; i++) { total += totals[i]; } double splitInfo = 0; for (int i = start; i < totals.length; i++) { if (totals[i] > 0) { splitInfo += -log2(totals[i] / total) * totals[i] / total; } } return splitInfo; } public double splitByColIntrinsicInfo() { double[] totals = new double[colLevels.length]; for (int i = start; i < rowLevels.length; i++) { for (int j = start; j < colLevels.length; j++) { totals[j] += values[i][j]; } } double total = 0; for (int i = start; i < totals.length; i++) { total += totals[i]; } double splitInfo = 0; for (int i = start; i < totals.length; i++) { if (totals[i] > 0) { splitInfo += -log2(totals[i] / total) * totals[i] / total; } } return splitInfo; } public double splitByRowGainRatio() { return splitByRowInfoGain() / splitByRowIntrinsicInfo(); } public double splitByColGainRatio() { return splitByColInfoGain() / splitByColIntrinsicInfo(); } /** * Computes the number of columns which have totals equal or greater than minWeight * * @return number of columns which meet criteria */ public boolean hasColsWithMinimumCount(double minWeight, int minCounts) { int count = 0; for (int i = start; i < rowLevels.length; i++) { double total = 0; for (int j = 1; j < colLevels.length; j++) { total += values[i][j]; } if (total >= minWeight) { count++; if (count >= minCounts) { return true; } } } return false; } public double splitByRowGiniGain() { double[] rowTotals = new double[rowLevels.length]; double[] colTotals = new double[colLevels.length]; double total = 0.0; for (int i = start; i < rowLevels.length; i++) { for (int j = start; j < colLevels.length; j++) { rowTotals[i] += values[i][j]; colTotals[j] += values[i][j]; total += values[i][j]; } } if (total <= 0) { return 1; } double gini = 1.0; for (int i = start; i < colLevels.length; i++) { gini -= Math.pow(colTotals[i] / total, 2); } for (int i = start; i < rowLevels.length; i++) { double gini_k = 1; for (int j = start; j < colLevels.length; j++) { if (rowTotals[i] > 0) gini_k -= Math.pow(values[i][j] / rowTotals[i], 2); } gini -= gini_k * rowTotals[i] / total; } return gini; } public double splitByColGiniGain() { double[] rowTotals = new double[rowLevels.length]; double[] colTotals = new double[colLevels.length]; double total = 0.0; for (int i = start; i < rowLevels.length; i++) { for (int j = start; j < colLevels.length; j++) { rowTotals[i] += values[i][j]; colTotals[j] += values[i][j]; total += values[i][j]; } } if (total <= 0) { return 1; } double gini = 1.0; for (int i = start; i < rowLevels.length; i++) { gini -= Math.pow(rowTotals[i] / total, 2); } for (int i = start; i < colLevels.length; i++) { double gini_k = 1; for (int j = start; j < rowLevels.length; j++) { if (colTotals[i] > 0) gini_k -= Math.pow(values[j][i] / colTotals[i], 2); } gini -= gini_k * colTotals[i] / total; } return gini; } public double[] rowTotals() { double[] totals = new double[rowLevels.length]; for (int i = 0; i < rowLevels.length; i++) { for (int j = 0; j < colLevels.length; j++) { totals[i] += values[i][j]; } } return totals; } public double[] colTotals() { double[] totals = new double[colLevels.length]; for (int i = 0; i < rowLevels.length; i++) { for (int j = 0; j < colLevels.length; j++) { totals[j] += values[i][j]; } } return totals; } public DTable normalizeOverall() { DTable norm = DTable.empty(rowLevels, colLevels, start == 0).withTotalSummary(totalSummary); double total = 0; for (int i = start; i < rowLevels.length; i++) { for (int j = start; j < colLevels.length; j++) { norm.values[i][j] = values[i][j]; total += values[i][j]; } } if (total > 0) { for (int i = start; i < rowLevels.length; i++) { for (int j = start; j < colLevels.length; j++) { norm.values[i][j] /= total; } } } return norm; } public DTable normalizeOnRows() { DTable norm = DTable.empty(rowLevels, colLevels, start == 0).withTotalSummary(totalSummary); double[] rowTotals = new double[rowLevels.length]; for (int i = start; i < rowLevels.length; i++) { for (int j = start; j < colLevels.length; j++) { norm.values[i][j] = values[i][j]; rowTotals[i] += values[i][j]; } } for (int i = start; i < rowLevels.length; i++) { if (rowTotals[i] > 0) for (int j = start; j < colLevels.length; j++) { norm.values[i][j] /= rowTotals[i]; } } return norm; } public DTable normalizeOnCols() { DTable norm = DTable.empty(rowLevels, colLevels, start == 0).withTotalSummary(totalSummary); double[] colTotals = new double[colLevels.length]; for (int i = start; i < rowLevels.length; i++) { for (int j = start; j < colLevels.length; j++) { norm.values[i][j] = values[i][j]; colTotals[j] += values[i][j]; } } for (int i = start; i < colLevels.length; i++) { if (colTotals[i] > 0) for (int j = start; j < rowLevels.length; j++) { norm.values[j][i] /= colTotals[i]; } } return norm; } @Override public String summary() { if (totalSummary) { TextTable tt = TextTable.newEmpty(rowLevels.length - start + 2, colLevels.length - start + 2); tt.withHeaderRows(1); tt.withSplit(WS.getPrinter().textWidth()); for (int i = start; i < rowLevels.length; i++) { tt.set(i - start + 1, 0, rowLevels[i], 1); } for (int i = start; i < colLevels.length; i++) { tt.set(0, i - start + 1, colLevels[i], 1); } tt.set(0, colLevels.length - start + 1, "total", 1); tt.set(rowLevels.length - start + 1, 0, "total", 1); for (int i = start; i < rowLevels.length; i++) { for (int j = start; j < colLevels.length; j++) { tt.set(i - start + 1, j - start + 1, WS.formatShort(values[i][j]), 1); } } double[] rowTotals = rowTotals(); for (int i = start; i < rowLevels.length; i++) { tt.set(i - start + 1, colLevels.length - start + 1, WS.formatShort(rowTotals[i]), 1); } double[] colTotals = colTotals(); for (int i = start; i < colLevels.length; i++) { tt.set(rowLevels.length - start + 1, i - start + 1, WS.formatShort(colTotals[i]), 1); } double total = Arrays.stream(rowTotals).skip(start).sum(); tt.set(rowLevels.length - start + 1, colLevels.length - start + 1, WS.formatShort(total), 1); return tt.summary(); } else { TextTable tt = TextTable.newEmpty(rowLevels.length - start + 1, colLevels.length - start + 1); tt.withHeaderRows(1); tt.withSplit(WS.getPrinter().textWidth()); for (int i = start; i < rowLevels.length; i++) { tt.set(i - start + 1, 0, rowLevels[i], 1); } for (int i = start; i < colLevels.length; i++) { tt.set(0, i - start + 1, colLevels[i], 1); } for (int i = start; i < rowLevels.length; i++) { for (int j = start; j < colLevels.length; j++) { tt.set(i - start + 1, j - start + 1, WS.formatShort(values[i][j]), 1); } } return tt.summary(); } } }