/* * TeleStax, Open Source Cloud Communications * Copyright 2011-2016, TeleStax Inc. and individual contributors * by the @authors tag. * * This program is free software: you can redistribute it and/or modify * under the terms of the GNU Affero General Public License as * published by the Free Software Foundation; either version 3 of * the License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see <http://www.gnu.org/licenses/> * * This file incorporates work covered by the following copyright and * permission notice: * * JBoss, Home of Professional Open Source * Copyright 2007-2011, Red Hat, Inc. and individual contributors * by the @authors tag. See the copyright.txt in the distribution for a * full listing of individual contributors. * * This is free software; you can redistribute it and/or modify it * under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2.1 of * the License, or (at your option) any later version. * * This software is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this software; if not, write to the Free * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA * 02110-1301 USA, or see the FSF site: http://www.fsf.org. */ package org.jdiameter.client.impl.annotation; import java.lang.reflect.Constructor; import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import org.jdiameter.api.ApplicationId; import org.jdiameter.api.Avp; import org.jdiameter.api.AvpDataException; import org.jdiameter.api.AvpSet; import org.jdiameter.api.InternalException; import org.jdiameter.api.Message; import org.jdiameter.api.MetaData; import org.jdiameter.api.Request; import org.jdiameter.api.SessionFactory; import org.jdiameter.api.annotation.AvpDscr; import org.jdiameter.api.annotation.AvpFlag; import org.jdiameter.api.annotation.AvpType; import org.jdiameter.api.annotation.Child; import org.jdiameter.api.annotation.CommandDscr; import org.jdiameter.api.annotation.CommandFlag; import org.jdiameter.api.annotation.Getter; import org.jdiameter.api.annotation.Setter; import org.jdiameter.client.api.IMessage; import org.jdiameter.client.api.annotation.IRecoder; import org.jdiameter.client.api.annotation.RecoderException; import org.jdiameter.client.impl.RawSessionImpl; import org.jdiameter.client.impl.annotation.internal.ClassInfo; import org.jdiameter.client.impl.annotation.internal.ConstructorInfo; import org.jdiameter.client.impl.annotation.internal.MethodInfo; import org.jdiameter.client.impl.annotation.internal.Storage; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * * @author erick.svenson@yahoo.com * @author <a href="mailto:brainslog@gmail.com"> Alexandre Mendonca </a> * @author <a href="mailto:baranowb@gmail.com"> Bartosz Baranowski </a> */ public class Recoder implements IRecoder { // TODO full min/max/position constrains and optimization (caching) private static final Logger log = LoggerFactory.getLogger(Recoder.class); private Storage storage = new Storage(); private final RawSessionImpl rawSession; private final MetaData metaData; public Recoder(SessionFactory factory, MetaData metaData) { this.metaData = metaData; try { this.rawSession = (RawSessionImpl) factory.getNewRawSession(); } catch (InternalException e) { throw new IllegalArgumentException(e); } } // ======================================================================================= //@Override @Override public Message encodeToRequest(Object yourDomainMessageObject, Avp... additionalAvp) throws RecoderException { return encode(yourDomainMessageObject, null, 0, additionalAvp); } //@Override @Override public Message encodeToAnswer(Object yourDomainMessageObject, Request request, long resultCode) throws RecoderException { return encode(yourDomainMessageObject, request, resultCode); } public Message encode(Object yourDomainMessageObject, Request request, long resultCode, Avp... addAvp) throws RecoderException { IMessage message = null; ClassInfo classInfo = storage.getClassInfo(yourDomainMessageObject.getClass()); CommandDscr commandDscr = classInfo.getAnnotation(CommandDscr.class); if (commandDscr != null) { // Get command parameters if (request == null) { message = (IMessage) rawSession.createMessage(commandDscr.code(), ApplicationId.createByAccAppId(0)); message.setRequest(true); message.getAvps().addAvp(addAvp); try { if (message.getAvps().getAvp(Avp.AUTH_APPLICATION_ID) != null) { message.setHeaderApplicationId(message.getAvps().getAvp(Avp.AUTH_APPLICATION_ID).getUnsigned32()); } else if (message.getAvps().getAvp(Avp.ACCT_APPLICATION_ID) != null) { message.setHeaderApplicationId(message.getAvps().getAvp(Avp.ACCT_APPLICATION_ID).getUnsigned32()); } else if (message.getAvps().getAvp(Avp.VENDOR_SPECIFIC_APPLICATION_ID) != null) { message.setHeaderApplicationId(message.getAvps().getAvp(Avp.VENDOR_SPECIFIC_APPLICATION_ID). getGrouped().getAvp(Avp.VENDOR_ID).getUnsigned32()); } } catch (Exception exc) { throw new RecoderException(exc); } if (message.getAvps().getAvp(Avp.ORIGIN_HOST) == null) { message.getAvps().addAvp(Avp.ORIGIN_HOST, metaData.getLocalPeer().getUri().getFQDN(), true, false, true); } if (message.getAvps().getAvp(Avp.ORIGIN_REALM) == null) { message.getAvps().addAvp(Avp.ORIGIN_REALM, metaData.getLocalPeer().getRealmName(), true, false, true); } } else { message = (IMessage) request.createAnswer(resultCode); } for (CommandFlag f : commandDscr.flags()) { switch (f) { case E: message.setError(true); break; case P: message.setProxiable(true); break; case R: message.setRequest(true); break; case T: message.setReTransmitted(true); break; } } // Find top level avp in getter-annotation methods Map<String, Object> chMap = getChildInstance(yourDomainMessageObject, classInfo, null); // Fill for (Child ch : commandDscr.childs()) { fillChild(message.getAvps(), ch, chMap); } } else { log.debug("Can not found annotation for object {}", yourDomainMessageObject); } return message; } private Map<String, Object> getChildInstance(Object yourDomainMessageObject, ClassInfo c, Map<String, Object> chMap) throws RecoderException { if (chMap == null) { chMap = new HashMap<String, Object>(); } for (MethodInfo mi : c.getMethodsInfo()) { if (mi.getAnnotation(Getter.class) != null) { try { Object value = mi.getMethod().invoke(yourDomainMessageObject); if (value != null) { Class mc = value.getClass().isArray() ? value.getClass().getComponentType() : value.getClass(); chMap.put(mc.getName(), value); for (Class<?> i : mc.getInterfaces()) { chMap.put(i.getName(), value); } } } catch (IllegalAccessException e) { throw new RecoderException(e); } catch (InvocationTargetException e) { throw new RecoderException(e); } } } return chMap; } private void fillChild(AvpSet as, Child ci, Map<String, Object> childs) throws RecoderException { Object c = childs.get(ci.ref().getName()); if (c != null) { ClassInfo cc = storage.getClassInfo(ci.ref()); AvpDscr ad = cc.getAnnotation(AvpDscr.class); if (ad != null) { boolean m = false, p = false; // cast <=> getter for primitive switch (ad.type()) { case Integer32: case Enumerated: { for (AvpFlag f : ad.must()) { if (AvpFlag.M.equals(f)) { m = true; } else if (AvpFlag.P.equals(f)) { p = true; } } // find in getter Collection<Integer> cv = getValue(c, Integer.class); for (Integer v : cv) { as.addAvp(ad.code(), v, ad.vendorId(), m, p); } } break; case Unsigned32: { for (AvpFlag f : ad.must()) { if (AvpFlag.M.equals(f)) { m = true; } else if (AvpFlag.P.equals(f)) { p = true; } } Collection<Long> cv = getValue(c, Long.class); for (Long v : cv) { as.addAvp(ad.code(), v, ad.vendorId(), m, p, true); } } break; case Unsigned64: case Integer64: { for (AvpFlag f : ad.must()) { if (AvpFlag.M.equals(f)) { m = true; } else if (AvpFlag.P.equals(f)) { p = true; } } Collection<Long> cv = getValue(c, Long.class); for (Long v : cv) { as.addAvp(ad.code(), v, ad.vendorId(), m, p); } } break; case Float32: { for (AvpFlag f : ad.must()) { if (AvpFlag.M.equals(f)) { m = true; } else if (AvpFlag.P.equals(f)) { p = true; } } Collection<Float> cv = getValue(c, Float.class); for (Float v : cv) { as.addAvp(ad.code(), v, ad.vendorId(), m, p); } } break; case Float64: { for (AvpFlag f : ad.must()) { if (AvpFlag.M.equals(f)) { m = true; } else if (AvpFlag.P.equals(f)) { p = true; } } Collection<Double> cv = getValue(c, Double.class); for (Double v : cv) { as.addAvp(ad.code(), v, ad.vendorId(), m, p); } } break; case OctetString: case Address: case Time: case DiameterIdentity: case DiameterURI: case IPFilterRule: case QoSFilterRule: { for (AvpFlag f : ad.must()) { if (AvpFlag.M.equals(f)) { m = true; } else if (AvpFlag.P.equals(f)) { p = true; } } Collection<String> cv = getValue(c, String.class); for (String v : cv) { as.addAvp(ad.code(), v, ad.vendorId(), m, p, true); } } break; case UTF8String: { for (AvpFlag f : ad.must()) { if (AvpFlag.M.equals(f)) { m = true; } else if (AvpFlag.P.equals(f)) { p = true; } } Collection<String> cv = getValue(c, String.class); for (String v : cv) { as.addAvp(ad.code(), v, ad.vendorId(), m, p, false); } } break; case Grouped: { for (AvpFlag f : ad.must()) { if (AvpFlag.M.equals(f)) { m = true; } else if (AvpFlag.P.equals(f)) { p = true; } } Collection<Object> cv = new ArrayList<Object>(); if (c.getClass().isArray()) { cv = Arrays.asList((Object[]) c); } else { cv.add(c); } for (Object cj : cv) { AvpSet las = as.addGroupedAvp(ad.code(), ad.vendorId(), m, p); Map<String, Object> lchilds = getChildInstance(cj, storage.getClassInfo(cj.getClass()), null); for (Child lci : ad.childs()) { fillChild(las, lci, lchilds); } } } break; } } } } private <T> Collection<T> getValue(Object ic, Class<T> type) throws RecoderException { Collection<T> rc = new ArrayList<T>(); Object[] xc = null; if (ic.getClass().isArray()) { xc = (Object[]) ic; } else { xc = new Object[] {ic}; } for (Object c : xc) { for (MethodInfo lm : storage.getClassInfo(c.getClass()).getMethodsInfo()) { if (lm.getAnnotation(Getter.class) != null) { try { rc.add((T) lm.getMethod().invoke(c)); } catch (IllegalAccessException e) { throw new RecoderException(e); } catch (InvocationTargetException e) { throw new RecoderException(e); } } } } return rc; } // ======================================================================================= @Override public <T> T decode(Message message, java.lang.Class<T> yourDomainMessageObject) throws RecoderException { Object rc = null; ClassInfo c = storage.getClassInfo(yourDomainMessageObject); CommandDscr cd = c.getAnnotation(CommandDscr.class); if (cd != null) { try { if (message.getCommandCode() != cd.code()) { throw new IllegalArgumentException("Invalid message code " + message.getCommandCode()); } if (message.getApplicationId() != 0 && message.getApplicationId() != cd.appId()) { throw new IllegalArgumentException("Invalid Application-Id " + message.getApplicationId()); } for (CommandFlag f : cd.flags()) { switch (f) { case E: if (!message.isError()) { throw new IllegalArgumentException("Flag e is not set"); } break; case P: if (!message.isProxiable()) { throw new IllegalArgumentException("Flag p is not set"); } break; case R: if (!message.isRequest()) { throw new IllegalArgumentException("Flag m is not set"); } break; case T: if (!message.isReTransmitted()) { throw new IllegalArgumentException("Flag t is not set"); } break; } } // Find max constructor + lost avp set by setters int cacount = 0; Constructor<?> cm = null; Map<String, Class<?>> cmargs = new HashMap<String, Class<?>>(); for (ConstructorInfo ci : c.getConstructorsInfo()) { if (ci.getAnnotation(Setter.class) != null) { // check params - all params must have avp annotation Class<?>[] params = ci.getConstructor().getParameterTypes(); boolean correct = true; for (Class<?> j : params) { if (j.isArray()) { j = j.getComponentType(); } if (storage.getClassInfo(j).getAnnotation(AvpDscr.class) == null) { correct = false; break; } } if (!correct) { continue; } // find max args constructor if (cacount < params.length) { cacount = params.length; cm = ci.getConstructor(); } } } // fill cm args List<Object> initargs = new ArrayList<Object>(); if (cm != null) { for (Class<?> ac : cm.getParameterTypes()) { Class<?> lac = ac.isArray() ? ac.getComponentType() : ac; cmargs.put(lac.getName(), ac); // Create params initargs.add(createChildByAvp(findChildDscr(cd.childs(), ac), ac, message.getAvps())); } // Create instance class rc = cm.newInstance(initargs.toArray()); } else { rc = yourDomainMessageObject.newInstance(); } // for (MethodInfo mi : c.getMethodsInfo()) { if (mi.getAnnotation(Setter.class) != null) { Class<?>[] pt = mi.getMethod().getParameterTypes(); if (pt.length == 1 && storage.getClassInfo(pt[0]).getAnnotation(AvpDscr.class) != null) { Class<?> ptc = pt[0].isArray() ? pt[0].getComponentType() : pt[0]; if (!cmargs.containsKey(ptc.getName())) { cmargs.put(ptc.getName(), ptc); mi.getMethod().invoke(rc, createChildByAvp(findChildDscr(cd.childs(), pt[0]), pt[0], message.getAvps())); } } } } // Fill undefined avp setUndefinedAvp(message.getAvps(), rc, c, cmargs); } catch (InstantiationException e) { throw new RecoderException(e); } catch (InvocationTargetException e) { throw new RecoderException(e); } catch (IllegalAccessException e) { throw new RecoderException(e); } } return (T) rc; } private void setUndefinedAvp(AvpSet set, Object rc, ClassInfo c, Map<String, Class<?>> cmargs) throws RecoderException { try { for (MethodInfo mi : c.getMethodsInfo()) { Setter s = mi.getAnnotation(Setter.class); if (s != null && Setter.Type.UNDEFINED.equals(s.value())) { Map<Integer, Integer> known = new HashMap<Integer, Integer>(); for (Class<?> argc : cmargs.values()) { AvpDscr argd = storage.getClassInfo((argc.isArray() ? argc.getComponentType() : argc)).getAnnotation(AvpDscr.class); known.put(argd.code(), argd.code()); } for (Avp a : set) { if (!known.containsKey(a.getCode())) { mi.getMethod().invoke(rc, new UnknownAvp(a.getCode(), a.isMandatory(), a.isVendorId(), a.isEncrypted(), a.getVendorId(), a.getRaw())); } } break; } } } catch (IllegalAccessException e) { throw new RecoderException(e); } catch (InvocationTargetException e) { throw new RecoderException(e); } catch (AvpDataException e) { throw new RecoderException(e); } } private Child findChildDscr(Child[] childs, Class<?> m) { for (Child c : childs) { Class<?> t = c.ref(); m = m.isArray() ? m.getComponentType() : m; if (m == t) { return c; } if (m.getSuperclass() == t) { return c; } for (Class<?> i : m.getInterfaces()) { if (i == t) { return c; } } } return null; } private Object createChildByAvp(Child mInfo, Class<?> m, AvpSet parentSet) throws RecoderException { Object rc; AvpDscr ad = storage.getClassInfo((m.isArray() ? m.getComponentType() : m)).getAnnotation(AvpDscr.class); Avp av = parentSet.getAvp(ad.code()); if (av != null) { for (AvpFlag i : ad.must()) { switch (i) { case M: if (!av.isMandatory()) { throw new IllegalArgumentException("not set flag M"); } break; case V: if (!av.isVendorId()) { throw new IllegalArgumentException("not set flag V"); } break; case P: if (!av.isEncrypted()) { throw new IllegalArgumentException("not set flag P"); } break; } } } else { if (mInfo.min() > 0) { throw new IllegalArgumentException("Avp " + ad.code() + " is mandatory"); } } if (AvpType.Grouped.equals(ad.type())) { if (m.isArray()) { Class<?> arrayClass = m.getComponentType(); AvpSet as = parentSet.getAvps(ad.code()); Object[] array = (Object[]) java.lang.reflect.Array.newInstance(arrayClass, as.size()); for (int ii = 0; ii < array.length; ii++) { array[ii] = newInstanceGroupedAvp(arrayClass, ad, as.getAvpByIndex(ii)); } rc = array; } else { rc = newInstanceGroupedAvp(m, ad, parentSet.getAvp(ad.code())); } } else { if (m.isArray()) { Class<?> arrayClass = m.getComponentType(); AvpSet as = parentSet.getAvps(ad.code()); Object[] array = (Object[]) java.lang.reflect.Array.newInstance(arrayClass, as.size()); for (int ii = 0; ii < array.length; ii++) { array[ii] = newInstanceSimpleAvp(arrayClass, ad, as.getAvpByIndex(ii)); } rc = array; } else { rc = newInstanceSimpleAvp(m, ad, parentSet.getAvp(ad.code())); } } // ========= return rc; } private Object newInstanceGroupedAvp(Class<?> m, AvpDscr ad, Avp avp) throws RecoderException { Object rc; int cacount = 0; ClassInfo c = storage.getClassInfo(m); Constructor<?> cm = null; Map<String, Class<?>> cmargs = new HashMap<String, Class<?>>(); for (ConstructorInfo ci : c.getConstructorsInfo()) { if (ci.getAnnotation(Setter.class) != null) { // check params - all params must have avp annotation Class<?>[] params = ci.getConstructor().getParameterTypes(); boolean correct = true; for (Class<?> j : params) { if (j.isArray()) { j = j.getComponentType(); } if (storage.getClassInfo(j).getAnnotation(AvpDscr.class) == null) { correct = false; break; } } if (!correct) { continue; } // find max args constructor if (cacount < params.length) { cacount = params.length; cm = ci.getConstructor(); } } } // fill cm args try { List<Object> initargs = new ArrayList<Object>(); if (cm != null) { for (Class<?> ac : cm.getParameterTypes()) { Class<?> lac = ac.isArray() ? ac.getComponentType() : ac; cmargs.put(lac.getName(), ac); // Create params initargs.add(createChildByAvp(findChildDscr(ad.childs(), ac), ac, avp.getGrouped())); } // Create instance class rc = cm.newInstance(initargs.toArray()); } else { rc = m.newInstance(); } // for (MethodInfo mi : c.getMethodsInfo()) { if (mi.getAnnotation(Setter.class) != null) { Class<?>[] pt = mi.getMethod().getParameterTypes(); if (pt.length == 1 && storage.getClassInfo(pt[0]).getAnnotation(AvpDscr.class) != null) { Class<?> ptc = pt[0].isArray() ? pt[0].getComponentType() : pt[0]; if (!cmargs.containsKey(ptc.getName())) { cmargs.put(ptc.getName(), ptc); mi.getMethod().invoke(rc, createChildByAvp(findChildDscr(ad.childs(), pt[0]), pt[0], avp.getGrouped())); } } } } // Fill undefined child setUndefinedAvp(avp.getGrouped(), rc, c, cmargs); } catch (InstantiationException e) { throw new RecoderException(e); } catch (InvocationTargetException e) { throw new RecoderException(e); } catch (AvpDataException e) { throw new RecoderException(e); } catch (IllegalAccessException e) { throw new RecoderException(e); } return rc; } private Object newInstanceSimpleAvp(Class<?> m, AvpDscr ad, Avp avp) { Object rc = null; if (avp == null) { return null; } ClassInfo c = storage.getClassInfo(m); try { for (ConstructorInfo ci : c.getConstructorsInfo()) { if (ci.getConstructor().getParameterTypes().length == 1 && ci.getAnnotation(Setter.class) != null) { List<Object> args = new ArrayList<Object>(); if (ci.getConstructor().getParameterTypes()[0].isArray()) { args.add(getValue(ad.type(), avp)); } else { args.add(getValue(ad.type(), avp)); } rc = ci.getConstructor().newInstance(args.toArray()); } } if (rc == null) { rc = m.newInstance(); for (MethodInfo mi : c.getMethodsInfo()) { if (mi.getAnnotation(Setter.class) != null) { List<Object> args = new ArrayList<Object>(); if (mi.getMethod().getParameterTypes()[0].isArray()) { args.add(getValue(ad.type(), avp)); } else { args.add(getValue(ad.type(), avp)); } mi.getMethod().invoke(rc, args); } } } } catch (InstantiationException e) { throw new RecoderException(e); } catch (InvocationTargetException e) { throw new RecoderException(e); } catch (AvpDataException e) { throw new RecoderException(e); } catch (IllegalAccessException e) { throw new RecoderException(e); } return rc; } private Object getValue(AvpType type, Avp avp) throws AvpDataException { switch (type) { case Integer32: case Enumerated: return avp.getInteger32(); case Unsigned32: return avp.getUnsigned32(); case Unsigned64: case Integer64: return avp.getInteger64(); case Float32: return avp.getFloat32(); case Float64: return avp.getFloat64(); case OctetString: case Address: case Time: case DiameterIdentity: case DiameterURI: case IPFilterRule: case QoSFilterRule: return avp.getOctetString(); case UTF8String: return avp.getUTF8String(); } return null; } // ======================================================================================= }