/*
* Copyright 2016 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.services.managers;
import org.jboss.logging.Logger;
import org.keycloak.common.ClientConnection;
import org.keycloak.common.util.Time;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.RealmModel;
import org.keycloak.models.UserLoginFailureModel;
import org.keycloak.models.UserModel;
import org.keycloak.services.ServicesLogger;
import java.util.ArrayList;
import java.util.Collections;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
/**
 * Brute-force login protection that funnels every login outcome through one worker thread.
 *
 * <p>A single thread will log failures. This is so that we can avoid concurrent writes as we
 * want an accurate failure count. Request threads enqueue a {@link FailedLogin} or
 * {@link SuccessfulLogin} event and then wait (bounded) on the event's latch; the worker
 * started by {@link #start()} drains the queue in batches, applies each batch inside one
 * session transaction and releases the waiting callers afterwards.
 *
 * @author <a href="mailto:bill@burkecentral.com">Bill Burke</a>
 * @version $Revision: 1 $
 */
public class DefaultBruteForceProtector implements Runnable, BruteForceProtector {
    private static final Logger logger = Logger.getLogger(DefaultBruteForceProtector.class);

    /** Worker loop flag; cleared by {@link #shutdown()} or when a {@link ShutdownEvent} is processed. */
    protected volatile boolean run = true;
    /** Gap between two logged failures after which {@link #totalTime} is reset. */
    protected int maxDeltaTimeSeconds = 60 * 60 * 12; // 12 hours
    protected KeycloakSessionFactory factory;
    /** Counted down once when the worker loop in {@link #run()} exits. */
    protected CountDownLatch shutdownLatch = new CountDownLatch(1);

    /** Number of failures passed to {@link #logFailure(LoginEvent)}. */
    protected volatile long failures;
    /** Timestamp (ms) of the most recent call to {@link #logFailure(LoginEvent)}; 0 if none yet. */
    protected volatile long lastFailure;
    /** Accumulated milliseconds between consecutive logged failures. */
    protected volatile long totalTime;

    /** Events waiting to be processed by the worker thread. */
    protected LinkedBlockingQueue<LoginEvent> queue = new LinkedBlockingQueue<LoginEvent>();

    /** Maximum number of additional events drained from the queue into one batch. */
    public static final int TRANSACTION_SIZE = 20;

    /**
     * Base type of everything placed on the queue. Events order by {@code userId} so that a
     * batch can be sorted before it is applied.
     */
    protected abstract class LoginEvent implements Comparable<LoginEvent> {
        protected final String realmId;
        protected final String userId;
        protected final String ip;

        protected LoginEvent(String realmId, String userId, String ip) {
            this.realmId = realmId;
            this.userId = userId;
            this.ip = ip;
        }

        /**
         * Orders events by {@code userId}. Events without a user id (such as
         * {@link ShutdownEvent}) sort before all others instead of raising a
         * {@link NullPointerException} while a mixed batch is being sorted.
         */
        @Override
        public int compareTo(LoginEvent o) {
            if (userId == null) {
                return o.userId == null ? 0 : -1;
            }
            if (o.userId == null) {
                return 1;
            }
            return userId.compareTo(o.userId);
        }
    }

    /** Marker event that tells the worker loop to stop; carries no realm, user or address. */
    protected class ShutdownEvent extends LoginEvent {
        public ShutdownEvent() {
            super(null, null, null);
        }
    }

    /** A failed login attempt; {@code latch} is released once the worker has handled the batch. */
    protected class FailedLogin extends LoginEvent {
        protected final CountDownLatch latch = new CountDownLatch(1);

        public FailedLogin(String realmId, String userId, String ip) {
            super(realmId, userId, ip);
        }
    }

    /** A successful login; {@code latch} is released once the worker has handled the batch. */
    protected class SuccessfulLogin extends LoginEvent {
        protected final CountDownLatch latch = new CountDownLatch(1);

        public SuccessfulLogin(String realmId, String userId, String ip) {
            super(realmId, userId, ip);
        }
    }

    /**
     * @param factory used by the worker thread to create one session per processed batch
     */
    public DefaultBruteForceProtector(KeycloakSessionFactory factory) {
        this.factory = factory;
    }

    /**
     * Records one failed login for the event's user and updates the lockout state.
     *
     * <p>The failure is always passed to {@link #logFailure(LoginEvent)}. Nothing else happens
     * when the realm or the user can no longer be found. With permanent lockout enabled the
     * user is disabled once the failure count reaches the realm's failure factor; otherwise a
     * {@code notBefore} timestamp is computed from the realm's wait settings.
     *
     * @param session session whose transaction the changes are made in
     * @param event   the failed login being processed
     */
    public void failure(KeycloakSession session, LoginEvent event) {
        logger.debug("failure");
        RealmModel realm = getRealmModel(session, event);
        logFailure(event);
        if (realm == null) {
            // The realm vanished between enqueueing and processing; there is nothing to update.
            return;
        }

        String userId = event.userId;
        UserModel user = session.users().getUserById(userId, realm);
        if (user == null) {
            return;
        }

        UserLoginFailureModel userLoginFailure = getUserModel(session, event);
        if (userLoginFailure == null) {
            userLoginFailure = session.sessions().addUserLoginFailure(realm, userId);
        }
        userLoginFailure.setLastIPFailure(event.ip);

        long currentTime = Time.currentTimeMillis();
        long last = userLoginFailure.getLastFailure();
        // deltaTime stays 0 when there is no previous failure on record.
        long deltaTime = last > 0 ? currentTime - last : 0;
        userLoginFailure.setLastFailure(currentTime);
        boolean quickLogin = last > 0 && deltaTime < realm.getQuickLoginCheckMilliSeconds();

        if (realm.isPermanentLockout()) {
            userLoginFailure.incrementFailures();
            logger.debugv("new num failures: {0}", userLoginFailure.getNumFailures());

            // ">=" rather than "==": if the counter ever moves past the threshold without
            // hitting it exactly, an equality test would never lock the account.
            if (userLoginFailure.getNumFailures() >= realm.getFailureFactor()) {
                logger.debugv("user {0} locked permanently due to too many login attempts", user.getUsername());
                user.setEnabled(false);
                return;
            }

            if (quickLogin) {
                logger.debugv("quick login, set min wait seconds");
                applyNotBefore(userLoginFailure, currentTime, realm.getMinimumQuickLoginWaitSeconds());
            }
            return;
        }

        // if last failure was more than MAX_DELTA clear failures
        if (deltaTime > 0 && deltaTime > (long) realm.getMaxDeltaTimeSeconds() * 1000L) {
            userLoginFailure.clearFailures();
        }
        userLoginFailure.incrementFailures();
        logger.debugv("new num failures: {0}", userLoginFailure.getNumFailures());

        // NOTE(review): a failure factor of 0 makes this division throw ArithmeticException;
        // confirm that realm configuration rules that value out.
        int waitSeconds = realm.getWaitIncrementSeconds() * (userLoginFailure.getNumFailures() / realm.getFailureFactor());
        logger.debugv("waitSeconds: {0}", waitSeconds);
        logger.debugv("deltaTime: {0}", deltaTime);

        if (waitSeconds == 0 && quickLogin) {
            logger.debugv("quick login, set min wait seconds");
            waitSeconds = realm.getMinimumQuickLoginWaitSeconds();
        }
        if (waitSeconds > 0) {
            applyNotBefore(userLoginFailure, currentTime, Math.min(realm.getMaxFailureWaitSeconds(), waitSeconds));
        }
    }

    /** Stores {@code currentTime} (ms, converted to seconds) plus {@code waitSeconds} as the record's notBefore. */
    private void applyNotBefore(UserLoginFailureModel userLoginFailure, long currentTime, int waitSeconds) {
        int notBefore = (int) (currentTime / 1000) + waitSeconds;
        logger.debugv("set notBefore: {0}", notBefore);
        userLoginFailure.setFailedLoginNotBefore(notBefore);
    }

    /**
     * Looks up the login-failure record of the event's user.
     *
     * @return the record, or {@code null} if the realm or the record does not exist
     */
    protected UserLoginFailureModel getUserModel(KeycloakSession session, LoginEvent event) {
        RealmModel realm = getRealmModel(session, event);
        if (realm == null) {
            return null;
        }
        return session.sessions().getUserLoginFailure(realm, event.userId);
    }

    /**
     * Resolves the realm the event belongs to.
     *
     * @return whatever {@code session.realms().getRealm(...)} yields for the event's realm id;
     *         callers in this class treat {@code null} as "realm not found"
     */
    protected RealmModel getRealmModel(KeycloakSession session, LoginEvent event) {
        return session.realms().getRealm(event.realmId);
    }

    /** Starts the worker thread that executes {@link #run()}. */
    public void start() {
        new Thread(this, "Brute Force Protector").start();
    }

    /**
     * Asks the worker to stop and waits up to 10 seconds for it to finish.
     *
     * @throws RuntimeException wrapping the {@link InterruptedException} if the calling thread
     *                          is interrupted while waiting; the interrupt flag is restored first
     */
    public void shutdown() {
        run = false;
        try {
            // Wake the worker in case it is blocked polling an empty queue.
            queue.offer(new ShutdownEvent());
            shutdownLatch.await(10, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Worker loop: takes events off the queue in batches of at most
     * {@code TRANSACTION_SIZE + 1}, applies each batch in its own session transaction, and
     * always releases the callers waiting on the batch, even when processing fails.
     */
    public void run() {
        final ArrayList<LoginEvent> events = new ArrayList<LoginEvent>(TRANSACTION_SIZE + 1);
        try {
            while (run) {
                try {
                    LoginEvent first = queue.poll(2, TimeUnit.SECONDS);
                    if (first == null) {
                        continue;
                    }
                    try {
                        events.add(first);
                        queue.drainTo(events, TRANSACTION_SIZE);
                        Collections.sort(events); // we sort to avoid deadlock due to ordered updates. Maybe I'm overthinking this.

                        KeycloakSession session = factory.create();
                        try {
                            session.getTransactionManager().begin();
                            try {
                                applyEvents(session, events);
                                session.getTransactionManager().commit();
                            } catch (Exception e) {
                                session.getTransactionManager().rollback();
                                throw e;
                            }
                        } finally {
                            session.close();
                        }
                    } catch (Exception e) {
                        ServicesLogger.LOGGER.failedProcessingType(e);
                    } finally {
                        // Runs for every batch, including ones that failed before a session
                        // existed, so callers never wait out their full timeout and stale
                        // events are never re-applied with the next batch.
                        releaseWaiters(events);
                        events.clear();
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    break;
                }
            }
        } finally {
            shutdownLatch.countDown();
        }
    }

    /** Dispatches every event of a batch to its handler; a {@link ShutdownEvent} stops the loop. */
    private void applyEvents(KeycloakSession session, Iterable<LoginEvent> events) {
        for (LoginEvent event : events) {
            if (event instanceof FailedLogin) {
                failure(session, event);
            } else if (event instanceof SuccessfulLogin) {
                success(session, event);
            } else if (event instanceof ShutdownEvent) {
                run = false;
            }
        }
    }

    /** Counts down the latch of every login event so the enqueueing threads can continue. */
    private void releaseWaiters(Iterable<LoginEvent> events) {
        for (LoginEvent event : events) {
            if (event instanceof FailedLogin) {
                ((FailedLogin) event).latch.countDown();
            } else if (event instanceof SuccessfulLogin) {
                ((SuccessfulLogin) event).latch.countDown();
            }
        }
    }

    /**
     * Clears the recorded failures of the event's user after a successful login. Does nothing
     * when the realm or the failure record is missing.
     */
    private void success(KeycloakSession session, LoginEvent event) {
        RealmModel realm = getRealmModel(session, event);
        if (realm == null) {
            return;
        }
        UserLoginFailureModel user = getUserModel(session, event);
        if (user == null) {
            return;
        }

        // The user may have been removed since the event was queued; fall back to the id
        // for the log line instead of failing the whole batch.
        UserModel model = session.users().getUserById(event.userId, realm);
        String name = model != null ? model.getUsername() : event.userId;
        logger.debugv("user {0} successfully logged in, clearing all failures", name);
        user.clearFailures();
    }

    /**
     * Logs the failure and updates the {@link #failures}, {@link #lastFailure} and
     * {@link #totalTime} statistics.
     */
    protected void logFailure(LoginEvent event) {
        ServicesLogger.LOGGER.loginFailure(event.userId, event.ip);
        // NOTE(review): ++ and += on volatile fields are not atomic; presumably this is only
        // reached from the single worker thread -- confirm before calling it from elsewhere.
        failures++;

        long now = Time.currentTimeMillis();
        long previous = lastFailure;
        if (previous > 0) {
            long delta = now - previous;
            if (delta > (long) maxDeltaTimeSeconds * 1000L) {
                totalTime = 0;
            } else {
                totalTime += delta;
            }
        }
        // Remember this failure so the next call can measure the gap between the two.
        lastFailure = now;
    }

    /**
     * Queues a failed login and waits up to 5 seconds for the worker to process it.
     */
    @Override
    public void failedLogin(RealmModel realm, UserModel user, ClientConnection clientConnection) {
        try {
            FailedLogin event = new FailedLogin(realm.getId(), user.getId(), clientConnection.getRemoteAddr());
            queue.offer(event);
            // Hold the caller until the event has been processed (or the timeout elapses) so
            // that an attacker cannot flood the queue with failed logins and get further
            // attempts in before notBefore has been updated to block them.
            // todo failure HTTP responses should be queued via async HTTP
            event.latch.await(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            // Stop waiting, but keep the interrupt visible to the caller.
            Thread.currentThread().interrupt();
        }
        logger.trace("sent failure event");
    }

    /**
     * Queues a successful login and waits up to 5 seconds for the worker to process it.
     */
    @Override
    public void successfulLogin(final RealmModel realm, final UserModel user, final ClientConnection clientConnection) {
        try {
            SuccessfulLogin event = new SuccessfulLogin(realm.getId(), user.getId(), clientConnection.getRemoteAddr());
            queue.offer(event);
            event.latch.await(5, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            // Stop waiting, but keep the interrupt visible to the caller.
            Thread.currentThread().interrupt();
        }
        logger.trace("sent success event");
    }

    /**
     * Tells whether the user is currently inside a wait period set by earlier failures.
     *
     * @return {@code true} if a failure record exists and the current time in seconds is
     *         before its {@code failedLoginNotBefore}; {@code false} otherwise
     */
    @Override
    public boolean isTemporarilyDisabled(KeycloakSession session, RealmModel realm, UserModel user) {
        UserLoginFailureModel failure = session.sessions().getUserLoginFailure(realm, user.getId());
        if (failure == null) {
            return false;
        }

        int currTime = (int) (Time.currentTimeMillis() / 1000);
        int failedLoginNotBefore = failure.getFailedLoginNotBefore();
        if (currTime < failedLoginNotBefore) {
            logger.debugv("Current: {0} notBefore: {1}", currTime, failedLoginNotBefore);
            return true;
        }
        return false;
    }

    /** No-op. */
    @Override
    public void close() {
    }
}