/*
* Apache License
* Version 2.0, January 2004
* http://www.apache.org/licenses/
*
* Copyright 2013 Aurelian Tutuianu
* Copyright 2014 Aurelian Tutuianu
* Copyright 2015 Aurelian Tutuianu
* Copyright 2016 Aurelian Tutuianu
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package rapaio.ml.regression;
import rapaio.core.CoreTools;
import rapaio.data.Frame;
import rapaio.data.Numeric;
import rapaio.data.SolidFrame;
import rapaio.printer.Printable;
import rapaio.printer.format.TextTable;
import java.util.*;
import java.util.stream.Collectors;
import static java.util.Collections.nCopies;
/**
* Result of a regression fit.
* <p>
* Created by <a href="mailto:padreati@yahoo.com">Aurelian Tutuianu</a> on 11/20/14.
*/
public class RFit implements Printable {

    private final Regression model;
    private final Frame df;
    private final boolean withResiduals;
    // Per-target fitted values and residuals, keyed by target variable name.
    private final Map<String, Numeric> fit;
    private final Map<String, Numeric> residuals;
    // Per-target goodness-of-fit statistics; NaN until buildComplete() computes them.
    private final Map<String, Double> tss;
    private final Map<String, Double> ess;
    private final Map<String, Double> rss;
    private final Map<String, Double> rsquare;

    // protected constructor; instances are normally obtained through build(...)

    /**
     * Creates a fit result holding, for every target name reported by the model,
     * an empty fitted variable and an empty residual variable with one slot per
     * row of {@code df}. All statistics start as {@code NaN}.
     *
     * @param model         regression model which produced this fit
     * @param df            frame on which the fit is computed
     * @param withResiduals if true, {@link #buildComplete()} computes residuals and statistics
     */
    protected RFit(final Regression model, final Frame df, final boolean withResiduals) {
        this.df = df;
        this.model = model;
        this.withResiduals = withResiduals;
        this.fit = new HashMap<>();
        this.residuals = new HashMap<>();
        this.tss = new HashMap<>();
        this.ess = new HashMap<>();
        this.rss = new HashMap<>();
        this.rsquare = new HashMap<>();
        for (String targetName : model.targetNames()) {
            fit.put(targetName, Numeric.empty(df.rowCount()).withName(targetName));
            residuals.put(targetName, Numeric.empty(df.rowCount()).withName(targetName + "-residual"));
            tss.put(targetName, Double.NaN);
            ess.put(targetName, Double.NaN);
            rss.put(targetName, Double.NaN);
            rsquare.put(targetName, Double.NaN);
        }
    }

    // static builder

    /**
     * Static factory which delegates to the constructor.
     *
     * @param model         regression model which produced this fit
     * @param df            frame on which the fit is computed
     * @param withResiduals if true, {@link #buildComplete()} computes residuals and statistics
     * @return new fit result instance
     */
    public static RFit build(Regression model, Frame df, boolean withResiduals) {
        return new RFit(model, df, withResiduals);
    }

    /**
     * Returns the frame this fit was built on.
     *
     * @return frame passed at construction time
     */
    public Frame getFrame() {
        return df;
    }

    /**
     * Tells whether residuals and statistics are computed by {@link #buildComplete()}.
     *
     * @return flag passed at construction time
     */
    public boolean isWithResiduals() {
        return withResiduals;
    }

    /**
     * Returns the target variable names reported by the underlying model.
     *
     * @return target variable names
     */
    public String[] targetNames() {
        return model.targetNames();
    }

    /**
     * Returns the first target variable name reported by the underlying model.
     *
     * @return first target variable name
     */
    public String firstTargetName() {
        return model.firstTargetName();
    }

    /**
     * Returns predicted target fit for each target variable name.
     * The returned map is the internal map of this instance, not a copy.
     *
     * @return map with numeric variables as predicted values
     */
    public Map<String, Numeric> fitMap() {
        return fit;
    }

    /**
     * Returns predicted target fit for each target variable name,
     * with columns in the order given by {@link #targetNames()}.
     *
     * @return frame with fitted variables as columns
     */
    public Frame fitFrame() {
        return SolidFrame.byVars(Arrays.stream(targetNames()).map(fit::get).collect(Collectors.toList()));
    }

    /**
     * Returns fitted target var for first target variable name.
     *
     * @return numeric variable with predicted values
     */
    public Numeric firstFit() {
        return fit.get(firstTargetName());
    }

    /**
     * Returns fitted target values for given target variable name.
     *
     * @param targetVar given target variable name
     * @return numeric variable with predicted values, or null if the name is unknown
     */
    public Numeric fit(String targetVar) {
        return fit.get(targetVar);
    }

    /**
     * Returns residuals for each target variable name.
     * The returned map is the internal map of this instance, not a copy.
     *
     * @return map with numeric variables as residual values
     */
    public Map<String, Numeric> residualsMap() {
        return residuals;
    }

    /**
     * Returns residuals for each target variable name,
     * with columns in the order given by {@link #targetNames()}.
     *
     * @return frame with residual variables as columns
     */
    public Frame residualsFrame() {
        return SolidFrame.byVars(Arrays.stream(targetNames()).map(residuals::get).collect(Collectors.toList()));
    }

    /**
     * Returns residuals for first target variable name.
     *
     * @return numeric variable with residual values
     */
    public Numeric firstResidual() {
        return residuals.get(firstTargetName());
    }

    /**
     * Returns residuals for given target variable name.
     *
     * @param targetVar given target variable name
     * @return numeric variable with residual values, or null if the name is unknown
     */
    public Numeric residual(String targetVar) {
        return residuals.get(targetVar);
    }

    /**
     * Computes residuals (observed minus fitted) and the TSS, ESS, RSS and R^2
     * statistics for every target, reading observed values from the frame and
     * predictions from the fitted variables. Does nothing when this fit was
     * built without residuals.
     * <p>
     * NOTE(review): presumably called by the regression after the fitted values
     * were filled in; confirm against callers.
     */
    public void buildComplete() {
        if (!withResiduals) {
            return;
        }
        final int rowCount = df.rowCount();
        for (String target : targetNames()) {
            final Numeric fitted = fit(target);
            final Numeric residual = residuals.get(target);
            final double mu = CoreTools.mean(df.var(target)).value();

            double tssValue = 0;
            double essValue = 0;
            double rssValue = 0;
            // single pass: the residual and the three sums of squares share the same operands
            for (int i = 0; i < rowCount; i++) {
                final double observed = df.var(target).value(i);
                final double predicted = fitted.value(i);
                final double error = observed - predicted;
                residual.setValue(i, error);
                tssValue += Math.pow(observed - mu, 2);
                essValue += Math.pow(predicted - mu, 2);
                rssValue += Math.pow(error, 2);
            }
            tss.put(target, tssValue);
            ess.put(target, essValue);
            rss.put(target, rssValue);
            // when tssValue is 0 this yields NaN or -Infinity, per IEEE division rules
            rsquare.put(target, 1 - rssValue / tssValue);
        }
    }

    /**
     * Builds a text report containing model identification, the list of input
     * and target variables, and for each target a summary of fitted values and
     * residuals followed by the TSS, ESS, RSS and R^2 statistics.
     *
     * @return multi-line summary text
     */
    @Override
    public String summary() {
        StringBuilder sb = new StringBuilder();
        sb.append("Regression Fit Summary").append("\n");
        sb.append("=======================\n");
        sb.append("\n");
        sb.append("Model type: ").append(model.name()).append("\n");
        sb.append("Model instance: ").append(model.fullName()).append("\n");
        sb.append("Predicted frame summary:\n");
        sb.append("> rows: ").append(df.rowCount()).append("\n");
        sb.append("> vars: ").append(df.varCount()).append("\n");
        sb.append("\n");

        // inputs
        sb.append("> input variables: \n");
        final int inputCount = model.inputNames().length;
        TextTable inputs = newVariableTable(inputCount);
        for (int i = 0; i < inputCount; i++) {
            inputs.set(i + 2, 0, String.valueOf(i + 1), 1);
            inputs.set(i + 2, 1, model.inputName(i), -1);
            inputs.set(i + 2, 2, model.inputType(i).code(), -1);
        }
        inputs.withHeaderRows(2);
        inputs.withMerge();
        sb.append(inputs.summary()).append("\n");

        // targets: the table must be sized by the number of targets, not inputs
        sb.append("> target variables: \n");
        final int targetCount = model.targetNames().length;
        TextTable targets = newVariableTable(targetCount);
        for (int i = 0; i < targetCount; i++) {
            targets.set(i + 2, 0, String.valueOf(i + 1), 1);
            targets.set(i + 2, 1, model.targetName(i), -1);
            targets.set(i + 2, 2, model.targetType(i).code(), -1);
        }
        targets.withHeaderRows(2);
        targets.withMerge();
        sb.append(targets.summary()).append("\n");
        sb.append("\n");

        for (String target : model.targetNames()) {
            sb.append("Fit and residuals for ").append(target).append("\n");
            sb.append("======================")
                    .append(String.join("", nCopies(target.length(), "="))).append("\n");

            // Drops the first 8 lines of the frame summary.
            // NOTE(review): presumably these are the frame's own header lines - confirm.
            String fullSummary = SolidFrame.byVars(fit(target), residual(target)).summary();
            sb.append(Arrays.stream(fullSummary.split("\n"))
                    .skip(8)
                    .collect(Collectors.joining("\n", "", "\n")));

            // one common width so that the three sums of squares line up
            double max = Math.max(Math.max(tss.get(target), ess.get(target)), rss.get(target));
            int dec = fieldWidth(max);

            sb.append(String.format("Total sum of squares (TSS) : %" + dec + ".3f\n", tss.get(target)));
            sb.append(String.format("Explained sum of squares (ESS) : %" + dec + ".3f\n", ess.get(target)));
            sb.append(String.format("Residual sum of squares (RSS) : %" + dec + ".3f\n", rss.get(target)));
            sb.append("\n");
            // rsquare holds the same 1 - RSS/TSS value computed by buildComplete(), NaN otherwise
            sb.append(String.format("Coeff. of determination (R^2) : %" + dec + ".3f\n", rsquare.get(target)));
        }
        return sb.toString();
    }

    /**
     * Creates a three column table (no, name, type) with two header rows
     * already filled in and room for {@code varCount} variable rows.
     */
    private static TextTable newVariableTable(int varCount) {
        TextTable tt = TextTable.newEmpty(varCount + 2, 3);
        tt.set(0, 0, "no", 1);
        tt.set(1, 0, "--", 1);
        tt.set(0, 1, "name", -1);
        tt.set(1, 1, "----", -1);
        tt.set(0, 2, "type", 1);
        tt.set(1, 2, "----", 1);
        return tt;
    }

    /**
     * Computes the format width for a value: 4 characters for the decimal point
     * and three decimals, plus one character for every division by ten needed
     * to bring the value to at most 1. NaN and infinite values get the minimum
     * width; without the finite check an infinite value would loop forever.
     */
    private static int fieldWidth(double max) {
        int width = 1 + 3;
        if (Double.isFinite(max)) {
            while (max > 1) {
                width++;
                max /= 10;
            }
        }
        return width;
    }
}