package com.github.lwhite1.tablesaw.api.ml; import com.github.lwhite1.tablesaw.api.BooleanColumn; import com.github.lwhite1.tablesaw.api.ColumnType; import com.github.lwhite1.tablesaw.api.Table; import com.github.lwhite1.tablesaw.api.ml.classification.LogisticRegression; import com.github.lwhite1.tablesaw.api.plot.Bar; import com.github.lwhite1.tablesaw.api.plot.Pareto; import com.github.lwhite1.tablesaw.reducing.NumericSummaryTable; import com.github.lwhite1.tablesaw.store.StorageManager; import com.google.common.base.Stopwatch; import java.util.concurrent.TimeUnit; import static com.github.lwhite1.tablesaw.api.ColumnType.*; import static com.github.lwhite1.tablesaw.api.QueryHelper.*; import static com.github.lwhite1.tablesaw.reducing.NumericReduceUtils.mean; import static java.lang.System.out; /** * */ public class AirlineDelays { private static Table flt2007; public static void main(String[] args) throws Exception { new AirlineDelays(); } private AirlineDelays() throws Exception { Stopwatch stopwatch = Stopwatch.createStarted(); out.println("loading"); ColumnType[] columnTypes = { SHORT_INT, // 0 Year SHORT_INT, // 1 Month SHORT_INT, // 2 DayofMonth SHORT_INT, // 3 DayOfWeek LOCAL_TIME, // 4 DepTime LOCAL_TIME, // 5 CRSDepTime LOCAL_TIME, // 6 ArrTime LOCAL_TIME, // 7 CRSArrTime CATEGORY, // 8 UniqueCarrier SHORT_INT, // 9 FlightNum CATEGORY, // 10 TailNum SHORT_INT, // 11 ActualElapsedTime SHORT_INT, // 12 CRSElapsedTime SHORT_INT, // 13 AirTime SHORT_INT, // 14 ArrDelay SHORT_INT, // 15 DepDelay CATEGORY, // 16 Origin CATEGORY, // 17 Dest SHORT_INT, // 18 Distance SHORT_INT, // 19 TaxiIn SHORT_INT, // 20 TaxiOut SHORT_INT, // 21 Cancelled CATEGORY, // 22 CancellationCode SHORT_INT, // 23 Diverted SHORT_INT, // 24 CarrierDelay SHORT_INT, // 25 WeatherDelay SHORT_INT, // 26 NASDelay SHORT_INT, // 27 SecurityDelay SHORT_INT, // 28 LateAircraftDelay }; // flt2007 = Table.createFromCsv(columnTypes, "/Users/larrywhite/Downloads/flight delays/2007.csv"); //String tableName = StorageManager.saveTable("bigdata", flt2007); //out("Wrote to saw store " + tableName); flt2007 = StorageManager.readTable("bigdata/2007.csv.saw"); out.println(String.format("loaded %d records in %d seconds", flt2007.rowCount(), (int) stopwatch.elapsed(TimeUnit.SECONDS))); out(flt2007.shape()); Table ord = flt2007.selectWhere( both(column("Origin").isEqualTo("ORD"), column("DepDelay").isNotMissing())); BooleanColumn delayed = ord.selectIntoColumn("Delayed?", column("DepDelay").isGreaterThanOrEqualTo(15)); ord.addColumn(delayed); out("total flights: " + ord.rowCount()); out("total delays: " + delayed.countTrue()); // Compute average number of delayed flights per month NumericSummaryTable monthGroup = ord.summarize("DepDelay", mean).by("Month"); Bar.show("Departure delay by month", monthGroup); NumericSummaryTable dayOfWeekGroup = ord.summarize("DepDelay", mean).by("DayOfWeek"); Bar.show("Departure delay by day-of-week", dayOfWeekGroup); ord.addColumn(ord.timeColumn("CRSDepTime").hour()); NumericSummaryTable hourGroup = ord.summarize("DepDelay", mean).by("CRSDepTime[hour]"); Bar.show("Departure delay by hour-of-day", hourGroup); // Compute average number of delayed flights per carrier NumericSummaryTable carrierGroup = ord.mean("DepDelay").by("UniqueCarrier"); Pareto.show("Departure delay by Carrier", carrierGroup); // we have no cancelled flights because we removed them earlier by filtering where delay is missing; out(ord.shape()); double lambda = 0.1; LogisticRegression logit = LogisticRegression.learn( ord.booleanColumn("Delayed?"), ord.nCol("dayOfWeek"), ord.nCol("CRSDepTime[hour]")); out(logit.toString()); } private static void out(Object obj) { System.out.println(String.valueOf(obj)); } }