package com.github.lwhite1.tablesaw.examples; import com.github.lwhite1.tablesaw.api.CategoryColumn; import com.github.lwhite1.tablesaw.api.ColumnType; import com.github.lwhite1.tablesaw.api.FloatColumn; import com.github.lwhite1.tablesaw.api.IntColumn; import com.github.lwhite1.tablesaw.api.ShortColumn; import com.github.lwhite1.tablesaw.api.Table; import com.github.lwhite1.tablesaw.api.ml.classification.ConfusionMatrix; import com.github.lwhite1.tablesaw.api.ml.classification.LogisticRegression; import com.github.lwhite1.tablesaw.io.csv.CsvReader; import static com.github.lwhite1.tablesaw.api.ColumnType.FLOAT; import static com.github.lwhite1.tablesaw.api.ColumnType.INTEGER; /** * */ public class SfCrimeTest { public static void main(String[] args) throws Exception { Table crime = Table.createFromCsv("/Users/larrywhite/IdeaProjects/testdata/bigdata/train.csv"); out(crime.shape()); out(crime.structure().print()); crime.removeColumns("DayOfWeek"); IntColumn precinct = crime.categoryColumn("PdDistrict").toIntColumn(); precinct.setName("Precinct"); crime.addColumn(precinct); ShortColumn year = crime.dateTimeColumn("Dates").year(); year.setName("Year"); crime.addColumn(year); CategoryColumn category = crime.categoryColumn("Category"); Table categorySummary = category.summary().sortDescendingOn("Count"); out(categorySummary.print()); ShortColumn minuteOfDay = crime.dateTimeColumn("Dates").minuteOfDay(); minuteOfDay.setName("MinuteOfDay"); crime.addColumn(minuteOfDay); ShortColumn dayOfYear = crime.dateTimeColumn("Dates").dayOfYear(); dayOfYear.setName("DayOfYear"); crime.addColumn(dayOfYear); ShortColumn dayOfWeekValue = crime.dateTimeColumn("Dates").dayOfWeekValue(); dayOfWeekValue.setName("DayOfWeek"); crime.addColumn(dayOfWeekValue); Table[] subTables = crime.sampleSplit(.1); Table train = subTables[0]; Table test = subTables[1]; out(CsvReader.printColumnTypes("/Users/larrywhite/IdeaProjects/testdata/bigdata/sampleSubmission.csv", true, ',')); LogisticRegression model = LogisticRegression.learn( train.categoryColumn("Category"), 0.1, 1.0E-3, 700, train.nCol("X"), train.nCol("Y"), train.nCol("MinuteOfDay"), train.nCol("DayOfYear"), train.nCol("DayOfWeek"), train.nCol("Year"), train.nCol("Precinct")); out("Model trained"); ConfusionMatrix matrix = model.predictMatrix(test.categoryColumn("Category"), test.nCol("X"), test.nCol("Y"), test.nCol("MinuteOfDay"), test.nCol("DayOfYear"), test.nCol("DayOfWeek"), test.nCol("Year"), test.nCol("Precinct")); out(matrix.accuracy()); out(matrix.toTable().print()); // Table trueCrime = Table.createFromCsv("/Users/larrywhite/IdeaProjects/testdata/bigdata/test.csv"); // out(CsvReader.printColumnTypes("/Users/larrywhite/IdeaProjects/testdata/bigdata/sampleSubmission.csv", // true, ',')); ColumnType[] columnTypes = { INTEGER, // 0 Id FLOAT, // 1 ARSON FLOAT, // 2 ASSAULT FLOAT, // 3 BAD CHECKS FLOAT, // 4 BRIBERY FLOAT, // 5 BURGLARY FLOAT, // 6 DISORDERLY CONDUCT FLOAT, // 7 DRIVING UNDER THE INFLUENCE FLOAT, // 8 DRUG/NARCOTIC FLOAT, // 9 DRUNKENNESS FLOAT, // 10 EMBEZZLEMENT FLOAT, // 11 EXTORTION FLOAT, // 12 FAMILY OFFENSES FLOAT, // 13 FORGERY/COUNTERFEITING FLOAT, // 14 FRAUD FLOAT, // 15 GAMBLING FLOAT, // 16 KIDNAPPING FLOAT, // 17 LARCENY/THEFT FLOAT, // 18 LIQUOR LAWS FLOAT, // 19 LOITERING FLOAT, // 20 MISSING PERSON FLOAT, // 21 NON-CRIMINAL FLOAT, // 22 OTHER OFFENSES FLOAT, // 23 PORNOGRAPHY/OBSCENE MAT FLOAT, // 24 PROSTITUTION FLOAT, // 25 RECOVERED VEHICLE FLOAT, // 26 ROBBERY FLOAT, // 27 RUNAWAY FLOAT, // 28 SECONDARY CODES FLOAT, // 29 SEX OFFENSES FORCIBLE FLOAT, // 30 SEX OFFENSES NON FORCIBLE FLOAT, // 31 STOLEN PROPERTY FLOAT, // 32 SUICIDE FLOAT, // 33 SUSPICIOUS OCC FLOAT, // 34 TREA FLOAT, // 35 TRESPASS FLOAT, // 36 VANDALISM FLOAT, // 37 VEHICLE THEFT FLOAT, // 38 WARRANTS FLOAT, // 39 WEAPON LAWS }; Table results = Table.createFromCsv(columnTypes, "/Users/larrywhite/IdeaProjects/testdata/bigdata/sampleSubmission.csv"); FloatColumn larceny = results.floatColumn("LARCENY/THEFT"); FloatColumn warrants = results.floatColumn("WARRANTS"); Table trueCrime = testData(); for (int row : trueCrime) { double[] posteriori = new double[39]; model.predictFromModel(row, posteriori, test.nCol("X"), test.nCol("Y"), test.nCol("MinuteOfDay"), test.nCol("DayOfYear"), test.nCol("DayOfWeek"), test.nCol("Year"), test.nCol("Precinct") ); larceny.set(row, 1.0f); warrants.set(row, 0f); } results.exportToCsv("newSubmission.csv"); } private static Table testData() throws Exception { // Setup actual test data Table trueCrime = Table.createFromCsv("/Users/larrywhite/IdeaProjects/testdata/bigdata/test.csv"); trueCrime.removeColumns("DayOfWeek"); IntColumn precinctT = trueCrime.categoryColumn("PdDistrict").toIntColumn(); precinctT.setName("Precinct"); trueCrime.addColumn(precinctT); ShortColumn yearT = trueCrime.dateTimeColumn("Dates").year(); yearT.setName("Year"); trueCrime.addColumn(yearT); ShortColumn minuteOfDayT = trueCrime.dateTimeColumn("Dates").minuteOfDay(); minuteOfDayT.setName("MinuteOfDay"); trueCrime.addColumn(minuteOfDayT); ShortColumn dayOfYearT = trueCrime.dateTimeColumn("Dates").dayOfYear(); dayOfYearT.setName("DayOfYear"); trueCrime.addColumn(dayOfYearT); ShortColumn dayOfWeekValueT = trueCrime.dateTimeColumn("Dates").dayOfWeekValue(); dayOfWeekValueT.setName("DayOfWeek"); trueCrime.addColumn(dayOfWeekValueT); return trueCrime; } private static void out(Object str) { System.out.println(String.valueOf(str)); } }