package org.zstack.network.service.portforwarding;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.transaction.annotation.Transactional;
import org.zstack.core.cloudbus.CloudBus;
import org.zstack.core.db.DatabaseFacade;
import org.zstack.core.db.SimpleQuery;
import org.zstack.core.db.SimpleQuery.Op;
import org.zstack.core.errorcode.ErrorFacade;
import org.zstack.header.errorcode.SysErrors;
import org.zstack.header.apimediator.ApiMessageInterceptionException;
import org.zstack.header.apimediator.ApiMessageInterceptor;
import org.zstack.header.apimediator.StopRoutingException;
import org.zstack.header.message.APIMessage;
import org.zstack.header.vm.VmInstanceVO;
import org.zstack.header.vm.VmNicVO;
import org.zstack.header.vm.VmNicVO_;
import org.zstack.network.service.vip.VipVO;
import org.zstack.network.service.vip.VipVO_;
import org.zstack.utils.network.NetworkUtils;
import static org.zstack.core.Platform.argerr;
import static org.zstack.core.Platform.operr;
import javax.persistence.Tuple;
import javax.persistence.TypedQuery;
import java.util.List;
import java.util.concurrent.Callable;
/**
 * Validates port forwarding API messages before they are routed to their handlers.
 *
 * <p>{@link #intercept(APIMessage)} dispatches on the concrete message type to one of the
 * private {@code validate} overloads. A failed check is reported by throwing
 * {@link ApiMessageInterceptionException}. Deleting a rule that no longer exists is answered
 * directly with an event, after which routing is stopped.
 */
public class PortForwardingApiInterceptor implements ApiMessageInterceptor {
    @Autowired
    private DatabaseFacade dbf;
    @Autowired
    private CloudBus bus;
    @Autowired
    private ErrorFacade errf;

    /**
     * Runs the validation matching the message's type; messages of any other type pass through
     * untouched.
     *
     * @param msg the incoming API message
     * @return the same message instance (the create validation may have filled in or reordered
     *         its port fields)
     * @throws ApiMessageInterceptionException if a validation fails
     */
    @Override
    public APIMessage intercept(APIMessage msg) throws ApiMessageInterceptionException {
        if (msg instanceof APIDeletePortForwardingRuleMsg) {
            validate((APIDeletePortForwardingRuleMsg) msg);
        } else if (msg instanceof APICreatePortForwardingRuleMsg) {
            validate((APICreatePortForwardingRuleMsg) msg);
        } else if (msg instanceof APIAttachPortForwardingRuleMsg) {
            validate((APIAttachPortForwardingRuleMsg) msg);
        } else if (msg instanceof APIDetachPortForwardingRuleMsg) {
            validate((APIDetachPortForwardingRuleMsg) msg);
        } else if (msg instanceof APIGetPortForwardingAttachableVmNicsMsg) {
            validate((APIGetPortForwardingAttachableVmNicsMsg) msg);
        }
        return msg;
    }

    /**
     * Requires the rule to be in the Enabled state. Whether the rule is currently attached to a
     * NIC has no influence on the outcome.
     */
    private void validate(APIGetPortForwardingAttachableVmNicsMsg msg) {
        SimpleQuery<PortForwardingRuleVO> q = dbf.createQuery(PortForwardingRuleVO.class);
        q.select(PortForwardingRuleVO_.state);
        q.add(PortForwardingRuleVO_.uuid, Op.EQ, msg.getRuleUuid());
        PortForwardingRuleState state = q.findValue();
        if (state == null) {
            // No state could be read for this uuid; report it instead of a misleading state error.
            throw new ApiMessageInterceptionException(argerr("port forwarding rule[uuid:%s] not found", msg.getRuleUuid()));
        }
        if (state != PortForwardingRuleState.Enabled) {
            throw new ApiMessageInterceptionException(operr("Port forwarding rule[uuid:%s] is not in state of Enabled, current state is %s", msg.getRuleUuid(), state));
        }
    }

    /**
     * Rejects the detach request when no NIC uuid is recorded for the rule.
     */
    private void validate(APIDetachPortForwardingRuleMsg msg) {
        SimpleQuery<PortForwardingRuleVO> q = dbf.createQuery(PortForwardingRuleVO.class);
        q.select(PortForwardingRuleVO_.vmNicUuid);
        q.add(PortForwardingRuleVO_.uuid, Op.EQ, msg.getUuid());
        String vmNicUuid = q.findValue();
        if (vmNicUuid == null) {
            throw new ApiMessageInterceptionException(operr("port forwarding rule rule[uuid:%s] has not been attached to any vm nic, can't detach", msg.getUuid()));
        }
    }

    /**
     * Checks that a rule may be attached to the requested NIC: the rule must be unattached and
     * Enabled, the NIC's L3 network must differ from the VIP's L3 network and match the VIP's
     * peer L3 network when one is set, and the NIC's VM must not already carry rules of a
     * different VIP.
     */
    private void validate(final APIAttachPortForwardingRuleMsg msg) {
        SimpleQuery<PortForwardingRuleVO> q = dbf.createQuery(PortForwardingRuleVO.class);
        q.select(PortForwardingRuleVO_.vmNicUuid, PortForwardingRuleVO_.state);
        q.add(PortForwardingRuleVO_.uuid, Op.EQ, msg.getRuleUuid());
        Tuple t = q.findTuple();
        if (t == null) {
            // Guard against a missing row (assumes findTuple() yields null in that case - TODO confirm).
            throw new ApiMessageInterceptionException(argerr("port forwarding rule[uuid:%s] not found", msg.getRuleUuid()));
        }

        String vmNicUuid = t.get(0, String.class);
        if (vmNicUuid != null) {
            throw new ApiMessageInterceptionException(operr("port forwarding rule[uuid:%s] has been attached to vm nic[uuid:%s], can't attach again", msg.getRuleUuid(), vmNicUuid));
        }

        PortForwardingRuleState state = t.get(1, PortForwardingRuleState.class);
        if (state != PortForwardingRuleState.Enabled) {
            throw new ApiMessageInterceptionException(operr("port forwarding rule[uuid:%s] is not in state of Enabled, current state is %s. A rule can only be attached when its state is Enabled", msg.getRuleUuid(), state));
        }

        // NOTE(review): whether @Transactional takes effect on this anonymous class depends on
        // the project's AOP weaving mode - confirm before restructuring.
        VipVO vip = new Callable<VipVO>() {
            @Override
            @Transactional(readOnly = true)
            public VipVO call() {
                String sql = "select vip from VipVO vip, PortForwardingRuleVO pf where vip.uuid = pf.vipUuid and pf.uuid = :pfUuid";
                TypedQuery<VipVO> q = dbf.getEntityManager().createQuery(sql, VipVO.class);
                q.setParameter("pfUuid", msg.getRuleUuid());
                return q.getSingleResult();
            }
        }.call();

        String guestL3Uuid = getNicL3NetworkUuid(msg.getVmNicUuid());
        if (guestL3Uuid.equals(vip.getL3NetworkUuid())) {
            throw new ApiMessageInterceptionException(argerr("guest l3Network of vm nic[uuid:%s] and vip l3Network of port forwarding rule[uuid:%s] are the same network",
                    msg.getVmNicUuid(), msg.getRuleUuid()));
        }

        if (vip.getPeerL3NetworkUuid() != null && !vip.getPeerL3NetworkUuid().equals(guestL3Uuid)) {
            throw new ApiMessageInterceptionException(argerr("the VIP[uuid:%s] is already bound the a guest L3 network[uuid:%s], but the VM nic[uuid:%s]" +
                    " is on another guest L3 network[uuid:%s]", vip.getUuid(), vip.getPeerL3NetworkUuid(),
                    msg.getVmNicUuid(), guestL3Uuid));
        }

        checkIfAnotherVip(vip.getUuid(), msg.getVmNicUuid());
    }

    /**
     * Looks up the L3 network uuid recorded for a VM NIC.
     *
     * @param vmNicUuid uuid of the NIC to look up
     * @return the NIC's L3 network uuid, never null
     * @throws ApiMessageInterceptionException if no L3 network uuid can be read for the NIC
     */
    private String getNicL3NetworkUuid(String vmNicUuid) {
        SimpleQuery<VmNicVO> q = dbf.createQuery(VmNicVO.class);
        q.select(VmNicVO_.l3NetworkUuid);
        q.add(VmNicVO_.uuid, Op.EQ, vmNicUuid);
        String l3Uuid = q.findValue();
        if (l3Uuid == null) {
            // Callers invoke equals() on the result, so fail with a clear error instead of an NPE.
            throw new ApiMessageInterceptionException(argerr("cannot find the l3Network of vm nic[uuid:%s]", vmNicUuid));
        }
        return l3Uuid;
    }

    /**
     * Tells whether the inclusive ranges [s1, e1] and [s2, e2] share at least one value.
     * Both ranges are expected to be passed with start &lt;= end.
     */
    private boolean rangeOverlap(int s1, int e1, int s2, int e2) {
        return (s1 >= s2 && s1 <= e2) || (s1 <= s2 && s2 <= e1);
    }

    /**
     * Normalizes and validates a create request.
     *
     * <p>Missing port fields are defaulted from the VIP ports, each start/end pair is reordered
     * so that start &lt;= end, and the normalized values are written back to the message. The
     * request is rejected when a VIP port range and the private range differ in size, the
     * allowed CIDR is malformed, the VIP port range overlaps an existing rule of the same
     * protocol on the same VIP, or (when a NIC is given) the NIC fails the L3 network and
     * single-VIP checks.
     */
    private void validate(APICreatePortForwardingRuleMsg msg) {
        if (msg.getVipPortEnd() == null) {
            msg.setVipPortEnd(msg.getVipPortStart());
        }
        if (msg.getPrivatePortStart() == null) {
            msg.setPrivatePortStart(msg.getVipPortStart());
        }
        if (msg.getPrivatePortEnd() == null) {
            msg.setPrivatePortEnd(msg.getVipPortEnd());
        }

        // Accept the bounds in either order; store them back as start <= end.
        int vipStart = Math.min(msg.getVipPortStart(), msg.getVipPortEnd());
        int vipEnd = Math.max(msg.getVipPortStart(), msg.getVipPortEnd());
        msg.setVipPortStart(vipStart);
        msg.setVipPortEnd(vipEnd);

        int privateStart = Math.min(msg.getPrivatePortStart(), msg.getPrivatePortEnd());
        int privateEnd = Math.max(msg.getPrivatePortStart(), msg.getPrivatePortEnd());
        msg.setPrivatePortStart(privateStart);
        msg.setPrivatePortEnd(privateEnd);

        // NOTE(review): the size check only runs when the VIP side is a range, so a single VIP
        // port combined with a wider private range is not rejected here - confirm this is intended.
        if (vipStart != vipEnd) {
            // it's a port range
            if (vipEnd - vipStart != privateEnd - privateStart) {
                throw new ApiMessageInterceptionException(argerr("for range port forwarding, the port range size must match; vip range[%s, %s]'s size doesn't match range[%s, %s]'s size",
                        msg.getVipPortStart(), msg.getVipPortEnd(), msg.getPrivatePortStart(), msg.getPrivatePortEnd()));
            }
        }

        if (msg.getAllowedCidr() != null && !NetworkUtils.isCidr(msg.getAllowedCidr())) {
            throw new ApiMessageInterceptionException(argerr("invalid CIDR[%s]", msg.getAllowedCidr()));
        }

        // Existing rules on the same VIP with the same protocol must not share any VIP port.
        SimpleQuery<PortForwardingRuleVO> q = dbf.createQuery(PortForwardingRuleVO.class);
        q.add(PortForwardingRuleVO_.vipUuid, Op.EQ, msg.getVipUuid());
        List<PortForwardingRuleVO> vos = q.list();
        for (PortForwardingRuleVO vo : vos) {
            if (vo.getProtocolType().toString().equals(msg.getProtocolType())) {
                if (rangeOverlap(vipStart, vipEnd, vo.getVipPortStart(), vo.getVipPortEnd())) {
                    throw new ApiMessageInterceptionException(argerr("vip port range[vipStartPort:%s, vipEndPort:%s] overlaps with rule[uuid:%s, vipStartPort:%s, vipEndPort:%s]",
                            vipStart, vipEnd, vo.getUuid(), vo.getVipPortStart(), vo.getVipPortEnd()));
                }
            }
        }

        if (msg.getVmNicUuid() != null) {
            SimpleQuery<VipVO> vq = dbf.createQuery(VipVO.class);
            vq.select(VipVO_.l3NetworkUuid, VipVO_.peerL3NetworkUuid);
            vq.add(VipVO_.uuid, Op.EQ, msg.getVipUuid());
            Tuple t = vq.findTuple();
            if (t == null) {
                // Guard against a missing row (assumes findTuple() yields null in that case - TODO confirm).
                throw new ApiMessageInterceptionException(argerr("vip[uuid:%s] not found", msg.getVipUuid()));
            }
            String vipL3Uuid = t.get(0, String.class);
            String peerL3Uuid = t.get(1, String.class);

            String nicL3Uuid = getNicL3NetworkUuid(msg.getVmNicUuid());
            if (nicL3Uuid.equals(vipL3Uuid)) {
                throw new ApiMessageInterceptionException(argerr("guest l3Network of vm nic[uuid:%s] and vip l3Network of vip[uuid: %s] are the same network", msg.getVmNicUuid(), msg.getVipUuid()));
            }

            if (peerL3Uuid != null && !peerL3Uuid.equals(nicL3Uuid)) {
                throw new ApiMessageInterceptionException(argerr("the VIP[uuid:%s] is already bound the a guest L3 network[uuid:%s], but the VM nic[uuid:%s]" +
                        " is on another guest L3 network[uuid:%s]", msg.getVipUuid(), peerL3Uuid, msg.getVmNicUuid(), nicL3Uuid));
            }

            checkIfAnotherVip(msg.getVipUuid(), msg.getVmNicUuid());
        }
    }

    /**
     * Rejects the request when any NIC of the VM owning {@code vmNicUuid} already has a port
     * forwarding rule whose VIP differs from {@code vipUuid}.
     *
     * @param vipUuid   uuid of the VIP the new or attached rule uses
     * @param vmNicUuid uuid of the NIC the rule is being bound to
     * @throws ApiMessageInterceptionException if such a rule exists
     */
    // NOTE(review): @Transactional on a private method relies on the project's AOP weaving mode - confirm.
    @Transactional(readOnly = true)
    private void checkIfAnotherVip(String vipUuid, String vmNicUuid) {
        String sql = "select nic.uuid from VmNicVO nic where nic.vmInstanceUuid = (select n.vmInstanceUuid from VmNicVO n where" +
                " n.uuid = :nicUuid)";
        TypedQuery<String> q = dbf.getEntityManager().createQuery(sql, String.class);
        q.setParameter("nicUuid", vmNicUuid);
        List<String> nicUuids = q.getResultList();
        if (nicUuids.isEmpty()) {
            // Nothing to compare against; also avoids binding an empty list to the IN clause
            // below, which is not valid JPQL.
            return;
        }

        sql = "select count(*) from VmNicVO nic, PortForwardingRuleVO pf where nic.uuid = pf.vmNicUuid and pf.vipUuid != :vipUuid and nic.uuid in (:nicUuids)";
        TypedQuery<Long> lq = dbf.getEntityManager().createQuery(sql, Long.class);
        lq.setParameter("vipUuid", vipUuid);
        lq.setParameter("nicUuids", nicUuids);
        long count = lq.getSingleResult();
        if (count > 0) {
            // Load the VM only on the failure path, to name it in the error message.
            sql = "select vm from VmInstanceVO vm, VmNicVO nic where vm.uuid = nic.vmInstanceUuid and nic.uuid = :nicUuid";
            TypedQuery<VmInstanceVO> vq = dbf.getEntityManager().createQuery(sql, VmInstanceVO.class);
            vq.setParameter("nicUuid", vmNicUuid);
            VmInstanceVO vm = vq.getSingleResult();
            throw new ApiMessageInterceptionException(operr("the VM[name:%s uuid:%s] already has port forwarding rules that have different VIPs than the one[uuid:%s]",
                    vm.getName(), vm.getUuid(), vipUuid));
        }
    }

    /**
     * When the rule to delete does not exist, publishes the delete event built from the
     * message id right away and stops further routing of the message.
     */
    private void validate(APIDeletePortForwardingRuleMsg msg) {
        if (!dbf.isExist(msg.getUuid(), PortForwardingRuleVO.class)) {
            bus.publish(new APIDeletePortForwardingRuleEvent(msg.getId()));
            throw new StopRoutingException();
        }
    }
}