/* * Copyright 2015 Red Hat, Inc. and/or its affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.optaplanner.examples.investment.solver.score; import java.util.HashMap; import java.util.List; import java.util.Map; import org.optaplanner.core.api.score.Score; import org.optaplanner.core.api.score.buildin.hardsoftlong.HardSoftLongScore; import org.optaplanner.core.impl.score.director.incremental.AbstractIncrementalScoreCalculator; import org.optaplanner.examples.investment.domain.AssetClassAllocation; import org.optaplanner.examples.investment.domain.InvestmentSolution; import org.optaplanner.examples.investment.domain.Region; import org.optaplanner.examples.investment.domain.Sector; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class InvestmentIncrementalScoreCalculator extends AbstractIncrementalScoreCalculator<InvestmentSolution> { protected final transient Logger logger = LoggerFactory.getLogger(getClass()); private InvestmentSolution solution; private long squaredStandardDeviationFemtosMaximum; private long squaredStandardDeviationFemtos; private Map<Region, Long> regionQuantityTotalMap; private Map<Sector, Long> sectorQuantityTotalMap; private long hardScore; private long softScore; // ************************************************************************ // Lifecycle methods // ************************************************************************ @Override public void resetWorkingSolution(InvestmentSolution solution) { this.solution = solution; squaredStandardDeviationFemtosMaximum = solution.getParametrization() .calculateSquaredStandardDeviationFemtosMaximum(); squaredStandardDeviationFemtos = 0L; List<Region> regionList = solution.getRegionList(); regionQuantityTotalMap = new HashMap<>(); for (Region region : regionList) { regionQuantityTotalMap.put(region, 0L); } List<Sector> sectorList = solution.getSectorList(); sectorQuantityTotalMap = new HashMap<>(sectorList.size()); for (Sector sector : sectorList) { sectorQuantityTotalMap.put(sector, 0L); } hardScore = 0L; softScore = 0L; for (AssetClassAllocation allocation : solution.getAssetClassAllocationList()) { insertQuantityMillis(allocation, true); } } @Override public void beforeEntityAdded(Object entity) { // Do nothing } @Override public void afterEntityAdded(Object entity) { insertQuantityMillis((AssetClassAllocation) entity, false); } @Override public void beforeVariableChanged(Object entity, String variableName) { retractQuantityMillis((AssetClassAllocation) entity); } @Override public void afterVariableChanged(Object entity, String variableName) { insertQuantityMillis((AssetClassAllocation) entity, false); } @Override public void beforeEntityRemoved(Object entity) { retractQuantityMillis((AssetClassAllocation) entity); } @Override public void afterEntityRemoved(Object entity) { // Do nothing } // ************************************************************************ // Modify methods // ************************************************************************ private void insertQuantityMillis(AssetClassAllocation allocation, boolean reset) { // Standard deviation maximum if (squaredStandardDeviationFemtos > squaredStandardDeviationFemtosMaximum) { hardScore += squaredStandardDeviationFemtos - squaredStandardDeviationFemtosMaximum; } squaredStandardDeviationFemtos += calculateStandardDeviationSquaredFemtosDelta(allocation, reset); if (squaredStandardDeviationFemtos > squaredStandardDeviationFemtosMaximum) { hardScore -= squaredStandardDeviationFemtos - squaredStandardDeviationFemtosMaximum; } Long quantityMillis = allocation.getQuantityMillis(); if (quantityMillis != null) { // Region quantity maximum Region region = allocation.getRegion(); long regionQuantityMaximum = region.getQuantityMillisMaximum(); long oldRegionQuantity = regionQuantityTotalMap.get(region); long oldRegionAvailable = regionQuantityMaximum - oldRegionQuantity; long newRegionQuantity = oldRegionQuantity + quantityMillis; long newRegionAvailable = regionQuantityMaximum - newRegionQuantity; hardScore += Math.min(newRegionAvailable, 0L) - Math.min(oldRegionAvailable, 0L); regionQuantityTotalMap.put(region, newRegionQuantity); // Sector quantity maximum Sector sector = allocation.getSector(); long sectorQuantityMaximum = sector.getQuantityMillisMaximum(); long oldSectorQuantity = sectorQuantityTotalMap.get(sector); long oldSectorAvailable = sectorQuantityMaximum - oldSectorQuantity; long newSectorQuantity = oldSectorQuantity + quantityMillis; long newSectorAvailable = sectorQuantityMaximum - newSectorQuantity; hardScore += Math.min(newSectorAvailable, 0L) - Math.min(oldSectorAvailable, 0L); sectorQuantityTotalMap.put(sector, newSectorQuantity); } // Maximize expected return softScore += allocation.getQuantifiedExpectedReturnMicros(); } private void retractQuantityMillis(AssetClassAllocation allocation) { // Standard deviation maximum if (squaredStandardDeviationFemtos > squaredStandardDeviationFemtosMaximum) { hardScore += squaredStandardDeviationFemtos - squaredStandardDeviationFemtosMaximum; } squaredStandardDeviationFemtos -= calculateStandardDeviationSquaredFemtosDelta(allocation, false); if (squaredStandardDeviationFemtos > squaredStandardDeviationFemtosMaximum) { hardScore -= squaredStandardDeviationFemtos - squaredStandardDeviationFemtosMaximum; } Long quantityMillis = allocation.getQuantityMillis(); if (quantityMillis != null) { // Region quantity maximum Region region = allocation.getRegion(); long regionQuantityMaximum = region.getQuantityMillisMaximum(); long oldRegionQuantity = regionQuantityTotalMap.get(region); long oldRegionAvailable = regionQuantityMaximum - oldRegionQuantity; long newRegionQuantity = oldRegionQuantity - quantityMillis; long newRegionAvailable = regionQuantityMaximum - newRegionQuantity; hardScore += Math.min(newRegionAvailable, 0L) - Math.min(oldRegionAvailable, 0L); regionQuantityTotalMap.put(region, newRegionQuantity); // Sector quantity maximum Sector sector = allocation.getSector(); long sectorQuantityMaximum = sector.getQuantityMillisMaximum(); long oldSectorQuantity = sectorQuantityTotalMap.get(sector); long oldSectorAvailable = sectorQuantityMaximum - oldSectorQuantity; long newSectorQuantity = oldSectorQuantity - quantityMillis; long newSectorAvailable = sectorQuantityMaximum - newSectorQuantity; hardScore += Math.min(newSectorAvailable, 0L) - Math.min(oldSectorAvailable, 0L); sectorQuantityTotalMap.put(sector, newSectorQuantity); } // Maximize expected return softScore -= allocation.getQuantifiedExpectedReturnMicros(); } private long calculateStandardDeviationSquaredFemtosDelta(AssetClassAllocation allocation, boolean reset) { long squaredFemtos = 0L; for (AssetClassAllocation other : solution.getAssetClassAllocationList()) { if (allocation == other) { long micros = allocation.getQuantifiedStandardDeviationRiskMicros(); squaredFemtos += micros * micros * 1000L; } else { long picos = allocation.getQuantifiedStandardDeviationRiskMicros() * other.getQuantifiedStandardDeviationRiskMicros(); squaredFemtos += picos * allocation.getAssetClass().getCorrelationMillisMap().get(other.getAssetClass()); // TODO FIXME the reset hack only works if there are no moves that mix multiple before/after notifications if (!reset) { squaredFemtos += picos * other.getAssetClass().getCorrelationMillisMap().get(allocation.getAssetClass()); } } } return squaredFemtos; } @Override public Score calculateScore() { return HardSoftLongScore.valueOf(hardScore, softScore); } }