/* * Apache License * Version 2.0, January 2004 * http://www.apache.org/licenses/ * * Copyright 2013 Aurelian Tutuianu * Copyright 2014 Aurelian Tutuianu * Copyright 2015 Aurelian Tutuianu * Copyright 2016 Aurelian Tutuianu * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package rapaio.core.tools; import rapaio.core.RandomSource; import rapaio.data.Numeric; import rapaio.data.Var; import rapaio.data.VarType; import rapaio.printer.Printable; import rapaio.printer.format.TextTable; import rapaio.sys.WS; import java.io.Serializable; import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.function.DoublePredicate; import java.util.stream.DoubleStream; /** * Nominal distribution vector. * <p> * Vector tool class to facilitate various operations on nominal variables regarding frequencies. * * @author <a href="mailto:padreati@yahoo.com>Aurelian Tutuianu</a> */ public class DVector implements Printable, Serializable { /** * Builds a distribution vector with given levels * * @param labels used to name values * @return new empty distribution vector */ public static DVector empty(boolean useFirst, String... labels) { return new DVector(useFirst, labels); } /** * Builds a distribution vector with given dimestion. Names are generated automatically. * * @param rows size of the distribution vector * @return new empty distribution vector */ public static DVector empty(boolean useFirst, int rows) { String[] labels = new String[rows]; for (int i = 0; i < labels.length; i++) { if (i == 0) { labels[i] = useFirst ? "v0" : "?"; } else { labels[i] = "v" + i; } } return new DVector(useFirst, labels); } /** * Builds a distribution vector as a frequency table from a * given nominal variable. For each cell value it will hold * the number of appearances and will have the cell names * from the levels of the nominal value given as input. * * @param var given nominal value * @return new distribution vector filled with counts */ public static DVector fromCount(boolean useFirst, Var var) { Var weights = Numeric.fill(var.rowCount(), 1); return new DVector(useFirst, var.levels(), var, weights); } /** * Builds a distribution vector as a table with one cell for each * value in the nominal variable and as value the sum of it's * corresponding weights. * * @param var given nominal variable * @param weights given numeric weights * @return new distribution variable */ public static DVector fromWeights(boolean useFirst, Var var, Var weights) { return new DVector(useFirst, var.levels(), var, weights); } /** * Builds a new distribution vector, with given names, grouped by * the nominal variable and with values as sums on numeric weights * * @param labels levels used for names * @param var defines nominal grouping * @param weights weights used to compute sums for each cell * @return new distribution vector */ public static DVector fromWeights(boolean useFirst, Var var, Var weights, String... labels) { return new DVector(useFirst, labels, var, weights); } private static final long serialVersionUID = -546802690694348698L; private final String[] levels; private final Map<String, Integer> reverse = new HashMap<>(); private final double[] values; private boolean useFirst; private int start; private double total; private DVector(boolean useFirst, String[] labels) { this.useFirst = useFirst; this.start = useFirst ? 0 : 1; this.levels = labels; for (int i = 0; i < labels.length; i++) { reverse.put(labels[i], i); } this.values = new double[this.levels.length]; } private DVector(boolean useFirst, String[] labels, Var var, Var weights) { this(useFirst, labels); int off = var.type().equals(VarType.BINARY) ? 1 : 0; var.stream().forEach(s -> values[s.index() + off] += weights.value(s.row())); total = Arrays.stream(values).sum(); } public boolean first() { return useFirst; } public DVector withFirst(boolean useFirst) { this.useFirst = useFirst; this.start = useFirst ? 0 : 1; return this; } public String[] levels() { return levels; } /** * Getter for the value from a given position * * @param pos position of the value * @return value from the give position */ public double get(int pos) { return values[pos]; } public double get(String name) { return get(reverse.get(name)); } public String label(int pos) { return levels[pos]; } /** * Updates the value from the given position {@param pos} by adding the {@param value} * * @param pos position of the denity vector to be updated * @param value value to be added to given cell */ public void increment(int pos, double value) { values[pos] += value; total += value; } /** * Updates the value from the given position {@param pos} by adding the {@param value} * * @param dv density vector which will be added * @param factor the factor used to multiply added density vector with */ public void increment(DVector dv, double factor) { for (int i = 0; i < values.length; i++) { values[i] += dv.get(i) * factor; total += dv.get(i) * factor; } } public void increment(String name, double value) { increment(reverse.get(name), value); } public void increment(DVector dv) { if (values.length != dv.values.length) throw new IllegalArgumentException("Cannot update density vector, row count is different"); for (int i = 0; i < values.length; i++) { values[i] += dv.values[i]; } total += dv.total; } /** * Setter for the value from a given position * * @param pos position of the value * @param value value to be set at the given position */ public void set(int pos, double value) { total += value - values[pos]; values[pos] = value; } public void set(String name, double value) { set(reverse.get(name), value); } /** * Find the index of the greatest value from all cells, including eventual missing label. * If there are multiple maximal values, one at random is chosen * * @return index of the greatest value */ public int findBestIndex() { double n = 1; int bestIndex = start; double best = values[start]; for (int i = start + 1; i < values.length; i++) { if (values[i] > best) { best = values[i]; bestIndex = i; n = 1; continue; } if (values[i] == best) { if (RandomSource.nextDouble() > n / (n + 1)) { best = values[i]; bestIndex = i; } n++; } } return bestIndex; } /** * Normalize values from density vector */ public DVector normalize() { total = 0.0; for (int i = start; i < values.length; i++) { total += values[i]; } if (total == 0) return this; for (int i = start; i < values.length; i++) { values[i] /= total; } total = 1.0; return this; } /** * Computes the sum of all cells. * Missing cell might be used or not, * * @return sum of elements */ public double sum() { return useFirst ? total : total - values[0]; } /** * Computes the sum of all cells except a given one and eventually the missing value cell. * * @param except the cell excepted from computation * @return partial sum of cells */ public double sumExcept(int except) { return sum() - values[except]; } /** * Count values which respects the condition given by the predicate. * * @param pred condition used to filter the values * @return count of filtered values */ public int countValues(DoublePredicate pred) { int count = 0; for (int i = start; i < values.length; i++) { if (pred.test(values[i])) { count++; } } return count; } public int rowCount() { return levels.length; } /** * Builds a solid copy of the distribution vector * * @return a solid copy of distribution vector */ public DVector solidCopy() { DVector d = new DVector(useFirst, levels); System.arraycopy(values, 0, d.values, 0, levels.length); d.total = total; return d; } /** * @return index of the first cell, is 1 if missing cell exists and {@param useMissing} exists, 0 otherwise */ public int start() { return start; } public DoubleStream streamValues() { return Arrays.stream(values); } @Override public String toString() { return "DVector{" + "levels=" + Arrays.toString(levels) + ", values=" + Arrays.toString(values) + ", total=" + total + '}'; } public boolean equalsFull(DVector o) { if (levels.length - start != o.levels.length - o.start) { return false; } if (values.length - start != o.values.length - o.start) { return false; } for (int i = 0; i < levels.length - start; i++) { if (!levels[i + start].equals(o.levels[i + o.start])) return false; } for (int i = 0; i < values.length - start; i++) { if (Math.abs(values[i + start] - o.values[i + o.start]) > 1e-30) { return false; } } return true; } @Override public String summary() { TextTable tt = TextTable.newEmpty(3, levels.length); for (int i = start; i < levels.length; i++) { tt.set(0, i, levels[i], 1); tt.set(1, i, repeat(levels[i].length(), '-'), 1); tt.set(2, i, WS.formatShort(values[i]), 1); } return tt.summary(); } private String repeat(int len, char ch) { char[] lineChars = new char[len]; for (int i = 0; i < len; i++) { lineChars[i] = ch; } return String.valueOf(lineChars); } }