package hep.aida.ref.pdf.examples; import hep.aida.*; import hep.aida.IFitResult; import hep.aida.dev.IDevFitData; import hep.aida.dev.IDevFitDataIterator; import hep.aida.ext.IFitMethod; import hep.aida.ref.fitter.InternalFitFunction; import hep.aida.ref.function.BaseModelFunction; import hep.aida.ref.pdf.Dependent; import hep.aida.ref.pdf.InternalObjectiveFunction; import hep.aida.ref.pdf.PdfFitter; import java.util.Random; import hep.aida.ref.pdf.*; public class TestGradient { public static void main(String[] args) { // Create factories IAnalysisFactory analysisFactory = IAnalysisFactory.create(); ITreeFactory treeFactory = analysisFactory.createTreeFactory(); ITree tree = treeFactory.create(); IPlotter plotter = analysisFactory.createPlotterFactory().create("Plotter"); IHistogramFactory histogramFactory = analysisFactory.createHistogramFactory(tree); IFunctionFactory functionFactory = analysisFactory.createFunctionFactory(tree); IFitFactory fitFactory = analysisFactory.createFitFactory(); IDataPointSetFactory dataPointSetFactory = analysisFactory.createDataPointSetFactory(tree); double lowRange = -4, highRange = -1*lowRange; IHistogram1D h1 = histogramFactory.createHistogram1D("Histogram 1D", 50, lowRange, highRange); ICloud1D c1 = histogramFactory.createCloud1D("Cloud"); Random r = new Random(); for (int i = 0; i < 100; i++) { double x = r.nextGaussian(); h1.fill(x); c1.fill(x); } double h1Norm = h1.sumBinHeights()*(h1.axis().upperEdge()-h1.axis().lowerEdge())/h1.axis().bins(); h1.scale(1./h1Norm); boolean norm = true; double[] f_pars = new double[]{0, 0.5}; double[] g_pars = new double[]{f_pars[0], f_pars[1], 1.}; Dependent x = new Dependent("x", lowRange, highRange); // Gaussian g_notNorm = new Gaussian("myGauss not normalized", x); // g_notNorm.setParameters(g_pars); Gaussian g = new Gaussian("myGauss", x); g.setParameters(g_pars); IRangeSet g_range = g.normalizationRange(0); g_range.excludeAll(); g_range.include(lowRange, highRange); g.normalize(norm); IModelFunction f = (IModelFunction) functionFactory.createFunctionByName("IGauss", "g"); IRangeSet f_range = f.normalizationRange(0); f_range.excludeAll(); f_range.include(lowRange, highRange); f.normalize(norm); // WATCH THIS!!! THE PARAMETERS MUST BE SET _AFTER_ THE NORMALIZATION . WHEN NORMALIZING THE FUNCTION CORE IS SWITCHED!!!! f.setParameters(f_pars); IDataPointSet gDataPointSet = dataPointSetFactory.create("g", 2); IDataPointSet fDataPointSet = dataPointSetFactory.create("f", 2); int points = 200; double delta = (highRange - lowRange) / (double) points; for (int i = 0; i < points; i++) { double[] xv = new double[] {lowRange + delta * (double) i}; IDataPoint gPoint = gDataPointSet.addPoint(); gPoint.coordinate(0).setValue(xv[0]); gPoint.coordinate(1).setValue(g.value(xv)); IDataPoint fPoint = fDataPointSet.addPoint(); fPoint.coordinate(0).setValue(xv[0]); fPoint.coordinate(1).setValue(f.value(xv)); } plotter.createRegions(2, 2); plotter.region(0).plot(h1); plotter.region(0).plot(g); plotter.region(1).plot(h1); plotter.region(1).plot(g); // plotter.region(1).plot(g_notNorm); plotter.region(1).plot(f); plotter.region(2).plot(gDataPointSet); plotter.region(2).plot(fDataPointSet); plotter.region(2).plot(gDataPointSet); plotter.show(); double[] xVals = new double[]{r.nextDouble()}; x.setValue(xVals[0]); System.out.println("Function value at x=" + xVals[0] + " g: " + g.value() + " f: " + f.value(xVals)); ((BaseModelFunction)f).calculateNormalizationAmplitude(); System.out.println("Function Normalization at x=" + xVals[0] + " g: " + 1./g.evaluateAnalyticalNormalization(x) + " f: " + ((BaseModelFunction)f).getNormalizationAmplitude()); double[] g_grad = g.gradient(); // double[] g_grad_notNorm = g_notNorm.gradient(); double[] f_grad = f.gradient(xVals); System.out.println("Gradient Size: " + g_grad.length + " " + f_grad.length); for (int i = 0; i < g_grad.length; i++) { System.out.println("Gradient at x=" + xVals[0] + " g: " + g_grad[i] + /*"(" + g_grad_notNorm[i] + ")*/" f: " + f_grad[i]); } double[] g_par_grad = g.parameterGradient(xVals); double[] f_par_grad = f.parameterGradient(xVals); for (int i = 0; i < f_par_grad.length; i++) { System.out.println("Gradient for par " + f.parameterNames()[i] + "(" + g.getParameter(i).name() + ") at x=" + xVals[0] + " g: " + g_par_grad[i] + " f: " + f_par_grad[i]); } IFitData fitData = fitFactory.createFitData(); fitData.create1DConnection(c1); IFitMethod fitMethod = PdfFitter.getFitMethod("uml"); InternalObjectiveFunction g_objectiveFunction = new InternalObjectiveFunction(new IFitData[]{fitData}, new Function[]{g}, fitMethod); IDevFitDataIterator dataIter = ((IDevFitData) fitData).dataIterator(); InternalFitFunction f_objectiveFunction = new InternalFitFunction(dataIter, f, fitMethod); String[] g_vars = g_objectiveFunction.variableNames(); String[] f_vars = f_objectiveFunction.variableNames(); if (g_vars.length != f_vars.length) { throw new RuntimeException("Should have the same dimension "); } for (int i = 0; i < g_vars.length; i++) { System.out.println("g var[" + i + "] = " + g_vars[i] + " f var[" + i + "] = " + f_vars[i]); } double g_of_value = g_objectiveFunction.value(g_pars); double[] g_of_grad = g_objectiveFunction.gradient(g_pars); double f_of_value = f_objectiveFunction.value(f_pars); double[] f_of_grad = f_objectiveFunction.gradient(f_pars); System.out.println("Objective function value g = " + g_of_value + " f = " + f_of_value); for (int i = 0; i < f_of_grad.length; i++) { System.out.println("Objective function gradient for var " + f_objectiveFunction.variableName(i) + " (" + g_objectiveFunction.variableName(i) + ") g: " + g_of_grad[i] + " f: " + f_of_grad[i]); } /* Gaussian g = new Gaussian("myGauss"); g.setParameter("norm",h1.maxBinHeight()); g.setParameter("mean",h1.mean()+1); g.setParameter("sigma",h1.rms()); plotter.region(0).plot(h1); PdfFitter gaussFit = new PdfFitter("Chi2","fminuit"); gaussFit.setUseFunctionGradient(false); long start = System.currentTimeMillis(); // gaussFit.fit(h1, g); long end = System.currentTimeMillis(); long time = end-start; System.out.println("Time to fit : "+time); IFitter fitter = fitFactory.createFitter("uml","fminuit","noClone=true"); fitter.setUseFunctionGradient(true); IFunction ig = FunctionConverter.convert(g); start = System.currentTimeMillis(); IFitResult fitResult = fitter.fit(c1, ig); end = System.currentTimeMillis(); time = end-start; System.out.println("Time to fit : "+time+" "+fitResult.quality()); // plotter.region(0).plot(g); plotter.region(0).plot(fitResult.fittedFunction()); plotter.show(); */ } }