package org.spongycastle.crypto.engines; import org.spongycastle.crypto.Digest; import org.spongycastle.crypto.params.NTRUEncryptionParameters; import org.spongycastle.util.Arrays; /** * An implementation of the Index Generation Function in IEEE P1363.1. */ public class IndexGenerator { private byte[] seed; private int N; private int c; private int minCallsR; private int totLen; private int remLen; private BitString buf; private int counter; private boolean initialized; private Digest hashAlg; private int hLen; /** * Constructs a new index generator. * * @param seed a seed of arbitrary length to initialize the index generator with * @param params NtruEncrypt parameters */ IndexGenerator(byte[] seed, NTRUEncryptionParameters params) { this.seed = seed; N = params.N; c = params.c; minCallsR = params.minCallsR; totLen = 0; remLen = 0; counter = 0; hashAlg = params.hashAlg; hLen = hashAlg.getDigestSize(); // hash length initialized = false; } /** * Returns a number <code>i</code> such that <code>0 <= i < N</code>. * * @return */ int nextIndex() { if (!initialized) { buf = new BitString(); byte[] hash = new byte[hashAlg.getDigestSize()]; while (counter < minCallsR) { appendHash(buf, hash); counter++; } totLen = minCallsR * 8 * hLen; remLen = totLen; initialized = true; } while (true) { totLen += c; BitString M = buf.getTrailing(remLen); if (remLen < c) { int tmpLen = c - remLen; int cThreshold = counter + (tmpLen + hLen - 1) / hLen; byte[] hash = new byte[hashAlg.getDigestSize()]; while (counter < cThreshold) { appendHash(M, hash); counter++; if (tmpLen > 8 * hLen) { tmpLen -= 8 * hLen; } } remLen = 8 * hLen - tmpLen; buf = new BitString(); buf.appendBits(hash); } else { remLen -= c; } int i = M.getLeadingAsInt(c); // assume c<32 if (i < (1 << c) - ((1 << c) % N)) { return i % N; } } } private void appendHash(BitString m, byte[] hash) { hashAlg.update(seed, 0, seed.length); putInt(hashAlg, counter); hashAlg.doFinal(hash, 0); m.appendBits(hash); } private void putInt(Digest hashAlg, int counter) { hashAlg.update((byte)(counter >> 24)); hashAlg.update((byte)(counter >> 16)); hashAlg.update((byte)(counter >> 8)); hashAlg.update((byte)counter); } /** * Represents a string of bits and supports appending, reading the head, and reading the tail. */ public static class BitString { byte[] bytes = new byte[4]; int numBytes; // includes the last byte even if only some of its bits are used int lastByteBits; // lastByteBits <= 8 /** * Appends all bits in a byte array to the end of the bit string. * * @param bytes a byte array */ void appendBits(byte[] bytes) { for (int i = 0; i != bytes.length; i++) { appendBits(bytes[i]); } } /** * Appends all bits in a byte to the end of the bit string. * * @param b a byte */ public void appendBits(byte b) { if (numBytes == bytes.length) { bytes = copyOf(bytes, 2 * bytes.length); } if (numBytes == 0) { numBytes = 1; bytes[0] = b; lastByteBits = 8; } else if (lastByteBits == 8) { bytes[numBytes++] = b; } else { int s = 8 - lastByteBits; bytes[numBytes - 1] |= (b & 0xFF) << lastByteBits; bytes[numBytes++] = (byte)((b & 0xFF) >> s); } } /** * Returns the last <code>numBits</code> bits from the end of the bit string. * * @param numBits number of bits * @return a new <code>BitString</code> of length <code>numBits</code> */ public BitString getTrailing(int numBits) { BitString newStr = new BitString(); newStr.numBytes = (numBits + 7) / 8; newStr.bytes = new byte[newStr.numBytes]; for (int i = 0; i < newStr.numBytes; i++) { newStr.bytes[i] = bytes[i]; } newStr.lastByteBits = numBits % 8; if (newStr.lastByteBits == 0) { newStr.lastByteBits = 8; } else { int s = 32 - newStr.lastByteBits; newStr.bytes[newStr.numBytes - 1] = (byte)(newStr.bytes[newStr.numBytes - 1] << s >>> s); } return newStr; } /** * Returns up to 32 bits from the beginning of the bit string. * * @param numBits number of bits * @return an <code>int</code> whose lower <code>numBits</code> bits are the beginning of the bit string */ public int getLeadingAsInt(int numBits) { int startBit = (numBytes - 1) * 8 + lastByteBits - numBits; int startByte = startBit / 8; int startBitInStartByte = startBit % 8; int sum = (bytes[startByte] & 0xFF) >>> startBitInStartByte; int shift = 8 - startBitInStartByte; for (int i = startByte + 1; i < numBytes; i++) { sum |= (bytes[i] & 0xFF) << shift; shift += 8; } return sum; } public byte[] getBytes() { return Arrays.clone(bytes); } } private static byte[] copyOf(byte[] src, int len) { byte[] tmp = new byte[len]; System.arraycopy(src, 0, tmp, 0, len < src.length ? len : src.length); return tmp; } }