package com.intrbiz.bergamot.net.raw.model; import java.net.Inet4Address; import java.net.InetAddress; import java.net.UnknownHostException; import java.nio.ByteBuffer; import com.intrbiz.bergamot.net.raw.model.payload.ICMPPacket; public class IPPacket { public static final byte IP_PROTO_ICMP = 1; public static final byte IP_PROTO_UDP = 17; private byte ipVersion; private byte ipHeaderLength; private byte dscp; private byte ecn; private short totalLength; private short id; private short flags; private short fragmentOffset; private byte ttl; private byte protocol; private short headerChecksum; private byte[] sourceIPAddress; private byte[] destinationIPAddress; private IPPayload payload; public IPPacket() { super(); this.ipVersion = 4; this.ipHeaderLength = 5; this.ttl = 64; } public IPPacket(byte protocol, InetAddress source, InetAddress destination, IPPayload payload) { super(); if (! (source instanceof Inet4Address && destination instanceof Inet4Address)) throw new IllegalArgumentException("Source and destination must be IPv4 addresses!"); this.ipVersion = 4; this.ipHeaderLength = 5; this.ttl = 64; this.protocol = protocol; this.sourceIPAddress = source.getAddress(); this.destinationIPAddress = source.getAddress(); this.payload = payload; } public IPPacket(ByteBuffer from) { byte b; short s; int start = from.position(); // verify the checksum this.headerChecksum = this.verifyChecksum(from); // 0 - version and ihl b = from.get(); this.ipHeaderLength = (byte) (b & 0xF); this.ipVersion = (byte) (b >> 4); // 1 - dscp and ecn b = from.get(); this.dscp = (byte) (b >> 2); this.ecn = (byte) (b & 0x3); // 2 - length this.totalLength = from.getShort(); // 4 - id this.id = from.getShort(); // 6 - flags and frag offset s = from.getShort(); this.flags = (short) (s >> 13); this.fragmentOffset = (short) (s & 0x1FFF); // 8 - ttl this.ttl = from.get(); // 9 - protocol this.protocol = from.get(); // 10 - checksum from.getShort(); // 12 - src this.sourceIPAddress = new byte[4]; from.get(this.sourceIPAddress); // 16 - dest this.destinationIPAddress = new byte[4]; from.get(this.destinationIPAddress); // options if (this.ipHeaderLength > 5) { // seek to the end of the options from.position(start + (this.ipHeaderLength * 4)); } // payload if (this.protocol == IP_PROTO_ICMP) { this.payload = new ICMPPacket(from); } } public int computeLength() { return 20 + (this.payload == null ? 0 : this.payload.computeLength()); } public void pack(ByteBuffer to) { int start = to.position(); // compute the total packet length this.totalLength = (short) (this.computeLength() & 0xFFFF); // version and ihl to.put((byte) ((this.ipVersion << 4) & (this.ipHeaderLength & 0xF))); // dscp and ecn to.put((byte) ((this.dscp << 2) & (this.ecn & 0x3))); // total length to.putShort(this.totalLength); // id to.putShort(this.id); // flags and fragment offset to.putShort((short) ((this.flags << 13) & (this.fragmentOffset & 0x1FFF))); // ttl to.put(this.ttl); // protocol to.put(this.protocol); // checksum to.putShort((short) 0); // src ip to.put(this.sourceIPAddress); // dst ip to.put(this.destinationIPAddress); // compute the header checksum int end = to.position(); to.position(start); this.headerChecksum = this.computeChecksum(to, 20); to.putShort(start + 10, this.headerChecksum); to.position(end); // pack the payload this.payload.pack(to); } private short computeChecksum(ByteBuffer buffer, int headerLength) { int start = buffer.position(); int end = start + headerLength; // ensure the checksum is zeroed buffer.putShort(start + 10, (short) 0); // compute int sum = 0; while (buffer.position() < end) { sum += ((int) buffer.getShort()) & 0xFFFF; } sum = (sum >> 16) + (sum & 0xFFFF); sum = (sum >> 16) + (sum & 0xFFFF); // reset to start of the payload buffer.position(start); return (short) ((~sum) & 0xFFFF); } private short verifyChecksum(ByteBuffer buffer) { int start = buffer.position(); int headerLength = (buffer.get(start) & 0xF) * 4; // get the checksum short checksum = buffer.getShort(start + 10); // compute the checksum short computedChecksum = this.computeChecksum(buffer, headerLength); // verify if (checksum != computedChecksum) { throw new RuntimeException("Invalid IP Header Checksum: got=" + checksum + ", expected=" + computedChecksum + ", ihl=" + headerLength); } return checksum; } public byte getIpVersion() { return ipVersion; } public void setIpVersion(byte ipVersion) { this.ipVersion = ipVersion; } public byte getIpHeaderLength() { return ipHeaderLength; } public void setIpHeaderLength(byte ipHeaderLength) { this.ipHeaderLength = ipHeaderLength; } public byte getDscp() { return dscp; } public void setDscp(byte dscp) { this.dscp = dscp; } public byte getEcn() { return ecn; } public void setEcn(byte ecn) { this.ecn = ecn; } public short getTotalLength() { return totalLength; } public void setTotalLength(short totalLength) { this.totalLength = totalLength; } public short getId() { return id; } public void setId(short id) { this.id = id; } public short getFlags() { return flags; } public void setFlags(short flags) { this.flags = flags; } public short getFragmentOffset() { return fragmentOffset; } public void setFragmentOffset(short fragmentOffset) { this.fragmentOffset = fragmentOffset; } public byte getTtl() { return ttl; } public void setTtl(byte ttl) { this.ttl = ttl; } public byte getProtocol() { return protocol; } public boolean isICMP() { return this.protocol == IP_PROTO_ICMP; } public boolean isUDP() { return this.protocol == IP_PROTO_UDP; } public String getProtocolStr() { switch (this.protocol) { case IP_PROTO_ICMP: return "ICMP"; case IP_PROTO_UDP: return "UDP"; } return "UNKNOWN"; } public void setProtocol(byte protocol) { this.protocol = protocol; } public short getHeaderChecksum() { return headerChecksum; } public void setHeaderChecksum(short headerChecksum) { this.headerChecksum = headerChecksum; } public byte[] getSourceIPAddress() { return sourceIPAddress; } public InetAddress getSource() { try { return InetAddress.getByAddress(this.getSourceIPAddress()); } catch (UnknownHostException e) { } return null; } public void setSourceIPAddress(byte[] sourceIPAddress) { this.sourceIPAddress = sourceIPAddress; } public byte[] getDestinationIPAddress() { return destinationIPAddress; } public InetAddress getDestination() { try { return InetAddress.getByAddress(this.getDestinationIPAddress()); } catch (UnknownHostException e) { } return null; } public void setDestinationIPAddress(byte[] destinationIPAddress) { this.destinationIPAddress = destinationIPAddress; } public IPPayload getPayload() { return payload; } public void setPayload(IPPayload payload) { this.payload = payload; } public String toString() { return "ip-packet {\n" + "src: " + this.getSource() + "\n" + "dst: " + this.getDestination() + "\n" + "proto: " + this.getProtocolStr() + "(" + this.getProtocol() + ")\n" + "ttl: " + this.getTtl() + "\n" + "ihl: " + this.getIpHeaderLength() + "\n" + "len: " + this.getTotalLength() + "\n" + this.getPayload() + "}"; } }