package edu.berkeley.nlp.math;
import edu.berkeley.nlp.mapper.AsynchronousMapper;
import edu.berkeley.nlp.mapper.SimpleMapper;
import edu.berkeley.nlp.util.Pair;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
/**
 * A {@link CachingDifferentiableFunction} whose objective is the sum of per-item
 * contributions over a collection of data items, plus an optional regularization term.
 * <p>
 * On every evaluation each {@link ObjectiveItemDifferentiableFunction} in {@code itemFns}
 * is given the current weights and wrapped in its own {@link SimpleMapper}, which owns a
 * private objective/gradient accumulator. The items and those mappers are handed to
 * {@link AsynchronousMapper#doMapping}, and the per-mapper results are summed afterwards.
 * NOTE(review): supplying several item functions presumably lets the mapping run on
 * several workers, one item function per worker. Confirm against {@code AsynchronousMapper}.
 * <p>
 * User: aria42
 * Date: Mar 10, 2009
 *
 * @param <I> the type of the data items the objective is summed over
 */
public class CachingObjectiveDifferentiableFunction<I> extends CachingDifferentiableFunction {

  private final List<? extends ObjectiveItemDifferentiableFunction<I>> itemFns;
  private final Regularizer regularizer;
  private final Collection<I> items;

  /**
   * Creates an objective over {@code items} backed by one or more item functions.
   *
   * @param items the data items whose contributions are summed; must not be null
   * @param itemFns the item functions; must be non-empty and contain no nulls. One
   *     accumulator is created per entry on every evaluation, and {@link #dimension()}
   *     is taken from the first entry.
   * @param regularizer regularizer whose value is added to the objective after the
   *     items are processed; may be {@code null} to disable regularization
   * @throws NullPointerException if {@code items}, {@code itemFns}, or any element of
   *     {@code itemFns} is null
   * @throws IllegalArgumentException if {@code itemFns} is empty
   */
  public CachingObjectiveDifferentiableFunction(Collection<I> items,
      List<? extends ObjectiveItemDifferentiableFunction<I>> itemFns,
      Regularizer regularizer) {
    if (items == null) {
      throw new NullPointerException("items must not be null");
    }
    if (itemFns == null) {
      throw new NullPointerException("itemFns must not be null");
    }
    if (itemFns.isEmpty()) {
      // dimension() reads itemFns.get(0), so an empty list can never be evaluated.
      throw new IllegalArgumentException("itemFns must contain at least one item function");
    }
    for (ObjectiveItemDifferentiableFunction<I> itemFn : itemFns) {
      if (itemFn == null) {
        throw new NullPointerException("itemFns must not contain null elements");
      }
    }
    this.items = items;
    this.itemFns = itemFns;
    this.regularizer = regularizer;
  }

  /**
   * Convenience constructor for a single item function.
   *
   * @param items the data items whose contributions are summed; must not be null
   * @param itemFn the single item function; must not be null
   * @param regularizer regularizer added to the objective; may be {@code null}
   * @throws NullPointerException if {@code items} or {@code itemFn} is null
   */
  public CachingObjectiveDifferentiableFunction(Collection<I> items,
      ObjectiveItemDifferentiableFunction<I> itemFn,
      Regularizer regularizer) {
    this(items, Collections.singletonList(itemFn), regularizer);
  }

  /**
   * Accumulates the objective value and gradient contributed by the items passed to
   * {@link #map}, using a single item function and a gradient buffer owned by this mapper.
   * Static: the mapper needs no reference to the enclosing function instance.
   */
  private static final class Mapper<T> implements SimpleMapper<T> {
    private final ObjectiveItemDifferentiableFunction<T> itemFn;
    private final double[] localGrad;
    private double objVal = 0.0;

    Mapper(ObjectiveItemDifferentiableFunction<T> itemFn) {
      this.itemFn = itemFn;
      this.localGrad = new double[itemFn.dimension()];
    }

    public void map(T elem) {
      // update(...) returns this item's objective contribution. Judging by the call
      // shape, it also accumulates the item's gradient into localGrad. Confirm in
      // ObjectiveItemDifferentiableFunction.
      objVal += itemFn.update(elem, localGrad);
    }
  }

  /**
   * Creates one fresh accumulator per item function and checks that each accumulator's
   * gradient buffer has the expected length.
   *
   * @param dim the expected gradient length, normally {@link #dimension()}
   * @return a new list with one mapper per entry of {@code itemFns}, in list order
   * @throws IllegalStateException if an item function reports a dimension other than {@code dim}
   */
  private List<Mapper<I>> getMappers(int dim) {
    List<Mapper<I>> mappers = new ArrayList<Mapper<I>>(itemFns.size());
    for (ObjectiveItemDifferentiableFunction<I> itemFn : itemFns) {
      Mapper<I> mapper = new Mapper<I>(itemFn);
      if (mapper.localGrad.length != dim) {
        // Without this check, summing into a gradient of length dim would either
        // overrun or silently drop entries.
        throw new IllegalStateException("item function dimension " + mapper.localGrad.length
            + " does not match objective dimension " + dim);
      }
      mappers.add(mapper);
    }
    return mappers;
  }

  /**
   * Evaluates the objective value and gradient at {@code x}.
   * <p>
   * Pushes {@code x} into every item function via {@code setWeights}, maps all items
   * through fresh per-item-function accumulators, and sums their objective values and
   * gradients. If a regularizer is present, its returned value is added to the objective.
   * Presumably it also adds its gradient into {@code grad}; confirm in {@code Regularizer}.
   *
   * @param x the point at which to evaluate
   * @return a pair of (objective value, gradient of length {@link #dimension()})
   * @throws IllegalStateException if the item functions do not all share the same dimension
   */
  protected Pair<Double, double[]> calculate(double[] x) {
    for (ObjectiveItemDifferentiableFunction<I> itemFn : itemFns) {
      itemFn.setWeights(x);
    }
    int dim = dimension();
    List<Mapper<I>> mappers = getMappers(dim);
    // NOTE(review): reading mapper state below assumes doMapping has finished all work
    // (and established happens-before with any worker threads) before returning. Verify.
    AsynchronousMapper.doMapping(items, mappers);

    double objVal = 0.0;
    double[] grad = new double[dim];
    for (Mapper<I> mapper : mappers) {
      objVal += mapper.objVal;
      DoubleArrays.addInPlace(grad, mapper.localGrad);
    }
    if (regularizer != null) {
      objVal += regularizer.update(x, grad, 1.0);
    }
    return Pair.newPair(objVal, grad);
  }

  /**
   * Returns the dimension of the parameter vector, as reported by the first item function.
   *
   * @return the number of parameters
   */
  public int dimension() {
    return itemFns.get(0).dimension();
  }
}