/* * Copyright (c) 2012 Diamond Light Source Ltd. * * All rights reserved. This program and the accompanying materials * are made available under the terms of the Eclipse Public License v1.0 * which accompanies this distribution, and is available at * http://www.eclipse.org/legal/epl-v10.html */ package uk.ac.diamond.scisoft.python; import java.lang.reflect.Array; import java.math.BigInteger; import java.util.Arrays; import java.util.List; import org.apache.commons.math3.complex.Complex; import org.eclipse.january.DatasetException; import org.eclipse.january.dataset.BroadcastUtils; import org.eclipse.january.dataset.Dataset; import org.eclipse.january.dataset.DatasetFactory; import org.eclipse.january.dataset.IDataset; import org.eclipse.january.dataset.ILazyDataset; import org.eclipse.january.dataset.Slice; import org.eclipse.january.dataset.SliceND; import org.python.core.Py; import org.python.core.PyArray; import org.python.core.PyComplex; import org.python.core.PyEllipsis; import org.python.core.PyException; import org.python.core.PyInteger; import org.python.core.PyNone; import org.python.core.PyObject; import org.python.core.PySequence; import org.python.core.PySequenceList; import org.python.core.PySlice; import org.python.core.PyString; /** * Class of utilities for interfacing with Jython */ public class PythonUtils { /** * Convert tuples/lists of tuples/lists to Java lists of lists. Also convert complex numbers to Apache Commons * version and Jython strings to strings * * @param obj * @return converted object */ public static Object convertToJava(Object obj) { if (obj == null || obj instanceof PyNone) return null; if (obj instanceof PySequenceList) { obj = ((PySequenceList) obj).toArray(); } if (obj instanceof PyArray) { obj = ((PyArray) obj).getArray(); } if (obj instanceof List<?>) { @SuppressWarnings("unchecked") List<Object> jl = (List<Object>) obj; int l = jl.size(); for (int i = 0; i < l; i++) { Object lo = jl.get(i); if (lo instanceof PyObject) { jl.set(i, convertToJava(lo)); } } return obj; } if (obj.getClass().isArray()) { int l = Array.getLength(obj); for (int i = 0; i < l; i++) { Object lo = Array.get(obj, i); if (lo instanceof PyObject) { Array.set(obj, i, convertToJava(lo)); } } return obj; } if (obj instanceof BigInteger || !(obj instanceof PyObject)) return obj; if (obj instanceof PyComplex) { PyComplex z = (PyComplex) obj; return new Complex(z.real, z.imag); } else if (obj instanceof PyString) { return obj.toString(); } return ((PyObject) obj).__tojava__(Object.class); } static class SliceData { SliceND slice; // slices /** * Required or output shape */ int[] shape; } /** * Convert an array of python slice objects to a slice array * * @param indexes * @param shape * @return slice array */ private static SliceData convertPySlicesToSlice(final PyObject indexes, final int[] shape) { PyObject indices[] = indexes instanceof PySequenceList ? ((PySequenceList) indexes).getArray() : new PyObject[] { indexes }; int orank = shape.length; int na = 0; // count new axes int nc = 0; // count collapsed dimensions int ns = 0; // count slices or ellipses for (int j = 0; j < indices.length; j++) { PyObject index = indices[j]; if (index instanceof PyNone) na++; else if (index instanceof PyInteger) nc++; else if (index instanceof PySlice) ns++; else if (index instanceof PyEllipsis) ns++; } int spare = orank - nc - ns; // number of spare dimensions SliceND slice = new SliceND(shape); boolean hasEllipse = false; boolean[] sdim = new boolean[orank]; // flag which dimensions are sliced int[] axes = new int[na]; // new axes int i = 0; int a = 0; // new axes int c = 0; // collapsed dimensions for (int j = 0; i < orank && j < indices.length; j++) { PyObject index = indices[j]; if (index instanceof PyEllipsis) { sdim[i++] = true; if (!hasEllipse) { // pad out with full slices on first ellipse hasEllipse = true; for (int k = 0; k < spare; k++) { sdim[i++] = true; } } } else if (index instanceof PyInteger) { int n = ((PyInteger) index).getValue(); if (n < -shape[i] || n >= shape[i]) { throw new PyException(Py.IndexError); } if (n < 0) { n += shape[i]; } sdim[i] = false; // nb specifying indexes whilst using slices will reduce rank slice.setSlice(i++, n, n + 1, 1); c++; } else if (index instanceof PySlice) { PySlice pyslice = (PySlice) index; sdim[i] = true; Slice nslice = convertToSlice(pyslice); slice.setSlice(i++, nslice.getStart(), nslice.getStop(), nslice.getStep()); } else if (index instanceof PyNone) { // newaxis axes[a++] = (hasEllipse ? j + spare : j) - c; } else { throw new IllegalArgumentException("Unexpected item in indexing"); } } assert nc == c; while (i < orank) { sdim[i++] = true; } while (a < na) { axes[a] = i - c + a; a++; } int[] sShape = slice.getShape(); int[] newShape = new int[orank - nc]; i = 0; for (int j = 0; i < orank; i++) { if (sdim[i]) { newShape[j++] = sShape[i]; } } if (na > 0) { int[] oldShape = newShape; newShape = new int[newShape.length + na]; i = 0; for (int k = 0, j = 0; i < newShape.length; i++) { if (k < na && i == axes[k]) { k++; newShape[i] = 1; } else { newShape[i] = oldShape[j++]; } } } SliceData sd = new SliceData(); sd.slice = slice; sd.shape = newShape; return sd; } /** * @param pyslice * @return slice */ public static Slice convertToSlice(PySlice pyslice) { return new Slice(pyslice.start instanceof PyNone ? null : ((PyInteger) pyslice.start).getValue(), pyslice.stop instanceof PyNone ? null : ((PyInteger) pyslice.stop).getValue(), pyslice.step instanceof PyNone ? null : ((PyInteger) pyslice.step).getValue()); } /** * @param indexes * @param shape * @return N-D slice */ public static SliceND convertToSliceND(PyObject indexes, int[] shape) { return convertPySlicesToSlice(indexes, shape).slice; } /** * Jython method to get slice within a dataset * * @param a * dataset * @param indexes * can be a mixed array of integers or slices * @return dataset of specified sub-dataset * @throws DatasetException */ public static IDataset getSlice(final ILazyDataset a, final PyObject indexes) throws DatasetException { int[] shape = a.getShape(); SliceData slice = convertPySlicesToSlice(indexes, shape); IDataset dataSlice; if (a instanceof IDataset) { dataSlice = ((IDataset) a).getSliceView(slice.slice); } else { dataSlice = a.getSlice(slice.slice); } dataSlice.setShape(slice.shape); return dataSlice; } /** * Jython method to set slice within a dataset * * @param a * dataset * @param object * can an item or a dataset * @param indexes * can be a mixed array of integers or slices */ public static void setSlice(Dataset a, Object object, final PyObject indexes) { if (a.isComplex() || a.getElementsPerItem() == 1) { if (object instanceof PySequence) { object = DatasetFactory.createFromObject(a.getDType(), object); } } SliceData slice = convertPySlicesToSlice(indexes, a.getShapeRef()); if (object instanceof IDataset) { IDataset d = (IDataset) object; int[] iShape = d instanceof Dataset ? ((Dataset) d).getShapeRef() : d.getShape(); int[] sShape = slice.slice.getShape(); if (!Arrays.equals(iShape, slice.shape)) { // check input shape matches required one try { if (iShape.length > slice.shape.length) { BroadcastUtils.broadcastShapesToMax(iShape, slice.shape); iShape = slice.shape; } else { iShape = BroadcastUtils.broadcastShapesToMax(slice.shape, iShape).get(0); } } catch (IllegalArgumentException e) { throw new IllegalArgumentException("Input dataset shape must match slice shape"); } } else if (!Arrays.equals(iShape, sShape)) { // check input shape matches slice shape if (iShape.length > sShape.length) { BroadcastUtils.broadcastShapesToMax(iShape, sShape); } iShape = sShape; } d = d.getSliceView(); try { d.setShape(iShape); } catch (IllegalArgumentException e) { throw new IllegalArgumentException("Input dataset could not be set to slice shape"); } object = d; } a.setSlice(object, slice.slice); } /** * Create a dataset from object (as workaround for Jython's funky dispatcher calling wrong method) * @param dtype * @param obj * can be a Java list, array or Number * @return dataset * @throws IllegalArgumentException if dataset type is not known */ public static Dataset createFromObject(final Integer dtype, final Object obj) { return DatasetFactory.createFromObject(dtype, obj, null); } /** * Create dataset with items ranging from given start to given stop in given steps * <p> * Use this to get around the overloaded method problem in Jython * @param start * @param stop * @param step * @param dtype * @return a new 1D dataset of given type, filled with values determined by parameters */ public static Dataset createRange(final double start, final double stop, final double step, final int dtype) { return DatasetFactory.createRange(start, stop, step, dtype); } }