/*
* Copyright 2016 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.models.sessions.infinispan.initializer;
import org.infinispan.Cache;
import org.infinispan.context.Flag;
import org.infinispan.distexec.DefaultExecutorService;
import org.infinispan.lifecycle.ComponentStatus;
import org.infinispan.remoting.transport.Transport;
import org.jboss.logging.Logger;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.KeycloakSessionFactory;
import org.keycloak.models.KeycloakSessionTask;
import org.keycloak.models.utils.KeycloakModelUtils;
import java.io.Serializable;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
/**
 * Startup initialization for reading persistent userSessions/clientSessions to be filled into infinispan/memory.
 * In a cluster, the initialization is distributed among all cluster nodes, so the startup time is even faster.
 *
 * <p>The progress of the initialization is tracked in an {@link InitializerState} stored in the work cache under a
 * key built from {@code "distributed::"} and the suffix passed to the constructor. Only the cluster coordinator
 * (or the single node when there is no transport) drives the loading; the other nodes poll the state until it is
 * reported as finished.
 *
 * TODO: Move to clusterService. Implementation is already pretty generic and doesn't contain any "userSession" specific stuff. All sessions-specific logic is in the SessionLoader implementation
 *
 * @author <a href="mailto:mposolda@redhat.com">Marek Posolda</a>
 */
public class InfinispanUserSessionInitializer {

    /** Prefix of the work-cache key under which the {@link InitializerState} is stored. */
    private static final String STATE_KEY_PREFIX = "distributed::";

    private static final Logger log = Logger.getLogger(InfinispanUserSessionInitializer.class);

    private final KeycloakSessionFactory sessionFactory;
    private final Cache<String, Serializable> workCache;
    private final SessionLoader sessionLoader;
    private final int maxErrors;
    private final int sessionsPerSegment;
    private final String stateKey;

    /**
     * Creates the initializer.
     *
     * @param sessionFactory     factory used to run the loader's init/count jobs in their own transactions
     * @param workCache          cache holding the shared {@link InitializerState}; also used as the basis of the
     *                           distributed executor when a transport is configured
     * @param sessionLoader      loader handed to every worker and used to initialize and count the sessions
     * @param maxErrors          once the number of interrupted/failed worker futures reaches this value, loading is
     *                           aborted with a {@link RuntimeException}
     * @param sessionsPerSegment segment size passed to the state and to every worker
     * @param stateKeySuffix     suffix appended to {@code "distributed::"} to form the state key in the work cache
     */
    public InfinispanUserSessionInitializer(KeycloakSessionFactory sessionFactory, Cache<String, Serializable> workCache, SessionLoader sessionLoader, int maxErrors, int sessionsPerSegment, String stateKeySuffix) {
        this.sessionFactory = sessionFactory;
        this.workCache = workCache;
        this.sessionLoader = sessionLoader;
        this.maxErrors = maxErrors;
        this.sessionsPerSegment = sessionsPerSegment;
        this.stateKey = STATE_KEY_PREFIX + stateKeySuffix;
    }

    /**
     * Registers the {@link KeycloakSessionFactory} in the component registry of the work cache.
     * NOTE(review): presumably this is what lets distributed workers look the factory up on their node - confirm
     * against SessionInitializerWorker.
     */
    public void initCache() {
        this.workCache.getAdvancedCache().getComponentRegistry().registerComponent(sessionFactory, KeycloakSessionFactory.class);
    }

    /**
     * Blocks until the initializer state in the work cache is reported as finished.
     *
     * <p>A non-coordinator node sleeps for one second between checks. The coordinator (or a node without any
     * transport) performs the loading itself via {@link #startLoading()}. The coordinator check is repeated on every
     * iteration, so a node that becomes coordinator while waiting takes the loading over.
     *
     * <p>The wait is not cancelable: an interrupt is logged and waiting continues, but the interrupt status of the
     * calling thread is restored before this method exits so that callers can still observe it.
     *
     * @throws RuntimeException if the coordinator hits the configured maximum count of worker errors
     */
    public void loadPersistentSessions() {
        boolean interrupted = false;
        try {
            while (!isFinished()) {
                if (!isCoordinator()) {
                    try {
                        Thread.sleep(1000);
                    } catch (InterruptedException ie) {
                        interrupted = true;
                        log.error("Interrupted", ie);
                    }
                } else {
                    startLoading();
                    // startLoading() re-asserts a pending interrupt on exit. Clear it here and remember it, so that
                    // the remaining loop iterations behave as before; it is restored in the finally block below.
                    if (Thread.interrupted()) {
                        interrupted = true;
                    }
                }
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * @return true when a state exists in the work cache and reports itself as finished
     */
    private boolean isFinished() {
        InitializerState state = getStateFromCache();
        return state != null && state.isFinished();
    }

    /**
     * Returns the state stored in the work cache. When there is none yet, the session loader is initialized, the
     * sessions are counted, and a fresh state built from that count is saved to the cache and returned.
     */
    private InitializerState getOrCreateInitializerState() {
        InitializerState state = getStateFromCache();
        if (state == null) {
            // Single-element array so that the anonymous task below can hand the count back
            final int[] count = new int[1];

            // Rather use separate transactions for update and counting
            KeycloakModelUtils.runJobInTransaction(sessionFactory, new KeycloakSessionTask() {

                @Override
                public void run(KeycloakSession session) {
                    sessionLoader.init(session);
                }

            });

            KeycloakModelUtils.runJobInTransaction(sessionFactory, new KeycloakSessionTask() {

                @Override
                public void run(KeycloakSession session) {
                    count[0] = sessionLoader.getSessionsCount(session);
                }

            });

            state = new InitializerState();
            state.init(count[0], sessionsPerSegment);
            saveStateToCache(state);
        }
        return state;
    }

    /**
     * Reads the state from the work cache, bypassing any configured cache store.
     *
     * @return the stored state, or null when the cache has no entry for the state key
     */
    private InitializerState getStateFromCache() {
        // TODO: We ignore the cacheStore for now, so that in a Cross-DC scenario (with RemoteStore enabled) the remoteStore is ignored. This means that every DC needs to load offline sessions separately.
        return (InitializerState) workCache.getAdvancedCache()
                .withFlags(Flag.SKIP_CACHE_STORE, Flag.SKIP_CACHE_LOAD)
                .get(stateKey);
    }

    /**
     * Writes the state to the work cache, bypassing any configured cache store. The write is attempted up to three
     * times; see {@link #retry(int, Runnable)} for the failure handling.
     */
    private void saveStateToCache(final InitializerState state) {
        // 3 attempts to send the message (it may fail if some node fails in the meantime)
        retry(3, new Runnable() {

            @Override
            public void run() {
                // Save this synchronously to ensure all nodes read correct state
                // TODO: We ignore the cacheStore for now, so that in a Cross-DC scenario (with RemoteStore enabled) the remoteStore is ignored. This means that every DC needs to load offline sessions separately.
                InfinispanUserSessionInitializer.this.workCache.getAdvancedCache()
                        .withFlags(Flag.IGNORE_RETURN_VALUES, Flag.FORCE_SYNCHRONOUS, Flag.SKIP_CACHE_STORE, Flag.SKIP_CACHE_LOAD)
                        .put(stateKey, state);
            }

        });
    }

    /**
     * @return true when the cache manager has no transport (non-clustered setup) or this node is the coordinator
     */
    private boolean isCoordinator() {
        Transport transport = workCache.getCacheManager().getTransport();
        return transport == null || transport.isCoordinator();
    }

    /**
     * Loads all unfinished segments. Just coordinator will run this.
     *
     * <p>Each iteration asks the state for a batch of unfinished segments (at most processors * cluster members),
     * submits one worker per segment, waits for all of them, marks the successful segments as finished and pushes
     * the updated state to the work cache. Segments that were not marked finished are offered again by the state in
     * a following iteration.
     *
     * <p>An interrupt received while waiting for a worker is counted as an error and waiting continues; the
     * interrupt status of the calling thread is restored when this method exits.
     *
     * @throws RuntimeException when the accumulated count of interrupted/failed futures reaches {@code maxErrors}
     */
    private void startLoading() {
        InitializerState state = getOrCreateInitializerState();

        // Assume each worker has same processor's count
        int processors = Runtime.getRuntime().availableProcessors();

        ExecutorService localExecutor = Executors.newCachedThreadPool();
        Transport transport = workCache.getCacheManager().getTransport();
        boolean distributed = transport != null;
        ExecutorService executorService = distributed ? new DefaultExecutorService(workCache, localExecutor) : localExecutor;

        int errors = 0;
        boolean interrupted = false;

        try {
            while (!state.isFinished()) {
                // Membership is re-read on every iteration, so the batch size follows the current cluster size
                int nodesCount = transport == null ? 1 : transport.getMembers().size();
                int distributedWorkersCount = processors * nodesCount;

                log.debugf("Starting next iteration with %d workers", distributedWorkersCount);

                List<Integer> segments = state.getUnfinishedSegments(distributedWorkersCount);

                if (log.isTraceEnabled()) {
                    log.trace("unfinished segments for this iteration: " + segments);
                }

                List<Future<WorkerResult>> futures = new LinkedList<>();
                for (Integer segment : segments) {
                    SessionInitializerWorker worker = new SessionInitializerWorker();
                    worker.setWorkerEnvironment(segment, sessionsPerSegment, sessionLoader);
                    if (!distributed) {
                        // With the plain local executor nothing else hands the cache to the worker, so do it here
                        worker.setEnvironment(workCache, null);
                    }

                    Future<WorkerResult> future = executorService.submit(worker);
                    futures.add(future);
                }

                for (Future<WorkerResult> future : futures) {
                    try {
                        WorkerResult result = future.get();

                        if (result.getSuccess()) {
                            int computedSegment = result.getSegment();
                            state.markSegmentFinished(computedSegment);
                        } else {
                            if (log.isTraceEnabled()) {
                                log.tracef("Segment %d failed to compute", result.getSegment());
                            }
                        }
                    } catch (InterruptedException ie) {
                        // Remember the interrupt; the flag is restored once this method is done
                        interrupted = true;
                        errors++;
                        log.error("Interruped exception when computed future. Errors: " + errors, ie);
                    } catch (ExecutionException ee) {
                        errors++;
                        log.error("ExecutionException when computed future. Errors: " + errors, ee);
                    }
                }

                if (errors >= maxErrors) {
                    throw new RuntimeException("Maximum count of worker errors occured. Limit was " + maxErrors + ". See server.log for details");
                }

                saveStateToCache(state);

                if (log.isDebugEnabled()) {
                    log.debug("New initializer state pushed. The state is: " + state.printState());
                }
            }
        } finally {
            try {
                if (distributed) {
                    executorService.shutdown();
                }
            } finally {
                // Always release the local thread pool, even if shutting down the distributed executor failed
                try {
                    localExecutor.shutdown();
                } finally {
                    if (interrupted) {
                        Thread.currentThread().interrupt();
                    }
                }
            }
        }
    }

    /**
     * Runs the given action, repeating it when it throws a {@link RuntimeException}.
     *
     * <p>When the action fails while the work cache is stopping or terminated, the failure is logged and the method
     * returns without retrying. Otherwise the action is repeated until it succeeds or the attempts are used up, in
     * which case the last exception is rethrown.
     *
     * @param retry    maximum number of attempts; values below 1 are treated as a single attempt
     * @param runnable action to run
     */
    private void retry(int retry, Runnable runnable) {
        while (true) {
            try {
                runnable.run();
                return;
            } catch (RuntimeException e) {
                ComponentStatus status = workCache.getStatus();
                if (status.isStopping() || status.isTerminated()) {
                    log.warn("Failed to put initializerState to the cache. Cache is already terminating");
                    log.debug(e.getMessage(), e);
                    return;
                }

                retry--;
                // "<=" rather than "==" so that a non-positive attempt count cannot loop forever
                if (retry <= 0) {
                    throw e;
                }
            }
        }
    }

    /**
     * Serializable outcome of a worker: the segment it was asked to compute and whether it succeeded.
     */
    public static class WorkerResult implements Serializable {

        private Integer segment;
        private Boolean success;

        /**
         * Creates a result for the given segment.
         *
         * @param segment segment the result belongs to
         * @param success whether the segment was computed successfully
         * @return new result carrying both values
         */
        public static WorkerResult create(Integer segment, boolean success) {
            WorkerResult res = new WorkerResult();
            res.setSegment(segment);
            res.setSuccess(success);
            return res;
        }

        public Integer getSegment() {
            return segment;
        }

        public void setSegment(Integer segment) {
            this.segment = segment;
        }

        public Boolean getSuccess() {
            return success;
        }

        public void setSuccess(Boolean success) {
            this.success = success;
        }

    }

}