/* * Copyright (c) 2012 Diamond Light Source Ltd. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html */ package uk.ac.diamond.scisoft.analysis.io.numpy; import java.io.File; import java.util.Collection; import java.util.LinkedList; import java.util.List; import org.apache.commons.lang.ArrayUtils; import org.eclipse.january.dataset.Dataset; import org.eclipse.january.dataset.DatasetFactory; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameters; import uk.ac.diamond.scisoft.analysis.PythonHelper; import uk.ac.diamond.scisoft.analysis.io.NumPyFileLoader; /** * Test that NaN is properly loaded and stored */ @RunWith(Parameterized.class) public class NumPyNanTest { @Parameters public static Collection<Object[]> configs() { List<Object[]> params = new LinkedList<Object[]>(); params.add(new Object[] { "'<f4'", Dataset.FLOAT32}); params.add(new Object[] { "'<f8'", Dataset.FLOAT64}); return params; } private String numpyDataType; private int dtype; public NumPyNanTest(String numpyDataType, int dtype) { this.numpyDataType = numpyDataType; this.dtype = dtype; } @Test public void testLoad() throws Exception { File loc = NumPyTest.getTempFile(); StringBuilder script = new StringBuilder(); script.append("import numpy; "); script.append("exp=numpy.array([float('NaN')]*2, dtype=" + numpyDataType + "); "); script.append("numpy.save(r'" + loc.toString() + "', exp);"); PythonHelper.runPythonScript(script.toString(), true); Dataset loadedFile = NumPyFileLoader.loadFileHelper(loc.toString()); Assert.assertTrue(ArrayUtils.isEquals(new int[] {2}, loadedFile.getShape())); Assert.assertEquals(dtype, loadedFile.getDType()); Assert.assertTrue(Double.isNaN(loadedFile.getDouble(0))); Assert.assertTrue(Double.isNaN(loadedFile.getDouble(1))); } @Test public void testSave() throws Exception { Dataset ds = DatasetFactory.createFromObject(dtype, new double[] {Double.NaN, Double.NaN}); File loc = NumPyTest.getTempFile(); NumPyTest.saveNumPyFile(ds, loc, false); StringBuilder script = new StringBuilder(); script.append("import numpy; "); script.append("exp=numpy.array([float('NaN')]*2, dtype=" + numpyDataType + "); "); script.append("act=numpy.load(r'" + loc.toString() + "');"); script.append("print(exp.dtype==act.dtype and exp.shape==act.shape and numpy.isnan(act[0]) and numpy.isnan(act[1]))"); String pythonStdout = PythonHelper.runPythonScript(script.toString(), false); Assert.assertTrue(Boolean.parseBoolean(pythonStdout.trim())); } }