package com.github.lwhite1.tablesaw.api.ml.regression; import com.github.lwhite1.tablesaw.api.IntColumn; import com.github.lwhite1.tablesaw.api.NumericColumn; import com.github.lwhite1.tablesaw.api.Table; import com.github.lwhite1.tablesaw.api.plot.Histogram; import com.github.lwhite1.tablesaw.api.plot.Scatter; import com.github.lwhite1.tablesaw.columns.Column; import static com.github.lwhite1.tablesaw.api.QueryHelper.column; /** * An example doing ordinary least squares regression */ public class MoneyballExample { public static void main(String[] args) throws Exception { // Get the data Table baseball = Table.createFromCsv("data/baseball.csv"); out(baseball.structure().print()); // filter to the data available in the 2002 season Table moneyball = baseball.selectWhere(column("year").isLessThan(2002)); // plot regular season wins against year, segregating on whether the team made the plays NumericColumn wins = moneyball.numericColumn("W"); NumericColumn year = moneyball.numericColumn("Year"); Column playoffs = moneyball.column("Playoffs"); Scatter.show("Regular season wins by year", wins, year, moneyball.splitOn(playoffs)); // Calculate the run difference for use in the regression model IntColumn runDifference = moneyball.shortColumn("RS").subtract(moneyball.shortColumn("RA")); moneyball.addColumn(runDifference); runDifference.setName("RD"); // Plot RD vs Wins to see if the relationship looks linear Scatter.show("RD x Wins", moneyball.numericColumn("RD"), moneyball.numericColumn("W")); // Create the regression model //ShortColumn wins = moneyball.shortColumn("W"); LeastSquares winsModel = LeastSquares.train(wins, runDifference); out(winsModel); // Make a prediction of how many games we win if we score 135 more runs than our opponents double[] testValue = new double[1]; testValue[0] = 135; double prediction = winsModel.predict(testValue); out("Predicted wins with RD = 135: " + prediction); // Predict runsScored based on On-base percentage, batting average and slugging percentage LeastSquares runsScored = LeastSquares.train(moneyball.nCol("RS"), moneyball.nCol("OBP"), moneyball.nCol("BA"), moneyball.nCol("SLG")); out(runsScored); LeastSquares runsScored2 = LeastSquares.train(moneyball.nCol("RS"), moneyball.nCol("OBP"), moneyball.nCol("SLG")); out(runsScored2); Histogram.show(runsScored2.residuals()); Scatter.fittedVsResidual(runsScored2); Scatter.actualVsFitted(runsScored2); // We use opponent OBP and opponent SLG to model the efficacy of our pitching and defence Table moneyball2 = moneyball.selectWhere(column("year").isGreaterThan(1998)); LeastSquares runsAllowed = LeastSquares.train(moneyball2.nCol("RA"), moneyball2.nCol("OOBP"), moneyball2.nCol("OSLG")); out(runsAllowed); } private static void out(Object o) { System.out.println(String.valueOf(o)); } }