/* * <p> * 版权: ©2011 * </p> */ package org.young.isocket.filter; import java.io.IOException; import java.util.HashMap; import java.util.Map; import java.util.Queue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; import org.glassfish.grizzly.Connection; import org.glassfish.grizzly.filterchain.BaseFilter; import org.glassfish.grizzly.filterchain.FilterChainContext; import org.glassfish.grizzly.filterchain.NextAction; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.young.icore.util.PropertiesLoaderUtils; import org.young.isocket.exception.AuthenticationException; import org.young.isocket.service.ServiceRequest; import org.young.isocket.service.ServiceResponse; import org.young.isocket.util.SocketKeys; /** * <p> * * </p> * * @see * @author yangjun2 * @email yangjun1120@gmail.com * */ public class ClientAuthFilter extends BaseFilter { private static final Logger logger = LoggerFactory.getLogger(ClientAuthFilter.class); // Map of authenticated connections private ConcurrentHashMap<Connection, ConnectionAuthInfo> authenticatedConnections = new ConcurrentHashMap<Connection, ConnectionAuthInfo>(); private ServiceRequest authRequest; public ClientAuthFilter(String userName, String password) { PropertiesLoaderUtils.setPropertiesFields(this); authRequest = new ServiceRequest(); //authRequest.setAuth(true); authRequest.setServiceId(SocketKeys.SERVICE_ID_AUTH); Map<String, String> user = new HashMap<String, String>(); user.put(SocketKeys.PARAMETER_USERNAME, userName); user.put(SocketKeys.PARAMETER_PASSWORD, password); authRequest.setRequestObject(user); } @Override public NextAction handleRead(FilterChainContext ctx) throws IOException { // Get the connection final Connection connection = ctx.getConnection(); // Get the processing packet final ServiceResponse serviceResponse = (ServiceResponse) ctx.getMessage(); // Check if the packet is authentication response if (serviceResponse.getServiceId().equals(SocketKeys.SERVICE_ID_AUTH)) { // if yes - retrieve the id, assigned by server if (serviceResponse.getResponseCode() == SocketKeys.RESPONSE_CODE_SUCCESS) { final String sessionId = serviceResponse.getResponseObject(); synchronized (connection) { // store id in the map ConnectionAuthInfo info = authenticatedConnections.get(connection); if (info == null) { serviceResponse.setResponseCode(SocketKeys.RESPONSE_CODE_AUTHENTICATIONERROR); serviceResponse.setResponseMessage(SocketKeys.MESSAGE_AUTH_DELETED); } info.sessionId = sessionId; info.setState(1); // resume pending writes if (info.pendingMessages != null) { for (FilterChainContext pendedContext : info.pendingMessages) { pendedContext.resume(); } } info.pendingMessages = null; } } else { logger.error(String.format(SocketKeys.MESSAGE_SERVICE_ERROR, new Object[] { serviceResponse.getId(), serviceResponse.getServiceId(), serviceResponse.getResponseMessage() })); synchronized (connection) { // store id in the map ConnectionAuthInfo info = authenticatedConnections.get(connection); if (info == null) { serviceResponse.setResponseCode(SocketKeys.RESPONSE_CODE_AUTHENTICATIONERROR); serviceResponse.setResponseMessage(SocketKeys.MESSAGE_AUTH_DELETED); } info.sessionId = null; info.setState(2); // resume pending writes if (info.pendingMessages != null) { for (FilterChainContext pendedContext : info.pendingMessages) { pendedContext.resume(); } } info.pendingMessages = null; } } // if it's authentication response - we don't pass processing to a next filter in a chain. return ctx.getStopAction(); } else { // if it's some custom message // Get id line final String sessionId = serviceResponse.getSessionId(); // Check the client id if (checkAuth(connection, sessionId)) { // if id corresponds to what client has - // Remove authentication header serviceResponse.setSessionId(null);//remove sessionId // Pass to a next filter return ctx.getInvokeAction(); } else { serviceResponse.setResponseCode(SocketKeys.RESPONSE_CODE_AUTHENTICATIONERROR); serviceResponse.setResponseMessage(SocketKeys.MESSAGE_AUTH_ERROR); return ctx.getInvokeAction(); } } } @Override public NextAction handleClose(FilterChainContext ctx) throws IOException { authenticatedConnections.remove(ctx.getConnection()); return ctx.getInvokeAction(); } public ConnectionAuthInfo getConnectionAuthInfo(Connection connection) { return authenticatedConnections.get(connection); } public void puttConnectionAuthInfoIfAbsent(Connection connection, ConnectionAuthInfo authInfo) { authenticatedConnections.putIfAbsent(connection, authInfo); } @Override public NextAction handleWrite(final FilterChainContext ctx) throws IOException { // Get the connection final Connection connection = ctx.getConnection(); //get send buffer final ServiceRequest sourceRequst = ctx.getMessage(); // Get the connection authentication information ConnectionAuthInfo authInfo = authenticatedConnections.get(connection); if (authInfo == null) { // connection is not authenticated authInfo = new ConnectionAuthInfo(); final ConnectionAuthInfo existingInfo = authenticatedConnections.putIfAbsent(connection, authInfo); if (existingInfo == null) { // it's the first message for this client - we need to start authentication process // sending authentication packet authRequest.setTransformType(sourceRequst.getTransformType()); ctx.write(authRequest); } else { // authentication has been already started. authInfo = existingInfo; } } //if (sourceRequst.getServiceId().equals(SocketKeys.SERVICE_ID_AUTH)) { if (authInfo.getState() == 0) {//un authenticate //if (authInfo.pendingMessages != null) { // it might be a sign, that authentication has been completed on another thread // synchronize and check one more time synchronized (connection) { //if (authInfo.pendingMessages != null) { if ((authInfo.getState() == 0)) { if (authInfo.sessionId == null) { // Authentication hs been started by another thread, but it is still in progress // add suspended write context to a queue ctx.suspend(); authInfo.pendingMessages.add(ctx); return ctx.getSuspendAction(); } } } } else if (authInfo.getState() == 1) {//authenticate success } else if (authInfo.getState() == 2) {//authenticate failure if (sourceRequst.getServiceId().equals(SocketKeys.SERVICE_ID_AUTH)) { } else { throw new AuthenticationException(SocketKeys.MESSAGE_AUTH_ERROR); } } else { throw new IllegalArgumentException(String.format("auth info state:%s is error.", authInfo.getState())); } // Authentication has been completed - add "auth-id" header and pass the message to a next filter in chain. sourceRequst.setSessionId(authInfo.getSessionId()); return ctx.getInvokeAction(); } private NextAction handleAuthFail(FilterChainContext ctx, ServiceRequest sourceRequest, String errMsg, Object... objects) throws IOException { ServiceResponse serviceResponse = new ServiceResponse(); serviceResponse.setResponseCode(SocketKeys.RESPONSE_CODE_AUTHENTICATIONERROR); serviceResponse.setResponseMessage(String.format(errMsg, objects)); serviceResponse.setTransformType(sourceRequest.getTransformType()); //serviceResponse.setAuth(sourceRequest.isAuth()); serviceResponse.setId(sourceRequest.getId()); serviceResponse.setSessionId(sourceRequest.getSessionId()); serviceResponse.setServiceId(sourceRequest.getServiceId()); ctx.write(serviceResponse); // stop the packet processing return ctx.getStopAction(); } private boolean checkAuth(Connection connection, String sessionId) { // Get the connection id, from the client map final ConnectionAuthInfo registeredId = authenticatedConnections.get(connection); if (registeredId == null || registeredId.sessionId == null) return false; return sessionId.equals(registeredId.sessionId); } /** * Single connection authentication info. */ public static class ConnectionAuthInfo { // Connection id private volatile String sessionId; // Queue of the pending writes private volatile Queue<FilterChainContext> pendingMessages; private int state = 0; public int getState() { return state; } public void setState(int state) { this.state = state; } public Queue<FilterChainContext> getPendingMessages() { return pendingMessages; } public void setPendingMessages(Queue<FilterChainContext> pendingMessages) { this.pendingMessages = pendingMessages; } public void setSessionId(String sessionId) { this.sessionId = sessionId; } public ConnectionAuthInfo() { pendingMessages = new ConcurrentLinkedQueue<FilterChainContext>(); } public String getSessionId() { return this.sessionId; } } }