/* * eID Applet Project. * Copyright (C) 2008-2009 FedICT. * * This is free software; you can redistribute it and/or modify it * under the terms of the GNU Lesser General Public License version * 3.0 as published by the Free Software Foundation. * * This software is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this software; if not, see * http://www.gnu.org/licenses/. */ package be.fedict.eid.applet.shared.protocol; import java.io.BufferedReader; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStreamReader; import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import be.fedict.eid.applet.shared.annotation.HttpBody; import be.fedict.eid.applet.shared.annotation.HttpHeader; import be.fedict.eid.applet.shared.annotation.MessageDiscriminator; import be.fedict.eid.applet.shared.annotation.NotNull; import be.fedict.eid.applet.shared.annotation.PostConstruct; import be.fedict.eid.applet.shared.annotation.ProtocolVersion; import be.fedict.eid.applet.shared.annotation.ValidateSemanticalIntegrity; /** * Unmarshaller component is responsible for governing the process of converting * HTTP transported data streams to Java objects. * * <p> * Keep this class stateless as it can be shared across different HTTP requests * inside AppletServiceServlet. * </p> * * @author Frank Cornelis * */ public class Unmarshaller { private String protocolMessageDiscriminatorHeaderName; private Map<String, Class<?>> protocolMessageClasses; private String protocolVersionHeaderName; private Integer protocolVersion; /** * Main constructor. * * @param catalog */ public Unmarshaller(ProtocolMessageCatalog catalog) { processMessageCatalog(catalog); } private void processMessageCatalog(ProtocolMessageCatalog catalog) { this.protocolMessageClasses = new HashMap<String, Class<?>>(); List<Class<?>> messageClasses = catalog.getCatalogClasses(); for (Class<?> messageClass : messageClasses) { Field discriminatorField = findDiscriminatorField(messageClass); HttpHeader httpHeaderAnnotation = discriminatorField.getAnnotation(HttpHeader.class); String discriminatorHttpHeaderName = httpHeaderAnnotation.value(); if (null == this.protocolMessageDiscriminatorHeaderName) { this.protocolMessageDiscriminatorHeaderName = discriminatorHttpHeaderName; } else { if (false == this.protocolMessageDiscriminatorHeaderName.equals(discriminatorHttpHeaderName)) { throw new RuntimeException("discriminator field not the same over all message classes"); } } String discriminatorValue; try { discriminatorValue = (String) discriminatorField.get(null); } catch (Exception e) { throw new RuntimeException("error reading field: " + e.getMessage()); } if (this.protocolMessageClasses.containsValue(discriminatorValue)) { throw new RuntimeException("discriminator field not unique for: " + messageClass.getName()); } this.protocolMessageClasses.put(discriminatorValue, messageClass); Field protocolVersionField = findProtocolVersionField(messageClass); httpHeaderAnnotation = protocolVersionField.getAnnotation(HttpHeader.class); String protocolVersionHttpHeaderName = httpHeaderAnnotation.value(); if (null == this.protocolVersionHeaderName) { this.protocolVersionHeaderName = protocolVersionHttpHeaderName; } else { if (false == this.protocolVersionHeaderName.equals(protocolVersionHeaderName)) { throw new RuntimeException("protocol version field not the same over all message classes"); } } Integer protocolVersion; try { protocolVersion = (Integer) protocolVersionField.get(null); } catch (Exception e) { throw new RuntimeException("error reading field: " + e.getMessage()); } if (null == this.protocolVersion) { this.protocolVersion = protocolVersion; } else { if (false == this.protocolVersion.equals(protocolVersion)) { throw new RuntimeException("protocol version not the same over all message classes"); } } } } private Field findDiscriminatorField(Class<?> messageClass) { Field[] fields = messageClass.getFields(); for (Field field : fields) { MessageDiscriminator messageDiscriminatorAnnotation = field.getAnnotation(MessageDiscriminator.class); if (null == messageDiscriminatorAnnotation) { continue; } if (Modifier.FINAL != (field.getModifiers() & Modifier.FINAL)) { throw new RuntimeException("message discriminator should be final"); } if (Modifier.STATIC != (field.getModifiers() & Modifier.STATIC)) { throw new RuntimeException("message discriminator should be static"); } if (false == String.class.equals(field.getType())) { throw new RuntimeException("message discriminator should be a String"); } HttpHeader httpHeaderAnnotation = field.getAnnotation(HttpHeader.class); if (null == httpHeaderAnnotation) { throw new RuntimeException("message discriminator should be a HTTP header"); } return field; } throw new RuntimeException("no message discriminator field found on " + messageClass.getName()); } private Field findProtocolVersionField(Class<?> messageClass) { Field[] fields = messageClass.getFields(); for (Field field : fields) { ProtocolVersion protocolVersionAnnotation = field.getAnnotation(ProtocolVersion.class); if (null == protocolVersionAnnotation) { continue; } if (Modifier.FINAL != (field.getModifiers() & Modifier.FINAL)) { throw new RuntimeException("protocol version field should be final"); } if (Modifier.STATIC != (field.getModifiers() & Modifier.STATIC)) { throw new RuntimeException("protocol version field should be static"); } if (false == Integer.TYPE.equals(field.getType())) { throw new RuntimeException("protocol version field should be an int"); } HttpHeader httpHeaderAnnotation = field.getAnnotation(HttpHeader.class); if (null == httpHeaderAnnotation) { throw new RuntimeException("protocol version field should be a HTTP header"); } return field; } throw new RuntimeException("no protocol version field field found on " + messageClass.getName()); } /** * Receive a certain transfer object from the given HTTP receiver component. * * @param httpReceiver * @return */ public Object receive(HttpReceiver httpReceiver) { /* * Secure channel check */ if (false == httpReceiver.isSecure()) { throw new SecurityException("HTTP receiver over unsecure channel"); } /* * Message protocol check */ String protocolVersionHeader = httpReceiver.getHeaderValue(this.protocolVersionHeaderName); if (null == protocolVersionHeader) { throw new RuntimeException("no protocol version header"); } Integer protocolVersion = Integer.parseInt(protocolVersionHeader); if (false == this.protocolVersion.equals(protocolVersion)) { throw new RuntimeException("protocol version mismatch"); } /* * Message discriminator */ String discriminatorValue = httpReceiver.getHeaderValue(this.protocolMessageDiscriminatorHeaderName); Class<?> protocolMessageClass = this.protocolMessageClasses.get(discriminatorValue); if (null == protocolMessageClass) { throw new RuntimeException("unsupported message: " + discriminatorValue); } /* * Create the message object */ Object transferObject; try { transferObject = protocolMessageClass.newInstance(); } catch (Exception e) { throw new RuntimeException("error: " + e.getMessage(), e); } /* * First inject all HTTP headers. Is also performing some syntactical * input validation. */ try { injectHttpHeaderFields(httpReceiver, protocolMessageClass, transferObject); } catch (Exception e) { throw new RuntimeException("error: " + e.getMessage(), e); } /* * Inject HTTP body. */ Field[] fields = protocolMessageClass.getFields(); injectHttpBody(httpReceiver, transferObject, fields); /* * Input validation. */ inputValidation(transferObject, fields); /* * Semantical integrity validation. */ semanticValidation(protocolMessageClass, transferObject); /* * PostConstruct semantics */ postConstructSemantics(protocolMessageClass, transferObject); return transferObject; } private void injectHttpBody(HttpReceiver httpReceiver, Object transferObject, Field[] fields) { Field bodyField = null; for (Field field : fields) { HttpBody httpBodyAnnotation = field.getAnnotation(HttpBody.class); if (null != httpBodyAnnotation) { if (null == bodyField) { bodyField = field; } else { throw new RuntimeException("multiple body fields detected"); } } } if (null != bodyField) { byte[] body = httpReceiver.getBody(); Object bodyValue; if (List.class.equals(bodyField.getType())) { List<String> bodyList = new LinkedList<String>(); BufferedReader reader = new BufferedReader(new InputStreamReader(new ByteArrayInputStream(body))); String line; try { while (null != (line = reader.readLine())) { bodyList.add(line); } } catch (IOException e) { throw new RuntimeException("IO error: " + e.getMessage()); } bodyValue = bodyList; } else { bodyValue = body; } try { bodyField.set(transferObject, bodyValue); } catch (Exception e) { throw new RuntimeException("error: " + e.getMessage(), e); } } } private void postConstructSemantics(Class<?> protocolMessageClass, Object transferObject) { Method[] methods = protocolMessageClass.getMethods(); for (Method method : methods) { PostConstruct postConstructAnnotation = method.getAnnotation(PostConstruct.class); if (null != postConstructAnnotation) { try { method.invoke(transferObject, new Object[] {}); } catch (InvocationTargetException e) { Throwable methodException = e.getTargetException(); if (methodException instanceof RuntimeException) { RuntimeException runtimeException = (RuntimeException) methodException; /* * We directly rethrow the runtime exception to have a * cleaner stack trace. */ throw runtimeException; } throw new RuntimeException( "@PostConstruct method invocation error: " + methodException.getMessage(), methodException); } catch (Exception e) { throw new RuntimeException("@PostConstruct error: " + e.getMessage(), e); } } } } @SuppressWarnings("unchecked") private void semanticValidation(Class<?> protocolMessageClass, Object transferObject) { ValidateSemanticalIntegrity validateSemanticalIntegrity = protocolMessageClass .getAnnotation(ValidateSemanticalIntegrity.class); if (null != validateSemanticalIntegrity) { Class<? extends SemanticValidator<?>> validatorClass = validateSemanticalIntegrity.value(); SemanticValidator validator; try { validator = validatorClass.newInstance(); } catch (Exception e) { throw new RuntimeException("error: " + e.getMessage(), e); } try { validator.validate(transferObject); } catch (SemanticValidatorException e) { throw new RuntimeException("semantic validation error: " + e.getMessage()); } } } private void inputValidation(Object transferObject, Field[] fields) { for (Field field : fields) { NotNull notNullAnnotation = field.getAnnotation(NotNull.class); if (null == notNullAnnotation) { continue; } // XXX: doesn't make sense for primitive fields Object fieldValue; try { fieldValue = field.get(transferObject); } catch (Exception e) { throw new RuntimeException("error: " + e.getMessage(), e); } if (null == fieldValue) { throw new RuntimeException("field should not be null: " + field.getName()); } } } private void injectHttpHeaderFields(HttpReceiver httpReceiver, Class<?> protocolMessageClass, Object transferObject) throws IllegalArgumentException, IllegalAccessException { List<String> headerNames = httpReceiver.getHeaderNames(); for (String headerName : headerNames) { Field httpHeaderField = findHttpHeaderField(protocolMessageClass, headerName); if (null != httpHeaderField) { String headerValue = httpReceiver.getHeaderValue(headerName); if (0 != (httpHeaderField.getModifiers() & Modifier.FINAL)) { /* * In this case we must check that the value corresponds. */ String constantValue; if (String.class.equals(httpHeaderField.getType())) { constantValue = (String) httpHeaderField.get(transferObject); } else if (Integer.TYPE.equals(httpHeaderField.getType())) { constantValue = ((Integer) httpHeaderField.get(transferObject)).toString(); } else { throw new RuntimeException("unsupported type: " + httpHeaderField.getType().getName()); } if (false == constantValue.equals(headerValue)) { throw new RuntimeException("constant value mismatch: " + httpHeaderField.getName() + "; expected value: " + constantValue + "; actual value: " + headerValue); } } else { if (String.class.equals(httpHeaderField.getType())) { httpHeaderField.set(transferObject, headerValue); } else if (Integer.TYPE.equals(httpHeaderField.getType()) || Integer.class.equals(httpHeaderField.getType())) { Integer intValue = Integer.parseInt(headerValue); httpHeaderField.set(transferObject, intValue); // TODO make this type handling more generic } else if (Boolean.TYPE.equals(httpHeaderField.getType()) || Boolean.class.equals(httpHeaderField.getType())) { Boolean boolValue = Boolean.parseBoolean(headerValue); httpHeaderField.set(transferObject, boolValue); } else if (httpHeaderField.getType().isEnum()) { Enum<?> e = (Enum<?>) httpHeaderField.getType().getEnumConstants()[0]; Object value = e.valueOf(e.getClass(), headerValue); httpHeaderField.set(transferObject, value); } else { throw new RuntimeException("unsupported http header field type: " + httpHeaderField.getType()); } } } } } private Field findHttpHeaderField(Class<?> protocolMessageClass, String headerName) { if (null == headerName) { throw new RuntimeException("header name should not be null"); } Field[] fields = protocolMessageClass.getFields(); for (Field field : fields) { HttpHeader httpHeaderAnnotation = field.getAnnotation(HttpHeader.class); if (null == httpHeaderAnnotation) { continue; } String fieldHttpHeaderName = httpHeaderAnnotation.value(); /* * Ignore cases since the HttpServletRequest class likes to do so. */ if (headerName.equalsIgnoreCase(fieldHttpHeaderName)) { return field; } } return null; } }