package org.zstack.network.securitygroup;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.transaction.annotation.Transactional;
import org.zstack.core.cloudbus.CloudBus;
import org.zstack.core.db.DatabaseFacade;
import org.zstack.core.db.SimpleQuery;
import org.zstack.core.db.SimpleQuery.Op;
import org.zstack.core.errorcode.ErrorFacade;
import org.zstack.header.errorcode.SysErrors;
import org.zstack.header.apimediator.ApiMessageInterceptionException;
import org.zstack.header.apimediator.ApiMessageInterceptor;
import org.zstack.header.apimediator.StopRoutingException;
import org.zstack.header.message.APIMessage;
import org.zstack.header.network.service.NetworkServiceL3NetworkRefVO;
import org.zstack.header.network.service.NetworkServiceL3NetworkRefVO_;
import org.zstack.header.vm.VmNicVO;
import org.zstack.header.vm.VmNicVO_;
import org.zstack.network.securitygroup.APIAddSecurityGroupRuleMsg.SecurityGroupRuleAO;
import org.zstack.utils.gson.JSONObjectUtil;
import org.zstack.utils.network.NetworkUtils;
import static org.zstack.core.Platform.argerr;
import static org.zstack.core.Platform.operr;
import javax.persistence.TypedQuery;
import java.util.ArrayList;
import java.util.List;
/**
*/
public class SecurityGroupApiInterceptor implements ApiMessageInterceptor {
@Autowired
private CloudBus bus;
@Autowired
private DatabaseFacade dbf;
@Autowired
private ErrorFacade errf;
@Override
public APIMessage intercept(APIMessage msg) throws ApiMessageInterceptionException {
if (msg instanceof APIAddSecurityGroupRuleMsg) {
validate((APIAddSecurityGroupRuleMsg) msg);
} else if (msg instanceof APIAddVmNicToSecurityGroupMsg) {
validate((APIAddVmNicToSecurityGroupMsg) msg);
} else if (msg instanceof APIAttachSecurityGroupToL3NetworkMsg) {
validate((APIAttachSecurityGroupToL3NetworkMsg) msg);
} else if (msg instanceof APIDeleteSecurityGroupMsg) {
validate((APIDeleteSecurityGroupMsg) msg);
} else if (msg instanceof APIDeleteSecurityGroupRuleMsg) {
validate((APIDeleteSecurityGroupRuleMsg) msg);
} else if (msg instanceof APIDeleteVmNicFromSecurityGroupMsg) {
validate((APIDeleteVmNicFromSecurityGroupMsg) msg);
} else if (msg instanceof APIDetachSecurityGroupFromL3NetworkMsg) {
validate((APIDetachSecurityGroupFromL3NetworkMsg) msg);
}
return msg;
}
private void validate(APIDetachSecurityGroupFromL3NetworkMsg msg) {
SimpleQuery<SecurityGroupL3NetworkRefVO> q = dbf.createQuery(SecurityGroupL3NetworkRefVO.class);
q.add(SecurityGroupL3NetworkRefVO_.l3NetworkUuid, Op.EQ, msg.getL3NetworkUuid());
q.add(SecurityGroupL3NetworkRefVO_.securityGroupUuid, Op.EQ, msg.getSecurityGroupUuid());
if (!q.isExists()) {
throw new ApiMessageInterceptionException(operr("security group[uuid:%s] has not attached to l3Network[uuid:%s], can't detach",
msg.getSecurityGroupUuid(), msg.getL3NetworkUuid()));
}
}
private void validate(APIDeleteVmNicFromSecurityGroupMsg msg) {
SimpleQuery<VmNicSecurityGroupRefVO> q = dbf.createQuery(VmNicSecurityGroupRefVO.class);
q.select(VmNicSecurityGroupRefVO_.vmNicUuid);
q.add(VmNicSecurityGroupRefVO_.vmNicUuid, Op.IN, msg.getVmNicUuids());
q.add(VmNicSecurityGroupRefVO_.securityGroupUuid, Op.EQ, msg.getSecurityGroupUuid());
List<String> vmNicUuids = q.listValue();
if (vmNicUuids.isEmpty()) {
APIDeleteVmNicFromSecurityGroupEvent evt = new APIDeleteVmNicFromSecurityGroupEvent(msg.getId());
bus.publish(evt);
throw new StopRoutingException();
}
msg.setVmNicUuids(vmNicUuids);
}
private void validate(APIDeleteSecurityGroupRuleMsg msg) {
SimpleQuery<SecurityGroupRuleVO> q = dbf.createQuery(SecurityGroupRuleVO.class);
q.select(SecurityGroupRuleVO_.uuid);
q.add(SecurityGroupRuleVO_.uuid, Op.IN, msg.getRuleUuids());
List<String> uuids = q.listValue();
uuids.retainAll(msg.getRuleUuids());
if (uuids.isEmpty()) {
APIDeleteSecurityGroupRuleEvent evt = new APIDeleteSecurityGroupRuleEvent(msg.getId());
bus.publish(evt);
throw new StopRoutingException();
}
msg.setRuleUuids(uuids);
}
private void validate(APIDeleteSecurityGroupMsg msg) {
if (!dbf.isExist(msg.getUuid(), SecurityGroupVO.class)) {
APIDeleteSecurityGroupEvent evt = new APIDeleteSecurityGroupEvent(msg.getId());
bus.publish(evt);
throw new StopRoutingException();
}
}
private void validate(APIAttachSecurityGroupToL3NetworkMsg msg) {
SimpleQuery<SecurityGroupL3NetworkRefVO> q = dbf.createQuery(SecurityGroupL3NetworkRefVO.class);
q.add(SecurityGroupL3NetworkRefVO_.l3NetworkUuid, Op.EQ, msg.getL3NetworkUuid());
q.add(SecurityGroupL3NetworkRefVO_.securityGroupUuid, Op.EQ, msg.getSecurityGroupUuid());
if (q.isExists()) {
throw new ApiMessageInterceptionException(operr("security group[uuid:%s] has attached to l3Network[uuid:%s], can't attach again",
msg.getSecurityGroupUuid(), msg.getL3NetworkUuid()));
}
SimpleQuery<NetworkServiceL3NetworkRefVO> nq = dbf.createQuery(NetworkServiceL3NetworkRefVO.class);
nq.add(NetworkServiceL3NetworkRefVO_.l3NetworkUuid, Op.EQ, msg.getL3NetworkUuid());
nq.add(NetworkServiceL3NetworkRefVO_.networkServiceType, Op.EQ, SecurityGroupConstant.SECURITY_GROUP_NETWORK_SERVICE_TYPE);
if (!nq.isExists()) {
throw new ApiMessageInterceptionException(argerr("the L3 network[uuid:%s] doesn't have the network service type[%s] enabled", msg.getL3NetworkUuid(), SecurityGroupConstant.SECURITY_GROUP_NETWORK_SERVICE_TYPE));
}
}
private void validate(APIAddVmNicToSecurityGroupMsg msg) {
SimpleQuery<VmNicVO> q = dbf.createQuery(VmNicVO.class);
q.select(VmNicVO_.uuid);
q.add(VmNicVO_.uuid, Op.IN, msg.getVmNicUuids());
List<String> uuids = q.listValue();
if (!uuids.containsAll(msg.getVmNicUuids())) {
msg.getVmNicUuids().removeAll(uuids);
throw new ApiMessageInterceptionException(errf.instantiateErrorCode(SysErrors.RESOURCE_NOT_FOUND,
String.format("cannot find vm nics[uuids:%s]", msg.getVmNicUuids())
));
}
checkIfVmNicFromAttachedL3Networks(msg.getSecurityGroupUuid(), uuids);
msg.setVmNicUuids(uuids);
}
@Transactional(readOnly = true)
private void checkIfVmNicFromAttachedL3Networks(String securityGroupUuid, List<String> uuids) {
String sql = "select nic.uuid from SecurityGroupL3NetworkRefVO ref, VmNicVO nic where ref.l3NetworkUuid = nic.l3NetworkUuid" +
" and ref.securityGroupUuid = :sgUuid and nic.uuid in (:nicUuids)";
TypedQuery<String> q = dbf.getEntityManager().createQuery(sql, String.class);
q.setParameter("nicUuids", uuids);
q.setParameter("sgUuid", securityGroupUuid);
List<String> nicUuids = q.getResultList();
List<String> wrongUuids = new ArrayList<String>();
for (String uuid : uuids) {
if (!nicUuids.contains(uuid)) {
wrongUuids.add(uuid);
}
}
if (!wrongUuids.isEmpty()) {
throw new ApiMessageInterceptionException(argerr("VM nics[uuids:%s] are not on L3 networks that have been attached to the security group[uuid:%s]",
wrongUuids, securityGroupUuid));
}
}
private boolean checkSecurityGroupRuleEqual(SecurityGroupRuleAO rule1, SecurityGroupRuleAO rule2) {
if (rule1 == rule2) {
return true;
}
if (rule1.getStartPort().equals(rule2.getStartPort()) &&
rule1.getEndPort().equals(rule2.getEndPort()) &&
rule1.getProtocol().equals(rule2.getProtocol()) &&
rule1.getType().equals(rule2.getType())) {
//
if (rule1.getAllowedCidr() == null) {
if (rule2.getAllowedCidr() == null ||
rule2.getAllowedCidr().equals("") ||
rule2.getAllowedCidr().equals(SecurityGroupConstant.WORLD_OPEN_CIDR)) {
return true;
}
}
if (rule2.getAllowedCidr() == null) {
if (rule1.getAllowedCidr() == null ||
rule1.getAllowedCidr().equals("") ||
rule1.getAllowedCidr().equals(SecurityGroupConstant.WORLD_OPEN_CIDR)) {
return true;
}
}
if (rule1.getAllowedCidr().equals(rule2.getAllowedCidr())) {
return true;
}
}
return false;
}
private void validate(APIAddSecurityGroupRuleMsg msg) {
// Basic check
for (SecurityGroupRuleAO ao : msg.getRules()) {
if (ao.getType() == null) {
throw new ApiMessageInterceptionException(argerr("rule type can not be null. rule dump: %s", JSONObjectUtil.toJsonString(ao)));
}
if (!ao.getType().equals(SecurityGroupRuleType.Egress.toString()) &&
!ao.getType().equals(SecurityGroupRuleType.Ingress.toString())) {
throw new ApiMessageInterceptionException(argerr("unknown rule type[%s], rule can only be Ingress/Egress. rule dump: %s",
ao.getType(), JSONObjectUtil.toJsonString(ao)));
}
if (ao.getProtocol() == null) {
throw new ApiMessageInterceptionException(argerr("protocol can not be null. rule dump: %s", JSONObjectUtil.toJsonString(ao)));
}
try {
SecurityGroupRuleProtocolType.valueOf(ao.getProtocol());
} catch (Exception e) {
throw new ApiMessageInterceptionException(argerr("invalid protocol[%s]. Valid protocols are [TCP, UDP, ICMP]. rule dump: %s",
ao.getProtocol(), JSONObjectUtil.toJsonString(ao)));
}
if (ao.getStartPort() == null) {
throw new ApiMessageInterceptionException(argerr("startPort can not be null. rule dump: %s", JSONObjectUtil.toJsonString(ao)));
}
if (SecurityGroupRuleProtocolType.ICMP.toString().equals(ao.getProtocol())) {
if (ao.getStartPort() < -1 || ao.getStartPort() > 255) {
throw new ApiMessageInterceptionException(argerr("invalid ICMP type[%s]. Valid type is [-1, 255]. rule dump: %s",
ao.getStartPort(), JSONObjectUtil.toJsonString(ao)));
}
} else {
if (ao.getStartPort() < 0 || ao.getStartPort() > 65535) {
throw new ApiMessageInterceptionException(argerr("invalid startPort[%s]. Valid range is [0, 65535]. rule dump: %s",
ao.getStartPort(), JSONObjectUtil.toJsonString(ao)));
}
}
if (ao.getEndPort() == null) {
ao.setEndPort(ao.getStartPort());
}
if (SecurityGroupRuleProtocolType.ICMP.toString().equals(ao.getProtocol())) {
if (ao.getEndPort() < -1 || ao.getEndPort() > 3) {
throw new ApiMessageInterceptionException(argerr("invalid ICMP code[%s]. Valid range is [-1, 3]. rule dump: %s",
ao.getEndPort(), JSONObjectUtil.toJsonString(ao)));
}
} else {
if (ao.getEndPort() < 0 || ao.getEndPort() > 65535) {
throw new ApiMessageInterceptionException(argerr("invalid endPort[%s]. Valid range is [0, 65535]. rule dump: %s",
ao.getEndPort(), JSONObjectUtil.toJsonString(ao)));
}
}
if (ao.getAllowedCidr() != null && !NetworkUtils.isCidr(ao.getAllowedCidr())) {
throw new ApiMessageInterceptionException(argerr("invalid CIDR[%s]. rule dump: %s", ao.getAllowedCidr(), JSONObjectUtil.toJsonString(ao)));
}
}
// Deduplicate in msg
for (int i = 0; i < msg.getRules().size() - 1; i++) {
for (int j = msg.getRules().size() - 1; j > i; j--) {
if (checkSecurityGroupRuleEqual(msg.getRules().get(j), msg.getRules().get(i))) {
throw new ApiMessageInterceptionException(argerr("rule should not be duplicated. rule dump: %s",
JSONObjectUtil.toJsonString(msg.getRules().get(j))));
}
}
}
// Deduplicate in database
SimpleQuery<SecurityGroupRuleVO> lsquery = dbf.createQuery(SecurityGroupRuleVO.class);
lsquery.add(SecurityGroupRuleVO_.securityGroupUuid, Op.EQ, msg.getSecurityGroupUuid());
List<SecurityGroupRuleVO> vos = lsquery.list();
for (SecurityGroupRuleVO svo : vos) {
SecurityGroupRuleAO ao = new SecurityGroupRuleAO();
ao.setType(svo.getType().toString());
ao.setAllowedCidr(svo.getAllowedCidr());
ao.setProtocol(svo.getProtocol().toString());
ao.setStartPort(svo.getStartPort());
ao.setEndPort(svo.getEndPort());
for (SecurityGroupRuleAO sao : msg.getRules()) {
if (checkSecurityGroupRuleEqual(ao, sao)) {
throw new ApiMessageInterceptionException(argerr("rule exist. rule dump: %s",
JSONObjectUtil.toJsonString(sao)));
}
}
}
// fin
for (SecurityGroupRuleAO ao : msg.getRules()) {
int start = Math.min(ao.getStartPort(), ao.getEndPort());
int end = Math.max(ao.getStartPort(), ao.getEndPort());
ao.setStartPort(start);
ao.setEndPort(end);
if (ao.getAllowedCidr() == null) {
ao.setAllowedCidr(SecurityGroupConstant.WORLD_OPEN_CIDR);
}
}
}
}