/* * Copyright (c) 2015 Villu Ruusmann * * This file is part of JPMML-SkLearn * * JPMML-SkLearn is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * JPMML-SkLearn is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>. */ package numpy.core; import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import java.nio.ByteOrder; import java.util.ArrayList; import java.util.Arrays; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import com.google.common.io.ByteStreams; import com.google.common.primitives.Ints; import com.google.common.primitives.Longs; import com.google.common.primitives.UnsignedInts; import net.razorvine.pickle.Unpickler; import net.razorvine.serpent.Parser; import net.razorvine.serpent.ast.Ast; import numpy.DType; import org.jpmml.converter.ValueUtil; import org.jpmml.sklearn.TupleUtil; public class NDArrayUtil { private NDArrayUtil(){ } static public int[] getShape(NDArray array){ Object[] shape = array.getShape(); List<? extends Number> values = (List)Arrays.asList(shape); return Ints.toArray(ValueUtil.asIntegers(values)); } /** * Gets the payload of a one-dimensional array. */ static public List<?> getContent(NDArray array){ Object content = array.getContent(); return asJavaList(array, (List<?>)content); } /** * Gets the payload of the specified dimension of a multi-dimensional array. * * @param key The dimension. */ static public List<?> getContent(NDArray array, String key){ Map<String, ?> content = (Map<String, ?>)array.getContent(); return asJavaList(array, (List<?>)content.get(key)); } static private <E> List<E> asJavaList(NDArray array, List<E> values){ boolean fortranOrder = array.getFortranOrder(); if(fortranOrder){ int[] shape = getShape(array); switch(shape.length){ case 1: return values; case 2: return toJavaList(values, shape[0], shape[1]); default: throw new IllegalArgumentException(); } } return values; } /** * Translates a column-major (ie. Fortran-type) array to a row-major (ie. C-type) array. */ static private <E> List<E> toJavaList(List<E> values, int rows, int columns){ List<E> result = new ArrayList<>(values.size()); for(int i = 0; i < values.size(); i++){ int row = i / columns; int column = i % columns; E value = values.get((column * rows) + row); result.add(value); } return result; } /** * http://docs.scipy.org/doc/numpy-dev/neps/npy-format.html */ static public NDArray parseNpy(InputStream is) throws IOException { byte[] magicBytes = new byte[MAGIC_STRING.length]; ByteStreams.readFully(is, magicBytes); if(!Arrays.equals(magicBytes, MAGIC_STRING)){ throw new IOException(); } int majorVersion = readUnsignedByte(is); int minorVersion = readUnsignedByte(is); if(majorVersion != 1 || minorVersion != 0){ throw new IOException(); } int headerLength = readUnsignedShort(is, ByteOrder.LITTLE_ENDIAN); if(headerLength < 0){ throw new IOException(); } byte[] headerBytes = new byte[headerLength]; ByteStreams.readFully(is, headerBytes); String header = new String(headerBytes); // Remove trailing whitespace header = header.trim(); Map<String, ?> headerDict = parseDict(header); Object descr = headerDict.get("descr"); Boolean fortranOrder = (Boolean)headerDict.get("fortran_order"); Object[] shape = (Object[])headerDict.get("shape"); byte[] data = ByteStreams.toByteArray(is); NDArray array = new NDArray(); array.__setstate__(new Object[]{Arrays.asList(majorVersion, minorVersion), shape, descr, fortranOrder, data}); return array; } static public Object parseData(InputStream is, Object descr, Object[] shape) throws IOException { if(descr instanceof DType){ DType dType = (DType)descr; descr = dType.toDescr(); } int length = 1; for(int i = 0; i < shape.length; i++){ length *= ValueUtil.asInt((Number)shape[i]); } // End if if(descr instanceof String){ return parseArray(is, (String)descr, length); } List<Object[]> dims = (List<Object[]>)descr; Map<String, List<?>> result = new LinkedHashMap<>(); List<Object[]> objects = parseMultiArray(is, (List)TupleUtil.extractElementList(dims, 1), length); for(int i = 0; i < dims.size(); i++){ Object[] dim = dims.get(i); result.put((String)dim[0], TupleUtil.extractElementList(objects, i)); } return result; } static public List<Object> parseArray(InputStream is, String descr, int length) throws IOException { List<Object> result = new ArrayList<>(length); TypeDescriptor descriptor = new TypeDescriptor(descr); while(result.size() < length){ Object element = descriptor.read(is); if(descriptor.isObject()){ NDArray array = (NDArray)element; result.addAll(NDArrayUtil.getContent(array)); continue; } result.add(element); } return result; } static public List<Object[]> parseMultiArray(InputStream is, List<String> descrs, int length) throws IOException { List<Object[]> result = new ArrayList<>(length); List<TypeDescriptor> descriptors = new ArrayList<>(); for(String descr : descrs){ TypeDescriptor descriptor = new TypeDescriptor(descr); if(descriptor.isObject()){ throw new IllegalArgumentException(descr); } descriptors.add(descriptor); } for(int i = 0; i < length; i++){ Object[] element = new Object[descriptors.size()]; for(int j = 0; j < descriptors.size(); j++){ TypeDescriptor descriptor = descriptors.get(j); element[j] = descriptor.read(is); } result.add(element); } return result; } static private Map<String, ?> parseDict(String string){ Parser parser = new Parser(); Ast ast = parser.parse(string); return (Map<String, ?>)ast.getData(); } static private byte readByte(InputStream is) throws IOException { int b = is.read(); if(b < 0){ throw new EOFException(); } return (byte)b; } static private int readUnsignedByte(InputStream is) throws IOException { int b = is.read(); if(b < 0){ throw new EOFException(); } return b; } static private int readUnsignedShort(InputStream is, ByteOrder byteOrder) throws IOException { byte b1 = readByte(is); byte b2 = readByte(is); if((ByteOrder.BIG_ENDIAN).equals(byteOrder)){ return Ints.fromBytes((byte)0, (byte)0, b1, b2); } else if((ByteOrder.LITTLE_ENDIAN).equals(byteOrder)){ return Ints.fromBytes((byte)0, (byte)0, b2, b1); } throw new IOException(); } static private int readInt(InputStream is, ByteOrder byteOrder) throws IOException { byte b1 = readByte(is); byte b2 = readByte(is); byte b3 = readByte(is); byte b4 = readByte(is); if((ByteOrder.BIG_ENDIAN).equals(byteOrder)){ return Ints.fromBytes(b1, b2, b3, b4); } else if((ByteOrder.LITTLE_ENDIAN).equals(byteOrder)){ return Ints.fromBytes(b4, b3, b2, b1); } throw new IOException(); } static private long readLong(InputStream is, ByteOrder byteOrder) throws IOException { byte b1 = readByte(is); byte b2 = readByte(is); byte b3 = readByte(is); byte b4 = readByte(is); byte b5 = readByte(is); byte b6 = readByte(is); byte b7 = readByte(is); byte b8 = readByte(is); if((ByteOrder.BIG_ENDIAN).equals(byteOrder)){ return Longs.fromBytes(b1, b2, b3, b4, b5, b6, b7, b8); } else if((ByteOrder.LITTLE_ENDIAN).equals(byteOrder)){ return Longs.fromBytes(b8, b7, b6, b5, b4, b3, b2, b1); } throw new IOException(); } static private float readFloat(InputStream is, ByteOrder byteOrder) throws IOException { return Float.intBitsToFloat(readInt(is, byteOrder)); } static private double readDouble(InputStream is, ByteOrder byteOrder) throws IOException { return Double.longBitsToDouble(readLong(is, byteOrder)); } static private Object readObject(InputStream is) throws IOException { Unpickler unpickler = new Unpickler(); return unpickler.load(is); } static private String readString(InputStream is, int size) throws IOException { byte[] buffer = new byte[size]; ByteStreams.readFully(is, buffer); return toString(buffer, "UTF-8"); } static private String readUnicode(InputStream is, ByteOrder byteOrder, int size) throws IOException { byte[] buffer = new byte[size * 4]; ByteStreams.readFully(is, buffer); if((ByteOrder.BIG_ENDIAN).equals(byteOrder)){ return toString(buffer, "UTF-32BE"); } else if((ByteOrder.LITTLE_ENDIAN).equals(byteOrder)){ return toString(buffer, "UTF-32LE"); } throw new IOException(); } static private String toString(byte[] buffer, String encoding) throws IOException { String string = new String(buffer, encoding); // Trim trailing zero characters while(string.length() > 0 && string.charAt(string.length() - 1) == '\0'){ string = string.substring(0, string.length() - 1); } return string; } /** * http://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html * http://docs.scipy.org/doc/numpy/reference/generated/numpy.dtype.byteorder.html */ static private class TypeDescriptor { private ByteOrder byteOrder = null; private Kind kind = null; private int size = 0; private TypeDescriptor(String descr){ int i = 0; ByteOrder byteOrder = null; switch(descr.charAt(i)){ // Native case '=': byteOrder = ByteOrder.nativeOrder(); i++; break; // Big-endian case '>': byteOrder = ByteOrder.BIG_ENDIAN; i++; break; // Little-endian case '<': byteOrder = ByteOrder.LITTLE_ENDIAN; i++; break; // Not applicable case '|': i++; break; } setByteOrder(byteOrder); Kind kind = Kind.forChar(descr.charAt(i)); i++; setKind(kind); if(i < descr.length()){ int size = Integer.parseInt(descr.substring(i)); setSize(size); } } public Object read(InputStream is) throws IOException { Kind kind = getKind(); ByteOrder byteOrder = getByteOrder(); int size = getSize(); switch(kind){ case BOOLEAN: { switch(size){ case 1: return (readByte(is) == 1); default: break; } } break; case INTEGER: { switch(size){ case 4: return readInt(is, byteOrder); case 8: return readLong(is, byteOrder); default: break; } } break; case UNSIGNED_INTEGER: { switch(size){ case 4: return UnsignedInts.toLong(readInt(is, byteOrder)); default: break; } } break; case FLOAT: { switch(size){ case 4: return readFloat(is, byteOrder); case 8: return readDouble(is, byteOrder); default: break; } } break; case OBJECT: { return readObject(is); } case STRING: { return readString(is, size); } case UNICODE: { return readUnicode(is, byteOrder, size); } case VOID: { byte[] buffer = new byte[size]; ByteStreams.readFully(is, buffer); return buffer; } default: break; } throw new IOException(); } public boolean isObject(){ Kind kind = getKind(); switch(kind){ case OBJECT: return true; default: return false; } } public ByteOrder getByteOrder(){ return this.byteOrder; } private void setByteOrder(ByteOrder byteOrder){ this.byteOrder = byteOrder; } public Kind getKind(){ return this.kind; } private void setKind(Kind kind){ this.kind = kind; } public int getSize(){ return this.size; } private void setSize(int size){ this.size = size; } static private enum Kind { BOOLEAN, INTEGER, UNSIGNED_INTEGER, FLOAT, COMPLEX_FLOAT, OBJECT, STRING, UNICODE, VOID, ; static public Kind forChar(char c){ switch(c){ case 'b': return BOOLEAN; case 'i': return INTEGER; case 'u': return UNSIGNED_INTEGER; case 'f': return FLOAT; case 'c': return COMPLEX_FLOAT; case 'O': return OBJECT; case 'S': case 'a': return STRING; case 'U': return UNICODE; case 'V': return VOID; default: throw new IllegalArgumentException(); } } } } private static final byte[] MAGIC_STRING = {(byte)'\u0093', 'N', 'U', 'M', 'P', 'Y'}; }