/* * This file is part of mlDHT. * * mlDHT is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 2 of the License, or * (at your option) any later version. * * mlDHT is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with mlDHT. If not, see <http://www.gnu.org/licenses/>. */ package lbms.plugins.mldht.kad.messages; import static the8472.bencode.Utils.prettyPrint; import static the8472.utils.Functional.castOrThrow; import static the8472.utils.Functional.tap; import static the8472.utils.Functional.tapThrow; import static the8472.utils.Functional.typedGet; import the8472.bencode.PathMatcher; import the8472.bencode.Tokenizer; import the8472.utils.Functional; import lbms.plugins.mldht.kad.BloomFilterBEP33; import lbms.plugins.mldht.kad.DBItem; import lbms.plugins.mldht.kad.DHT; import lbms.plugins.mldht.kad.DHT.DHTtype; import lbms.plugins.mldht.kad.DHT.LogLevel; import lbms.plugins.mldht.kad.Key; import lbms.plugins.mldht.kad.NodeList; import lbms.plugins.mldht.kad.NodeList.AddressType; import lbms.plugins.mldht.kad.PeerAddressDBItem; import lbms.plugins.mldht.kad.messages.ErrorMessage.ErrorCode; import lbms.plugins.mldht.kad.messages.MessageBase.Method; import lbms.plugins.mldht.kad.messages.MessageBase.Type; import lbms.plugins.mldht.kad.utils.AddressUtils; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.function.Function; import java.util.stream.Stream; /** * @author Damokles * */ public class MessageDecoder { public MessageDecoder(Function<byte[], Optional<Method>> transactionIdMapper, DHTtype type) { this.transactionIdMapper = transactionIdMapper; this.type = type; } Map<String, Object> rootMap; ByteBuffer raw; final Function<byte[], Optional<Method>> transactionIdMapper; final DHTtype type; public void toDecode(ByteBuffer rawMessage, Map<String, Object> map) { this.raw = rawMessage; this.rootMap = map; } public MessageBase parseMessage() throws MessageException, IOException { try { String msgType = getStringFromBytes((byte[]) rootMap.get(Type.TYPE_KEY), true); if (msgType == null) { throw new MessageException("message type (y) missing", ErrorCode.ProtocolError); } Optional<byte[]> version = typedGet(rootMap, MessageBase.VERSION_KEY, byte[].class); MessageBase mb = null; if (msgType.equals(Type.REQ_MSG.getRPCTypeName())) { mb = parseRequest(rootMap, transactionIdMapper, type); } else if (msgType.equals(Type.RSP_MSG.getRPCTypeName())) { mb = parseResponse(rootMap, transactionIdMapper); } else if (msgType.equals(Type.ERR_MSG.getRPCTypeName())) { mb = parseError(rootMap, transactionIdMapper); } else throw new MessageException("unknown RPC type (y="+msgType+")"); if (mb != null) { version.ifPresent(mb::setVersion); } return mb; } catch (Exception e) { if(e instanceof MessageException) throw (MessageException)e; throw new IOException("could not parse message",e); } } /** * @param map * @return */ private MessageBase parseError (Map<String, Object> map, Function<byte[], Optional<Method>> transactionIdMapper) { Object error = map.get(Type.ERR_MSG.innerKey()); int errorCode = 0; String errorMsg = null; if(error instanceof byte[]) errorMsg = getStringFromBytes((byte[])error); else if (error instanceof List<?>) { List<Object> errmap = (List<Object>)error; try { errorCode = ((Long) errmap.get(0)).intValue(); errorMsg = getStringFromBytes((byte[]) errmap.get(1)); } catch (Exception e) { // do nothing } } Object rawMtid = map.get(MessageBase.TRANSACTION_KEY); if (errorMsg == null && (rawMtid == null || !(rawMtid instanceof byte[]))) return null; byte[] mtid = (byte[]) rawMtid; ErrorMessage msg = new ErrorMessage(mtid, errorCode,errorMsg); typedGet(map, "id", byte[].class).filter(b -> b.length == Key.SHA1_HASH_LENGTH).ifPresent(h -> msg.setID(new Key(h))); transactionIdMapper.apply(mtid).ifPresent(m -> msg.method = m); return msg; } /** * @param map * @param srv * @return */ private MessageBase parseResponse (Map<String, Object> map, Function<byte[], Optional<Method>> transactionIdMapper) throws MessageException { byte[] mtid = (byte[]) map.get(MessageBase.TRANSACTION_KEY); if (mtid == null || mtid.length < 1) throw new MessageException("missing transaction ID",ErrorCode.ProtocolError); // responses don't have explicit methods, need to match them to a request to figure that one out Method m = transactionIdMapper.apply(mtid).orElse(Method.UNKNOWN); return parseResponse(map, m, mtid); } /** * @param map * @param msgMethod * @param mtid * @return */ private MessageBase parseResponse (Map<String, Object> map, Method msgMethod, byte[] mtid) throws MessageException { Map<String, Object> args = (Map<String, Object>) map.get(Type.RSP_MSG.innerKey()); if (args == null) { throw new MessageException("response did not contain a body",ErrorCode.ProtocolError); } byte[] hash = Optional.ofNullable(args.get("id")) .map(castOrThrow(byte[].class, (o) -> new MessageException("expected parameter 'id' to be a byte-string, got "+o.getClass().getSimpleName(), ErrorCode.ProtocolError))) .orElseThrow(() -> new MessageException("mandatory parameter 'id' missing", ErrorCode.ProtocolError)); byte[] ip = (byte[]) map.get(MessageBase.EXTERNAL_IP_KEY); if (hash.length != Key.SHA1_HASH_LENGTH) { throw new MessageException("invalid or missing origin ID",ErrorCode.ProtocolError); } Key id = new Key(hash); MessageBase msg = null; switch (msgMethod) { case PING: msg = new PingResponse(mtid); break; case PUT: msg = new PutResponse(mtid); break; case ANNOUNCE_PEER: msg = new AnnounceResponse(mtid); break; case FIND_NODE: if (!args.containsKey("nodes") && !args.containsKey("nodes6")) throw new MessageException("received response to find_node request with neither 'nodes' nor 'nodes6' entry", ErrorCode.ProtocolError); //return null; msg = tapThrow(new FindNodeResponse(mtid), (m) -> { extractNodes(args, "nodes", DHTtype.IPV4_DHT).ifPresent(n -> m.setNodes(n)); extractNodes(args, "nodes6", DHTtype.IPV6_DHT).ifPresent(n -> m.setNodes(n)); }); break; case GET: GetResponse get = new GetResponse(mtid); extractNodes(args, "nodes", DHTtype.IPV4_DHT).ifPresent(get::setNodes); extractNodes(args, "nodes6", DHTtype.IPV6_DHT).ifPresent(get::setNodes); PathMatcher m = new PathMatcher(Type.RSP_MSG.innerKey(),"v"); Tokenizer t = new Tokenizer(); m.tokenizer(t); ByteBuffer rawVal = m.match(raw); get.setRawValue(rawVal); typedGet(args, "token", byte[].class).ifPresent(get::setToken);; typedGet(args, "k", byte[].class).ifPresent(get::setKey); typedGet(args, "sig", byte[].class).ifPresent(get::setSignature); typedGet(args, "seq", Long.class).ifPresent(get::setSequenceNumber); msg = get; break; case GET_PEERS: byte[] token = Functional.typedGet(args, "token", byte[].class).orElse(null); Optional<NodeList> nodes = extractNodes(args, "nodes", DHTtype.IPV4_DHT); Optional<NodeList> nodes6 = extractNodes(args, "nodes6", DHTtype.IPV6_DHT); List<DBItem> dbl = null; @SuppressWarnings("unchecked") List<byte[]> vals = Optional.ofNullable(args.get("values")) .map(castOrThrow(List.class, val -> new MessageException("expected 'values' field in get_peers to be list of strings, got "+val.getClass(), ErrorCode.ProtocolError))) .orElse(Collections.EMPTY_LIST); if(vals.size() > 0) { dbl = new ArrayList<>(vals.size()); for (int i = 0; i < vals.size(); i++) { // only accept ipv4 or ipv6 for now if (vals.get(i).length != DHTtype.IPV4_DHT.ADDRESS_ENTRY_LENGTH && vals.get(i).length != DHTtype.IPV6_DHT.ADDRESS_ENTRY_LENGTH) continue; dbl.add(new PeerAddressDBItem(vals.get(i), false)); } } byte[] peerFilter = (byte[]) args.get("BFpe"); byte[] seedFilter = (byte[]) args.get("BFse"); if((peerFilter != null && peerFilter.length != BloomFilterBEP33.m/8) || (seedFilter != null && seedFilter.length != BloomFilterBEP33.m/8)) throw new MessageException("invalid BEP33 filter length", ErrorCode.ProtocolError); if (dbl != null || nodes.isPresent() || nodes6.isPresent()) { GetPeersResponse resp = new GetPeersResponse(mtid); nodes.ifPresent(l -> resp.setNodes(l)); nodes6.ifPresent(l -> resp.setNodes(l)); resp.setPeerItems(dbl); resp.setToken(token); resp.setScrapePeers(peerFilter); resp.setScrapeSeeds(seedFilter); msg = resp; break; } throw new MessageException("Neither nodes nor values in get_peers response",ErrorCode.ProtocolError); case UNKNOWN: msg = new UnknownTypeResponse(mtid); break; default: throw new RuntimeException("should not happen!!!"); } if(ip != null) { InetSocketAddress addr = AddressUtils.unpackAddress(ip); msg.setPublicIP(addr); if(addr == null) DHT.logError("could not decode IP: " + prettyPrint(map)); } msg.setID(id); return msg; } private Optional<NodeList> extractNodes(Map<String, Object> args, String key, DHTtype nodesType) throws MessageException { byte[] raw = typedGet(args, key, byte[].class).orElse(null); if(raw == null) return Optional.empty(); if(raw.length % nodesType.NODES_ENTRY_LENGTH != 0) throw new MessageException("expected "+key+" length to be a multiple of "+nodesType.NODES_ENTRY_LENGTH+", received "+raw.length, ErrorCode.ProtocolError); return Optional.of(NodeList.fromBuffer(ByteBuffer.wrap(raw), nodesType == DHTtype.IPV4_DHT ? AddressType.V4 : AddressType.V6)); } /** * @param map * @return */ private MessageBase parseRequest (Map<String, Object> map, Function<byte[], Optional<Method>> transactionIdMapper, DHTtype type) throws MessageException { Object rawRequestMethod = map.get(Type.REQ_MSG.getRPCTypeName()); Map<String, Object> args = typedGet(map, Type.REQ_MSG.innerKey(), Map.class).orElseThrow(() -> new MessageException("expected a bencoded dictionary under key " + Type.REQ_MSG.innerKey(), ErrorCode.ProtocolError)); if (rawRequestMethod == null || args == null) return null; byte[] mtid = Functional.typedGet(map, MessageBase.TRANSACTION_KEY, byte[].class).filter(tid -> tid.length > 0).orElseThrow(() -> new MessageException("missing or zero-length transaction ID in request", ErrorCode.ProtocolError)); byte[] hash = Functional.typedGet(args,"id", byte[].class).filter(id -> id.length == Key.SHA1_HASH_LENGTH).orElseThrow(() -> new MessageException("missing or invalid node ID", ErrorCode.ProtocolError)); Key id = new Key(hash); MessageBase msg = null; String requestMethod = getStringFromBytes((byte[]) rawRequestMethod, true); Method method = Optional.ofNullable(MessageBase.messageMethod.get(requestMethod)).orElse(Method.UNKNOWN); switch(method) { case PING: msg = new PingRequest(); break; case FIND_NODE: case GET_PEERS: case GET: case UNKNOWN: hash = Stream.of(args.get("target"), args.get("info_hash")).filter(byte[].class::isInstance).findFirst().map(byte[].class::cast).orElseThrow(() -> { if(method == Method.UNKNOWN) return new MessageException("Received unknown Message Type: " + requestMethod,ErrorCode.MethodUnknown); return new MessageException("missing/invalid target key in request",ErrorCode.ProtocolError); }); if (hash.length != Key.SHA1_HASH_LENGTH) { throw new MessageException("invalid target key in request",ErrorCode.ProtocolError); } Key target = new Key(hash); AbstractLookupRequest req; switch(method) { case FIND_NODE: req = new FindNodeRequest(target); break; case GET_PEERS: req = new GetPeersRequest(target); break; case GET: req = new GetRequest(target); break; default: req = new UnknownTypeRequest(target); } @SuppressWarnings("unchecked") List<byte[]> explicitWants = Optional.ofNullable(args.get("want")).map(castOrThrow(List.class, w -> new MessageException("invalid 'want' parameter, expected a list of byte-strings"))).orElse(null); if(explicitWants != null) req.decodeWant(explicitWants); else { req.setWant4(type == DHTtype.IPV4_DHT); req.setWant6(type == DHTtype.IPV6_DHT); } if (req instanceof GetPeersRequest) { GetPeersRequest peerReq = (GetPeersRequest) req; peerReq.setNoSeeds(Long.valueOf(1).equals(args.get("noseed"))); peerReq.setScrape(Long.valueOf(1).equals(args.get("scrape"))); } if(req instanceof GetRequest) { GetRequest getReq = (GetRequest) req; typedGet(args, "seq", Long.class).ifPresent(seq -> { getReq.setSeq(seq); }); } msg = req; break; case PUT: PathMatcher m = new PathMatcher(Type.REQ_MSG.innerKey(),"v"); Tokenizer t = new Tokenizer(); m.tokenizer(t); ByteBuffer rawVal = m.match(raw); msg = tapThrow(new PutRequest(), put -> { if(rawVal != null) put.setValue(rawVal); put.pubkey = Functional.typedGet(args, "k", byte[].class).orElse(null); put.sequenceNumber = Functional.typedGet(args, "seq", Long.class).orElse(-1L); put.expectedSequenceNumber = Functional.typedGet(args, "cas", Long.class).orElse(-1L); put.salt = Functional.typedGet(args, "salt", byte[].class).orElse(null); put.signature = Functional.typedGet(args, "sig", byte[].class).orElse(null); put.token = Functional.typedGet(args, "token", byte[].class).filter(b -> b.length > 0).orElseThrow(() -> new MessageException("missing or invalid token in PUT request")); put.validate(); }); break; case ANNOUNCE_PEER: hash = Functional.typedGet(args, "info_hash", byte[].class).filter(b -> b.length == Key.SHA1_HASH_LENGTH).orElse(null); int port = Functional.typedGet(args, "port", Long.class).filter(p -> p > 0 && p <= 65535).orElse(0L).intValue(); byte[] token = Functional.typedGet(args, "token", byte[].class).orElse(null); boolean isSeed = Long.valueOf(1).equals(args.get("seed")); if(hash == null || token == null || port == 0) throw new MessageException("missing or invalid mandatory arguments (info_hash, port, token) for announce", ErrorCode.ProtocolError); if(token.length == 0) throw new MessageException("zero-length token in announce_peer request. see BEP33 for reasons why tokens might not have been issued by get_peers response", ErrorCode.ProtocolError); Key infoHash = new Key(hash); msg = tap(new AnnounceRequest(infoHash, port, token), ar -> { ar.setSeed(isSeed); typedGet(args, "name", byte[].class).ifPresent(b -> ar.setName(ByteBuffer.wrap(b))); }); break; } if (msg != null) { msg.setMTID(mtid); msg.setID(id); } return msg; } private static String getStringFromBytes (byte[] bytes, boolean preserveBytes) { if (bytes == null) { return null; } try { return new String(bytes, preserveBytes ? StandardCharsets.ISO_8859_1 : StandardCharsets.UTF_8); } catch (Exception e) { DHT.log(e, LogLevel.Verbose); return null; } } private static String getStringFromBytes (byte[] bytes) { return getStringFromBytes(bytes, false); } }