/*
* Copyright (c) 2007, PostgreSQL Global Development Group
* See the LICENSE file in the project root for more information.
*/
package org.postgresql.test.hostchooser;
import static java.lang.Integer.parseInt;
import static java.util.Arrays.asList;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.postgresql.hostchooser.HostRequirement.any;
import static org.postgresql.hostchooser.HostRequirement.master;
import static org.postgresql.hostchooser.HostRequirement.preferSlave;
import static org.postgresql.hostchooser.HostRequirement.slave;
import static org.postgresql.hostchooser.HostStatus.ConnectFail;
import static org.postgresql.hostchooser.HostStatus.Slave;
import static org.postgresql.test.TestUtil.closeDB;
import org.postgresql.hostchooser.GlobalHostStatusTracker;
import org.postgresql.hostchooser.HostRequirement;
import org.postgresql.test.TestUtil;
import org.postgresql.util.HostSpec;
import org.postgresql.util.PSQLException;
import junit.framework.TestCase;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.HashSet;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
public class MultiHostsConnectionTest extends TestCase {
static final String user = TestUtil.getUser();
static final String password = TestUtil.getPassword();
static final String master1 = TestUtil.getServer() + ":" + TestUtil.getPort();
static final String slave1 =
MultiHostTestSuite.getSlaveServer() + ":" + MultiHostTestSuite.getSlavePort();
static final String fake1 = "127.127.217.217:1";
static String masterIp;
static String slaveIp;
static String fakeIp = fake1;
static Connection con;
private static Map<HostSpec, Object> hostStatusMap;
static {
try {
Field field = GlobalHostStatusTracker.class.getDeclaredField("hostStatusMap");
field.setAccessible(true);
hostStatusMap = (Map<HostSpec, Object>) field.get(null);
con = TestUtil.openDB();
masterIp = getRemoteHostSpec();
closeDB(con);
con = MultiHostTestSuite.openSlaveDB();
slaveIp = getRemoteHostSpec();
closeDB(con);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
private static Connection getConnection(HostRequirement hostType, String... targets)
throws SQLException {
return getConnection(hostType, true, targets);
}
private static HostSpec hostSpec(String host) {
int split = host.indexOf(':');
return new HostSpec(host.substring(0, split), parseInt(host.substring(split + 1)));
}
private static Connection getConnection(HostRequirement hostType, boolean reset,
String... targets) throws SQLException {
return getConnection(hostType, reset, false, targets);
}
private static Connection getConnection(HostRequirement hostType, boolean reset, boolean lb,
String... targets) throws SQLException {
TestUtil.closeDB(con);
if (reset) {
resetGlobalState();
}
Properties props = new Properties();
props.setProperty("user", user);
props.setProperty("password", password);
props.setProperty("targetServerType", hostType.name());
props.setProperty("hostRecheckSeconds", "2");
if (lb) {
props.setProperty("loadBalanceHosts", "true");
}
StringBuilder sb = new StringBuilder();
sb.append("jdbc:postgresql://");
for (String target : targets) {
sb.append(target).append(',');
}
sb.setLength(sb.length() - 1);
sb.append("/test");
return con = DriverManager.getConnection(sb.toString(), props);
}
private static void assertRemote(String expectedHost) throws SQLException {
assertEquals(expectedHost, getRemoteHostSpec());
}
private static String getRemoteHostSpec() throws SQLException {
ResultSet rs = con.createStatement()
.executeQuery("select inet_server_addr() || ':' || inet_server_port()");
rs.next();
return rs.getString(1);
}
public static boolean isMaster(Connection con) throws SQLException {
ResultSet rs = con.createStatement().executeQuery("show transaction_read_only");
rs.next();
return "off".equals(rs.getString(1));
}
private static void assertGlobalState(String host, String status) {
HostSpec spec = hostSpec(host);
if (status == null) {
assertNull(hostStatusMap.get(spec));
} else {
assertEquals(host + "=" + status, hostStatusMap.get(spec).toString());
}
}
private static void resetGlobalState() {
hostStatusMap.clear();
}
public static void testConnectToAny() throws SQLException {
getConnection(any, fake1, master1);
assertRemote(masterIp);
assertGlobalState(master1, "ConnectOK");
assertGlobalState(fake1, "ConnectFail");
getConnection(any, fake1, slave1);
assertRemote(slaveIp);
assertGlobalState(slave1, "ConnectOK");
getConnection(any, fake1, master1);
assertRemote(masterIp);
assertGlobalState(master1, "ConnectOK");
assertGlobalState(fake1, "ConnectFail");
}
public static void testConnectToMaster() throws SQLException {
getConnection(master, true, fake1, master1, slave1);
assertRemote(masterIp);
assertGlobalState(fake1, "ConnectFail");
assertGlobalState(master1, "Master");
assertGlobalState(slave1, null);
getConnection(master, false, fake1, slave1, master1);
assertRemote(masterIp);
assertGlobalState(fake1, "ConnectFail");
assertGlobalState(master1, "Master");
assertGlobalState(slave1, "Slave");
}
public static void testConnectToSlave() throws SQLException {
getConnection(slave, true, fake1, slave1, master1);
assertRemote(slaveIp);
assertGlobalState(fake1, "ConnectFail");
assertGlobalState(slave1, "Slave");
assertGlobalState(master1, null);
getConnection(slave, false, fake1, master1, slave1);
assertRemote(slaveIp);
assertGlobalState(fake1, "ConnectFail");
assertGlobalState(slave1, "Slave");
assertGlobalState(master1, "Master");
}
public static void testConnectToSlaveFirst() throws SQLException {
getConnection(preferSlave, true, fake1, slave1, master1);
assertRemote(slaveIp);
assertGlobalState(fake1, "ConnectFail");
assertGlobalState(slave1, "Slave");
assertGlobalState(master1, null);
getConnection(preferSlave, false, fake1, master1, slave1);
assertRemote(masterIp);
assertGlobalState(fake1, "ConnectFail");
assertGlobalState(slave1, "Slave");
assertGlobalState(master1, "Master");
getConnection(preferSlave, false, fake1, master1, slave1);
assertRemote(slaveIp);
assertGlobalState(fake1, "ConnectFail");
assertGlobalState(slave1, "Slave");
assertGlobalState(master1, "Master");
}
public static void testFailedConnection() throws SQLException {
try {
getConnection(any, true, fake1);
fail();
} catch (PSQLException ex) {
}
}
public static void testLoadBalancing() throws SQLException {
Set<String> connectedHosts = new HashSet<String>();
boolean fake1FoundTried = false;
for (int i = 0; i < 20; ++i) {
getConnection(any, true, true, fake1, master1, slave1);
connectedHosts.add(getRemoteHostSpec());
fake1FoundTried |= hostStatusMap.containsKey(hostSpec(fake1));
if (connectedHosts.size() == 2 && fake1FoundTried) {
break;
}
}
assertEquals("Never connected to all hosts", new HashSet<String>(asList(masterIp, slaveIp)),
connectedHosts);
assertTrue("Never tried to connect to fake node", fake1FoundTried);
}
public static void testHostRechecks() throws SQLException, InterruptedException {
getConnection(master, true, fake1, master1, slave1);
assertRemote(masterIp);
assertGlobalState(fake1, "ConnectFail");
assertGlobalState(slave1, null);
GlobalHostStatusTracker.reportHostStatus(hostSpec(master1), ConnectFail);
assertGlobalState(master1, "ConnectFail");
try {
getConnection(master, false, fake1, slave1, master1);
fail();
} catch (SQLException ex) {
}
SECONDS.sleep(3);
getConnection(master, false, slave1, fake1, master1);
assertRemote(masterIp);
}
public static void testNoGoodHostsRechecksEverything() throws SQLException, InterruptedException {
GlobalHostStatusTracker.reportHostStatus(hostSpec(master1), Slave);
GlobalHostStatusTracker.reportHostStatus(hostSpec(slave1), Slave);
GlobalHostStatusTracker.reportHostStatus(hostSpec(fake1), Slave);
getConnection(master, false, slave1, fake1, master1);
assertRemote(masterIp);
}
}