/*_########################################################################## _## _## Copyright (C) 2011-2013 Kaito Yamada _## _########################################################################## */ package com.github.kaitoy.sneo.network.protocol; import java.net.Inet4Address; import java.net.InetAddress; import java.net.UnknownHostException; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.TimeoutException; import org.pcap4j.packet.IcmpV4CommonPacket; import org.pcap4j.packet.IpV4Packet; import org.pcap4j.packet.IpV4Rfc791Tos; import org.pcap4j.packet.Packet; import org.pcap4j.packet.SimpleBuilder; import org.pcap4j.packet.TcpPacket; import org.pcap4j.packet.UdpPacket; import org.pcap4j.packet.namednumber.IpNumber; import org.pcap4j.packet.namednumber.IpVersion; import org.pcap4j.util.ByteArrays; import com.github.kaitoy.sneo.network.NetworkInterface; import com.github.kaitoy.sneo.network.NifIpAddress; import com.github.kaitoy.sneo.network.NifIpV4Address; public final class IpV4Helper { public static final Inet4Address UNSPECIFIED_ADDRESS; static { try { UNSPECIFIED_ADDRESS = (Inet4Address)InetAddress.getByName("0.0.0.0"); } catch (UnknownHostException e) { throw new AssertionError("Never get here."); } } private IpV4Helper() { throw new AssertionError(); } public static IpV4RoutingTable newRoutingTable() { return new IpV4RoutingTable(); } public static boolean matchesDestination( Packet packet, Inet4Address addr, Inet4Address subnetmask ) { IpV4Packet ipv4Packet = packet.get(IpV4Packet.class); if (ipv4Packet == null) { throw new IllegalArgumentException(packet.toString()); } Inet4Address dstAddr = ipv4Packet.getHeader().getDstAddr(); if (dstAddr.equals(addr)) { return true; } if (!isSameNetwork(addr, dstAddr, subnetmask)) { return false; } return isBroadcastAddr(dstAddr, subnetmask); } public static boolean matchesDestination(Packet packet, NetworkInterface nif) { for (NifIpAddress nifAddr: nif.getIpAddresses()) { if (nifAddr instanceof NifIpV4Address) { NifIpV4Address nifV4Addr = (NifIpV4Address)nifAddr; Inet4Address mask = getSubnetMaskFrom(nifV4Addr.getPrefixLength()); if (matchesDestination(packet, nifV4Addr.getIpAddr(), mask)) { return true; } } } return false; } public static boolean isBroadcastAddr(Inet4Address addr, Inet4Address subnetmask) { int subnetmaskBitmap = ByteArrays.getInt(subnetmask.getAddress(), 0); int addrBitmap = ByteArrays.getInt(addr.getAddress(), 0); return ~((addrBitmap & ~subnetmaskBitmap) | subnetmaskBitmap) == 0; } public static boolean isNetworkAddr(Inet4Address addr, Inet4Address subnetmask) { int subnetmaskBitmap = ByteArrays.getInt(subnetmask.getAddress(), 0); int addrBitmap = ByteArrays.getInt(addr.getAddress(), 0); return (addrBitmap & ~subnetmaskBitmap) == 0; } public static boolean isSameNetwork( Inet4Address addr1, Inet4Address addr2, Inet4Address subnetmask ) { int addr1Bitmap = ByteArrays.getInt(addr1.getAddress(), 0); int addr2Bitmap = ByteArrays.getInt(addr2.getAddress(), 0); int subnetmaskBitmap = ByteArrays.getInt(subnetmask.getAddress(), 0); return (addr1Bitmap & subnetmaskBitmap) == (addr2Bitmap & subnetmaskBitmap); } public static boolean isSameNetwork( Inet4Address addr1, NetworkInterface nif ) { for (NifIpAddress nifAddr: nif.getIpAddresses()) { if (nifAddr instanceof NifIpV4Address) { NifIpV4Address nifV4Addr = (NifIpV4Address)nifAddr; Inet4Address mask = getSubnetMaskFrom(nifV4Addr.getPrefixLength()); if (isSameNetwork(addr1, nifV4Addr.getIpAddr(), mask)) { return true; } } } return false; } public static Inet4Address getSubnetMaskFrom(int prefixLength) { if (prefixLength < 0 || prefixLength > 32) { throw new IllegalArgumentException( "Invalid prefix length: " + prefixLength ); } byte[] mask = new byte[4]; int byteIdx = 0; for (; byteIdx < prefixLength / 8; byteIdx++) { mask[byteIdx] = (byte)255; } if (!(byteIdx == mask.length)) { int value = 0; int tmp = 128; for (int i = 0; i < prefixLength % 8; i++) { value += tmp; tmp >>= 1; } mask[byteIdx] = (byte)value; } try { return (Inet4Address)InetAddress.getByAddress(mask); } catch (UnknownHostException e) { throw new AssertionError("Never get here"); } } public static int getPrefixLengthFrom(Inet4Address subnetMask) { int length = 0; for (byte b: subnetMask.getAddress()) { for (int mask = 128; mask > 0; mask >>= 1) { if ((b & mask) == 0) { return length; } length++; } } return length; } public static Inet4Address getNextAddress(Inet4Address addr, Inet4Address mask) { if (isBroadcastAddr(addr, mask)) { return null; } byte[] rawAddr = addr.getAddress(); rawAddr[3]++; Inet4Address newAddr; try { newAddr = (Inet4Address)InetAddress.getByAddress(rawAddr); } catch (UnknownHostException e) { throw new AssertionError("Never get here."); } if (isBroadcastAddr(newAddr, mask)) { return null; } else { return newAddr; } } public static Inet4Address getPrevAddress(Inet4Address addr, Inet4Address mask) { if (isNetworkAddr(addr, mask)) { return null; } byte[] rawAddr = addr.getAddress(); rawAddr[3]--; Inet4Address newAddr; try { newAddr = (Inet4Address)InetAddress.getByAddress(rawAddr); } catch (UnknownHostException e) { throw new AssertionError("Never get here."); } if (isNetworkAddr(newAddr, mask)) { return null; } else { return newAddr; } } public static IpV4Packet pack( final Packet payload, Inet4Address src, Inet4Address dst, int ttl, short id ) { IpNumber ipNum; if (payload instanceof UdpPacket) { ipNum = IpNumber.UDP; } else if (payload instanceof IcmpV4CommonPacket) { ipNum = IpNumber.ICMPV4; } else if (payload instanceof TcpPacket) { ipNum = IpNumber.TCP; } else { throw new AssertionError(); } IpV4Packet.Builder builder = new IpV4Packet.Builder(); return builder.version(IpVersion.IPV4) .tos(IpV4Rfc791Tos.newInstance((byte)0)) .identification(id) .ttl((byte)ttl) .protocol(ipNum) .srcAddr(src) .dstAddr(dst) .payloadBuilder(new SimpleBuilder(payload)) .correctChecksumAtBuild(true) .correctLengthAtBuild(true) .build(); } public static IpV4Packet decrementTtl( IpV4Packet packet ) throws TimeoutException { int ttl = packet.getHeader().getTtlAsInt(); if (ttl <= 1) { throw new TimeoutException(); } ttl--; IpV4Packet.Builder b = packet.getBuilder().ttl((byte)ttl).correctChecksumAtBuild(true); return b.build(); } public static Inet4Address getNextHop( Inet4Address dstIpAddr, IpV4RoutingTable ipV4RoutingTable ) { return ipV4RoutingTable.getNextHop(dstIpAddr); } public static class IpV4RoutingTable { private final Map<Inet4Address, IpV4RoutingTableEntry> entries = new HashMap<Inet4Address, IpV4RoutingTableEntry>(); private IpV4RoutingTable() {} public void addRoute( Inet4Address dst, Inet4Address mask, Inet4Address gw, int metric ) { synchronized (entries) { entries.put( dst, new IpV4RoutingTableEntry(dst, mask, gw, metric) ); } } private Inet4Address getNextHop(Inet4Address dst) { Collection<IpV4RoutingTableEntry> values = null; synchronized (entries) { IpV4RoutingTableEntry justMatchedEntry = entries.get(dst); if (justMatchedEntry != null) { return justMatchedEntry.gw; } values = entries.values(); } int dstBitmap = ByteArrays.getInt(dst.getAddress(), 0); IpV4RoutingTableEntry mostMatchedEntry = null; for (IpV4RoutingTableEntry entry: values) { if ( entry.dstBitmap == (dstBitmap & entry.maskBitmap) ) { if (mostMatchedEntry == null) { mostMatchedEntry = entry; } else if (entry.prefixLength > mostMatchedEntry.prefixLength) { mostMatchedEntry = entry; } else if ( entry.prefixLength == mostMatchedEntry.prefixLength && entry.metric < mostMatchedEntry.metric ) { mostMatchedEntry = entry; } } } if (mostMatchedEntry == null) { return null; } return mostMatchedEntry.gw; } public List<IpV4RoutingTableEntry> getEntries() { return new ArrayList<IpV4RoutingTableEntry>(entries.values()); } public final class IpV4RoutingTableEntry { private final Inet4Address dst; private final int dstBitmap; private final Inet4Address mask; private final int maskBitmap; private final int prefixLength; private final Inet4Address gw; private final int metric; private IpV4RoutingTableEntry( Inet4Address dst, Inet4Address mask, Inet4Address gw, int metric ) { this.dst = dst; this.dstBitmap = ByteArrays.getInt(dst.getAddress(), 0); this.mask = mask; this.maskBitmap = ByteArrays.getInt(mask.getAddress(), 0); this.prefixLength = getPrefixLengthFrom(mask); this.gw = gw; this.metric = metric; } @Override public String toString() { StringBuilder sb = new StringBuilder(); sb.append("DST[").append(dst).append("] ") .append("MASK[").append(mask).append("] ") .append("GW[").append(gw).append("] ") .append("METRIC[").append(metric).append("]"); return sb.toString(); } } } }