package com.bahadirakin.ml;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.Tuple2;
/**
 * Trains a binary logistic-regression model on the qualitative bankruptcy
 * data set and saves it to disk.
 *
 * <p>Each input line is a comma-separated record: every field except the last
 * is a categorical feature ({@code P}, {@code A} or {@code N}) and the last
 * field is the class label ({@code NB} or {@code B}). The records are split
 * 60/40 into training and test sets with a fixed seed, the model is trained
 * with L-BFGS, its accuracy on the test set is printed, and the model is saved.
 */
public class QualitativeBankruptcyModelGenerator {

    /** Input file used when no path is supplied on the command line. */
    private static final String DEFAULT_DATA_PATH = "src/main/resources/Qualitative_Bankruptcy.data.txt";

    /** Model output location used when no path is supplied on the command line. */
    private static final String DEFAULT_MODEL_PATH = "target/model/Qualitative_Bankruptcy_Model";

    /**
     * Runs the full load / train / evaluate / save pipeline on a local Spark master.
     *
     * @param args optional overrides: {@code args[0]} is the input data path and
     *             {@code args[1]} is the model output path; missing entries fall
     *             back to the built-in defaults
     * @throws IllegalArgumentException (surfaced through the Spark job) if a
     *             record contains an unknown feature or label token
     */
    public static void main(String[] args) {
        final String path = (args != null && args.length > 0) ? args[0] : DEFAULT_DATA_PATH;
        final String modelLink = (args != null && args.length > 1) ? args[1] : DEFAULT_MODEL_PATH;

        SparkConf conf = new SparkConf()
                .setAppName("QualitativeBankruptcyModelGenerator")
                .setMaster("local");

        // JavaSparkContext is Closeable; closing it stops the wrapped SparkContext,
        // so the context is released even when reading, training or saving fails.
        try (JavaSparkContext sc = new JavaSparkContext(new SparkContext(conf))) {
            final SparkContext sparkContext = sc.sc();

            final JavaRDD<String> textFile = sc.textFile(path, 1);
            System.out.println("Data Count: " + textFile.count());

            // Blank lines carry no record; skip them instead of failing the whole job.
            final JavaRDD<LabeledPoint> data = textFile
                    .filter(line -> !line.trim().isEmpty())
                    .map(QualitativeBankruptcyModelGenerator::parseLine);
            data.take(10).forEach(System.out::println);

            // Split initial RDD into two... [60% training data, 40% testing data].
            // The fixed seed keeps the split reproducible between runs.
            JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L);
            JavaRDD<LabeledPoint> training = splits[0].cache();
            JavaRDD<LabeledPoint> test = splits[1];

            // Run training algorithm to build the model.
            final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
                    .setNumClasses(2)
                    .run(training.rdd());

            // Pair each predicted class with the true label of the test point.
            JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(p -> {
                Double prediction = model.predict(p.features());
                return new Tuple2<>(prediction, p.label());
            });

            // Get evaluation metrics.
            final MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
            System.out.println("Accuracy = " + metrics.accuracy());

            // Save model
            model.save(sparkContext, modelLink);
        }
    }

    /**
     * Converts one comma-separated record into a labeled point.
     *
     * @param line a record whose last field is the label and whose preceding
     *             fields are the features
     * @return the labeled point with numerically encoded label and features
     * @throws IllegalArgumentException if any field is not a recognised token
     */
    private static LabeledPoint parseLine(String line) {
        final String[] split = line.split(",");
        final int featureCount = split.length - 1;
        final double label = normalizeLabel(split[featureCount]);
        final double[] doubles = new double[featureCount];
        for (int i = 0; i < featureCount; i++) {
            doubles[i] = normalizeFeature(split[i]);
        }
        final Vector features = Vectors.dense(doubles);
        return new LabeledPoint(label, features);
    }

    /**
     * Encodes a categorical feature token as a number:
     * {@code P} is 1.0, {@code A} is 0.0 and {@code N} is -1.0.
     * Surrounding whitespace is ignored.
     *
     * @param data the raw feature token
     * @return the numeric encoding of the token
     * @throws IllegalArgumentException if the token is not one of the three above
     */
    private static double normalizeFeature(String data) {
        final String token = data == null ? null : data.trim();
        if ("P".equals(token)) {
            return 1.0;
        }
        if ("A".equals(token)) {
            return 0.0;
        }
        if ("N".equals(token)) {
            return -1.0;
        }
        throw new IllegalArgumentException("Unexpected data: " + data);
    }

    /**
     * Encodes a class label token as a number: {@code NB} is 1.0 and
     * {@code B} is 0.0. Surrounding whitespace is ignored.
     *
     * @param data the raw label token
     * @return the numeric class label
     * @throws IllegalArgumentException if the token is neither {@code NB} nor {@code B}
     */
    private static double normalizeLabel(String data) {
        final String token = data == null ? null : data.trim();
        if ("NB".equals(token)) {
            return 1.0;
        }
        if ("B".equals(token)) {
            return 0.0;
        }
        throw new IllegalArgumentException("Unexpected data: " + data);
    }
}