package com.bahadirakin.ml.service;
import com.bahadirakin.ml.dto.CompanyInfo;
import com.bahadirakin.ml.dto.CompanyPrediction;
import com.bahadirakin.ml.dto.RiskStatus;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.Collections;
import java.util.Objects;
/**
 * Predicts company bankruptcy with a pre-trained Spark MLlib {@link LogisticRegressionModel}.
 *
 * <p>A {@link CompanyInfo} is converted into a six-element dense feature vector, where
 * each {@link RiskStatus} maps POSITIVE to 1.0, AVERAGE to 0.0 and NEGATIVE to -1.0. The
 * vector is then scored by the model. A model output of 1.0 maps to
 * {@link CompanyPrediction#NON_BANKRUPTCY} and 0.0 maps to {@link CompanyPrediction#BANKRUPTCY}.
 */
@Service
public class QualitativeBankruptcyServiceImpl implements QualitativeBankruptcyService {

    private final JavaSparkContext javaSparkContext;
    private final LogisticRegressionModel logisticRegressionModel;

    /**
     * Creates the service.
     *
     * @param javaSparkContext        Spark context used to turn the input into an RDD for scoring
     * @param logisticRegressionModel trained model that produces the prediction
     */
    @Autowired
    public QualitativeBankruptcyServiceImpl(JavaSparkContext javaSparkContext,
                                            LogisticRegressionModel logisticRegressionModel) {
        this.javaSparkContext = javaSparkContext;
        this.logisticRegressionModel = logisticRegressionModel;
    }

    /**
     * Predicts whether the given company will go bankrupt.
     *
     * @param companyInfo qualitative risk assessment of the company; must not be {@code null}
     * @return the model's prediction
     * @throws NullPointerException     if {@code companyInfo} is {@code null}
     * @throws IllegalArgumentException if any risk attribute is {@code null}/unrecognised, or
     *                                  the model returns a value other than 0.0 or 1.0
     */
    @Override
    public CompanyPrediction predict(CompanyInfo companyInfo) {
        Objects.requireNonNull(companyInfo, "companyInfo");
        // Build the features on the driver. Invalid input then fails here with a plain
        // IllegalArgumentException instead of inside a Spark task. It also avoids a map
        // closure that ignored its RDD element and captured the outer variable instead.
        final Vector features = toFeatureVector(companyInfo);
        // NOTE(review): LogisticRegressionModel also offers predict(Vector) for local scoring,
        // which would avoid running a Spark job per request. Consider it if the context is unneeded.
        final JavaRDD<Vector> featureRdd =
                javaSparkContext.parallelize(Collections.singletonList(features));
        final double prediction = logisticRegressionModel.predict(featureRdd).first();
        return deNormalizeResult(prediction);
    }

    /**
     * Converts a company's risk attributes into the model's feature vector.
     *
     * @param info company to encode
     * @return dense vector of the six normalized risk features
     * @throws IllegalArgumentException if any risk attribute is {@code null} or unrecognised
     */
    private static Vector toFeatureVector(CompanyInfo info) {
        // Order is important! Presumably this must match the column order used when the
        // model was trained. Keep in sync with the training pipeline.
        return Vectors.dense(
                normalizeFeature(info.getIndustrialRisk()),
                normalizeFeature(info.getManagementRisk()),
                normalizeFeature(info.getFinancialFlexibility()),
                normalizeFeature(info.getCredibility()),
                normalizeFeature(info.getCompetitiveness()),
                normalizeFeature(info.getOperatingRisk()));
    }

    /**
     * Maps a risk status onto the numeric scale used by the model.
     *
     * @param riskStatus status to encode
     * @return 1.0 for POSITIVE, 0.0 for AVERAGE, -1.0 for NEGATIVE
     * @throws IllegalArgumentException if {@code riskStatus} is {@code null} or any other value
     */
    private static double normalizeFeature(RiskStatus riskStatus) {
        // Guard before switching: a switch on a null enum would throw NullPointerException.
        if (riskStatus != null) {
            switch (riskStatus) {
                case POSITIVE:
                    return 1.0;
                case AVERAGE:
                    return 0.0;
                case NEGATIVE:
                    return -1.0;
                default:
                    break;
            }
        }
        throw new IllegalArgumentException("Unexpected riskStatus: " + riskStatus);
    }

    /**
     * Maps the model's numeric output back to a {@link CompanyPrediction}.
     *
     * <p>Exact comparison is intended. The model is expected to emit the discrete class
     * labels 0.0/1.0; any other value (e.g. a raw probability) is rejected.
     *
     * @param result model output
     * @return NON_BANKRUPTCY for 1.0, BANKRUPTCY for 0.0
     * @throws IllegalArgumentException for any other value
     */
    private static CompanyPrediction deNormalizeResult(double result) {
        if (result == 1.0) {
            return CompanyPrediction.NON_BANKRUPTCY;
        }
        if (result == 0.0) {
            return CompanyPrediction.BANKRUPTCY;
        }
        throw new IllegalArgumentException("Unexpected prediction result: " + result);
    }
}