package io.fathom.cloud.compute.networks;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Arrays;
import com.google.common.base.Strings;
import com.google.common.net.InetAddresses;
public abstract class IpRange {
protected final InetAddress address;
protected final int netmaskLength;
protected final byte[] mask;
protected final byte[] masked;
public static IpRange build(byte[] addr, int prefixLength) {
InetAddress address;
try {
address = InetAddress.getByAddress(addr);
} catch (UnknownHostException e) {
throw new IllegalStateException("Error building address", e);
}
return build(address, prefixLength);
}
private static IpRange build(InetAddress address, int prefixLength) {
int addressLength = address.getAddress().length * 8;
if (prefixLength == -1) {
prefixLength = addressLength;
}
if (addressLength == 32) {
return new IpV4Range(address, prefixLength);
}
if (addressLength == 128) {
return new IpV6Range((Inet6Address) address, prefixLength);
}
throw new IllegalStateException();
}
public static IpRange parse(String cidr) {
if (Strings.isNullOrEmpty(cidr)) {
return null;
}
int slashPosition = cidr.indexOf('/');
int maskLength = -1;
String addressString;
if (slashPosition != -1) {
addressString = cidr.substring(0, slashPosition);
maskLength = Integer.parseInt(cidr.substring(slashPosition + 1));
} else {
addressString = cidr;
}
InetAddress address;
try {
address = InetAddress.getByName(addressString);
} catch (UnknownHostException e) {
throw new IllegalArgumentException("Cannot resolve address: " + addressString, e);
}
return build(address, maskLength);
}
public IpRange(InetAddress address, int netmaskLength) {
this.address = address;
this.netmaskLength = netmaskLength;
if (address instanceof Inet4Address) {
// TODO: Cache?
this.mask = buildMask(new byte[4], netmaskLength);
} else if (address instanceof Inet6Address) {
// TODO: Cache
this.mask = buildMask(new byte[16], netmaskLength);
} else {
throw new IllegalArgumentException();
}
this.masked = address.getAddress();
for (int i = 0; i < masked.length; i++) {
masked[i] &= mask[i];
}
}
// public abstract String getNetmask();
public InetAddress getAddress() {
return address;
}
// public Iterable<InetAddress> all() {
// return new Iterable<InetAddress>() {
// @Override
// public Iterator<InetAddress> iterator() {
// return new SimpleIterator<InetAddress>() {
// @Override
// protected InetAddress getNext(InetAddress current) {
// if (current == null) {
// return IpRange.this.getAddressInRange(0);
// } else {
// return IpRange.this.getNext(current);
// }
// }
// };
// }
// };
// }
//
// public byte[] getMasked() {
// byte[] addressBytes = address.getAddress();
// applyMask(addressBytes, netmaskLength);
// return addressBytes;
// }
//
// protected static void applyMask(byte[] addr, int length) {
// // TODO: This is slow (ish)
// for (int i = length; i < addr.length * 8; i++) {
// int pos = i / 8;
// if ((i % 8) == 0) {
// addr[pos] = 0;
// i += 7;
// } else {
// int bit = 7 - (i % 8);
// int bitMask = 1 << bit;
// addr[pos] &= ~bitMask;
// }
// }
// }
//
protected byte[] getNetmaskBytes() {
return mask;
}
private static byte[] buildMask(byte[] addr, int length) {
// TOOD: This is kind of slow... implement 8 byte fast-jump?
for (int i = 0; i < length; i++) {
int pos = i / 8;
int bit = 7 - (i % 8);
int bitMask = 1 << bit;
addr[pos] |= bitMask;
}
return addr;
}
// static void addBit(byte[] data, int addBitIndex) {
// if (addBitIndex < 0 || addBitIndex >= (data.length * 8)) {
// throw new IllegalArgumentException();
// }
//
// int pos = addBitIndex / 8;
// int bit = 7 - addBitIndex % 8;
// int bitMask = 1 << bit;
//
// int v = data[pos] & 0xff;
// v += bitMask;
//
// data[pos] = (byte) (v & 0xff);
// if (v > 0xff) {
// v >>= 8;
// if (v == 1) {
// addBit(data, addBitIndex - 8);
// } else {
// throw new UnsupportedOperationException();
// }
// }
// }
//
// protected InetAddress getNext(InetAddress current) {
// byte[] addr = current.getAddress();
// addBit(addr, addr.length * 8 - 1);
//
// InetAddress next = toAddress(addr);
// if (!isInRange(next)) {
// return null;
// }
//
// return next;
// }
//
// protected InetAddress toAddress(byte[] addr) {
// try {
// return InetAddress.getByAddress(addr);
// } catch (UnknownHostException e) {
// throw new IllegalStateException(
// "Error building address from bytes", e);
// }
// }
@Override
public String toString() {
return getClass().getSimpleName() + ":" + toCidr();
}
public String toCidr() {
return InetAddresses.toAddrString(address) + "/" + netmaskLength;
}
public int getNetmaskLength() {
return netmaskLength;
}
// public InetAddress getGatewayAddress() {
// return getAddressInRange(1);
// }
//
// protected InetAddress getAddressInRange(int offset) {
// byte[] addr = getMasked();
// if (offset == 0) {
// } else if (offset == 1) {
// addBit(addr, 8 * addr.length - 1);
// } else {
// // TODO: Not implemented
// throw new UnsupportedOperationException();
// }
//
// return toAddress(addr);
// }
public boolean isIpv6() {
return this instanceof IpV6Range;
}
public boolean isIpv4() {
return this instanceof IpV4Range;
}
public boolean contains(InetAddress address) {
byte[] addressBytes = address.getAddress();
for (int i = 0; i < addressBytes.length; i++) {
addressBytes[i] &= mask[i];
}
return Arrays.equals(masked, addressBytes);
}
@Override
public int hashCode() {
String cidr = toCidr();
final int prime = 31;
int result = 1;
result = prime * result + ((cidr == null) ? 0 : cidr.hashCode());
return result;
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
IpRange other = (IpRange) obj;
String cidr = toCidr();
if (cidr == null) {
if (other.toCidr() != null) {
return false;
}
} else if (!cidr.equals(other.toCidr())) {
return false;
}
return true;
}
}