package org.jdiameter.client.impl.annotation; import org.jdiameter.api.*; import org.jdiameter.api.annotation.*; import org.jdiameter.client.impl.annotation.internal.ClassInfo; import org.jdiameter.client.impl.annotation.internal.ConstructorInfo; import org.jdiameter.client.impl.annotation.internal.MethodInfo; import org.jdiameter.client.impl.annotation.internal.Storage; import org.jdiameter.client.impl.RawSessionImpl; import org.jdiameter.client.api.annotation.IRecoder; import org.jdiameter.client.api.annotation.RecoderException; import org.jdiameter.client.api.IMessage; import org.slf4j.LoggerFactory; import org.slf4j.Logger; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.util.*; // TODO full min/max/position constrains and optimization (caching) public class Recoder implements IRecoder { private static final Logger log = LoggerFactory.getLogger(Recoder.class); private Storage storage = new Storage(); private final RawSessionImpl rawSession; private final MetaData metaData; public Recoder(SessionFactory factory, MetaData metaData) { this.metaData = metaData; try { this.rawSession = (RawSessionImpl) factory.getNewRawSession(); } catch (InternalException e) { throw new IllegalArgumentException(e); } } // ======================================================================================= //@Override public Message encodeToRequest(Object yourDomainMessageObject, Avp... additionalAvp) throws RecoderException { return encode(yourDomainMessageObject, null, 0, additionalAvp); } //@Override public Message encodeToAnswer(Object yourDomainMessageObject, Request request, long resultCode) throws RecoderException { return encode(yourDomainMessageObject, request, resultCode); } public Message encode( Object yourDomainMessageObject, Request request, long resultCode, Avp... addAvp) throws RecoderException { IMessage message = null; ClassInfo classInfo = storage.getClassInfo(yourDomainMessageObject.getClass()); CommandDscr commandDscr = classInfo.getAnnotation(CommandDscr.class); if (commandDscr != null) { // Get command parameters if (request == null) { message = (IMessage) rawSession.createMessage(commandDscr.code(), ApplicationId.createByAccAppId(0)); message.setRequest(true); message.getAvps().addAvp(addAvp); try { if ( message.getAvps().getAvp(Avp.AUTH_APPLICATION_ID) != null ) { message.setHeaderApplicationId( message.getAvps().getAvp(Avp.AUTH_APPLICATION_ID).getUnsigned32() ); } else if ( message.getAvps().getAvp(Avp.ACCT_APPLICATION_ID) != null ) { message.setHeaderApplicationId( message.getAvps().getAvp(Avp.ACCT_APPLICATION_ID).getUnsigned32() ); } else if ( message.getAvps().getAvp(Avp.VENDOR_SPECIFIC_APPLICATION_ID) != null ) { message.setHeaderApplicationId( message.getAvps().getAvp(Avp.VENDOR_SPECIFIC_APPLICATION_ID). getGrouped().getAvp(Avp.VENDOR_ID).getUnsigned32() ); } } catch (Exception exc) { throw new RecoderException(exc); } if ( message.getAvps().getAvp(Avp.ORIGIN_HOST) == null ) { message.getAvps().addAvp(Avp.ORIGIN_HOST, metaData.getLocalPeer().getUri().getFQDN(), true, false, true); } if ( message.getAvps().getAvp(Avp.ORIGIN_REALM) == null ) { message.getAvps().addAvp(Avp.ORIGIN_REALM, metaData.getLocalPeer().getRealmName(), true, false, true); } } else { message = (IMessage) request.createAnswer(resultCode); } for (CommandFlag f : commandDscr.flags()) { switch (f) { case E: message.setError(true); break; case P: message.setProxiable(true); break; case R: message.setRequest(true); break; case T: message.setReTransmitted(true); break; } } // Find top level avp in getter-annotation methods Map<String, Object> chMap = getChildInstance(yourDomainMessageObject, classInfo, null); // Fill for (Child ch : commandDscr.childs()) { fillChild(message.getAvps(), ch, chMap); } } else { log.debug("Can not found annotation for object {}", yourDomainMessageObject); } return message; } private Map<String, Object> getChildInstance(Object yourDomainMessageObject, ClassInfo c, Map<String, Object> chMap) throws RecoderException { if (chMap == null) chMap = new HashMap<String, Object>(); for (MethodInfo mi : c.getMethodsInfo()) { if ( mi.getAnnotation(Getter.class) != null) { try { Object value = mi.getMethod().invoke(yourDomainMessageObject); if (value != null) { Class mc = value.getClass().isArray() ? value.getClass().getComponentType() : value.getClass(); chMap.put(mc.getName(), value); for (Class<?> i : mc.getInterfaces()) chMap.put(i.getName(), value); } } catch (IllegalAccessException e) { throw new RecoderException(e); } catch (InvocationTargetException e) { throw new RecoderException(e); } } } return chMap; } private void fillChild(AvpSet as, Child ci, Map<String, Object> childs) throws RecoderException { Object c = childs.get( ci.ref().getName() ); if (c != null) { ClassInfo cc = storage.getClassInfo(ci.ref()); AvpDscr ad = cc.getAnnotation(AvpDscr.class); if (ad != null) { boolean m = false, p = false; // cast <=> getter for primitive switch (ad.type()) { case Integer32: case Enumerated: { for (AvpFlag f : ad.must()) if (AvpFlag.M.equals(f)) m = true; else if (AvpFlag.P.equals(f)) p = true; // find in getter Collection<Integer> cv = getValue(c, Integer.class); for (Integer v : cv) as.addAvp(ad.code(), v, ad.vendorId(), m, p); } break; case Unsigned32: { for (AvpFlag f : ad.must()) if (AvpFlag.M.equals(f)) m = true; else if (AvpFlag.P.equals(f)) p = true; Collection<Long> cv = getValue(c, Long.class); for (Long v : cv) as.addAvp(ad.code(), v, ad.vendorId(), m, p, true); } break; case Unsigned64: case Integer64: { for (AvpFlag f : ad.must()) if (AvpFlag.M.equals(f)) m = true; else if (AvpFlag.P.equals(f)) p = true; Collection<Long> cv = getValue(c, Long.class); for (Long v : cv) as.addAvp(ad.code(), v, ad.vendorId(), m, p); } break; case Float32: { for (AvpFlag f : ad.must()) if (AvpFlag.M.equals(f)) m = true; else if (AvpFlag.P.equals(f)) p = true; Collection<Float> cv = getValue(c, Float.class); for (Float v : cv) as.addAvp(ad.code(), v, ad.vendorId(), m, p); } break; case Float64: { for (AvpFlag f : ad.must()) if (AvpFlag.M.equals(f)) m = true; else if (AvpFlag.P.equals(f)) p = true; Collection<Double> cv = getValue(c, Double.class); for (Double v : cv) as.addAvp(ad.code(), v, ad.vendorId(), m, p); } break; case OctetString: case Address: case Time: case DiameterIdentity: case DiameterURI: case IPFilterRule: case QoSFilterRule: { for (AvpFlag f : ad.must()) if (AvpFlag.M.equals(f)) m = true; else if (AvpFlag.P.equals(f)) p = true; Collection<String> cv = getValue(c, String.class); for (String v : cv) as.addAvp(ad.code(), v, ad.vendorId(), m, p, true); } break; case UTF8String: { for (AvpFlag f : ad.must()) if (AvpFlag.M.equals(f)) m = true; else if (AvpFlag.P.equals(f)) p = true; Collection<String> cv = getValue(c, String.class); for (String v : cv) as.addAvp(ad.code(), v, ad.vendorId(), m, p, false); } break; case Grouped: { for (AvpFlag f : ad.must()) { if (AvpFlag.M.equals(f)) { m = true; } else if (AvpFlag.P.equals(f)) { p = true; } } Collection<Object> cv = new ArrayList<Object>(); if (c.getClass().isArray()) { cv = Arrays.asList((Object[])c); } else { cv.add(c); } for (Object cj : cv) { AvpSet las = as.addGroupedAvp(ad.code(),ad.vendorId(), m, p); Map<String, Object> lchilds = getChildInstance(cj, storage.getClassInfo(cj.getClass()), null); for (Child lci : ad.childs()) { fillChild(las, lci, lchilds); } } } break; } } } } private <T> Collection<T> getValue(Object ic, Class<T> type ) throws RecoderException { Collection<T> rc = new ArrayList<T>(); Object[] xc = null; if (ic.getClass().isArray()) xc = (Object[]) ic; else xc = new Object[] {ic}; for (Object c : xc) { for (MethodInfo lm : storage.getClassInfo(c.getClass()).getMethodsInfo()) { if ( lm.getAnnotation(Getter.class) != null ) { try { rc.add((T)lm.getMethod().invoke(c)); } catch (IllegalAccessException e) { throw new RecoderException(e); } catch (InvocationTargetException e) { throw new RecoderException(e); } } } } return rc; } // ======================================================================================= public <T> T decode( Message message, java.lang.Class<T> yourDomainMessageObject) throws RecoderException { Object rc = null; ClassInfo c = storage.getClassInfo(yourDomainMessageObject); CommandDscr cd = c.getAnnotation(CommandDscr.class); if (cd != null) { try { if (message.getCommandCode() != cd.code()) throw new IllegalArgumentException("Invalid message code " + message.getCommandCode()); if (message.getApplicationId() != 0 && message.getApplicationId() != cd.appId()) throw new IllegalArgumentException("Invalid Application-Id " + message.getApplicationId()); for (CommandFlag f : cd.flags()) { switch (f) { case E: if ( !message.isError() ) throw new IllegalArgumentException("Flag e is not set"); break; case P: if ( !message.isProxiable() ) throw new IllegalArgumentException("Flag p is not set"); break; case R: if ( !message.isRequest()) throw new IllegalArgumentException("Flag m is not set"); break; case T: if ( !message.isReTransmitted() ) throw new IllegalArgumentException("Flag t is not set"); break; } } // Find max constructor + lost avp set by setters int cacount = 0; Constructor<?> cm = null; Map<String, Class<?>> cmargs = new HashMap<String, Class<?>>(); for (ConstructorInfo ci : c.getConstructorsInfo()) { if ( ci.getAnnotation(Setter.class) != null ) { // check params - all params must have avp annotation Class<?>[] params = ci.getConstructor().getParameterTypes(); boolean correct = true; for (Class<?> j : params ) { if (j.isArray()) j = j.getComponentType(); if ( storage.getClassInfo(j).getAnnotation(AvpDscr.class) == null) { correct = false; break; } } if ( !correct ) continue; // find max args constructor if (cacount < params.length) { cacount = params.length; cm = ci.getConstructor(); } } } // fill cm args List<Object> initargs = new ArrayList<Object>(); if (cm != null) { for (Class<?> ac : cm.getParameterTypes()) { Class<?> lac = ac.isArray() ? ac.getComponentType() : ac; cmargs.put(lac.getName(), ac); // Create params initargs.add( createChildByAvp( findChildDscr(cd.childs(), ac), ac, message.getAvps()) ); } // Create instance class rc = cm.newInstance(initargs.toArray()); } else { rc = yourDomainMessageObject.newInstance(); } // for (MethodInfo mi : c.getMethodsInfo()) { if ( mi.getAnnotation(Setter.class) != null ) { Class<?>[] pt = mi.getMethod().getParameterTypes(); if (pt.length == 1 && storage.getClassInfo(pt[0]).getAnnotation(AvpDscr.class) != null) { Class<?> ptc = pt[0].isArray()? pt[0].getComponentType() : pt[0]; if ( !cmargs.containsKey(ptc.getName()) ) { cmargs.put(ptc.getName(), ptc); mi.getMethod().invoke(rc, createChildByAvp( findChildDscr(cd.childs(), pt[0]), pt[0], message.getAvps())); } } } } // Fill undefined avp setUndefinedAvp(message.getAvps(), rc, c, cmargs); } catch (InstantiationException e) { throw new RecoderException(e); } catch (InvocationTargetException e) { throw new RecoderException(e); } catch (IllegalAccessException e) { throw new RecoderException(e); } } return (T) rc; } private void setUndefinedAvp(AvpSet set, Object rc, ClassInfo c, Map<String, Class<?>> cmargs) throws RecoderException { try { for (MethodInfo mi : c.getMethodsInfo()) { Setter s = mi.getAnnotation(Setter.class); if ( s != null && Setter.Type.UNDEFINED.equals(s.value())) { Map<Integer, Integer> known = new HashMap<Integer, Integer>(); for (Class<?> argc : cmargs.values()) { AvpDscr argd = storage.getClassInfo( (argc.isArray() ? argc.getComponentType() : argc) ).getAnnotation(AvpDscr.class); known.put(argd.code(), argd.code()); } for (Avp a : set) { if ( !known.containsKey(a.getCode())) mi.getMethod().invoke(rc, new UnknownAvp(a.getCode(), a.isMandatory(), a.isVendorId(), a.isEncrypted(), a.getVendorId(), a.getRaw()) ); } break; } } } catch (IllegalAccessException e) { throw new RecoderException(e); } catch (InvocationTargetException e) { throw new RecoderException(e); } catch (AvpDataException e) { throw new RecoderException(e); } } private Child findChildDscr(Child[] childs, Class<?> m) { for (Child c : childs) { Class<?> t = c.ref(); m = m.isArray() ? m.getComponentType() : m; if ( m == t ) return c; if ( m.getSuperclass() == t ) return c; for (Class<?> i : m.getInterfaces()) if ( i == t ) return c; } return null; } private Object createChildByAvp(Child mInfo, Class<?> m, AvpSet parentSet) throws RecoderException { Object rc; AvpDscr ad = storage.getClassInfo((m.isArray() ? m.getComponentType():m)).getAnnotation(AvpDscr.class); Avp av = parentSet.getAvp(ad.code()); if (av != null) { for (AvpFlag i : ad.must()) switch (i) { case M: if ( !av.isMandatory() ) throw new IllegalArgumentException("not set flag M"); break; case V: if ( !av.isVendorId() ) throw new IllegalArgumentException("not set flag V"); break; case P: if ( !av.isEncrypted() ) throw new IllegalArgumentException("not set flag P"); break; } } else { if (mInfo.min() > 0) throw new IllegalArgumentException("Avp " + ad.code() + " is mandatory"); } if (AvpType.Grouped.equals( ad.type()) ) { if (m.isArray()) { Class<?> arrayClass = m.getComponentType(); AvpSet as = parentSet.getAvps(ad.code()); Object[] array = (Object[]) java.lang.reflect.Array.newInstance(arrayClass, as.size()); for (int ii = 0; ii < array.length; ii++) { array[ii] = newInstanceGroupedAvp(arrayClass, ad, as.getAvpByIndex(ii)); } rc = array; } else { rc = newInstanceGroupedAvp(m, ad, parentSet.getAvp(ad.code())); } } else { if (m.isArray()) { Class<?> arrayClass = m.getComponentType(); AvpSet as = parentSet.getAvps(ad.code()); Object[] array = (Object[]) java.lang.reflect.Array.newInstance(arrayClass, as.size()); for (int ii = 0; ii < array.length; ii++) { array[ii] = newInstanceSimpleAvp(arrayClass, ad, as.getAvpByIndex(ii)); } rc = array; } else { rc = newInstanceSimpleAvp(m, ad, parentSet.getAvp(ad.code())); } } // ========= return rc; } private Object newInstanceGroupedAvp(Class<?> m, AvpDscr ad, Avp avp) throws RecoderException { Object rc; int cacount = 0; ClassInfo c = storage.getClassInfo(m); Constructor<?> cm = null; Map<String, Class<?>> cmargs = new HashMap<String, Class<?>>(); for (ConstructorInfo ci : c.getConstructorsInfo()) { if ( ci.getAnnotation(Setter.class) != null ) { // check params - all params must have avp annotation Class<?>[] params = ci.getConstructor().getParameterTypes(); boolean correct = true; for (Class<?> j : params ) { if (j.isArray()) j = j.getComponentType(); if ( storage.getClassInfo(j).getAnnotation(AvpDscr.class) == null) { correct = false; break; } } if ( !correct ) continue; // find max args constructor if (cacount < params.length) { cacount = params.length; cm = ci.getConstructor(); } } } // fill cm args try { List<Object> initargs = new ArrayList<Object>(); if (cm != null) { for (Class<?> ac : cm.getParameterTypes()) { Class<?> lac = ac.isArray() ? ac.getComponentType() : ac; cmargs.put(lac.getName(), ac); // Create params initargs.add( createChildByAvp( findChildDscr(ad.childs(), ac), ac, avp.getGrouped()) ); } // Create instance class rc = cm.newInstance(initargs.toArray()); } else { rc = m.newInstance(); } // for (MethodInfo mi : c.getMethodsInfo()) { if ( mi.getAnnotation(Setter.class) != null ) { Class<?>[] pt = mi.getMethod().getParameterTypes(); if (pt.length == 1 && storage.getClassInfo(pt[0]).getAnnotation(AvpDscr.class) != null) { Class<?> ptc = pt[0].isArray()? pt[0].getComponentType() : pt[0]; if ( !cmargs.containsKey(ptc.getName()) ) { cmargs.put(ptc.getName(), ptc); mi.getMethod().invoke( rc, createChildByAvp( findChildDscr(ad.childs(), pt[0]), pt[0], avp.getGrouped()) ); } } } } // Fill undefined child setUndefinedAvp(avp.getGrouped(), rc, c, cmargs); } catch (InstantiationException e) { throw new RecoderException(e); } catch (InvocationTargetException e) { throw new RecoderException(e); } catch (AvpDataException e) { throw new RecoderException(e); } catch (IllegalAccessException e) { throw new RecoderException(e); } return rc; } private Object newInstanceSimpleAvp(Class<?> m, AvpDscr ad, Avp avp) { Object rc = null; if (avp == null) return null; ClassInfo c = storage.getClassInfo(m); try { for (ConstructorInfo ci : c.getConstructorsInfo()) { if ( ci.getConstructor().getParameterTypes().length == 1 && ci.getAnnotation(Setter.class) != null ) { List<Object> args = new ArrayList<Object>(); if ( ci.getConstructor().getParameterTypes()[0].isArray() ) { args.add( getValue(ad.type(), avp ) ); } else { args.add( getValue(ad.type(), avp ) ); } rc = ci.getConstructor().newInstance(args.toArray()); } } if (rc == null) { rc = m.newInstance(); for (MethodInfo mi : c.getMethodsInfo()) { if ( mi.getAnnotation(Setter.class) != null ) { List<Object> args = new ArrayList<Object>(); if ( mi.getMethod().getParameterTypes()[0].isArray() ) { args.add( getValue(ad.type(), avp ) ); } else { args.add( getValue(ad.type(), avp ) ); } mi.getMethod().invoke(rc, args); } } } } catch (InstantiationException e) { throw new RecoderException(e); } catch (InvocationTargetException e) { throw new RecoderException(e); } catch (AvpDataException e) { throw new RecoderException(e); } catch (IllegalAccessException e) { throw new RecoderException(e); } return rc; } private Object getValue(AvpType type, Avp avp) throws AvpDataException { switch (type) { case Integer32: case Enumerated: return avp.getInteger32(); case Unsigned32: return avp.getUnsigned32(); case Unsigned64: case Integer64: return avp.getInteger64(); case Float32: return avp.getFloat32(); case Float64: return avp.getFloat64(); case OctetString: case Address: case Time: case DiameterIdentity: case DiameterURI: case IPFilterRule: case QoSFilterRule: return avp.getOctetString(); case UTF8String: return avp.getUTF8String(); } return null; } // ======================================================================================= }