/*
* Copyright (C) 2011 Laurent Caillette
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation, either
* version 3 of the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.novelang.novelist;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.novelang.logger.Logger;
import org.novelang.logger.LoggerFactory;
/**
* Given a {@code Map} of objects and their frequencies expressed as percentages, this class
* picks one of those objects for a value in the half-open interval [0, 100[, so that evenly
* spread input values yield objects with the same distribution as the given frequencies.
*
* <p>
* Design note: while all the {@code Float} values should be between 0 and 100, there is
* no clear advantage in using a {@link Bounded.Percentage} at creation.
* Each {@code Float} is held internally and is bound-checked.
* Exposing a {@link Bounded.Percentage} for Map creation would require verbose code like
* {@code put( 'a', Bounded.newPercentage( 7.636f ) )} instead of {@code put( 'a', 7.636f )}.
*
* @author Laurent Caillette
*/
public abstract class Distribution< T > {

  private static final Logger LOGGER = LoggerFactory.getLogger( Distribution.class );

  /**
   * The total that corrected frequencies add up to. The last cumulated frequency is pinned
   * to this exact value so that no percentage below it falls through.
   */
  private static final float TOTAL = 100.0f ;

  /**
   * Keys in descending order of individual frequency, each mapped to the cumulated frequency
   * of itself and all preceding keys. Keys with a zero frequency are absent. Unmodifiable.
   */
  private final Map< T, Float > summedFrequencies ;

  /**
   * @param requesterNameForLogging a name identifying the requester, only used in log and
   *     exception messages.
   * @param individualFrequencies each object's own frequency, as a percentage. Every value
   *     must be non-null and accepted by {@code Bounded.Percentage.isValid}, and at least one
   *     must be strictly positive. Values need not add up to exactly 100: the difference is
   *     applied to the most frequent object, with a warning logged when it exceeds 1.
   * @throws IllegalArgumentException if a frequency is {@code null} or invalid, or if no
   *     frequency is strictly positive.
   */
  protected Distribution(
      final String requesterNameForLogging,
      final Map< T, Float > individualFrequencies
  ) {
    this.summedFrequencies = sumDistributions( requesterNameForLogging, individualFrequencies ) ;
  }

  /**
   * Returns the object whose slot contains {@code probability}. Objects are scanned from the
   * most to the least frequent, and the first one whose cumulated frequency is strictly greater
   * than {@code probability} is returned.
   *
   * @param probability a percentage, expected to lie in [0, 100[.
   * @return the matching key of the map given at construction.
   * @throws IllegalStateException if no cumulated frequency is strictly greater than
   *     {@code probability}. Since the last one is pinned to 100, this is expected only for
   *     a value of 100 or more.
   */
  public T get( final Bounded.Percentage probability ) {
    for( final Map.Entry< T, Float > entry : summedFrequencies.entrySet() ) {
      if( probability.isStrictlySmallerThan( entry.getValue() ) ) {
        return entry.getKey() ;
      }
    }
    throw new IllegalStateException( "Should not happen: found nothing for " + probability ) ;
  }

  /**
   * Returns an unmodifiable {@code Map} with keys sorted by descending frequency, each
   * associated with the sum of its own frequency and the frequencies of all preceding keys.
   * <p>
   * Keys with a zero frequency are left out, because they can never be picked.
   * The difference between 100 and the sum of all frequencies is added to the first (most
   * frequent) key, which distorts the distribution the least. The last cumulated value is
   * then set to exactly 100. Otherwise float rounding could leave a thin range just below
   * 100 that no key covers.
   *
   * @throws IllegalArgumentException if a frequency is {@code null} or invalid, or if no
   *     frequency is strictly positive (the result would then match nothing).
   */
  private static< T > Map< T, Float > sumDistributions(
      final String requesterNameForLogging,
      final Map< T, Float > individualDistributions
  ) {
    // Validate before sorting. Otherwise a null frequency would crash the comparator with a
    // bare NullPointerException instead of the explicit message below.
    for( final Map.Entry< T, Float > entry : individualDistributions.entrySet() ) {
      final Float frequency = entry.getValue() ;
      if( frequency == null || ! Bounded.Percentage.isValid( frequency ) ) {
        throw new IllegalArgumentException(
            "Invalid frequency for '" + entry.getKey() + "': " + frequency ) ;
      }
    }

    final List< Map.Entry< T, Float > > sortedEntries =
        Lists.newArrayList( individualDistributions.entrySet() ) ;
    Collections.sort( sortedEntries, new Comparator< Map.Entry< T, Float > >() {
      @Override
      public int compare(
          final Map.Entry< T, Float > entry1,
          final Map.Entry< T, Float > entry2
      ) {
        // Inverted arguments give a descending sort.
        return entry2.getValue().compareTo( entry1.getValue() ) ;
      }
    } ) ;

    // Sum in sorted order so the raw total (and the logged correction) stays stable.
    float rawSum = 0.0f ;
    for( final Map.Entry< T, Float > entry : sortedEntries ) {
      rawSum += entry.getValue() ;
    }

    // Start the running sum at the correction, which boosts (or shrinks) the first,
    // most frequent element.
    final float correction = TOTAL - rawSum ;
    logMessageAboutCorrection( requesterNameForLogging, rawSum, correction ) ;

    final Map< T, Float > cumulatedFrequencies = Maps.newLinkedHashMap() ;
    float cumulated = correction ;
    T lastKey = null ;
    for( final Map.Entry< T, Float > entry : sortedEntries ) {
      final float frequency = entry.getValue() ;
      if( frequency != 0.0f ) {
        cumulated += frequency ;
        cumulatedFrequencies.put( entry.getKey(), cumulated ) ;
        lastKey = entry.getKey() ;
      }
    }

    if( cumulatedFrequencies.isEmpty() ) {
      throw new IllegalArgumentException(
          "No strictly positive frequency for " + requesterNameForLogging ) ;
    }

    // Re-putting an existing key keeps LinkedHashMap's insertion order intact.
    cumulatedFrequencies.put( lastKey, TOTAL ) ;

    return Collections.unmodifiableMap( cumulatedFrequencies ) ;
  }

  /**
   * Logs the raw sum of frequencies and the correction applied to reach 100.
   * Uses warn level when the correction exceeds 1 in absolute value, debug level otherwise.
   */
  private static void logMessageAboutCorrection(
      final String requesterNameForLogging,
      final float sum,
      final float correction
  ) {
    final String message = "For " + requesterNameForLogging + ", " +
        "sum of raw frequencies: " + sum + "; correction: " + correction ;
    if( Math.abs( correction ) > 1.0f ) {
      LOGGER.warn( message ) ;
    } else {
      LOGGER.debug( message ) ;
    }
  }
}