/*
 * Copyright (C) 2012 Red Hat, Inc. and/or its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.jboss.errai.bus.server;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import javax.servlet.http.HttpSession;

import org.jboss.errai.bus.client.api.QueueSession;
import org.jboss.errai.bus.client.api.SessionEndEvent;
import org.jboss.errai.bus.client.api.SessionEndListener;
import org.jboss.errai.bus.client.api.laundry.LaundryListProviderFactory;
import org.jboss.errai.bus.server.api.SessionProvider;
import org.jboss.errai.bus.server.service.ErraiConfigAttribs;
import org.jboss.errai.bus.server.service.ErraiServiceConfigurator;
import org.jboss.errai.bus.server.servlet.CSRFTokenCheck;
import org.jboss.errai.bus.server.servlet.RequestSecurityCheck;
import org.jboss.errai.bus.server.util.SecureHashUtil;
import org.jboss.errai.bus.server.util.ServerLaundryList;
import org.jboss.errai.common.client.api.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * The SessionProvider for HTTP-based queue sessions.
*/
public class HttpSessionProvider implements SessionProvider<HttpSession> {
  static final Logger log = LoggerFactory.getLogger(HttpSessionProvider.class);

  /**
   * Queue-session containers keyed by {@link HttpSession#getId()}. This map is static and therefore
   * shared by every request thread, so it must be a concurrent map.
   * <p>
   * NOTE(review): nothing in this class removes entries from this map. Confirm that code elsewhere
   * in the package evicts a container when its HTTP session ends; otherwise the map grows without
   * bound.
   */
  static final Map<String, SessionsContainer> containersByHttpSessionId = new ConcurrentHashMap<>();

  private RequestSecurityCheck csrfCheck = RequestSecurityCheck.noCheck();

  /**
   * Selects the request security check: the CSRF bus-token check when
   * {@code ErraiConfigAttribs.ENABLE_CSRF_BUS_TOKEN} is enabled in {@code config}, otherwise no check.
   *
   * @param config the Errai service configuration
   */
  @Override
  public void init(final ErraiServiceConfigurator config) {
    csrfCheck = (ErraiConfigAttribs.ENABLE_CSRF_BUS_TOKEN.getBoolean(config)
        ? CSRFTokenCheck.INSTANCE
        : RequestSecurityCheck.noCheck());
  }

  /**
   * Returns the queue session for {@code remoteQueueID} within the given HTTP session, creating it
   * (and, the first time the HTTP session is seen, its container) if necessary.
   * <p>
   * A newly created queue session stores {@code externSessRef} as a shared attribute under the key
   * {@code HttpSession.class.getName()} and registers a listener that removes it from its container
   * when the queue session ends.
   *
   * @param externSessRef the servlet HTTP session the queue session belongs to
   * @param remoteQueueID the remote queue id identifying the queue session
   * @return the existing or newly created queue session
   */
  @Override
  public QueueSession createOrGetSession(final HttpSession externSessRef, final String remoteQueueID) {
    final SessionsContainer sc = getOrCreateContainer(externSessRef);

    // The lookup and creation must be atomic: otherwise two concurrent requests for the same queue
    // id could each create (and configure) a separate session, with the later one overwriting the
    // earlier in the container.
    synchronized (sc) {
      final QueueSession existing = sc.getSession(remoteQueueID);
      if (existing != null) {
        return existing;
      }

      log.debug("queue session {} started", remoteQueueID);
      final QueueSession qs = sc.createSession(externSessRef.getId(), remoteQueueID);
      qs.setAttribute(HttpSession.class.getName(), externSessRef);
      qs.addSessionEndListener(new SessionEndListener() {
        @Override
        public void onSessionEnd(final SessionEndEvent event) {
          log.debug("queue session {} ended", remoteQueueID);
          sc.removeSession(remoteQueueID);
        }
      });
      return qs;
    }
  }

  /**
   * Returns the container registered for the given HTTP session, creating and registering one if
   * none exists. When a container is created, the configured request security check first prepares
   * the HTTP session.
   * <p>
   * Creation is serialized on the map's monitor, so concurrent first requests for the same HTTP
   * session cannot create two containers. {@code prepareSession} also completes before the new
   * container becomes visible to other threads. If it throws, nothing is registered, and the next
   * request retries.
   */
  private SessionsContainer getOrCreateContainer(final HttpSession externSessRef) {
    final String httpSessionId = externSessRef.getId();
    SessionsContainer sc = containersByHttpSessionId.get(httpSessionId);
    if (sc == null) {
      synchronized (containersByHttpSessionId) {
        sc = containersByHttpSessionId.get(httpSessionId);
        if (sc == null) {
          sc = new SessionsContainer();
          csrfCheck.prepareSession(externSessRef, log);
          containersByHttpSessionId.put(httpSessionId, sc);
        }
      }
    }
    return sc;
  }

  /**
   * Per-HTTP-session state: the queue sessions opened within one HTTP session, keyed by remote queue
   * id, plus an attribute map shared by all of those queue sessions. Both maps are guarded by this
   * container's monitor.
   */
  public static class SessionsContainer {
    private final Map<String, Object> sharedAttributes = new HashMap<>();
    private final Map<String, QueueSession> queueSessions = new HashMap<>();

    /**
     * Creates a queue session for {@code remoteQueueId}, registers it in this container (replacing
     * any session already registered under that id), and returns it.
     */
    public synchronized QueueSession createSession(final String httpSessionId, final String remoteQueueId) {
      final QueueSession qs = new HttpSessionWrapper(this, httpSessionId, remoteQueueId);
      queueSessions.put(remoteQueueId, qs);
      return qs;
    }

    /** Returns the queue session registered under {@code remoteQueueId}, or {@code null} if none. */
    public synchronized QueueSession getSession(final String remoteQueueId) {
      return queueSessions.get(remoteQueueId);
    }

    /** Unregisters the queue session for {@code remoteQueueId}; a no-op if none is registered. */
    public synchronized void removeSession(final String remoteQueueId) {
      queueSessions.remove(remoteQueueId);
    }
  }

  /**
   * A {@link QueueSession} backed by a {@link SessionsContainer}. Attributes are stored in the
   * container and are therefore shared with every other queue session of the same HTTP session.
   */
  private static class HttpSessionWrapper implements QueueSession {
    private final SessionsContainer container;
    private final String parentSessionId;
    private final String sessionId;
    private final String remoteQueueID;
    private List<SessionEndListener> sessionEndListeners;

    /**
     * @param container the owning container (must not be null)
     * @param httpSessionId the id of the parent HTTP session (must not be null)
     * @param remoteQueueID the remote queue id (must not be null)
     */
    public HttpSessionWrapper(final SessionsContainer container, final String httpSessionId,
        final String remoteQueueID) {
      this.container = Assert.notNull(container);
      this.remoteQueueID = Assert.notNull(remoteQueueID);
      this.parentSessionId = Assert.notNull(httpSessionId);
      // Use an explicit charset so the hash input does not depend on the platform default encoding.
      this.sessionId = SecureHashUtil.nextSecureHash("SHA-256",
          httpSessionId.getBytes(StandardCharsets.UTF_8), remoteQueueID.getBytes(StandardCharsets.UTF_8));
    }

    @Override
    public String getSessionId() {
      return sessionId;
    }

    @Override
    public String getParentSessionId() {
      return parentSessionId;
    }

    /** Unregisters this session from its container, then notifies end listeners. Always returns true. */
    @Override
    public boolean endSession() {
      container.removeSession(remoteQueueID);
      fireSessionEndListeners();
      return true;
    }

    /** A session is valid while some session is still registered in its container under its queue id. */
    @Override
    public boolean isValid() {
      return container.getSession(remoteQueueID) != null;
    }

    @Override
    public void setAttribute(final String attribute, final Object value) {
      synchronized (container) {
        container.sharedAttributes.put(attribute, value);
      }
    }

    @Override
    @SuppressWarnings("unchecked") // unchecked by design: callers name the expected type
    public <T> T getAttribute(final Class<T> type, final String attribute) {
      synchronized (container) {
        return (T) container.sharedAttributes.get(attribute);
      }
    }

    /**
     * Returns a snapshot of the shared attribute names. A copy is returned because a live view of
     * the map could not be iterated safely while other requests modify it.
     */
    @Override
    public Collection<String> getAttributeNames() {
      synchronized (container) {
        return new ArrayList<>(container.sharedAttributes.keySet());
      }
    }

    @Override
    public boolean hasAttribute(final String attribute) {
      synchronized (container) {
        return container.sharedAttributes.containsKey(attribute);
      }
    }

    @Override
    public Object removeAttribute(final String attribute) {
      synchronized (container) {
        return container.sharedAttributes.remove(attribute);
      }
    }

    @Override
    public void addSessionEndListener(final SessionEndListener listener) {
      synchronized (this) {
        if (sessionEndListeners == null) {
          sessionEndListeners = new ArrayList<>();
        }
        sessionEndListeners.add(listener);
      }
    }

    /**
     * Cleans this session's server laundry list, then invokes every registered end listener in
     * registration order with a single {@link SessionEndEvent}.
     */
    private void fireSessionEndListeners() {
      ((ServerLaundryList) LaundryListProviderFactory.get().getLaundryList(this)).cleanAll();

      final List<SessionEndListener> listeners;
      synchronized (this) {
        if (sessionEndListeners == null) {
          return;
        }
        // Copy under the lock that addSessionEndListener uses, so a concurrent registration cannot
        // trigger a ConcurrentModificationException. Listeners are then invoked outside the lock.
        listeners = new ArrayList<>(sessionEndListeners);
      }

      final SessionEndEvent event = new SessionEndEvent(this);
      for (final SessionEndListener sessionEndListener : listeners) {
        sessionEndListener.onSessionEnd(event);
      }
    }

    @Override
    public boolean equals(final Object o) {
      if (this == o) return true;
      if (!(o instanceof HttpSessionWrapper)) return false;

      final HttpSessionWrapper that = (HttpSessionWrapper) o;

      if (remoteQueueID != null ? !remoteQueueID.equals(that.remoteQueueID) : that.remoteQueueID != null) return false;
      if (sessionId != null ? !sessionId.equals(that.sessionId) : that.sessionId != null) return false;

      return true;
    }

    @Override
    public int hashCode() {
      int result = (sessionId != null ? sessionId.hashCode() : 0);
      result = 31 * result + (remoteQueueID != null ? remoteQueueID.hashCode() : 0);
      return result;
    }

    @Override
    public String toString() {
      return "HttpSessionWrapper{" +
          "sessionId='" + sessionId + '\'' +
          ", remoteQueueID='" + remoteQueueID + '\'' +
          '}';
    }
  }
}