package me.brandonc.security.totp.util; //Copyright (C) 2009 Google Inc. import java.io.ByteArrayInputStream; import java.io.DataInput; import java.io.DataInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.security.GeneralSecurityException; import javax.crypto.Mac; /** * An implementation of the HOTP generator specified by RFC 4226. Generates * short passcodes that may be used in challenge-response protocols or as * timeout passcodes that are only valid for a short period. * * The default passcode is a 6-digit decimal code and the default timeout period * is 5 minutes. * * @author sweis@google.com (Steve Weis) * */ public class PasscodeGenerator { /** Default decimal passcode length */ private static final int PASS_CODE_LENGTH = 6; /** Default passcode timeout period (in seconds) */ static final int INTERVAL = 30; /** The number of previous and future intervals to check */ private static final int ADJACENT_INTERVALS = 1; private static final int PIN_MODULO = (int) Math.pow(10, PASS_CODE_LENGTH); private final Signer signer; private final int codeLength; private final int intervalPeriod; /* * Using an interface to allow us to inject different signature * implementations. */ interface Signer { byte[] sign(byte[] data) throws GeneralSecurityException; } /** * @param mac * A {@link Mac} used to generate passcodes */ public PasscodeGenerator(Mac mac) { this(mac, PASS_CODE_LENGTH, INTERVAL); } /** * @param mac * A {@link Mac} used to generate passcodes * @param passCodeLength * The length of the decimal passcode * @param interval * The interval that a passcode is valid for */ public PasscodeGenerator(final Mac mac, int passCodeLength, int interval) { this(new Signer() { @Override public byte[] sign(byte[] data) { return mac.doFinal(data); } }, passCodeLength, interval); } public PasscodeGenerator(Signer signer, int passCodeLength, int interval) { this.signer = signer; this.codeLength = passCodeLength; this.intervalPeriod = interval; } private String padOutput(int value) { String result = Integer.toString(value); for (int i = result.length(); i < codeLength; i++) { result = "0" + result; } return result; } /** * @return A decimal timeout code */ public String generateTimeoutCode() throws GeneralSecurityException { return generateResponseCode(clock.getCurrentInterval()); } /** * @param challenge * A long-valued challenge * @return A decimal response code * @throws GeneralSecurityException * If a JCE exception occur */ public String generateResponseCode(long challenge) throws GeneralSecurityException { byte[] value = ByteBuffer.allocate(8).putLong(challenge).array(); return generateResponseCode(value); } /** * @param challenge * An arbitrary byte array used as a challenge * @return A decimal response code * @throws GeneralSecurityException * If a JCE exception occur */ public String generateResponseCode(byte[] challenge) throws GeneralSecurityException { byte[] hash = signer.sign(challenge); // Dynamically truncate the hash // OffsetBits are the low order bits of the last byte of the hash int offset = hash[hash.length - 1] & 0xF; // Grab a positive integer value starting at the given offset. int truncatedHash = hashToInt(hash, offset) & 0x7FFFFFFF; int pinValue = truncatedHash % PIN_MODULO; return padOutput(pinValue); } /** * Grabs a positive integer value from the input array starting at the given * offset. * * @param bytes * the array of bytes * @param start * the index into the array to start grabbing bytes * @return the integer constructed from the four bytes in the array */ private int hashToInt(byte[] bytes, int start) { DataInput input = new DataInputStream(new ByteArrayInputStream(bytes, start, bytes.length - start)); int val; try { val = input.readInt(); } catch (IOException e) { throw new IllegalStateException(e); } return val; } /** * @param challenge * A challenge to check a response against * @param response * A response to verify * @return True if the response is valid */ public boolean verifyResponseCode(long challenge, String response) throws GeneralSecurityException { String expectedResponse = generateResponseCode(challenge); return expectedResponse.equals(response); } /** * Verify a timeout code. The timeout code will be valid for a time * determined by the interval period and the number of adjacent intervals * checked. * * @param timeoutCode * The timeout code * @return True if the timeout code is valid */ public boolean verifyTimeoutCode(String timeoutCode) throws GeneralSecurityException { return verifyTimeoutCode(timeoutCode, ADJACENT_INTERVALS, ADJACENT_INTERVALS); } /** * Verify a timeout code. The timeout code will be valid for a time * determined by the interval period and the number of adjacent intervals * checked. * * @param timeoutCode * The timeout code * @param pastIntervals * The number of past intervals to check * @param futureIntervals * The number of future intervals to check * @return True if the timeout code is valid */ public boolean verifyTimeoutCode(String timeoutCode, int pastIntervals, int futureIntervals) throws GeneralSecurityException { long currentInterval = clock.getCurrentInterval(); String expectedResponse = generateResponseCode(currentInterval); if (expectedResponse.equals(timeoutCode)) { return true; } for (int i = 1; i <= pastIntervals; i++) { String pastResponse = generateResponseCode(currentInterval - i); if (pastResponse.equals(timeoutCode)) { return true; } } for (int i = 1; i <= futureIntervals; i++) { String futureResponse = generateResponseCode(currentInterval + i); if (futureResponse.equals(timeoutCode)) { return true; } } return false; } private IntervalClock clock = new IntervalClock() { /* * @return The current interval */ @Override public long getCurrentInterval() { long currentTimeSeconds = System.currentTimeMillis() / 1000; return currentTimeSeconds / getIntervalPeriod(); } @Override public int getIntervalPeriod() { return intervalPeriod; } }; // To facilitate injecting a mock clock interface IntervalClock { int getIntervalPeriod(); long getCurrentInterval(); } }