/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.nifi.security.util.crypto.scrypt; import static java.lang.Integer.MAX_VALUE; import static java.lang.System.arraycopy; import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; import java.security.SecureRandom; import java.util.ArrayList; import java.util.List; import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; import org.apache.commons.codec.binary.Base64; import org.apache.commons.lang3.StringUtils; import org.apache.nifi.security.util.crypto.CipherUtility; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Copyright (C) 2011 - Will Glozer. All rights reserved. * <p/> * Taken from Will Glozer's port of Colin Percival's C implementation. Glozer's project located at <a href="https://github.com/wg/scrypt">https://github.com/wg/scrypt</a> was released under the ASF * 2.0 license and has not been updated since May 25, 2013 and there are outstanding issues which have been patched in this version. * <p/> * An implementation of the <a href="http://www.tarsnap.com/scrypt/scrypt.pdf">scrypt</a> * key derivation function. * <p/> * Allows for hashing passwords using the * <a href="http://www.tarsnap.com/scrypt.html">scrypt</a> key derivation function * and comparing a plain text password to a hashed one. */ public class Scrypt { private static final Logger logger = LoggerFactory.getLogger(Scrypt.class); private static final int DEFAULT_SALT_LENGTH = 16; /** * Hash the supplied plaintext password and generate output in the format described * below: * <p/> * The hashed output is an * extended implementation of the Modular Crypt Format that also includes the scrypt * algorithm parameters. * <p/> * Format: <code>$s0$PARAMS$SALT$KEY</code>. * <p/> * <dl> * <dd>PARAMS</dd><dt>32-bit hex integer containing log2(N) (16 bits), r (8 bits), and p (8 bits)</dt> * <dd>SALT</dd><dt>base64-encoded salt</dt> * <dd>KEY</dd><dt>base64-encoded derived key</dt> * </dl> * <p/> * <code>s0</code> identifies version 0 of the scrypt format, using a 128-bit salt and 256-bit derived key. * <p/> * This method generates a 16 byte random salt internally. * * @param password password * @param n CPU cost parameter * @param r memory cost parameter * @param p parallelization parameter * @param dkLen the desired key length in bits * @return the hashed password */ public static String scrypt(String password, int n, int r, int p, int dkLen) { byte[] salt = new byte[DEFAULT_SALT_LENGTH]; new SecureRandom().nextBytes(salt); return scrypt(password, salt, n, r, p, dkLen); } /** * Hash the supplied plaintext password and generate output in the format described * in {@link Scrypt#scrypt(String, int, int, int, int)}. * * @param password password * @param salt the raw salt (16 bytes) * @param n CPU cost parameter * @param r memory cost parameter * @param p parallelization parameter * @param dkLen the desired key length in bits * @return the hashed password */ public static String scrypt(String password, byte[] salt, int n, int r, int p, int dkLen) { try { byte[] derived = deriveScryptKey(password.getBytes(StandardCharsets.UTF_8), salt, n, r, p, dkLen); return formatHash(salt, n, r, p, derived); } catch (GeneralSecurityException e) { throw new IllegalStateException("JVM doesn't support SHA1PRNG or HMAC_SHA256?"); } } public static String formatSalt(byte[] salt, int n, int r, int p) { String params = encodeParams(n, r, p); StringBuilder sb = new StringBuilder((salt.length) * 2); sb.append("$s0$").append(params).append('$'); sb.append(CipherUtility.encodeBase64NoPadding(salt)); return sb.toString(); } private static String encodeParams(int n, int r, int p) { return Long.toString(log2(n) << 16L | r << 8 | p, 16); } private static String formatHash(byte[] salt, int n, int r, int p, byte[] derived) { StringBuilder sb = new StringBuilder((salt.length + derived.length) * 2); sb.append(formatSalt(salt, n, r, p)).append('$'); sb.append(CipherUtility.encodeBase64NoPadding(derived)); return sb.toString(); } /** * Returns the expected memory cost of the provided parameters in bytes. * * @param n the N value, iterations >= 2 * @param r the r value, block size >= 1 * @param p the p value, parallelization factor >= 1 * @return the memory cost in bytes */ public static int calculateExpectedMemory(int n, int r, int p) { return 128 * r * n + 128 * r * p; } /** * Compare the supplied plaintext password to a hashed password. * * @param password plaintext password * @param hashed scrypt hashed password * @return true if password matches hashed value */ public static boolean check(String password, String hashed) { try { if (StringUtils.isEmpty(password)) { throw new IllegalArgumentException("Password cannot be empty"); } if (StringUtils.isEmpty(hashed)) { throw new IllegalArgumentException("Hash cannot be empty"); } String[] parts = hashed.split("\\$"); if (parts.length != 5 || !parts[1].equals("s0")) { throw new IllegalArgumentException("Hash is not properly formatted"); } List<Integer> splitParams = parseParameters(parts[2]); int n = splitParams.get(0); int r = splitParams.get(1); int p = splitParams.get(2); byte[] salt = Base64.decodeBase64(parts[3]); byte[] derived0 = Base64.decodeBase64(parts[4]); // Previously this was hard-coded to 32 bits but the publicly-available scrypt methods accept arbitrary bit lengths int hashLength = derived0.length * 8; byte[] derived1 = deriveScryptKey(password.getBytes(StandardCharsets.UTF_8), salt, n, r, p, hashLength); if (derived0.length != derived1.length) return false; int result = 0; for (int i = 0; i < derived0.length; i++) { result |= derived0[i] ^ derived1[i]; } return result == 0; } catch (GeneralSecurityException e) { throw new IllegalStateException("JVM doesn't support SHA1PRNG or HMAC_SHA256?"); } } /** * Parses the individual values from the encoded params value in the modified-mcrypt format for the salt & hash. * <p/> * Example: * <p/> * Hash: $s0$e0801$epIxT/h6HbbwHaehFnh/bw$7H0vsXlY8UxxyW/BWx/9GuY7jEvGjT71GFd6O4SZND0 * Params: e0801 * <p/> * N = 16384 * r = 8 * p = 1 * * @param encodedParams the String representation of the second section of the mcrypt format hash * @return a list containing N, r, p */ public static List<Integer> parseParameters(String encodedParams) { long params = Long.parseLong(encodedParams, 16); List<Integer> paramsList = new ArrayList<>(3); // Parse N, r, p from encoded value and add to return list paramsList.add((int) Math.pow(2, params >> 16 & 0xffff)); paramsList.add((int) params >> 8 & 0xff); paramsList.add((int) params & 0xff); return paramsList; } private static int log2(int n) { int log = 0; if ((n & 0xffff0000) != 0) { n >>>= 16; log = 16; } if (n >= 256) { n >>>= 8; log += 8; } if (n >= 16) { n >>>= 4; log += 4; } if (n >= 4) { n >>>= 2; log += 2; } return log + (n >>> 1); } /** * Implementation of the <a href="http://www.tarsnap.com/scrypt/scrypt.pdf">scrypt KDF</a>. * * @param password password * @param salt salt * @param n CPU cost parameter * @param r memory cost parameter * @param p parallelization parameter * @param dkLen intended length of the derived key in bits * @return the derived key * @throws GeneralSecurityException when HMAC_SHA256 is not available */ protected static byte[] deriveScryptKey(byte[] password, byte[] salt, int n, int r, int p, int dkLen) throws GeneralSecurityException { if (n < 2 || (n & (n - 1)) != 0) { throw new IllegalArgumentException("N must be a power of 2 greater than 1"); } if (r < 1) { throw new IllegalArgumentException("Parameter r must be 1 or greater"); } if (p < 1) { throw new IllegalArgumentException("Parameter p must be 1 or greater"); } if (n > MAX_VALUE / 128 / r) { throw new IllegalArgumentException("Parameter N is too large"); } // Must be enforced before r check if (p > MAX_VALUE / 128) { throw new IllegalArgumentException("Parameter p is too large"); } if (r > MAX_VALUE / 128 / p) { throw new IllegalArgumentException("Parameter r is too large"); } if (password == null || password.length == 0) { throw new IllegalArgumentException("Password cannot be empty"); } int saltLength = salt == null ? 0 : salt.length; if (salt == null || saltLength == 0) { // Do not enforce this check here. According to the scrypt spec, the salt can be empty. However, in the user-facing ScryptCipherProvider, enforce an arbitrary check to avoid empty salts logger.warn("An empty salt was used for scrypt key derivation"); // throw new IllegalArgumentException("Salt cannot be empty"); // as the Exception is not being thrown, prevent NPE if salt is null by setting it to empty array if( salt == null ) salt = new byte[]{}; } if (saltLength < 8 || saltLength > 32) { // Do not enforce this check here. According to the scrypt spec, the salt can be empty. However, in the user-facing ScryptCipherProvider, enforce an arbitrary check of [8..32] bytes logger.warn("A salt of length {} was used for scrypt key derivation", saltLength); // throw new IllegalArgumentException("Salt must be between 8 and 32 bytes"); } Mac mac = Mac.getInstance("HmacSHA256"); mac.init(new SecretKeySpec(password, "HmacSHA256")); byte[] b = new byte[128 * r * p]; byte[] xy = new byte[256 * r]; byte[] v = new byte[128 * r * n]; int i; pbkdf2(mac, salt, 1, b, p * 128 * r); for (i = 0; i < p; i++) { smix(b, i * 128 * r, r, n, v, xy); } byte[] dk = new byte[dkLen / 8]; pbkdf2(mac, b, 1, dk, dkLen / 8); return dk; } /** * Implementation of PBKDF2 (RFC2898). * * @param alg the HMAC algorithm to use * @param p the password * @param s the salt * @param c the iteration count * @param dkLen the intended length, in octets, of the derived key * @return The derived key */ private static byte[] pbkdf2(String alg, byte[] p, byte[] s, int c, int dkLen) throws GeneralSecurityException { Mac mac = Mac.getInstance(alg); mac.init(new SecretKeySpec(p, alg)); byte[] dk = new byte[dkLen]; pbkdf2(mac, s, c, dk, dkLen); return dk; } /** * Implementation of PBKDF2 (RFC2898). * * @param mac the pre-initialized {@link Mac} instance to use * @param s the salt * @param c the iteration count * @param dk the byte array that derived key will be placed in * @param dkLen the intended length, in octets, of the derived key * @throws GeneralSecurityException if the key length is too long */ private static void pbkdf2(Mac mac, byte[] s, int c, byte[] dk, int dkLen) throws GeneralSecurityException { int hLen = mac.getMacLength(); if (dkLen > (Math.pow(2, 32) - 1) * hLen) { throw new GeneralSecurityException("Requested key length too long"); } byte[] U = new byte[hLen]; byte[] T = new byte[hLen]; byte[] block1 = new byte[s.length + 4]; int l = (int) Math.ceil((double) dkLen / hLen); int r = dkLen - (l - 1) * hLen; arraycopy(s, 0, block1, 0, s.length); for (int i = 1; i <= l; i++) { block1[s.length + 0] = (byte) (i >> 24 & 0xff); block1[s.length + 1] = (byte) (i >> 16 & 0xff); block1[s.length + 2] = (byte) (i >> 8 & 0xff); block1[s.length + 3] = (byte) (i >> 0 & 0xff); mac.update(block1); mac.doFinal(U, 0); arraycopy(U, 0, T, 0, hLen); for (int j = 1; j < c; j++) { mac.update(U); mac.doFinal(U, 0); for (int k = 0; k < hLen; k++) { T[k] ^= U[k]; } } arraycopy(T, 0, dk, (i - 1) * hLen, (i == l ? r : hLen)); } } private static void smix(byte[] b, int bi, int r, int n, byte[] v, byte[] xy) { int xi = 0; int yi = 128 * r; int i; arraycopy(b, bi, xy, xi, 128 * r); for (i = 0; i < n; i++) { arraycopy(xy, xi, v, i * (128 * r), 128 * r); blockmix_salsa8(xy, xi, yi, r); } for (i = 0; i < n; i++) { int j = integerify(xy, xi, r) & (n - 1); blockxor(v, j * (128 * r), xy, xi, 128 * r); blockmix_salsa8(xy, xi, yi, r); } arraycopy(xy, xi, b, bi, 128 * r); } private static void blockmix_salsa8(byte[] by, int bi, int yi, int r) { byte[] X = new byte[64]; int i; arraycopy(by, bi + (2 * r - 1) * 64, X, 0, 64); for (i = 0; i < 2 * r; i++) { blockxor(by, i * 64, X, 0, 64); salsa20_8(X); arraycopy(X, 0, by, yi + (i * 64), 64); } for (i = 0; i < r; i++) { arraycopy(by, yi + (i * 2) * 64, by, bi + (i * 64), 64); } for (i = 0; i < r; i++) { arraycopy(by, yi + (i * 2 + 1) * 64, by, bi + (i + r) * 64, 64); } } private static int r(int a, int b) { return (a << b) | (a >>> (32 - b)); } private static void salsa20_8(byte[] b) { int[] b32 = new int[16]; int[] x = new int[16]; int i; for (i = 0; i < 16; i++) { b32[i] = (b[i * 4 + 0] & 0xff) << 0; b32[i] |= (b[i * 4 + 1] & 0xff) << 8; b32[i] |= (b[i * 4 + 2] & 0xff) << 16; b32[i] |= (b[i * 4 + 3] & 0xff) << 24; } arraycopy(b32, 0, x, 0, 16); for (i = 8; i > 0; i -= 2) { x[4] ^= r(x[0] + x[12], 7); x[8] ^= r(x[4] + x[0], 9); x[12] ^= r(x[8] + x[4], 13); x[0] ^= r(x[12] + x[8], 18); x[9] ^= r(x[5] + x[1], 7); x[13] ^= r(x[9] + x[5], 9); x[1] ^= r(x[13] + x[9], 13); x[5] ^= r(x[1] + x[13], 18); x[14] ^= r(x[10] + x[6], 7); x[2] ^= r(x[14] + x[10], 9); x[6] ^= r(x[2] + x[14], 13); x[10] ^= r(x[6] + x[2], 18); x[3] ^= r(x[15] + x[11], 7); x[7] ^= r(x[3] + x[15], 9); x[11] ^= r(x[7] + x[3], 13); x[15] ^= r(x[11] + x[7], 18); x[1] ^= r(x[0] + x[3], 7); x[2] ^= r(x[1] + x[0], 9); x[3] ^= r(x[2] + x[1], 13); x[0] ^= r(x[3] + x[2], 18); x[6] ^= r(x[5] + x[4], 7); x[7] ^= r(x[6] + x[5], 9); x[4] ^= r(x[7] + x[6], 13); x[5] ^= r(x[4] + x[7], 18); x[11] ^= r(x[10] + x[9], 7); x[8] ^= r(x[11] + x[10], 9); x[9] ^= r(x[8] + x[11], 13); x[10] ^= r(x[9] + x[8], 18); x[12] ^= r(x[15] + x[14], 7); x[13] ^= r(x[12] + x[15], 9); x[14] ^= r(x[13] + x[12], 13); x[15] ^= r(x[14] + x[13], 18); } for (i = 0; i < 16; ++i) b32[i] = x[i] + b32[i]; for (i = 0; i < 16; i++) { b[i * 4 + 0] = (byte) (b32[i] >> 0 & 0xff); b[i * 4 + 1] = (byte) (b32[i] >> 8 & 0xff); b[i * 4 + 2] = (byte) (b32[i] >> 16 & 0xff); b[i * 4 + 3] = (byte) (b32[i] >> 24 & 0xff); } } private static void blockxor(byte[] s, int si, byte[] d, int di, int len) { for (int i = 0; i < len; i++) { d[di + i] ^= s[si + i]; } } private static int integerify(byte[] b, int bi, int r) { int n; bi += (2 * r - 1) * 64; n = (b[bi + 0] & 0xff) << 0; n |= (b[bi + 1] & 0xff) << 8; n |= (b[bi + 2] & 0xff) << 16; n |= (b[bi + 3] & 0xff) << 24; return n; } public static int getDefaultSaltLength() { return DEFAULT_SALT_LENGTH; } }