/*- * Copyright (c) 2013 Diamond Light Source Ltd. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html */ package uk.ac.diamond.scisoft.analysis.fitting.functions; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import java.util.Arrays; import org.eclipse.dawnsci.analysis.api.fitting.functions.IFunction; import org.eclipse.dawnsci.analysis.api.fitting.functions.IOperator; import org.eclipse.january.dataset.Dataset; import org.eclipse.january.dataset.DatasetFactory; import org.eclipse.january.dataset.DoubleDataset; import org.eclipse.january.dataset.Random; import org.junit.Assert; import org.junit.Test; public class OperatorTest { private static final double ABS_TOL = 1e-7; @Test public void testAdd() { Add op = new Add(); IFunction fa = new Cubic(); fa.setParameterValues(23., -10., 1.2, -5.2); IFunction fb = new StraightLine(); fb.setParameterValues(4.2, -7.5); op.addFunction(fa); op.addFunction(fb); Assert.assertEquals(6, op.getNoOfParameters()); Assert.assertArrayEquals(new double[] {23., -10., 1.2, -5.2, 4.2, -7.5}, op.getParameterValues(), ABS_TOL); Assert.assertEquals(-23. - 10. - 1.2 - 5.2 - 4.2 - 7.5, op.val(-1), ABS_TOL); Dataset xd = DatasetFactory.createFromObject(new double[] {-1, 0, 2}); DoubleDataset dx; dx = op.calculateValues(xd); Assert.assertArrayEquals(new double[] {-23. - 10. - 1.2 - 5.2 - 4.2 - 7.5, -5.2 - 7.5, 23.*8 - 10.*4 + 1.2*2 - 5.2 + 4.2*2 - 7.5}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(0), xd); Assert.assertArrayEquals(new double[] {-1, 0, 8}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(1), xd); Assert.assertArrayEquals(new double[] {1, 0, 4}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(2), xd); Assert.assertArrayEquals(new double[] {-1, 0, 2}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(3), xd); Assert.assertArrayEquals(new double[] {1, 1, 1}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(4), xd); Assert.assertArrayEquals(new double[] {-1, 0, 2}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(5), xd); Assert.assertArrayEquals(new double[] {1, 1, 1}, dx.getData(), ABS_TOL); DoubleDataset[] coords = new DoubleDataset[] {DatasetFactory.createRange(DoubleDataset.class, 15, 30, 0.25)}; DoubleDataset weight = null; CoordinatesIterator it = CoordinatesIterator.createIterator(null, coords); DoubleDataset current = DatasetFactory.zeros(DoubleDataset.class, it.getShape()); DoubleDataset data = Random.randn(it.getShape()); op.fillWithValues(current, it); double rd = data.residual(current, weight, false); double rf = op.residual(true, data, weight, coords); Assert.assertEquals(rd, rf, 1e-9); } @Test public void testMultiply() { Multiply op = new Multiply(); IFunction fa = new Cubic(); fa.setParameterValues(23., -10., 1.2, -5.2); IFunction fb = new StraightLine(); fb.setParameterValues(4.2, -7.5); op.addFunction(fa); op.addFunction(fb); Assert.assertEquals(6, op.getNoOfParameters()); Assert.assertArrayEquals(new double[] {23., -10., 1.2, -5.2, 4.2, -7.5}, op.getParameterValues(), ABS_TOL); Assert.assertEquals((-23. - 10. - 1.2 - 5.2) * (- 4.2 - 7.5), op.val(-1), ABS_TOL); Dataset xd = DatasetFactory.createFromObject(new double[] {-1, 0, 2}); DoubleDataset dx; dx = op.calculateValues(xd); Assert.assertArrayEquals(new double[] {(-23. - 10. - 1.2 - 5.2) * (- 4.2 - 7.5), -5.2 * - 7.5, (23.*8 - 10.*4 + 1.2*2 - 5.2) * (4.2*2 - 7.5)}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(0), xd); Assert.assertArrayEquals(new double[] {-(-4.2 - 7.5), 0, 8*(4.2*2 - 7.5)}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(1), xd); Assert.assertArrayEquals(new double[] {(-4.2 - 7.5), 0, 4*(4.2*2 - 7.5)}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(2), xd); Assert.assertArrayEquals(new double[] {-(-4.2 - 7.5), 0, 2*(4.2*2 - 7.5)}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(3), xd); Assert.assertArrayEquals(new double[] {-4.2 - 7.5, -7.5, 4.2*2 - 7.5}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(4), xd); Assert.assertArrayEquals(new double[] {(-23. - 10. - 1.2 - 5.2) * -1, 0, (23.*8 - 10.*4 + 1.2*2 - 5.2) * 2}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(5), xd); Assert.assertArrayEquals(new double[] {(-23. - 10. - 1.2 - 5.2), -5.2, (23.*8 - 10.*4 + 1.2*2 - 5.2)}, dx.getData(), ABS_TOL); } @Test public void testSubtract() { Subtract op = new Subtract(); IFunction fa = new Cubic(); fa.setParameterValues(23., -10., 1.2, -5.2); IFunction fb = new StraightLine(); fb.setParameterValues(-4.2, 7.5); op.addFunction(fa); op.addFunction(fb); Assert.assertEquals(6, op.getNoOfParameters()); Assert.assertArrayEquals(new double[] {23., -10., 1.2, -5.2, -4.2, 7.5}, op.getParameterValues(), ABS_TOL); Assert.assertEquals(-23. - 10. - 1.2 - 5.2 - 4.2 - 7.5, op.val(-1), ABS_TOL); Dataset xd = DatasetFactory.createFromObject(new double[] {-1, 0, 2}); DoubleDataset dx; dx = op.calculateValues(xd); Assert.assertArrayEquals(new double[] {-23. - 10. - 1.2 - 5.2 - 4.2 - 7.5, -5.2 - 7.5, 23.*8 - 10.*4 + 1.2*2 - 5.2 + 4.2*2 - 7.5}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(0), xd); Assert.assertArrayEquals(new double[] {-1, 0, 8}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(1), xd); Assert.assertArrayEquals(new double[] {1, 0, 4}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(2), xd); Assert.assertArrayEquals(new double[] {-1, 0, 2}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(3), xd); Assert.assertArrayEquals(new double[] {1, 1, 1}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(4), xd); Assert.assertArrayEquals(new double[] {1, 0, -2}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(5), xd); Assert.assertArrayEquals(new double[] {-1, -1, -1}, dx.getData(), ABS_TOL); } @Test public void testDivide() { Divide op = new Divide(); IFunction fa = new Cubic(); fa.setParameterValues(23., -10., 1.2, -5.2); IFunction fb = new StraightLine(); fb.setParameterValues(4.2, -7.5); op.addFunction(fa); op.addFunction(fb); Assert.assertEquals(6, op.getNoOfParameters()); Assert.assertArrayEquals(new double[] {23., -10., 1.2, -5.2, 4.2, -7.5}, op.getParameterValues(), ABS_TOL); Assert.assertEquals((-23. - 10. - 1.2 - 5.2) / (- 4.2 - 7.5), op.val(-1), ABS_TOL); Dataset xd = DatasetFactory.createFromObject(new double[] {-1, 0, 2}); DoubleDataset dx; dx = op.calculateValues(xd); Assert.assertArrayEquals(new double[] {(-23. - 10. - 1.2 - 5.2) / (-4.2 - 7.5), -5.2 / - 7.5, (23.*8 - 10.*4 + 1.2*2 - 5.2) / (4.2*2 - 7.5)}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(0), xd); Assert.assertArrayEquals(new double[] {-1/(-4.2 - 7.5), 0, 8/(4.2*2 - 7.5)}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(1), xd); Assert.assertArrayEquals(new double[] {1/(-4.2 - 7.5), 0, 4/(4.2*2 - 7.5)}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(2), xd); Assert.assertArrayEquals(new double[] {-1/(-4.2 - 7.5), 0, 2/(4.2*2 - 7.5)}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(3), xd); Assert.assertArrayEquals(new double[] {1/(-4.2 - 7.5), 1/-7.5, 1/(4.2*2 - 7.5)}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(4), xd); Assert.assertArrayEquals(new double[] {(-23. - 10. - 1.2 - 5.2) / ((-4.2 - 7.5)*(-4.2 - 7.5)), 0, (23.*8 - 10.*4 + 1.2*2 - 5.2) * -2 / ((4.2*2 - 7.5)*(4.2*2 - 7.5))}, dx.getData(), ABS_TOL); dx = op.calculatePartialDerivativeValues(op.getParameter(5), xd); Assert.assertArrayEquals(new double[] {(-23. - 10. - 1.2 - 5.2) * -1 / ((-4.2 - 7.5)*(-4.2 - 7.5)), -5.2 * -1 / (7.5 * 7.5), (23.*8 - 10.*4 + 1.2*2 - 5.2) * -1 / ((4.2*2 - 7.5)*(4.2*2 - 7.5)) }, dx.getData(), ABS_TOL); } @Test public void testBinaryOperators() { IFunction fa = new Cubic(); fa.setParameterValues(23., -10., 1.2, -5.2); IFunction fb = new StraightLine(); fb.setParameterValues(-4.2, 7.5); IFunction fc = new Offset(); fc.setParameterValues(42.); IOperator op = new Subtract(); op.addFunction(fa); op.addFunction(fb); try { op.addFunction(fc); Assert.fail("Should have thrown exception"); } catch (Exception e) { // do nothing } op = new Divide(); op.addFunction(fa); op.addFunction(fb); try { op.addFunction(fc); Assert.fail("Should have thrown exception"); } catch (Exception e) { // do nothing } } @Test public void testConvolve() { double w = 110 * Math.log(2)*FermiGauss.K2EV_CONVERSION_FACTOR; Dataset xd = DatasetFactory.createFromObject(new double[] {23. - w, 23, 23. + 2 * w}); AFunction f = new Fermi(); f.setParameterValues(23., 110*FermiGauss.K2EV_CONVERSION_FACTOR, 1, 0); AFunction g = new Gaussian(); g.setParameterValues((double) xd.mean(), 1., 1.); Convolve cfg = new Convolve(); cfg.addFunction(f); cfg.addFunction(g); AFunction fg = new FermiGauss(); fg.setParameterValues(23., 110., 0, 1, 0, 1); Assert.assertEquals(7, cfg.getNoOfParameters()); double[] cps = Arrays.copyOf(cfg.getParameterValues(), 6); cps[1] /= FermiGauss.K2EV_CONVERSION_FACTOR; cps[3] = cps[2]; cps[2] = 0; cps[4] = 0; double[] ps = fg.getParameterValues(); Assert.assertArrayEquals(cps, ps, ABS_TOL); DoubleDataset fgx = fg.calculateValues(xd); DoubleDataset cfgx = cfg.calculateValues(xd); Assert.assertArrayEquals(cfgx.getData(), fgx.getData(), 200*ABS_TOL); } @Test public void testToString() { // make sure empty CompositeFunction does not throw exception // this test is not concerned with the contents of toString, just // that there is no exception. final String operatorText = "Operator has no functions"; CompositeFunction compositeFunction = new CompositeFunction(); Assert.assertTrue(compositeFunction.toString().startsWith(operatorText)); Subtract subFunction = new Subtract(); Assert.assertTrue(compositeFunction.toString().startsWith(operatorText)); subFunction.setFunction(0, new Gaussian()); Assert.assertFalse(subFunction.toString().startsWith(operatorText)); subFunction.removeFunction(0); Assert.assertTrue(subFunction.toString().startsWith(operatorText)); subFunction.setFunction(1, new Gaussian()); Assert.assertFalse(subFunction.toString().startsWith(operatorText)); subFunction.removeFunction(1); Assert.assertTrue(subFunction.toString().startsWith(operatorText)); } @Test public void testCopy() throws Exception { CompositeFunction compositeFunction = new CompositeFunction(); compositeFunction.addFunction(new Gaussian()); CompositeFunction copy = compositeFunction.copy(); assertEquals(1, copy.getFunctions().length); assertTrue(copy.getFunction(0) instanceof Gaussian); Add addOperator = new Add(); addOperator.addFunction(new Gaussian()); Add addCopy = (Add)addOperator.copy(); assertEquals(1, addCopy.getFunctions().length); assertTrue(addCopy.getFunction(0) instanceof Gaussian); } public static class MockGaussianLikeThing extends Gaussian { public MockGaussianLikeThing() { name = "MockGaussianLikeThing"; } } @Test public void testEquals() { // Test that two function trees with different child functions // but the same parameters do not compare to equals CompositeFunction expected = new CompositeFunction(); expected.addFunction(new Gaussian()); CompositeFunction actual = new CompositeFunction(); actual.addFunction(new MockGaussianLikeThing()); assertTrue(!expected.equals(actual)); } @Test public void testRemoveFunction() { CompositeFunction composite = new CompositeFunction(); Add add = new Add(); assertEquals(0, add.getParameters().length); assertEquals(0, composite.getParameters().length); Gaussian gaussian = new Gaussian(1, 2, 3); add.addFunction(gaussian); composite.addFunction(add); assertEquals(3, add.getParameters().length); assertEquals(3, composite.getParameters().length); add.removeFunction(0); assertEquals(0, add.getParameters().length); assertEquals(0, composite.getParameters().length); } @Test public void testAddFunctionOrder_BottomUp() { CompositeFunction composite = new CompositeFunction(); Add add = new Add(); Gaussian gaussian = new Gaussian(1, 2, 3); add.addFunction(gaussian); composite.addFunction(add); assertEquals(3, add.getParameters().length); assertEquals(3, composite.getParameters().length); } @Test public void testAddFunctionOrder_TopDown() { CompositeFunction composite = new CompositeFunction(); Add add = new Add(); Gaussian gaussian = new Gaussian(1, 2, 3); composite.addFunction(add); add.addFunction(gaussian); assertEquals(3, add.getParameters().length); assertEquals(3, composite.getParameters().length); } @Test public void testRemoveFunction_WithWorkaround() { CompositeFunction composite = new CompositeFunction(); Add add = new Add(); Gaussian gaussian = new Gaussian(1, 2, 3); add.addFunction(gaussian); composite.addFunction(add); assertEquals(3, add.getParameters().length); assertEquals(3, composite.getParameters().length); add.removeFunction(0); assertEquals(0, add.getParameters().length); assertEquals(0, composite.getParameters().length); } @Test public void testAddFunctionOrder_BottomUp_WithWorkaround() { CompositeFunction composite = new CompositeFunction(); Add add = new Add(); Gaussian gaussian = new Gaussian(1, 2, 3); add.addFunction(gaussian); composite.addFunction(add); assertEquals(3, add.getParameters().length); assertEquals(3, composite.getParameters().length); } @Test public void testAddFunctionOrder_TopDown_WithWorkaround() { CompositeFunction composite = new CompositeFunction(); Add add = new Add(); Gaussian gaussian = new Gaussian(1, 2, 3); composite.addFunction(add); add.addFunction(gaussian); assertEquals(3, add.getParameters().length); assertEquals(3, composite.getParameters().length); } @Test public void testUpdateParametersWorkaround() { CompositeFunction composite = new CompositeFunction(); Subtract subtract = new Subtract(); Gaussian gaussian = new Gaussian(1, 2, 3); composite.addFunction(subtract); subtract.setFunction(1, gaussian); assertEquals(3, gaussian.getParameters().length); assertEquals(3, subtract.getParameters().length); assertEquals(3, composite.getParameters().length); } @Test public void testUpdateParamtersChangingFunction() { CompositeFunction composite = new CompositeFunction(); Polynomial poly = new Polynomial(); composite.addFunction(poly); assertEquals(1, poly.getParameters().length); assertEquals(1, composite.getParameters().length); poly.setDegree(1); assertEquals(2, poly.getParameters().length); assertEquals(2, composite.getParameters().length); } private static final class TestGenericBinaryOperator extends ABinaryOperator { public TestGenericBinaryOperator() { super(); } @Override protected void setNames() { setNames("Hello", "World"); } @Override public double val(double... values) { return 0; } @Override public void fillWithValues(DoubleDataset data, CoordinatesIterator it) { } } @Test public void testSetFunction() { IOperator myOp = new TestGenericBinaryOperator(); assertEquals("Hello", myOp.getName()); assertEquals("World", myOp.getDescription()); myOp.setFunction(1, new Gaussian()); assertEquals("Hello", myOp.getName()); assertEquals("World", myOp.getDescription()); assertEquals(1, myOp.getNoOfFunctions()); assertEquals(2, myOp.getFunctions().length); assertEquals(new Gaussian(), myOp.getFunction(1)); assertEquals(3, myOp.getNoOfParameters()); } @Test public void testNoOfFunctions() { IOperator myOp = new TestGenericBinaryOperator(); assertEquals("Hello", myOp.getName()); assertEquals("World", myOp.getDescription()); myOp.addFunction(new Gaussian()); myOp.addFunction(new Gaussian()); assertEquals("Hello", myOp.getName()); assertEquals("World", myOp.getDescription()); assertEquals(2, myOp.getNoOfFunctions()); myOp.setFunction(0, null); // functions can be set to null(!) assertEquals(1, myOp.getNoOfFunctions()); myOp.addFunction(new Gaussian()); assertEquals(2, myOp.getNoOfFunctions()); assertNotNull(myOp.getFunction(0)); assertNotNull(myOp.getFunction(1)); } @Test public void testUpdateParametersWithEqualParameters() { Add add = new Add(); assertEquals(0, add.getNoOfParameters()); add.addFunction(new Gaussian()); assertEquals(3, add.getNoOfParameters()); add.addFunction(new Gaussian()); assertEquals(6, add.getNoOfParameters()); } @Test public void testIsValid() { // invalid because binary operator requires 2 functions assertFalse(new TestGenericBinaryOperator().isValid()); CompositeFunction compositeFunction = new CompositeFunction(); TestGenericBinaryOperator binaryOperator = new TestGenericBinaryOperator(); compositeFunction.addFunction(binaryOperator); // make sure that the invalid is propagated up assertFalse(compositeFunction.isValid()); // correct the invalidity and make sure it is valid and propagated binaryOperator.addFunction(new Gaussian()); binaryOperator.addFunction(new Gaussian()); assertTrue(binaryOperator.isValid()); assertTrue(compositeFunction.isValid()); } }