package au.gov.amsa.spark.ais;

import java.net.URL;
import java.util.Objects;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;

/**
 * Classifies a position report as {@link Status#MOORED}, {@link Status#ANCHORED} or
 * {@link Status#OTHER} using a {@link DecisionTreeModel} loaded from the classpath resource
 * {@code /anchoredOrMooredModel}.
 */
public class AnchoredPredictor {

    /** Classpath location of the saved decision tree model. */
    private static final String MODEL_RESOURCE = "/anchoredOrMooredModel";

    /** Tolerance used when matching the model's numeric label output to a known label. */
    private static final double LABEL_TOLERANCE = 0.0001;

    /** Model label that maps to {@link Status#MOORED}. */
    private static final double MOORED_LABEL = 1;

    /** Model label that maps to {@link Status#ANCHORED}. */
    private static final double ANCHORED_LABEL = 2;

    private final DecisionTreeModel model;

    /**
     * Loads the decision tree model from the classpath resource {@code /anchoredOrMooredModel}.
     *
     * @param sc Spark context used to load the model
     * @throws NullPointerException if the model resource cannot be found on the classpath, or if
     *     {@code sc} is null
     */
    public AnchoredPredictor(JavaSparkContext sc) {
        URL resource = Objects.requireNonNull(
                AnchoredPredictor.class.getResource(MODEL_RESOURCE),
                "model resource not found on classpath: " + MODEL_RESOURCE);
        // NOTE(review): the resource URL string is passed straight to DecisionTreeModel.load; a
        // resource packaged inside a jar ("jar:...!/...") presumably cannot be read this way —
        // confirm how this class is deployed.
        model = DecisionTreeModel.load(sc.sc(), resource.toString());
    }

    /** Classification outcome returned by {@link AnchoredPredictor#predict}. */
    public enum Status {
        OTHER, MOORED, ANCHORED;
    }

    /**
     * Predicts the status of a single position report.
     *
     * <p>The arguments are packed into a feature vector in exactly the parameter order below.
     * Presumably this order must match the one used when the model was trained — verify before
     * reordering.
     *
     * @return {@link Status#MOORED} if the model outputs label 1, {@link Status#ANCHORED} if it
     *     outputs label 2 (each within a small tolerance), otherwise {@link Status#OTHER}
     */
    public Status predict(double lat, double lon, double speedKnots, double courseMinusHeading,
            double preEffectiveSpeedKnots, double preError, double postEffectiveSpeedKnots,
            double postError) {
        Vector features = new DenseVector(new double[] { lat, lon, speedKnots,
                courseMinusHeading, preEffectiveSpeedKnots, preError, postEffectiveSpeedKnots,
                postError });
        double prediction = model.predict(features);
        if (is(prediction, MOORED_LABEL)) {
            return Status.MOORED;
        } else if (is(prediction, ANCHORED_LABEL)) {
            return Status.ANCHORED;
        } else {
            return Status.OTHER;
        }
    }

    /** Returns true when {@code a} and {@code b} differ by less than {@link #LABEL_TOLERANCE}. */
    private static boolean is(double a, double b) {
        return Math.abs(a - b) < LABEL_TOLERANCE;
    }
}