package org.bouncycastle.pqc.crypto.xmss; import java.io.IOException; import java.security.SecureRandom; import java.text.ParseException; import java.util.Map; import java.util.TreeMap; /** * XMSS^MT. * */ public final class XMSSMT { private XMSSMTParameters params; private XMSS xmss; private SecureRandom prng; private KeyedHashFunctions khf; private XMSSMTPrivateKeyParameters privateKey; private XMSSMTPublicKeyParameters publicKey; /** * XMSSMT constructor... * * @param params * XMSSMTParameters. */ public XMSSMT(XMSSMTParameters params) { super(); if (params == null) { throw new NullPointerException("params == null"); } this.params = params; xmss = params.getXMSS(); prng = params.getXMSS().getParams().getPRNG(); khf = xmss.getKhf(); try { privateKey = new XMSSMTPrivateKeyParameters.Builder(params).build(); publicKey = new XMSSMTPublicKeyParameters.Builder(params).build(); } catch (ParseException e) { /* should not be possible */ e.printStackTrace(); } catch (ClassNotFoundException e) { /* should not be possible */ e.printStackTrace(); } catch (IOException e) { /* should not be possible */ e.printStackTrace(); } } /** * Generate a new XMSSMT private key / public key pair. * */ public void generateKeys() { /* generate XMSSMT private key */ privateKey = generatePrivateKey(); /* init global xmss */ XMSSPrivateKeyParameters xmssPrivateKey = null; XMSSPublicKeyParameters xmssPublicKey = null; try { xmssPrivateKey = new XMSSPrivateKeyParameters.Builder(xmss.getParams()) .withSecretKeySeed(privateKey.getSecretKeySeed()).withSecretKeyPRF(privateKey.getSecretKeyPRF()) .withPublicSeed(privateKey.getPublicSeed()).withBDSState(new BDS(xmss)).build(); xmssPublicKey = new XMSSPublicKeyParameters.Builder(xmss.getParams()).withPublicSeed(getPublicSeed()) .build(); } catch (ParseException ex) { /* should not be possible */ ex.printStackTrace(); } catch (ClassNotFoundException e) { /* should not be possible */ e.printStackTrace(); } catch (IOException e) { /* should not be possible */ e.printStackTrace(); } /* import to xmss */ try { xmss.importState(xmssPrivateKey.toByteArray(), xmssPublicKey.toByteArray()); } catch (ParseException e) { e.printStackTrace(); } catch (ClassNotFoundException e) { e.printStackTrace(); } catch (IOException e) { e.printStackTrace(); } /* get root */ int rootLayerIndex = params.getLayers() - 1; OTSHashAddress otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder().withLayerAddress(rootLayerIndex) .build(); /* store BDS instance of root xmss instance */ BDS bdsRoot = new BDS(xmss); XMSSNode root = bdsRoot.initialize(otsHashAddress); getBDSState().put(rootLayerIndex, bdsRoot); xmss.setRoot(root.getValue()); /* set XMSS^MT root / create public key */ try { privateKey = new XMSSMTPrivateKeyParameters.Builder(params).withSecretKeySeed(privateKey.getSecretKeySeed()) .withSecretKeyPRF(privateKey.getSecretKeyPRF()).withPublicSeed(privateKey.getPublicSeed()) .withRoot(xmss.getRoot()).withBDSState(privateKey.getBDSState()).build(); publicKey = new XMSSMTPublicKeyParameters.Builder(params).withRoot(root.getValue()) .withPublicSeed(getPublicSeed()).build(); } catch (ParseException e) { /* should not be possible */ e.printStackTrace(); } catch (ClassNotFoundException e) { /* should not be possible */ e.printStackTrace(); } catch (IOException e) { /* should not be possible */ e.printStackTrace(); } } private XMSSMTPrivateKeyParameters generatePrivateKey() { int n = params.getDigestSize(); byte[] secretKeySeed = new byte[n]; prng.nextBytes(secretKeySeed); byte[] secretKeyPRF = new byte[n]; prng.nextBytes(secretKeyPRF); byte[] publicSeed = new byte[n]; prng.nextBytes(publicSeed); XMSSMTPrivateKeyParameters privateKey = null; try { privateKey = new XMSSMTPrivateKeyParameters.Builder(params).withSecretKeySeed(secretKeySeed) .withSecretKeyPRF(secretKeyPRF).withPublicSeed(publicSeed) .withBDSState(this.privateKey.getBDSState()).build(); } catch (ParseException ex) { /* should not be possible */ ex.printStackTrace(); } catch (ClassNotFoundException e) { /* should not be possible */ e.printStackTrace(); } catch (IOException e) { /* should not be possible */ e.printStackTrace(); } return privateKey; } /** * Import XMSSMT private key / public key pair. * * @param privateKey * XMSSMT private key. * @param publicKey * XMSSMT public key. * @throws ParseException * @throws ClassNotFoundException * @throws IOException */ public void importState(byte[] privateKey, byte[] publicKey) throws ParseException, ClassNotFoundException, IOException { if (privateKey == null) { throw new NullPointerException("privateKey == null"); } if (publicKey == null) { throw new NullPointerException("publicKey == null"); } XMSSMTPrivateKeyParameters xmssMTPrivateKey = new XMSSMTPrivateKeyParameters.Builder(params) .withPrivateKey(privateKey, xmss).build(); XMSSMTPublicKeyParameters xmssMTPublicKey = new XMSSMTPublicKeyParameters.Builder(params) .withPublicKey(publicKey).build(); if (!XMSSUtil.compareByteArray(xmssMTPrivateKey.getRoot(), xmssMTPublicKey.getRoot())) { throw new IllegalStateException("root of private key and public key do not match"); } if (!XMSSUtil.compareByteArray(xmssMTPrivateKey.getPublicSeed(), xmssMTPublicKey.getPublicSeed())) { throw new IllegalStateException("public seed of private key and public key do not match"); } /* init global xmss */ XMSSPrivateKeyParameters xmssPrivateKey = new XMSSPrivateKeyParameters.Builder(xmss.getParams()) .withSecretKeySeed(xmssMTPrivateKey.getSecretKeySeed()) .withSecretKeyPRF(xmssMTPrivateKey.getSecretKeyPRF()).withPublicSeed(xmssMTPrivateKey.getPublicSeed()) .withRoot(xmssMTPrivateKey.getRoot()).withBDSState(new BDS(xmss)).build(); XMSSPublicKeyParameters xmssPublicKey = new XMSSPublicKeyParameters.Builder(xmss.getParams()) .withRoot(xmssMTPrivateKey.getRoot()).withPublicSeed(getPublicSeed()).build(); /* import to xmss */ xmss.importState(xmssPrivateKey.toByteArray(), xmssPublicKey.toByteArray()); this.privateKey = xmssMTPrivateKey; this.publicKey = xmssMTPublicKey; } /** * Sign message. * * @param message * Message to sign. * @return XMSSMT signature on digest of message. */ public byte[] sign(byte[] message) { if (message == null) { throw new NullPointerException("message == null"); } if (getBDSState().isEmpty()) { throw new IllegalStateException("not initialized"); } // privateKey.increaseIndex(this); long globalIndex = getIndex(); int totalHeight = params.getHeight(); int xmssHeight = xmss.getParams().getHeight(); if (!XMSSUtil.isIndexValid(totalHeight, globalIndex)) { throw new IllegalArgumentException("index out of bounds"); } /* compress message */ byte[] random = khf.PRF(privateKey.getSecretKeyPRF(), XMSSUtil.toBytesBigEndian(globalIndex, 32)); byte[] concatenated = XMSSUtil.concat(random, privateKey.getRoot(), XMSSUtil.toBytesBigEndian(globalIndex, params.getDigestSize())); byte[] messageDigest = khf.HMsg(concatenated, message); XMSSMTSignature signature = null; try { signature = new XMSSMTSignature.Builder(params).withIndex(globalIndex).withRandom(random).build(); } catch (ParseException ex) { /* should not be possible */ ex.printStackTrace(); } /* layer 0 */ long indexTree = XMSSUtil.getTreeIndex(globalIndex, xmssHeight); int indexLeaf = XMSSUtil.getLeafIndex(globalIndex, xmssHeight); /* reset xmss */ xmss.setIndex(indexLeaf); xmss.setPublicSeed(getPublicSeed()); /* create signature with XMSS tree on layer 0 */ /* adjust addresses */ OTSHashAddress otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder().withTreeAddress(indexTree) .withOTSAddress(indexLeaf).build(); /* sign message digest */ WOTSPlusSignature wotsPlusSignature = xmss.wotsSign(messageDigest, otsHashAddress); /* get authentication path from BDS */ if (getBDSState().get(0) == null || indexLeaf == 0) { getBDSState().put(0, new BDS(xmss)); getBDSState().get(0).initialize(otsHashAddress); } XMSSReducedSignature reducedSignature = null; try { reducedSignature = new XMSSReducedSignature.Builder(xmss.getParams()) .withWOTSPlusSignature(wotsPlusSignature).withAuthPath(getBDSState().get(0).getAuthenticationPath()) .build(); } catch (ParseException ex) { /* should never happen */ ex.printStackTrace(); } signature.getReducedSignatures().add(reducedSignature); /* prepare authentication path for next leaf */ if (indexLeaf < ((1 << xmssHeight) - 1)) { getBDSState().get(0).nextAuthenticationPath(otsHashAddress); } /* loop over remaining layers */ for (int layer = 1; layer < params.getLayers(); layer++) { /* get root of layer - 1 */ XMSSNode root = getBDSState().get(layer - 1).getRoot(); indexLeaf = XMSSUtil.getLeafIndex(indexTree, xmssHeight); indexTree = XMSSUtil.getTreeIndex(indexTree, xmssHeight); xmss.setIndex(indexLeaf); /* adjust addresses */ otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder().withLayerAddress(layer) .withTreeAddress(indexTree).withOTSAddress(indexLeaf).build(); /* sign root digest of layer - 1 */ wotsPlusSignature = xmss.wotsSign(root.getValue(), otsHashAddress); /* get authentication path from BDS */ if (getBDSState().get(layer) == null || XMSSUtil.isNewBDSInitNeeded(globalIndex, xmssHeight, layer)) { getBDSState().put(layer, new BDS(xmss)); getBDSState().get(layer).initialize(otsHashAddress); } try { reducedSignature = new XMSSReducedSignature.Builder(xmss.getParams()) .withWOTSPlusSignature(wotsPlusSignature) .withAuthPath(getBDSState().get(layer).getAuthenticationPath()).build(); } catch (ParseException ex) { /* should never happen */ ex.printStackTrace(); } signature.getReducedSignatures().add(reducedSignature); /* prepare authentication path for next leaf */ if (indexLeaf < ((1 << xmssHeight) - 1) && XMSSUtil.isNewAuthenticationPathNeeded(globalIndex, xmssHeight, layer)) { getBDSState().get(layer).nextAuthenticationPath(otsHashAddress); } } /* update private key */ try { privateKey = new XMSSMTPrivateKeyParameters.Builder(params).withIndex(globalIndex + 1) .withSecretKeySeed(privateKey.getSecretKeySeed()).withSecretKeyPRF(privateKey.getSecretKeyPRF()) .withPublicSeed(privateKey.getPublicSeed()).withRoot(privateKey.getRoot()) .withBDSState(privateKey.getBDSState()).build(); } catch (ParseException e) { /* should not be possible */ e.printStackTrace(); } catch (ClassNotFoundException e) { /* should not be possible */ e.printStackTrace(); } catch (IOException e) { /* should not be possible */ e.printStackTrace(); } return signature.toByteArray(); } /** * Verify an XMSSMT signature. * * @param message * Message. * @param signature * XMSSMT signature. * @param publicKey * XMSSMT public key. * @return true if signature is valid false else. * @throws ParseException */ public boolean verifySignature(byte[] message, byte[] signature, byte[] publicKey) throws ParseException { if (message == null) { throw new NullPointerException("message == null"); } if (signature == null) { throw new NullPointerException("signature == null"); } if (publicKey == null) { throw new NullPointerException("publicKey == null"); } /* (re)create compressed message */ XMSSMTSignature sig = new XMSSMTSignature.Builder(params).withSignature(signature).build(); XMSSMTPublicKeyParameters pubKey = new XMSSMTPublicKeyParameters.Builder(params).withPublicKey(publicKey) .build(); byte[] concatenated = XMSSUtil.concat(sig.getRandom(), pubKey.getRoot(), XMSSUtil.toBytesBigEndian(sig.getIndex(), params.getDigestSize())); byte[] messageDigest = khf.HMsg(concatenated, message); long globalIndex = sig.getIndex(); int xmssHeight = xmss.getParams().getHeight(); long indexTree = XMSSUtil.getTreeIndex(globalIndex, xmssHeight); int indexLeaf = XMSSUtil.getLeafIndex(globalIndex, xmssHeight); /* adjust xmss */ xmss.setIndex(indexLeaf); xmss.setPublicSeed(pubKey.getPublicSeed()); /* prepare addresses */ OTSHashAddress otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder().withTreeAddress(indexTree) .withOTSAddress(indexLeaf).build(); /* get root node on layer 0 */ XMSSReducedSignature xmssMTSignature = sig.getReducedSignatures().get(0); XMSSNode rootNode = xmss.getRootNodeFromSignature(messageDigest, xmssMTSignature, otsHashAddress); for (int layer = 1; layer < params.getLayers(); layer++) { xmssMTSignature = sig.getReducedSignatures().get(layer); indexLeaf = XMSSUtil.getLeafIndex(indexTree, xmssHeight); indexTree = XMSSUtil.getTreeIndex(indexTree, xmssHeight); xmss.setIndex(indexLeaf); /* adjust address */ otsHashAddress = (OTSHashAddress) new OTSHashAddress.Builder().withLayerAddress(layer) .withTreeAddress(indexTree).withOTSAddress(indexLeaf).build(); /* get root node */ rootNode = xmss.getRootNodeFromSignature(rootNode.getValue(), xmssMTSignature, otsHashAddress); } /* compare roots */ return XMSSUtil.compareByteArray(rootNode.getValue(), pubKey.getRoot()); } /** * Export XMSSMT private key. * * @return XMSSMT private key. */ public byte[] exportPrivateKey() { return privateKey.toByteArray(); } /** * Export XMSSMT public key. * * @return XMSSMT public key. */ public byte[] exportPublicKey() { return publicKey.toByteArray(); } /** * Getter XMSSMT params. * * @return XMSSMT params. */ public XMSSMTParameters getParams() { return params; } /** * Getter XMSSMT index. * * @return XMSSMT index. */ public long getIndex() { return privateKey.getIndex(); } /** * Getter public seed. * * @return Public seed. */ public byte[] getPublicSeed() { return privateKey.getPublicSeed(); } protected Map<Integer, BDS> getBDSState() { return privateKey.getBDSState(); } protected XMSS getXMSS() { return xmss; } }