/*******************************************************************************
* Copyright (c) 2013, 2014 Lectorius, Inc.
* Authors:
* Vijay Pandurangan (vijayp@mitro.co)
* Evan Jones (ej@mitro.co)
* Adam Hilss (ahilss@mitro.co)
*
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* You can contact the authors at inbound@mitro.co.
*******************************************************************************/
package co.mitro.twofactor;
import java.io.IOException;
import java.sql.SQLException;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import co.mitro.core.crypto.KeyInterfaces.KeyFactory;
import co.mitro.core.server.Manager;
import co.mitro.core.server.ManagerFactory;
import co.mitro.core.server.Templates.ResourceMustacheFactory;
import co.mitro.core.server.data.DBIdentity;
import com.github.mustachejava.Mustache;
import com.google.common.base.Strings;
/**
 * Common base for the two-factor authentication servlets.
 *
 * <p>Each subclass declares the single HTTP method it serves (see {@link HttpMethod}) and
 * supplies the request parsing, request verification and request processing steps. This class
 * wires those steps together in {@link #doCommon}: it opens a {@link Manager}, looks up the
 * {@link DBIdentity} for the requested user name, asks the subclass to verify the request and,
 * if that succeeds, hands control to the subclass.
 */
public abstract class TwoFactorServlet extends HttpServlet {
  private static final long serialVersionUID = 1L;

  // TODO: Maybe this should use server.Templates instead of this static factory?
  private static final ResourceMustacheFactory MUSTACHE_FACTORY = new ResourceMustacheFactory();
  private static final Logger logger = LoggerFactory.getLogger(TwoFactorServlet.class);

  /** Source of the {@link Manager} opened for every request handled by {@link #doCommon}. */
  protected final ManagerFactory managerFactory;

  /** Key factory made available to subclasses; not used by this base class itself. */
  protected final KeyFactory keyFactory;

  /** The HTTP methods a concrete servlet may choose to serve. */
  public enum HttpMethod { GET, POST }

  /** The one HTTP method this servlet handles; the other is left to {@link HttpServlet}. */
  protected final HttpMethod method;

  /** Returned by {@link #verifyCode} when a numeric code matches neither check. */
  public static final String INCORRECT_CODE_ERROR_MESSAGE =
      "The verification code you entered was incorrect or expired";

  /** Returned by {@link #verifyCode} when the code cannot be parsed as a number. */
  public static final String INVALID_CODE_ERROR_MESSAGE =
      "Invalid verification code";

  /** Returned by {@link #verifyCode} when the code is null or empty. */
  public static final String MISSING_CODE_ERROR_MESSAGE =
      "Missing verification code";

  /**
   * Plain holder for the values a subclass extracts from an incoming request in
   * {@link TwoFactorServlet#parseRequestParams}.
   */
  public static class TwoFactorRequestParams {
    public String token;
    public String signature;
    /** User name used by {@link TwoFactorServlet#doCommon} to look up the identity. */
    public String username;
    public String codeString;
    /** May be null. */
    public Long timestampMs = null;

    /** Stores every argument in the field of the same name. */
    public TwoFactorRequestParams(String token, String signature, String username,
        String codeString, Long timestampMs) {
      this.token = token;
      this.signature = signature;
      this.username = username;
      this.codeString = codeString;
      this.timestampMs = timestampMs;
    }
  }

  /**
   * Everything a subclass needs to verify and process one request: the servlet request and
   * response, the parsed parameters, the open {@link Manager} and the identity that was looked
   * up for {@code params.username}.
   */
  public static class TwoFactorRequestContext {
    public HttpServletRequest request;
    public HttpServletResponse response;
    public TwoFactorRequestParams params;
    public Manager manager;
    /** Result of the identity lookup; null when no identity was found. */
    public DBIdentity identity;

    /** Stores every argument in the field of the same name. */
    TwoFactorRequestContext(HttpServletRequest request, HttpServletResponse response,
        TwoFactorRequestParams params, Manager manager, DBIdentity identity) {
      this.request = request;
      this.response = response;
      this.params = params;
      this.manager = manager;
      this.identity = identity;
    }
  }

  /**
   * @param managerFactory used to open a {@link Manager} for each request
   * @param keyFactory stored for use by subclasses
   * @param method the HTTP method this servlet serves
   */
  public TwoFactorServlet(ManagerFactory managerFactory, KeyFactory keyFactory,
      HttpMethod method) {
    this.managerFactory = managerFactory;
    this.keyFactory = keyFactory;
    this.method = method;
  }

  /** Returns the HTTP method this servlet was configured to serve. */
  public HttpMethod getMethod() {
    return method;
  }

  /** Sets the response headers that tell clients and proxies not to cache the response. */
  public static void preventCaching(HttpServletResponse response) {
    response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate"); // HTTP 1.1.
    response.setHeader("Pragma", "no-cache"); // HTTP 1.0.
    response.setDateHeader("Expires", 0); // Proxies.
  }

  /**
   * Compiles the named template, resolved relative to this class's package, and executes it
   * into the response writer.
   *
   * @param templateName name of the template, relative to the package of this class
   * @param b scope object passed to the template
   * @param response response whose writer receives the rendered output
   * @throws IOException if the response writer cannot be obtained
   */
  public static void renderTemplate(String templateName, Object b, HttpServletResponse response)
      throws IOException {
    Mustache template =
        MUSTACHE_FACTORY.compilePackageRelative(TwoFactorServlet.class, templateName);
    template.execute(response.getWriter(), b);
  }

  /**
   * Tries {@code codeString} as both a TFA code and a backup code.
   *
   * <p>IMPORTANT: the caller must commit the transaction on successful verification. This is
   * required to reset the backup code, if one was used.
   *
   * @param codeString the code to check; may be null or empty
   * @param secret secret handed to both the TFA check and the backup-code check
   * @param time time value handed to the TFA check (units not visible here; confirm against
   *     {@code TwoFactorCodeChecker})
   * @param identity user the code belongs to; also used for log messages
   * @param manager manager handed to the backup-code check
   * @return null if verification succeeded, otherwise one of the {@code *_ERROR_MESSAGE}
   *     constants describing the failure
   * @throws SQLException if the backup-code check fails at the database level
   */
  public String verifyCode(String codeString, String secret, long time, DBIdentity identity,
      Manager manager) throws SQLException {
    if (Strings.isNullOrEmpty(codeString)) {
      return MISSING_CODE_ERROR_MESSAGE;
    }

    try {
      long code = Long.parseLong(codeString);
      // Short-circuit: the backup code is only tried when the TFA check fails.
      boolean verified = TwoFactorCodeChecker.checkCode(secret, code, time)
          || CryptoForBackupCodes.tryBackupCode(secret, codeString, identity, manager);
      if (verified) {
        return null;
      }
      logger.info("invalid code '{}' for user {}", codeString, identity.getName());
      return INCORRECT_CODE_ERROR_MESSAGE;
    } catch (NumberFormatException e) {
      // NOTE(review): codeString is presumably taken straight from the request and is logged
      // verbatim here; consider sanitizing it to rule out log forging.
      logger.info("invalid number '{}' for user {}", codeString, identity.getName());
      return INVALID_CODE_ERROR_MESSAGE;
    }
  }

  /** Extracts the two-factor parameters from the incoming request. */
  abstract protected TwoFactorRequestParams parseRequestParams(HttpServletRequest request);

  /** Handles a request that has already passed {@link #verifyRequest}. */
  abstract protected void processRequest(TwoFactorRequestContext context)
      throws IOException, SQLException, ServletException;

  /**
   * Decides whether the request may proceed. Note that {@link #doCommon} only calls this when
   * {@code context.identity} is non-null.
   */
  abstract protected boolean verifyRequest(TwoFactorRequestContext context);

  /**
   * Shared request pipeline for GET and POST: parse the parameters, open a {@link Manager},
   * look up the identity for the requested user name, verify the request and finally let the
   * subclass process it. The manager is closed when this method returns or throws.
   *
   * @throws ServletException with message {@code "Invalid identity or signature"} when no
   *     identity exists for the user name or {@link #verifyRequest} rejects the request, and
   *     with message {@code "Internal server error"} when a {@link SQLException} occurs
   * @throws IOException propagated from {@link #processRequest}
   */
  protected void doCommon(HttpServletRequest request, HttpServletResponse response)
      throws ServletException, IOException {
    TwoFactorRequestParams params = parseRequestParams(request);
    try (Manager mgr = managerFactory.newManager()) {
      DBIdentity identity = DBIdentity.getIdentityForUserName(mgr, params.username);
      TwoFactorRequestContext context =
          new TwoFactorRequestContext(request, response, params, mgr, identity);
      if (identity == null || !verifyRequest(context)) {
        throw new ServletException("Invalid identity or signature");
      }
      processRequest(context);
    } catch (SQLException e) {
      // Record the root cause server-side instead of discarding it. The exception that
      // propagates keeps its generic message, so no SQL detail is added to what the
      // container may show to the client.
      logger.error("SQL error while handling two-factor request", e);
      throw new ServletException("Internal server error");
    }
  }

  /**
   * Runs {@link #doCommon} with caching disabled when this servlet serves GET; otherwise
   * defers to the default {@link HttpServlet} handling.
   */
  @Override
  protected void doGet(HttpServletRequest request, HttpServletResponse response)
      throws ServletException, IOException {
    if (method == HttpMethod.GET) {
      preventCaching(response);
      doCommon(request, response);
    } else {
      super.doGet(request, response);
    }
  }

  /**
   * Runs {@link #doCommon} when this servlet serves POST; otherwise defers to the default
   * {@link HttpServlet} handling. Unlike {@link #doGet}, no cache-suppression headers are set.
   */
  @Override
  protected void doPost(HttpServletRequest request, HttpServletResponse response)
      throws ServletException, IOException {
    if (method == HttpMethod.POST) {
      doCommon(request, response);
    } else {
      super.doPost(request, response);
    }
  }
}