package com.limegroup.gnutella.messages.vendor;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.Comparator;
import java.util.Map;
import java.util.TreeMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.limewire.collection.Comparators;
import org.limewire.collection.IntHashMap;
import org.limewire.inject.EagerSingleton;
import org.limewire.service.ErrorService;
import org.limewire.util.ByteUtils;
import com.google.inject.Inject;
import com.limegroup.gnutella.messages.BadPacketException;
import com.limegroup.gnutella.messages.Message.Network;
/**
 * Factory that turns the binary payload of a vendor message, as read from
 * the network, into {@link VendorMessage} objects.
 * <p>
 * Parsers are registered per (vendor ID, selector) pair through
 * {@link #setParser(int, byte[], VendorMessageParser)}. The registry is
 * replaced wholesale on every registration: writers build a modified copy
 * while holding a private lock and then publish it through a volatile field,
 * so {@link #getParser(int, byte[])} reads without locking.
 */
@EagerSingleton
public class VendorMessageFactoryImpl implements VendorMessageFactory {

    private static final Log LOG = LogFactory.getLog(VendorMessageFactoryImpl.class);

    /** Orders the 4-byte vendor IDs used as keys of the registry. */
    private static final Comparator<byte[]> COMPARATOR = new Comparators.ByteArrayComparator();

    /**
     * Shared instance thrown whenever no parser is registered for a message.
     * Because the same object is reused for every such message, its stack
     * trace does not describe the call that actually failed.
     */
    private static final BadPacketException UNRECOGNIZED_EXCEPTION =
            new BadPacketException("Unrecognized Vendor Message");

    /**
     * Guards the copy-and-swap sequence in {@link #setParser}. A dedicated
     * final object is required: locking on {@code VENDORS} itself is unsafe
     * because that reference is reassigned inside the critical section, which
     * would let two writers hold different monitors at the same time and lose
     * one of the registrations.
     */
    private final Object vendorsLock = new Object();

    /**
     * Map (VendorID -> Map (selector -> Parser)). A published map is never
     * modified afterwards; it is only ever replaced by a new copy.
     */
    private volatile Map<byte[], IntHashMap<VendorMessageParser>> VENDORS =
            new TreeMap<byte[], IntHashMap<VendorMessageParser>>(COMPARATOR);

    /** Creates a factory with no parsers registered. */
    public VendorMessageFactoryImpl() {
    }

    /**
     * Creates a factory and lets the given binder register its parsers on it.
     *
     * @param vendorMessageParserBinder binder whose {@code bind} method is
     *        invoked with this factory
     */
    @Inject
    public VendorMessageFactoryImpl(VendorMessageParserBinder vendorMessageParserBinder) {
        vendorMessageParserBinder.bind(this);
    }

    /**
     * Registers {@code parser} for the given vendor ID and selector,
     * replacing any parser previously registered for that pair. A replacement
     * is reported through the error log.
     *
     * @param selector the message selector, in the range 0..0xFFFF
     * @param vendorId the 4-byte vendor ID
     * @param parser the parser to register
     * @throws IllegalArgumentException if {@code selector} is out of range or
     *         {@code vendorId} is not exactly 4 bytes long
     * @throws NullPointerException if {@code vendorId} or {@code parser} is
     *         {@code null}
     */
    public void setParser(int selector, byte[] vendorId, VendorMessageParser parser) {
        if (selector < 0 || selector > 0xFFFF) {
            throw new IllegalArgumentException("Selector is out of range: " + selector);
        }

        if (vendorId == null) {
            throw new NullPointerException("Vendor ID is null");
        }

        if (vendorId.length != 4) {
            throw new IllegalArgumentException("Vendor ID must be 4 bytes long");
        }

        if (parser == null) {
            throw new NullPointerException("VendorMessageParser is null");
        }

        Object previous;
        synchronized (vendorsLock) {
            Map<byte[], IntHashMap<VendorMessageParser>> vendors = copyVendors();
            IntHashMap<VendorMessageParser> selectors = vendors.get(vendorId);
            if (selectors == null) {
                selectors = new IntHashMap<VendorMessageParser>();
                // Store a private copy of the key so that a caller mutating
                // its array later cannot corrupt the ordering of the map.
                vendors.put(vendorId.clone(), selectors);
            }
            previous = selectors.put(selector, parser);
            // Publish the fully built copy; readers see either old or new map.
            VENDORS = vendors;
        }

        if (previous != null && LOG.isErrorEnabled()) {
            LOG.error("There was already a VendorMessageParser of type "
                    + previous.getClass() + " registered for selector " + selector);
        }
    }

    /**
     * A helper method to create a deep copy of the VENDORS TreeMap: the outer
     * map and every per-vendor selector map are copied, the keys and parsers
     * themselves are shared.
     */
    private Map<byte[], IntHashMap<VendorMessageParser>> copyVendors() {
        Map<byte[], IntHashMap<VendorMessageParser>> copy =
                new TreeMap<byte[], IntHashMap<VendorMessageParser>>(COMPARATOR);
        for (Map.Entry<byte[], IntHashMap<VendorMessageParser>> entry : VENDORS.entrySet()) {
            copy.put(entry.getKey(), new IntHashMap<VendorMessageParser>(entry.getValue()));
        }
        return copy;
    }

    /**
     * Looks up the parser registered for the given selector and vendor ID.
     *
     * @param selector the message selector
     * @param vendorId the vendor ID
     * @return the registered parser, or {@code null} if the vendor ID is
     *         unknown or has no parser for {@code selector}
     */
    public VendorMessageParser getParser(int selector, byte[] vendorId) {
        IntHashMap<VendorMessageParser> selectors = VENDORS.get(vendorId);
        if (selectors == null) {
            return null;
        }
        return selectors.get(selector);
    }

    /**
     * Decodes the vendor message header from {@code fromNetwork} (4 bytes of
     * vendor ID, then selector and version, each read as a little-endian
     * unsigned short) and delegates the remaining bytes to the parser
     * registered for that vendor ID and selector.
     *
     * @param guid passed through to the parser
     * @param ttl passed through to the parser
     * @param hops passed through to the parser
     * @param fromNetwork the raw vendor message payload
     * @param network passed through to the parser
     * @return whatever the selected parser produces
     * @throws BadPacketException if {@code fromNetwork} is shorter than
     *         {@code VendorMessage.LENGTH_MINUS_PAYLOAD}, if no parser is
     *         registered for the message, or if the parser rejects it
     */
    public VendorMessage deriveVendorMessage(byte[] guid, byte ttl,
            byte hops, byte[] fromNetwork, Network network)
            throws BadPacketException {

        // sanity check
        if (fromNetwork.length < VendorMessage.LENGTH_MINUS_PAYLOAD) {
            throw new BadPacketException("Not enough bytes for a VM!!");
        }

        // get very necessary parameters....
        ByteArrayInputStream bais = new ByteArrayInputStream(fromNetwork);
        byte[] vendorID = null, restOf = null;
        int selector = -1, version = -1;
        try {
            // first 4 bytes are vendor ID
            vendorID = new byte[4];
            bais.read(vendorID, 0, vendorID.length);
            // get the selector....
            selector = ByteUtils.ushort2int(ByteUtils.leb2short(bais));
            // get the version....
            version = ByteUtils.ushort2int(ByteUtils.leb2short(bais));
            // get the rest....
            restOf = new byte[bais.available()];
            bais.read(restOf, 0, restOf.length);
        } catch (IOException ioe) {
            // Reading from an in-memory stream is not expected to fail, so
            // this is only reported; the lookup below then runs with whatever
            // was decoded so far.
            ErrorService.error(ioe); // impossible.
        }

        VendorMessageParser parser = getParser(selector, vendorID);
        if (parser == null) {
            throw UNRECOGNIZED_EXCEPTION;
        }

        return parser.parse(guid, ttl, hops, version, restOf, network);
    }
}