/* * Copyright (c) 2012 Diamond Light Source Ltd. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html */ package uk.ac.diamond.scisoft.analysis.io.numpy; import java.io.File; import java.io.IOException; import java.util.Collection; import java.util.LinkedList; import java.util.List; import org.apache.commons.lang.ArrayUtils; import org.eclipse.dawnsci.analysis.api.io.ScanFileHolderException; import org.eclipse.january.dataset.Dataset; import org.eclipse.january.dataset.DatasetFactory; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.junit.runners.Parameterized.Parameters; import uk.ac.diamond.scisoft.analysis.PythonHelper; import uk.ac.diamond.scisoft.analysis.io.DataHolder; import uk.ac.diamond.scisoft.analysis.io.NumPyFileLoader; import uk.ac.diamond.scisoft.analysis.io.NumPyFileSaver; @RunWith(Parameterized.class) public class NumPyTest { protected static final String PYTHON_NUMPY_PRINT_MATCHES = "print(exp.dtype==act.dtype and isinstance((exp==act), numpy.ndarray) and (exp==act).all())"; /** * Return a self deleting temp file * * @return temp file * @throws IOException */ protected static File getTempFile() throws IOException { File loc = File.createTempFile("scisoft", ".npy"); loc.deleteOnExit(); return loc; } /** * Wraps a dataset in a dataholder and saves it to the specified location * * @param ds * to save * @param loc * to save to * @throws ScanFileHolderException */ protected static void saveNumPyFile(Dataset ds, File loc, boolean unsigned) throws ScanFileHolderException { final DataHolder dh = new DataHolder(); dh.addDataset("", ds); new NumPyFileSaver(loc.toString(), unsigned).saveFile(dh); } static int[][] shapesToTest = { { 1 }, { 100 }, { 1000000 }, { 10, 10 }, { 5, 6, 7, 8 } }; static Object[][] types = new Object[][] { { "'|b1'", Dataset.BOOL }, { "'|i1'", Dataset.INT8 }, { "'<i2'", Dataset.INT16 }, { "'<i4'", Dataset.INT32 }, { "'<i8'", Dataset.INT64 }, { "'|u1'", Dataset.INT8, true }, { "'<u2'", Dataset.INT16, true }, { "'<u4'", Dataset.INT32, true }, { "'<f4'", Dataset.FLOAT32 }, { "'<f8'", Dataset.FLOAT64 }, { "'<c8'", Dataset.COMPLEX64 }, { "'<c16'", Dataset.COMPLEX128 },}; @Parameters public static Collection<Object[]> configs() { List<Object[]> params = new LinkedList<Object[]>(); int index = 0; for (int i = 0; i < types.length; i++) { Object[] type = types[i]; boolean unsigned = type.length > 2; for (int j = 0; j < shapesToTest.length; j++) { params.add(new Object[] { index++, type[0], type[1], shapesToTest[j], false, unsigned }); switch ((Integer) type[1]) { case Dataset.FLOAT32: case Dataset.FLOAT64: // Add some Inf values params.add(new Object[] { index++, type[0], type[1], shapesToTest[j], true, unsigned }); } } } return params; } private int index = 0; private String numpyDataType; private int abstractDatasetDataType; private int[] shape; private String shapeStr; private int len; private boolean addInf; private boolean unsigned; public NumPyTest(int index, String numpyDataType, int abstractDatasetDataType, int[] shape, boolean addInf, boolean unsigned) { this.index = index; this.numpyDataType = numpyDataType; this.abstractDatasetDataType = abstractDatasetDataType; this.shape = shape; this.addInf = addInf; this.unsigned = unsigned; this.len = 1; for (int i = 0; i < shape.length; i++) { this.len *= shape[i]; } this.shapeStr = ArrayUtils.toString(shape); this.shapeStr = this.shapeStr.substring(1, shapeStr.length() - 1); // System.out.println(this.toString()); } @Override public String toString() { return String.format("TEST %d: numpyType=%s datasetType=%d len=%d shape=%s", index, numpyDataType, abstractDatasetDataType, len, ArrayUtils.toString(shape)); } private Dataset createDataset() { final Dataset ds; if (abstractDatasetDataType != Dataset.BOOL) { ds = DatasetFactory.createRange(len, abstractDatasetDataType); } else { // creates an array of all False, so make two entries True if the array is big enough boolean[] boolarr = new boolean[len]; if (len > 0) boolarr[0] = true; if (len > 100) boolarr[100] = true; ds = DatasetFactory.createFromObject(abstractDatasetDataType, boolarr); } if (addInf && len > 3) { ds.set(Double.POSITIVE_INFINITY, 2); ds.set(Double.NEGATIVE_INFINITY, 3); } ds.setShape(shape); return ds; } private String createNumPyArray(String postCommands) { StringBuilder script = new StringBuilder(); script.append("import numpy; "); if (abstractDatasetDataType != Dataset.BOOL) { script.append("exp=numpy.arange(" + len + ", dtype=" + numpyDataType + "); "); } else { script.append("exp=numpy.array([False] * " + len + ", dtype=" + numpyDataType + "); "); if (len > 0) script.append("exp[0]=True; "); if (len > 100) script.append("exp[100]=True; "); } if (addInf && len > 3) { script.append("exp[2]=float('Inf'); "); script.append("exp[3]=-float('Inf'); "); } script.append("exp.shape=" + shapeStr + "; "); script.append(postCommands); return script.toString(); } // This test writes a numpy array with a small python script and then uses NumPyFileLoader to load it, and then make // sure it is equals to a newly created abstract dataset @Test public void testLoad() throws Exception { File loc = getTempFile(); String script = createNumPyArray(" numpy.save(r'" + loc.toString() + "', exp)"); PythonHelper.runPythonScript(script, true); Dataset loadedFile = NumPyFileLoader.loadFileHelper(loc.toString()); Dataset ds = createDataset(); if (unsigned) ds = DatasetFactory.createFromObject(unsigned, ds); Assert.assertEquals(toString(), ds, loadedFile); } // This test writes an abstract data set with NumPyFileSaver and runs a short python script // to load it and check it is as expected @Test public void testSave() throws Exception { Dataset ds = createDataset(); File loc = getTempFile(); saveNumPyFile(ds, loc, unsigned); String script = createNumPyArray(" act=numpy.load(r'" + loc.toString() + "');" + PYTHON_NUMPY_PRINT_MATCHES); String pythonStdout = PythonHelper.runPythonScript(script, false); Assert.assertTrue(toString(), Boolean.parseBoolean(pythonStdout.trim())); } // Test we can load what we just saved @Test public void testSaveAndLoad() throws Exception { Dataset exp = createDataset(); File loc = getTempFile(); saveNumPyFile(exp, loc, unsigned); Dataset act = NumPyFileLoader.loadFileHelper(loc.toString()); if (unsigned) exp = DatasetFactory.createFromObject(unsigned, exp); Assert.assertEquals(toString(), exp, act); } }