/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package hivemall.common; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; public final class ConversionState { private static final Log logger = LogFactory.getLog(ConversionState.class); /** Whether to check conversion */ protected final boolean conversionCheck; /** Threshold to determine convergence */ protected final double convergenceRate; /** being ready to end iteration */ protected boolean readyToFinishIterations; /** The cumulative errors in the training */ protected double totalErrors; /** The cumulative losses in an iteration */ protected double currLosses, prevLosses; protected int curIter; protected float curEta; public ConversionState() { this(true, 0.005d); } public ConversionState(boolean conversionCheck, double convergenceRate) { this.conversionCheck = conversionCheck; this.convergenceRate = convergenceRate; this.readyToFinishIterations = false; this.totalErrors = 0.d; this.currLosses = 0.d; this.prevLosses = Double.POSITIVE_INFINITY; this.curIter = 0; this.curEta = Float.NaN; } public double getTotalErrors() { return totalErrors; } public double getCumulativeLoss() { return currLosses; } public double getPreviousLoss() { return prevLosses; } public void incrError(double error) { this.totalErrors += error; } public void incrLoss(double loss) { this.currLosses += loss; } public void multiplyLoss(double multi) { this.currLosses = currLosses * multi; } public boolean isLossIncreased() { return currLosses > prevLosses; } public boolean isConverged(final int iter, final long obserbedTrainingExamples) { if (conversionCheck == false) { this.prevLosses = currLosses; this.currLosses = 0.d; return false; } if (currLosses > prevLosses) { if (logger.isInfoEnabled()) { logger.info("Iteration #" + iter + " currLoss `" + currLosses + "` > prevLosses `" + prevLosses + '`'); } this.prevLosses = currLosses; this.currLosses = 0.d; this.readyToFinishIterations = false; return false; } final double changeRate = (prevLosses - currLosses) / prevLosses; if (changeRate < convergenceRate) { if (readyToFinishIterations) { // NOTE: never be true at the first iteration where prevLosses == Double.POSITIVE_INFINITY logger.info("Training converged at " + iter + "-th iteration. [curLosses=" + currLosses + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate + ']'); return true; } else { this.readyToFinishIterations = true; } } else { if (logger.isDebugEnabled()) { logger.debug("Iteration #" + iter + " [curLosses=" + currLosses + ", prevLosses=" + prevLosses + ", changeRate=" + changeRate + ", #trainingExamples=" + obserbedTrainingExamples + ']'); } this.readyToFinishIterations = false; } this.prevLosses = currLosses; this.currLosses = 0.d; return false; } public void logState(int iter, float eta) { if (logger.isInfoEnabled()) { logger.info("Iteration #" + iter + " [curLoss=" + currLosses + ", prevLoss=" + prevLosses + ", eta=" + eta + ']'); } this.curIter = iter; this.curEta = eta; } public int getCurrentIteration() { return curIter; } public float getCurrentEta() { return curEta; } }