/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.sun.jini.discovery; import com.sun.jini.collection.WeakIdentityMap; import com.sun.jini.logging.Levels; import com.sun.jini.resource.Service; import java.io.DataInputStream; import java.io.IOException; import java.io.OutputStream; import java.lang.ref.SoftReference; import java.lang.ref.Reference; import java.net.DatagramPacket; import java.net.InetAddress; import java.net.Socket; import java.nio.ByteBuffer; import java.security.AccessController; import java.security.MessageDigest; import java.security.PrivilegedAction; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; import net.jini.core.constraint.InvocationConstraints; import net.jini.discovery.Constants; import net.jini.io.UnsupportedConstraintException; /** * Class providing methods for implementing discovery protocol version 2. */ class DiscoveryV2 extends Discovery { private static final byte MULTICAST_ANNOUNCEMENT = 0; private static final byte MULTICAST_REQUEST = 1; private static final long NULL_FORMAT_ID = 0; private static final int FORMAT_ID_LEN = 8; private static final int MULTICAST_HEADER_LEN = 4 + 1 + 8; private static final int UNICAST_REQUEST_HEADER_LEN = 4 + 2; private static final int UNICAST_RESPONSE_HEADER_LEN = 4 + 8; private static final int MULTICAST_REQUEST_ENCODER = 0; private static final int MULTICAST_REQUEST_DECODER = 1; private static final int MULTICAST_ANNOUNCEMENT_ENCODER = 2; private static final int MULTICAST_ANNOUNCEMENT_DECODER = 3; private static final int UNICAST_DISCOVERY_CLIENT = 4; private static final int UNICAST_DISCOVERY_SERVER = 5; private static final int NUM_PROVIDER_TYPES = 6; private static final Class[] providerTypes; static { Class[] t = new Class[NUM_PROVIDER_TYPES]; t[MULTICAST_REQUEST_ENCODER] = MulticastRequestEncoder.class; t[MULTICAST_REQUEST_DECODER] = MulticastRequestDecoder.class; t[MULTICAST_ANNOUNCEMENT_ENCODER] = MulticastAnnouncementEncoder.class; t[MULTICAST_ANNOUNCEMENT_DECODER] = MulticastAnnouncementDecoder.class; t[UNICAST_DISCOVERY_CLIENT] = UnicastDiscoveryClient.class; t[UNICAST_DISCOVERY_SERVER] = UnicastDiscoveryServer.class; providerTypes = t; } private static final WeakIdentityMap instances = new WeakIdentityMap(); private static final Logger logger = Logger.getLogger(DiscoveryV2.class.getName()); private final Map[] formatIdMaps; /** * Returns DiscoveryV2 instance which uses providers loaded from the given * class loader, or the current context class loader if the given loader is * null. */ static DiscoveryV2 getInstance(ClassLoader loader) { if (loader == null) { loader = getContextClassLoader(); } DiscoveryV2 disco; synchronized (instances) { disco = null; Reference softDisco = (Reference) instances.get(loader); if (softDisco != null) { disco = (DiscoveryV2) softDisco.get(); } } if (disco == null) { disco = new DiscoveryV2(getProviders(loader)); synchronized (instances) { instances.put(loader, new SoftReference(disco)); } } if (logger.isLoggable(Level.FINEST)) { logger.log(Level.FINEST, "returning {0}", new Object[]{ disco }); } return disco; } /** * Returns DiscoveryV2 instance which uses the given providers. */ static DiscoveryV2 getInstance(MulticastRequestEncoder[] mre, MulticastRequestDecoder[] mrd, MulticastAnnouncementEncoder[] mae, MulticastAnnouncementDecoder[] mad, UnicastDiscoveryClient[] udc, UnicastDiscoveryServer[] uds) { List[] providers = new List[NUM_PROVIDER_TYPES]; providers[MULTICAST_REQUEST_ENCODER] = asList(mre); providers[MULTICAST_REQUEST_DECODER] = asList(mrd); providers[MULTICAST_ANNOUNCEMENT_ENCODER] = asList(mae); providers[MULTICAST_ANNOUNCEMENT_DECODER] = asList(mad); providers[UNICAST_DISCOVERY_CLIENT] = asList(udc); providers[UNICAST_DISCOVERY_SERVER] = asList(uds); DiscoveryV2 disco = new DiscoveryV2(providers); if (logger.isLoggable(Level.FINEST)) { logger.log(Level.FINEST, "returning {0}", new Object[]{ disco }); } return disco; } private DiscoveryV2(List[] providers) { formatIdMaps = new Map[NUM_PROVIDER_TYPES]; for (int i = 0; i < formatIdMaps.length; i++) { formatIdMaps[i] = makeFormatIdMap(providers[i]); } } public EncodeIterator encodeMulticastRequest( final MulticastRequest request, final int maxPacketSize, InvocationConstraints constraints) { if (maxPacketSize < MIN_MAX_PACKET_SIZE) { throw new IllegalArgumentException("maxPacketSize too small"); } final InvocationConstraints absc = (constraints != null) ? constraints.makeAbsolute() : null; return new EncodeIterator() { private final Iterator entries = formatIdMaps[MULTICAST_REQUEST_ENCODER].entrySet().iterator(); public DatagramPacket[] next() throws IOException { // fetch next encoder, format ID Map.Entry ent = (Map.Entry) entries.next(); long fid = ((Long) ent.getKey()).longValue(); MulticastRequestEncoder mre = (MulticastRequestEncoder) ent.getValue(); // prepare buffer factory, which writes packet headers DatagramBuffers db = new DatagramBuffers( Constants.getRequestAddress(), maxPacketSize, MULTICAST_REQUEST, fid); // encode data mre.encodeMulticastRequest(request, db, absc); if (logger.isLoggable(Level.FINEST)) { logger.log(Level.FINEST, "encoded {0} using {1}, {2}", new Object[]{ request, mre, absc }); } return db.getDatagrams(); } public boolean hasNext() { return entries.hasNext(); } }; } public MulticastRequest decodeMulticastRequest( DatagramPacket packet, InvocationConstraints constraints, ClientSubjectChecker checker, boolean delayConstraintCheck) throws IOException { if (constraints != null) { constraints = constraints.makeAbsolute(); } ByteBuffer buf = ByteBuffer.wrap( packet.getData(), packet.getOffset(), packet.getLength()).slice(); if (buf.remaining() < MULTICAST_HEADER_LEN) { throw new DiscoveryProtocolException("incomplete header"); } // read protocol version int pv = buf.getInt(); if (pv != PROTOCOL_VERSION_2) { throw new DiscoveryProtocolException( "wrong protocol version: " + pv); } // read packet type byte pt = buf.get(); if (pt != MULTICAST_REQUEST) { throw new DiscoveryProtocolException("wrong packet type: " + pt); } // read format ID long fid = buf.getLong(); // lookup decoder MulticastRequestDecoder mrd = (MulticastRequestDecoder) formatIdMaps[MULTICAST_REQUEST_DECODER].get(new Long(fid)); if (mrd == null) { throw new DiscoveryProtocolException( "unsupported format ID: " + fid); } // decode payload MulticastRequest req; if (mrd instanceof DelayedMulticastRequestDecoder) { DelayedMulticastRequestDecoder dmrd = (DelayedMulticastRequestDecoder) mrd; req = dmrd.decodeMulticastRequest(buf, constraints, checker, delayConstraintCheck); } else { req = mrd.decodeMulticastRequest(buf, constraints, checker); } if (logger.isLoggable(Level.FINEST)) { logger.log(Level.FINEST, "decoded {0} using {1}, {2}, {3}", new Object[]{ req, mrd, constraints, checker }); } return req; } public MulticastRequest decodeMulticastRequest(DatagramPacket packet, InvocationConstraints constraints, ClientSubjectChecker checker) throws IOException { // default behavior is no delayed constraint checking. return decodeMulticastRequest(packet, constraints, checker, false); } public EncodeIterator encodeMulticastAnnouncement( final MulticastAnnouncement announcement, final int maxPacketSize, InvocationConstraints constraints) { if (maxPacketSize < MIN_MAX_PACKET_SIZE) { throw new IllegalArgumentException("maxPacketSize too small"); } final InvocationConstraints absc = (constraints != null) ? constraints.makeAbsolute() : null; return new EncodeIterator() { private final Iterator entries = formatIdMaps[ MULTICAST_ANNOUNCEMENT_ENCODER].entrySet().iterator(); public DatagramPacket[] next() throws IOException { // fetch next encoder, format ID Map.Entry ent = (Map.Entry) entries.next(); long fid = ((Long) ent.getKey()).longValue(); MulticastAnnouncementEncoder mae = (MulticastAnnouncementEncoder) ent.getValue(); // prepare buffer factory, which writes packet headers DatagramBuffers db = new DatagramBuffers( Constants.getAnnouncementAddress(), maxPacketSize, MULTICAST_ANNOUNCEMENT, fid); // encode data mae.encodeMulticastAnnouncement(announcement, db, absc); if (logger.isLoggable(Level.FINEST)) { logger.log(Level.FINEST, "encoded {0} using {1}, {2}", new Object[]{ announcement, mae, absc }); } return db.getDatagrams(); } public boolean hasNext() { return entries.hasNext(); } }; } public MulticastAnnouncement decodeMulticastAnnouncement( DatagramPacket packet, InvocationConstraints constraints, boolean delayConstraintCheck) throws IOException { if (constraints != null) { constraints = constraints.makeAbsolute(); } ByteBuffer buf = ByteBuffer.wrap( packet.getData(), packet.getOffset(), packet.getLength()).slice(); if (buf.remaining() < MULTICAST_HEADER_LEN) { throw new DiscoveryProtocolException("incomplete header"); } // read protocol version int pv = buf.getInt(); if (pv != PROTOCOL_VERSION_2) { throw new DiscoveryProtocolException( "wrong protocol version: " + pv); } // read packet type byte pt = buf.get(); if (pt != MULTICAST_ANNOUNCEMENT) { throw new DiscoveryProtocolException("wrong packet type: " + pt); } // read format ID long fid = buf.getLong(); // lookup decoder MulticastAnnouncementDecoder mad = (MulticastAnnouncementDecoder) formatIdMaps[MULTICAST_ANNOUNCEMENT_DECODER].get(new Long(fid)); if (mad == null) { throw new DiscoveryProtocolException( "unsupported format ID: " + fid); } MulticastAnnouncement ann; // decode payload if (mad instanceof DelayedMulticastAnnouncementDecoder) { DelayedMulticastAnnouncementDecoder dmad = (DelayedMulticastAnnouncementDecoder) mad; ann = dmad.decodeMulticastAnnouncement(buf, constraints, delayConstraintCheck); } else { ann = mad.decodeMulticastAnnouncement(buf, constraints); } if (logger.isLoggable(Level.FINEST)) { logger.log(Level.FINEST, "decoded {0} using {1}, {2}", new Object[]{ ann, mad, constraints }); } return ann; } public MulticastAnnouncement decodeMulticastAnnouncement( DatagramPacket packet, InvocationConstraints constraints) throws IOException { // default behavior is no delayed constraint checking. return decodeMulticastAnnouncement(packet, constraints, false); } public UnicastResponse doUnicastDiscovery( Socket socket, InvocationConstraints constraints, ClassLoader defaultLoader, ClassLoader verifierLoader, Collection context) throws IOException, ClassNotFoundException { final int MAX_FORMATS = 0xFFFF; if (constraints != null) { constraints = constraints.makeAbsolute(); } // determine set of acceptable formats to propose Map udcMap = formatIdMaps[UNICAST_DISCOVERY_CLIENT]; Set fids = new HashSet(); Exception ex = null; for (Iterator i = udcMap.entrySet().iterator(); i.hasNext(); ) { Map.Entry ent = (Map.Entry) i.next(); UnicastDiscoveryClient udc = (UnicastDiscoveryClient) ent.getValue(); try { udc.checkUnicastDiscoveryConstraints(constraints); fids.add(ent.getKey()); if (fids.size() == MAX_FORMATS) { logger.log(Level.WARNING, "truncating format ID list"); break; } } catch (Exception e) { if (e instanceof UnsupportedConstraintException || e instanceof SecurityException) { ex = e; logger.log(Levels.HANDLED, "constraint check failed", e); } else { throw (RuntimeException) e; } } } if (fids.isEmpty()) { if (ex == null) { throw new DiscoveryProtocolException("no supported formats"); } else if (ex instanceof UnsupportedConstraintException) { throw (UnsupportedConstraintException) ex; } else { throw (SecurityException) ex; } } ByteBuffer outBuf = ByteBuffer.allocate( UNICAST_REQUEST_HEADER_LEN + (FORMAT_ID_LEN * fids.size())); // write protocol version outBuf.putInt(PROTOCOL_VERSION_2); // write proposed format IDs outBuf.putShort((short) fids.size()); for (Iterator i = fids.iterator(); i.hasNext(); ) { outBuf.putLong(((Long) i.next()).longValue()); } OutputStream out = socket.getOutputStream(); out.write(outBuf.array(), outBuf.arrayOffset(), outBuf.position()); out.flush(); ByteBuffer inBuf = ByteBuffer.allocate(UNICAST_RESPONSE_HEADER_LEN); new DataInputStream(socket.getInputStream()).readFully( inBuf.array(), inBuf.arrayOffset() + inBuf.position(), inBuf.remaining()); // read protocol version int pv = inBuf.getInt(); if (pv != PROTOCOL_VERSION_2) { throw new DiscoveryProtocolException( "wrong protocol version: " + pv); } // read selected format ID Long fid = new Long(inBuf.getLong()); if (fid.longValue() == NULL_FORMAT_ID) { throw new DiscoveryProtocolException("format negotiation failed"); } if (!fids.contains(fid)) { throw new DiscoveryProtocolException( "response format ID not proposed: " + fid); } // hand off to format provider to receive response data UnicastDiscoveryClient udc = (UnicastDiscoveryClient) udcMap.get(fid); UnicastResponse resp = udc.doUnicastDiscovery( socket, constraints, defaultLoader, verifierLoader, context, (ByteBuffer) outBuf.flip(), (ByteBuffer) inBuf.flip()); if (logger.isLoggable(Level.FINEST)) { logger.log(Level.FINEST, "received {0} using {1}, {2}", new Object[]{ resp, udc, constraints }); } return resp; } public void handleUnicastDiscovery(UnicastResponse response, Socket socket, InvocationConstraints constraints, ClientSubjectChecker checker, Collection context) throws IOException { if (constraints != null) { constraints = constraints.makeAbsolute(); } // note: protocol version already consumed // read proposed format IDs DataInputStream din = new DataInputStream(socket.getInputStream()); int nfids = din.readUnsignedShort(); if (nfids < 0) { throw new DiscoveryProtocolException( "invalid format ID count: " + nfids); } ByteBuffer inBuf = ByteBuffer.allocate( UNICAST_REQUEST_HEADER_LEN + (FORMAT_ID_LEN * nfids)); inBuf.putInt(PROTOCOL_VERSION_2); inBuf.putShort((short) nfids); din.readFully(inBuf.array(), inBuf.arrayOffset() + inBuf.position(), inBuf.remaining()); // select format provider UnicastDiscoveryServer uds = null; long fid = NULL_FORMAT_ID; Map udsMap = formatIdMaps[UNICAST_DISCOVERY_SERVER]; while (inBuf.hasRemaining()) { fid = inBuf.getLong(); UnicastDiscoveryServer s = (UnicastDiscoveryServer) udsMap.get(new Long(fid)); if (s != null) { try { s.checkUnicastDiscoveryConstraints(constraints); uds = s; break; } catch (Exception e) { logger.log(Levels.HANDLED, "constraint check failed", e); } } } ByteBuffer outBuf = ByteBuffer.allocate(UNICAST_RESPONSE_HEADER_LEN); // write protocol version outBuf.putInt(PROTOCOL_VERSION_2); // write selected format ID outBuf.putLong((uds != null) ? fid : NULL_FORMAT_ID); OutputStream out = socket.getOutputStream(); out.write(outBuf.array(), outBuf.arrayOffset(), outBuf.position()); out.flush(); if (uds == null) { throw new DiscoveryProtocolException("format negotiation failed"); } // hand off to format provider to send response data uds.handleUnicastDiscovery( response, socket, constraints, checker, context, (ByteBuffer) inBuf.flip(), (ByteBuffer) outBuf.flip()); if (logger.isLoggable(Level.FINEST)) { logger.log(Level.FINEST, "sent {0} using {1}, {2}, {3}", new Object[]{ response, uds, constraints, checker }); } } public String toString() { // REMIND: cache string? List l = new ArrayList(NUM_PROVIDER_TYPES); for (int i = 0; i < NUM_PROVIDER_TYPES; i++) { l.add(formatIdMaps[i].values()); } return "DiscoveryV2" + l; } private static ClassLoader getContextClassLoader() { return (ClassLoader) AccessController.doPrivileged( new PrivilegedAction() { public Object run() { return Thread.currentThread().getContextClassLoader(); } }); } private static List[] getProviders(final ClassLoader ldr) { return (List[]) AccessController.doPrivileged(new PrivilegedAction() { public Object run() { List[] providers = new List[NUM_PROVIDER_TYPES]; for (int i = 0; i < providers.length; i++) { providers[i] = new ArrayList(); } Iterator iter = Service.providers( DiscoveryFormatProvider.class, ldr); while (iter.hasNext()) { Object obj = iter.next(); boolean used = false; for (int i = 0; i < providerTypes.length; i++) { if (providerTypes[i].isInstance(obj)) { providers[i].add(obj); used = true; } } if (!used) { logger.log(Level.WARNING, "unusable format provider {0}", new Object[]{ obj }); } } return providers; } }); } private static Map makeFormatIdMap(List providers) { Map map = new HashMap(); for (Iterator i = providers.iterator(); i.hasNext(); ) { DiscoveryFormatProvider p = (DiscoveryFormatProvider) i.next(); Long fid = new Long(computeFormatID(p.getFormatName())); if (map.keySet().contains(fid)) { logger.log(Level.WARNING, "ignoring provider {0} ({1}) with " + "conflicting format ID {2}", new Object[]{ p, p.getFormatName(), fid }); continue; } map.put(fid, p); } return map; } private static long computeFormatID(String format) { try { MessageDigest md = MessageDigest.getInstance("SHA-1"); byte[] b = md.digest(format.getBytes("UTF-8")); return ((b[7] & 0xFFL) << 0) + ((b[6] & 0xFFL) << 8) + ((b[5] & 0xFFL) << 16) + ((b[4] & 0xFFL) << 24) + ((b[3] & 0xFFL) << 32) + ((b[2] & 0xFFL) << 40) + ((b[1] & 0xFFL) << 48) + ((b[0] & 0xFFL) << 56); } catch (Exception e) { throw new AssertionError(e); } } private static List asList(Object[] a) { return (a != null) ? Arrays.asList(a) : Collections.EMPTY_LIST; } /** * Buffer factory passed to multicast request and announcement encoders. */ private static class DatagramBuffers implements DatagramBufferFactory { private static final int TRIM_THRESHOLD = 512; private final List datagrams = new ArrayList(); private final InetAddress addr; private final int maxPacketSize; private final byte packetType; private final long formatId; DatagramBuffers(InetAddress addr, int maxPacketSize, byte packetType, long formatId) { this.addr = addr; this.maxPacketSize = maxPacketSize; this.packetType = packetType; this.formatId = formatId; } public ByteBuffer newBuffer() { DatagramInfo di = new DatagramInfo(); datagrams.add(di); return di.getBuffer(); } DatagramPacket[] getDatagrams() { DatagramPacket[] dp = new DatagramPacket[datagrams.size()]; for (int i = 0; i < dp.length; i++) { dp[i] = ((DatagramInfo) datagrams.get(i)).getDatagram(); } return dp; } private class DatagramInfo { private final DatagramPacket datagram; private final ByteBuffer buf; DatagramInfo() { datagram = new DatagramPacket(new byte[maxPacketSize], 0, addr, Constants.discoveryPort); buf = ByteBuffer.wrap(datagram.getData()); // write packet header buf.putInt(PROTOCOL_VERSION_2); buf.put(packetType); buf.putLong(formatId); } ByteBuffer getBuffer() { return buf; } DatagramPacket getDatagram() { int len = buf.position(); // trim excess buffer space if too large if (buf.remaining() > TRIM_THRESHOLD) { byte[] b = new byte[len]; System.arraycopy(datagram.getData(), 0, b, 0, len); datagram.setData(b); } datagram.setLength(len); return datagram; } } } }