package com.rubiconproject.oss.kv.distributed.impl; import java.io.IOException; import java.io.PrintWriter; import java.sql.Connection; import java.sql.DriverManager; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.LinkedList; import java.util.List; import java.util.Properties; import javax.naming.Context; import javax.naming.InitialContext; import javax.naming.NamingException; import javax.sql.DataSource; import com.rubiconproject.oss.kv.distributed.AbstractRefreshingNodeStore; import com.rubiconproject.oss.kv.distributed.ConfigurationException; import com.rubiconproject.oss.kv.distributed.Node; import com.rubiconproject.oss.kv.distributed.NodeStore; /** * Reads nodes from a jdbc database. Table should look like this: * * create table node ( id int primary key, store_id int not null, physical_id int not null, salt * varchar(10) unique not null, connection_uri varchar(128) not null, status * tinyint not null); * * Nodes with a status other than 1 will be ignored. * * @author sam * */ public class JdbcNodeStore extends AbstractRefreshingNodeStore implements NodeStore { public static final String DATA_SOURCE_PROPERTY = "nodeStore.dataSource"; public static final String JDBC_DRIVER_PROPERTY = "nodeStore.jdbcDriver"; public static final String JDBC_URL_PROPERTY = "nodeStore.jdbcUrl"; public static final String JDBC_USER_PROPERTY = "nodeStore.jdbcUsername"; public static final String JDBC_PASSWORD_PROPERTY = "nodeStore.jdbcPassword"; public static final String JDBC_STORE_ID = "nodeStore.id"; private DataSource ds; private int storeId; public JdbcNodeStore() { super(); } public JdbcNodeStore(Properties props) { super(); setProperties(props); } @Override public void addNode(Node node) { Connection conn = null; PreparedStatement select = null; PreparedStatement upsert = null; ResultSet rs = null; try { conn = getConnection(); // would love to use ON DUPLICATE KEY UPDATE here but for // compatibility with non-mysql databases I'm not going to do so. select = conn .prepareStatement("select count(id) as count from node where id = ?"); select.setInt(1, node.getId()); rs = select.executeQuery(); boolean update = false; if (rs.next()) { int count = rs.getInt(1); update = (count > 0); } if (update) { upsert = conn .prepareStatement("update node set status = ? where id = ?"); upsert.setInt(1, 1); upsert.setInt(2, node.getId()); } else { upsert = conn .prepareStatement("insert into node (id, store_id, physical_id, salt, connection_uri, status) values (?, ?, ?, ?, ?)"); upsert.setInt(1, node.getId()); upsert.setInt(2, storeId); upsert.setInt(3, node.getPhysicalId()); upsert.setString(4, node.getSalt()); upsert.setString(5, node.getConnectionURI()); upsert.setInt(6, 1); } upsert.executeUpdate(); if (!conn.getAutoCommit()) conn.commit(); // only add node to in-memory structure if above code succeeded super.addNode(node); } catch (SQLException e) { log.error("SQLException adding node()", e); } catch (NamingException e) { log.error("NamingException adding node()", e); } catch (ClassNotFoundException e) { log.error("ClassNotFoundException adding node()", e); } finally { if (rs != null) { try { rs.close(); } catch (Exception e) { } } if (select != null) { try { select.close(); } catch (Exception e) { } } if (upsert != null) { try { upsert.close(); } catch (Exception e) { } } if (conn != null) { try { conn.close(); } catch (Exception e) { } } } } @Override public void removeNode(Node node) { Connection conn = null; PreparedStatement ps = null; try { // remove node before any sql operations super.removeNode(node); conn = getConnection(); ps = conn .prepareStatement("update node set status = ? where id = ?"); ps.setInt(1, 2); ps.setInt(2, node.getId()); ps.execute(); if (!conn.getAutoCommit()) conn.commit(); } catch (SQLException e) { log.error("SQLException removing node()", e); } catch (NamingException e) { log.error("NamingException removing node()", e); } catch (ClassNotFoundException e) { log.error("ClassNotFoundException removing node()", e); } finally { if (ps != null) { try { ps.close(); } catch (Exception e) { } } if (conn != null) { try { conn.close(); } catch (Exception e) { } } } } @Override public List<Node> refreshActiveNodes() throws IOException, ConfigurationException { Connection conn = null; PreparedStatement ps = null; ResultSet rs = null; try { List<Node> nodes = new LinkedList<Node>(); conn = getConnection(); ps = conn .prepareStatement("select id, physical_id, salt, connection_uri from node where store_id = ? and status = 1 order by id asc"); ps.setInt(1, storeId); rs = ps.executeQuery(); while (rs.next()) { DefaultNodeImpl node = new DefaultNodeImpl(); node.setConnectionURI(rs.getString("connection_uri")); node.setId(rs.getInt("id")); node.setPhysicalId(rs.getInt("physical_id")); node.setSalt(rs.getString("salt")); nodes.add(node); } return nodes; } catch (SQLException e) { throw new IOException(e); } catch (NamingException e) { throw new ConfigurationException(e); } catch (ClassNotFoundException e) { throw new ConfigurationException(e); } finally { if (rs != null) { try { rs.close(); } catch (Exception e) { } } if (ps != null) { try { ps.close(); } catch (Exception e) { } } if (conn != null) { try { conn.close(); } catch (Exception e) { } } } } private Connection getConnection() throws SQLException, NamingException, ClassNotFoundException { if (ds == null) { storeId = Integer.parseInt(props.getProperty(JDBC_STORE_ID)); String dataSourceName = props.getProperty(DATA_SOURCE_PROPERTY); if (dataSourceName != null) { Context initCtx = new InitialContext(); ds = (DataSource) initCtx.lookup(dataSourceName); } else { String driver = props.getProperty(JDBC_DRIVER_PROPERTY); String url = props.getProperty(JDBC_URL_PROPERTY); String user = props.getProperty(JDBC_USER_PROPERTY); String password = props.getProperty(JDBC_PASSWORD_PROPERTY); ds = new SimpleDataSource(driver, url, user, password); } } return ds.getConnection(); } private class SimpleDataSource implements DataSource { private PrintWriter pw; private int loginTimeout; private String url; private String user; private String password; public SimpleDataSource(String driver, String url, String user, String password) throws ClassNotFoundException { this.url = url; this.user = user; this.password = password; Class.forName(driver); } public Connection getConnection() throws SQLException { return DriverManager.getConnection(url, user, password); } public Connection getConnection(String username, String password) throws SQLException { return getConnection(); } public PrintWriter getLogWriter() throws SQLException { return pw; } public int getLoginTimeout() throws SQLException { return loginTimeout; } public void setLogWriter(PrintWriter pw) throws SQLException { this.pw = pw; } public void setLoginTimeout(int loginTimeout) throws SQLException { this.loginTimeout = loginTimeout; } public boolean isWrapperFor(Class<?> cls) throws SQLException { return false; } public <T> T unwrap(Class<T> cls) throws SQLException { return null; } } }