package org.zstack.test.securitygroup; import junit.framework.Assert; import org.junit.BeforeClass; import org.junit.Test; import org.zstack.core.componentloader.ComponentLoader; import org.zstack.core.db.DatabaseFacade; import org.zstack.core.db.SimpleQuery; import org.zstack.core.db.SimpleQuery.Op; import org.zstack.core.thread.AsyncThread; import org.zstack.header.host.HostVO; import org.zstack.header.vm.VmInstanceInventory; import org.zstack.header.vm.VmInstanceState; import org.zstack.header.vm.VmInstanceVO; import org.zstack.header.vm.VmInstanceVO_; import org.zstack.network.securitygroup.RuleTO; import org.zstack.network.securitygroup.SecurityGroupInventory; import org.zstack.network.securitygroup.SecurityGroupRuleTO; import org.zstack.simulator.SimulatorSecurityGroupBackend; import org.zstack.test.Api; import org.zstack.test.ApiSenderException; import org.zstack.test.DBUtil; import org.zstack.test.WebBeanConstructor; import org.zstack.test.deployer.Deployer; import org.zstack.utils.Utils; import org.zstack.utils.logging.CLogger; import java.util.ArrayList; import java.util.List; import java.util.Random; import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; public class TestSecurityGroupRuleRandom { static CLogger logger = Utils.getLogger(TestSecurityGroupRuleRandom.class); static Deployer deployer; static Api api; static ComponentLoader loader; static DatabaseFacade dbf; static SimulatorSecurityGroupBackend sbkd; static String[] operations = {"stop", "reboot", "start"}; static int num = 10; CountDownLatch latch = new CountDownLatch(3); @BeforeClass public static void setUp() throws Exception { DBUtil.reDeployDB(); WebBeanConstructor con = new WebBeanConstructor(); deployer = new Deployer("deployerXml/securityGroup/TestApplySeurityGroupRulesToVmOnSimulator2.xml", con); deployer.build(); api = deployer.getApi(); loader = deployer.getComponentLoader(); dbf = loader.getComponent(DatabaseFacade.class); sbkd = loader.getComponent(SimulatorSecurityGroupBackend.class); } private String nextOp(VmInstanceState state) { Random r = new Random(); while (true) { int i = r.nextInt(3); String nextOp = operations[i]; if (state == VmInstanceState.Running && "start".equals(nextOp)) { continue; } if (state == VmInstanceState.Stopped && "stop".equals(nextOp)) { continue; } if (state == VmInstanceState.Stopped && "reboot".equals(nextOp)) { continue; } return nextOp; } } @AsyncThread private void randomOpOnVm(VmInstanceInventory vm) throws ApiSenderException { for (int i = 0; i < num; i++) { String nextOp = nextOp(VmInstanceState.valueOf(vm.getState())); if ("start".equals(nextOp)) { vm = api.startVmInstance(vm.getUuid()); } else if ("stop".equals(nextOp)) { vm = api.stopVmInstance(vm.getUuid()); } else if ("reboot".equals(nextOp)) { vm = api.rebootVmInstance(vm.getUuid()); } } latch.countDown(); } private void validate(List<String> internalAllowedIps) { List<HostVO> hosts = dbf.listAll(HostVO.class); for (HostVO h : hosts) { logger.debug(String.format("checking security group rules on host[uuid:%s]", h.getUuid())); Set<SecurityGroupRuleTO> tos = sbkd.getRulesOnHost(h.getUuid()); SimpleQuery<VmInstanceVO> q = dbf.createQuery(VmInstanceVO.class); q.add(VmInstanceVO_.state, Op.EQ, VmInstanceState.Running); q.add(VmInstanceVO_.hostUuid, Op.EQ, h.getUuid()); long count = q.count(); if (count == 0) { List<RuleTO> rules = new ArrayList<RuleTO>(); for (SecurityGroupRuleTO to : tos) { rules.addAll(to.getRules()); } Assert.assertEquals(0, rules.size()); } else { for (SecurityGroupRuleTO to : tos) { for (RuleTO r : to.getRules()) { logger.debug(String.format("expected: %s, real: %s", internalAllowedIps, r.getAllowedInternalIpRange())); Assert.assertTrue(r.getAllowedInternalIpRange().containsAll(internalAllowedIps)); } } } } } @Test public void test() throws ApiSenderException, InterruptedException { SecurityGroupInventory scinv = deployer.securityGroups.get("test"); VmInstanceInventory vm1 = deployer.vms.get("TestVm1"); VmInstanceInventory vm2 = deployer.vms.get("TestVm2"); VmInstanceInventory vm3 = deployer.vms.get("TestVm3"); List<String> nicUuids = new ArrayList<String>(); nicUuids.add(vm1.getVmNics().get(0).getUuid()); nicUuids.add(vm2.getVmNics().get(0).getUuid()); nicUuids.add(vm3.getVmNics().get(0).getUuid()); api.addVmNicToSecurityGroup(scinv.getUuid(), nicUuids); TimeUnit.MILLISECONDS.sleep(500); randomOpOnVm(vm1); randomOpOnVm(vm2); randomOpOnVm(vm3); latch.await(); TimeUnit.SECONDS.sleep(1); List<String> internalAllowedIps = new ArrayList<String>(); List<VmInstanceVO> vmvos = dbf.listAll(VmInstanceVO.class); for (VmInstanceVO vmvo : vmvos) { internalAllowedIps.add(vmvo.getVmNics().iterator().next().getIp()); } validate(internalAllowedIps); } }