package org.bouncycastle.pqc.crypto.xmss;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

import org.bouncycastle.crypto.Digest;
import org.bouncycastle.util.encoders.Hex;

/**
 * Utils for XMSS implementation.
 */
public class XMSSUtil {

    /**
     * Calculates the (floored) logarithm base 2 for a given Integer.
     *
     * @param n
     *            Number, must not be negative.
     * @return Logarithm to base 2 of {@code n}, rounded down; 0 for
     *         {@code n == 0} and {@code n == 1}.
     * @throws IllegalArgumentException
     *             if {@code n} is negative.
     */
    public static int log2(int n) {
        // A shift loop using >>= never reaches 0 for negative input, so
        // negative values are rejected instead of looping forever.
        if (n < 0) {
            throw new IllegalArgumentException("n must not be negative");
        }
        if (n == 0) {
            return 0;
        }
        // Position of the highest set bit.
        return 31 - Integer.numberOfLeadingZeros(n);
    }

    /**
     * Convert int/long to n-byte array.
     * <p>
     * Only the lowest {@code sizeInByte} bytes of {@code value} are written;
     * if {@code sizeInByte} is greater than 8 the leading bytes are zero.
     *
     * @param value
     *            int/long value.
     * @param sizeInByte
     *            Size of byte array in byte.
     * @return int/long as big-endian byte array of size {@code sizeInByte}.
     */
    public static byte[] toBytesBigEndian(long value, int sizeInByte) {
        byte[] out = new byte[sizeInByte];
        // Fill from the least significant byte (last index) backwards.
        for (int i = sizeInByte - 1; i >= 0; i--) {
            out[i] = (byte) value;
            value >>>= 8;
        }
        return out;
    }

    /**
     * Copy int to byte array in big-endian at specific offset.
     *
     * @param in
     *            Byte array to write into.
     * @param value
     *            Integer to put.
     * @param offset
     *            Offset in {@code in}.
     * @throws NullPointerException
     *             if {@code in} is null.
     * @throws IllegalArgumentException
     *             if {@code offset} is negative or fewer than 4 bytes are
     *             available at {@code offset}.
     */
    public static void intToBytesBigEndianOffset(byte[] in, int value, int offset) {
        if (in == null) {
            throw new NullPointerException("in == null");
        }
        // Without this check a negative offset passes the space check below.
        if (offset < 0) {
            throw new IllegalArgumentException("offset must not be negative");
        }
        if ((in.length - offset) < 4) {
            throw new IllegalArgumentException("not enough space in array");
        }
        in[offset] = (byte) ((value >> 24) & 0xff);
        in[offset + 1] = (byte) ((value >> 16) & 0xff);
        in[offset + 2] = (byte) ((value >> 8) & 0xff);
        in[offset + 3] = (byte) (value & 0xff);
    }

    /**
     * Copy long to byte array in big-endian at specific offset.
     *
     * @param in
     *            Byte array to write into.
     * @param value
     *            Long to put.
     * @param offset
     *            Offset in {@code in}.
     * @throws NullPointerException
     *             if {@code in} is null.
     * @throws IllegalArgumentException
     *             if {@code offset} is negative or fewer than 8 bytes are
     *             available at {@code offset}.
     */
    public static void longToBytesBigEndianOffset(byte[] in, long value, int offset) {
        if (in == null) {
            throw new NullPointerException("in == null");
        }
        // Without this check a negative offset passes the space check below.
        if (offset < 0) {
            throw new IllegalArgumentException("offset must not be negative");
        }
        if ((in.length - offset) < 8) {
            throw new IllegalArgumentException("not enough space in array");
        }
        in[offset] = (byte) ((value >> 56) & 0xff);
        in[offset + 1] = (byte) ((value >> 48) & 0xff);
        in[offset + 2] = (byte) ((value >> 40) & 0xff);
        in[offset + 3] = (byte) ((value >> 32) & 0xff);
        in[offset + 4] = (byte) ((value >> 24) & 0xff);
        in[offset + 5] = (byte) ((value >> 16) & 0xff);
        in[offset + 6] = (byte) ((value >> 8) & 0xff);
        in[offset + 7] = (byte) (value & 0xff);
    }

    /**
     * Generic convert from big endian byte array to long.
     * <p>
     * No bounds validation is performed beyond the null check. If
     * {@code size} exceeds 8, the earliest bytes are shifted out of the
     * 64-bit result.
     *
     * @param in
     *            Byte array to read from.
     * @param offset
     *            Index of the first (most significant) byte.
     * @param size
     *            Number of bytes to read.
     * @return Long assembled from the {@code size} bytes starting at
     *         {@code offset}.
     * @throws NullPointerException
     *             if {@code in} is null.
     */
    public static long bytesToXBigEndian(byte[] in, int offset, int size) {
        if (in == null) {
            throw new NullPointerException("in == null");
        }
        long res = 0;
        for (int i = offset; i < (offset + size); i++) {
            res = (res << 8) | (in[i] & 0xff);
        }
        return res;
    }

    /**
     * Clone a byte array.
     *
     * @param in
     *            byte array.
     * @return Copy of byte array.
     * @throws NullPointerException
     *             if {@code in} is null.
     */
    public static byte[] cloneArray(byte[] in) {
        if (in == null) {
            throw new NullPointerException("in == null");
        }
        byte[] out = new byte[in.length];
        System.arraycopy(in, 0, out, 0, in.length);
        return out;
    }

    /**
     * Clone a 2d byte array (deep copy: every row is copied).
     *
     * @param in
     *            2d byte array.
     * @return Copy of 2d byte array.
     * @throws NullPointerException
     *             if {@code in} or any of its rows is null.
     */
    public static byte[][] cloneArray(byte[][] in) {
        if (hasNullPointer(in)) {
            throw new NullPointerException("in has null pointers");
        }
        byte[][] out = new byte[in.length][];
        for (int i = 0; i < in.length; i++) {
            out[i] = new byte[in[i].length];
            System.arraycopy(in[i], 0, out[i], 0, in[i].length);
        }
        return out;
    }

    /**
     * Concatenates an arbitrary number of byte arrays.
     *
     * @param arrays
     *            Arrays that shall be concatenated.
     * @return Concatenated array.
     */
    public static byte[] concat(byte[]... arrays) {
        int totalLength = 0;
        for (int i = 0; i < arrays.length; i++) {
            totalLength += arrays[i].length;
        }
        byte[] result = new byte[totalLength];
        int currentIndex = 0;
        for (int i = 0; i < arrays.length; i++) {
            System.arraycopy(arrays[i], 0, result, currentIndex, arrays[i].length);
            currentIndex += arrays[i].length;
        }
        return result;
    }

    /**
     * Compares two byte arrays.
     * <p>
     * Every byte pair is inspected regardless of where the first difference
     * occurs, so the running time does not depend on the array contents.
     *
     * @param a
     *            byte array 1.
     * @param b
     *            byte array 2.
     * @return true if all values in byte array are equal false else.
     * @throws NullPointerException
     *             if {@code a} or {@code b} is null.
     * @throws IllegalArgumentException
     *             if the arrays differ in length.
     */
    public static boolean compareByteArray(byte[] a, byte[] b) {
        if (a == null || b == null) {
            throw new NullPointerException("a or b == null");
        }
        if (a.length != b.length) {
            throw new IllegalArgumentException("size of a and b must be equal");
        }
        // Accumulate differences instead of returning at the first mismatch.
        int diff = 0;
        for (int i = 0; i < a.length; i++) {
            diff |= a[i] ^ b[i];
        }
        return diff == 0;
    }

    /**
     * Compares two 2d-byte arrays.
     *
     * @param a
     *            2d-byte array 1.
     * @param b
     *            2d-byte array 2.
     * @return true if all values in 2d-byte array are equal false else.
     * @throws NullPointerException
     *             if {@code a}, {@code b} or any of their rows is null.
     * @throws IllegalArgumentException
     *             if the number of rows differs, or if a pair of rows that
     *             is compared differs in length.
     */
    public static boolean compareByteArray(byte[][] a, byte[][] b) {
        if (hasNullPointer(a) || hasNullPointer(b)) {
            throw new NullPointerException("a or b == null");
        }
        // Mirror the 1d variant: without this check a longer b with a
        // matching prefix would compare equal, and a shorter b would fail
        // with an ArrayIndexOutOfBoundsException.
        if (a.length != b.length) {
            throw new IllegalArgumentException("size of a and b must be equal");
        }
        for (int i = 0; i < a.length; i++) {
            if (!compareByteArray(a[i], b[i])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Dump content of 2d byte array to standard output, one hex encoded row
     * per line.
     *
     * @param x
     *            byte array.
     * @throws NullPointerException
     *             if {@code x} or any of its rows is null.
     */
    public static void dumpByteArray(byte[][] x) {
        if (hasNullPointer(x)) {
            throw new NullPointerException("x has null pointers");
        }
        for (int i = 0; i < x.length; i++) {
            System.out.println(Hex.toHexString(x[i]));
        }
    }

    /**
     * Checks whether 2d byte array has null pointers.
     *
     * @param in
     *            2d byte array.
     * @return true if {@code in} itself or at least one of its rows is null
     *         false else.
     */
    public static boolean hasNullPointer(byte[][] in) {
        if (in == null) {
            return true;
        }
        for (int i = 0; i < in.length; i++) {
            if (in[i] == null) {
                return true;
            }
        }
        return false;
    }

    /**
     * Copy src byte array to dst byte array at offset.
     *
     * @param dst
     *            Destination.
     * @param src
     *            Source.
     * @param offset
     *            Destination offset.
     * @throws NullPointerException
     *             if {@code dst} or {@code src} is null.
     * @throws IllegalArgumentException
     *             if {@code offset} is negative or {@code src} does not fit
     *             into {@code dst} at {@code offset}.
     */
    public static void copyBytesAtOffset(byte[] dst, byte[] src, int offset) {
        if (dst == null) {
            throw new NullPointerException("dst == null");
        }
        if (src == null) {
            throw new NullPointerException("src == null");
        }
        if (offset < 0) {
            throw new IllegalArgumentException("offset hast to be >= 0");
        }
        // Written as a subtraction so that a large offset cannot overflow
        // the int range and slip past the check.
        if (src.length > (dst.length - offset)) {
            throw new IllegalArgumentException("src length + offset must not be greater than size of destination");
        }
        System.arraycopy(src, 0, dst, offset, src.length);
    }

    /**
     * Copy length bytes at position offset from src.
     *
     * @param src
     *            Source byte array.
     * @param offset
     *            Offset in source byte array.
     * @param length
     *            Length of bytes to copy.
     * @return New byte array.
     * @throws NullPointerException
     *             if {@code src} is null.
     * @throws IllegalArgumentException
     *             if {@code offset} or {@code length} is negative, or the
     *             requested range exceeds {@code src}.
     */
    public static byte[] extractBytesAtOffset(byte[] src, int offset, int length) {
        if (src == null) {
            throw new NullPointerException("src == null");
        }
        if (offset < 0) {
            throw new IllegalArgumentException("offset hast to be >= 0");
        }
        if (length < 0) {
            throw new IllegalArgumentException("length hast to be >= 0");
        }
        // Written as a subtraction so that offset + length cannot overflow
        // the int range, bypass the check and trigger a huge allocation.
        if (length > (src.length - offset)) {
            throw new IllegalArgumentException("offset + length must not be greater then size of source array");
        }
        byte[] out = new byte[length];
        System.arraycopy(src, offset, out, 0, length);
        return out;
    }

    /**
     * Check whether an index is valid or not.
     * <p>
     * The bound is computed as {@code 1L << height}, so the result is only
     * meaningful for heights below 63.
     *
     * @param height
     *            Height of binary tree.
     * @param index
     *            Index to validate.
     * @return true if index is valid false else.
     * @throws IllegalStateException
     *             if {@code index} is negative.
     */
    public static boolean isIndexValid(int height, long index) {
        if (index < 0) {
            throw new IllegalStateException("index must not be negative");
        }
        return index < (1L << height);
    }

    /**
     * Determine digest size of digest.
     * <p>
     * Fixed sizes are returned for the algorithm names "SHAKE128" (32) and
     * "SHAKE256" (64); every other digest reports its own size.
     *
     * @param digest
     *            Digest.
     * @return Digest size.
     * @throws NullPointerException
     *             if {@code digest} is null.
     */
    public static int getDigestSize(Digest digest) {
        if (digest == null) {
            throw new NullPointerException("digest == null");
        }
        String algorithmName = digest.getAlgorithmName();
        if ("SHAKE128".equals(algorithmName)) {
            return 32;
        }
        if ("SHAKE256".equals(algorithmName)) {
            return 64;
        }
        return digest.getDigestSize();
    }

    /**
     * Strips the lowest {@code xmssTreeHeight} bits from an index.
     *
     * @param index
     *            Index.
     * @param xmssTreeHeight
     *            Number of low bits to drop.
     * @return {@code index >> xmssTreeHeight}.
     */
    public static long getTreeIndex(long index, int xmssTreeHeight) {
        return index >> xmssTreeHeight;
    }

    /**
     * Extracts the lowest {@code xmssTreeHeight} bits of an index.
     *
     * @param index
     *            Index.
     * @param xmssTreeHeight
     *            Number of low bits to keep.
     * @return The lowest {@code xmssTreeHeight} bits of {@code index},
     *         narrowed to int.
     */
    public static int getLeafIndex(long index, int xmssTreeHeight) {
        return (int) (index & ((1L << xmssTreeHeight) - 1L));
    }

    /**
     * Serializes an object using Java object serialization.
     *
     * @param obj
     *            Object to serialize.
     * @return Serialized form of {@code obj}.
     * @throws IOException
     *             if writing the object fails.
     */
    public static byte[] serialize(Object obj) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        ObjectOutputStream oos = new ObjectOutputStream(out);
        try {
            oos.writeObject(obj);
            oos.flush();
            return out.toByteArray();
        } finally {
            oos.close();
        }
    }

    /**
     * Deserializes an object using Java object serialization.
     * <p>
     * WARNING: this reads an arbitrary object graph without any class
     * filtering. Java-native deserialization of untrusted data is unsafe;
     * only pass data that originates from a trusted source.
     *
     * @param data
     *            Serialized object.
     * @return The deserialized object.
     * @throws IOException
     *             if reading the object fails.
     * @throws ClassNotFoundException
     *             if the class of the serialized object cannot be found.
     */
    public static Object deserialize(byte[] data) throws IOException, ClassNotFoundException {
        ByteArrayInputStream in = new ByteArrayInputStream(data);
        ObjectInputStream is = new ObjectInputStream(in);
        try {
            return is.readObject();
        } finally {
            is.close();
        }
    }

    /**
     * Determines the position of the lowest zero bit of {@code index} among
     * its {@code height} least significant bits.
     *
     * @param index
     *            Index.
     * @param height
     *            Number of low bits to inspect.
     * @return Position of the lowest zero bit; 0 if none of the inspected
     *         bits is zero.
     */
    public static int calculateTau(int index, int height) {
        for (int i = 0; i < height; i++) {
            if (((index >> i) & 1) == 0) {
                return i;
            }
        }
        return 0;
    }

    /**
     * Checks whether {@code globalIndex} is a non-zero multiple of
     * {@code (1 << xmssHeight)} raised to the power {@code layer + 1}.
     *
     * @param globalIndex
     *            Global index.
     * @param xmssHeight
     *            Height of a single XMSS tree.
     * @param layer
     *            Layer.
     * @return true if the condition above holds false else; always false for
     *         {@code globalIndex == 0}.
     */
    public static boolean isNewBDSInitNeeded(long globalIndex, int xmssHeight, int layer) {
        if (globalIndex == 0) {
            return false;
        }
        return globalIndex % (long) Math.pow((1 << xmssHeight), layer + 1) == 0;
    }

    /**
     * Checks whether {@code globalIndex + 1} is a multiple of
     * {@code (1 << xmssHeight)} raised to the power {@code layer}.
     *
     * @param globalIndex
     *            Global index.
     * @param xmssHeight
     *            Height of a single XMSS tree.
     * @param layer
     *            Layer.
     * @return true if the condition above holds false else; always false for
     *         {@code globalIndex == 0}.
     */
    public static boolean isNewAuthenticationPathNeeded(long globalIndex, int xmssHeight, int layer) {
        if (globalIndex == 0) {
            return false;
        }
        return (globalIndex + 1) % (long) Math.pow((1 << xmssHeight), layer) == 0;
    }
}