package com.github.yingzhuo.spring.auto.datasource.composite;
import com.github.yingzhuo.spring.auto.datasource.NamedDataSource;
import com.github.yingzhuo.spring.auto.datasource.composite.aop.DataSourceRemoter;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import javax.sql.DataSource;
import java.io.PrintWriter;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
/**
 * A {@link DataSource} that forwards every JDBC call to one of several named delegate
 * {@link DataSource}s registered through {@link #add(NamedDataSource...)}.
 *
 * <p>The delegate is chosen on every call: when exactly one data source is registered it is
 * always used; otherwise the name obtained from {@code DataSourceRemoter.getInstance().get()} is
 * looked up, falling back to the data source registered under
 * {@link #getDefaultDataSourceName()} when that lookup finds nothing.
 *
 * <p>Delegates are kept in a plain {@link HashMap}, so {@link #add(NamedDataSource...)} should
 * only be called while configuring the instance, before it is shared between threads.
 */
public final class CompositeDataSource implements NamedDataSource, InitializingBean {

    private static final org.slf4j.Logger LOGGER = LoggerFactory.getLogger(CompositeDataSource.class);

    /** Registered delegates keyed by their name. */
    private final Map<String, DataSource> namedDataSources = new HashMap<>();

    /** Name of the delegate used when the routing lookup finds no registered data source. */
    private String defaultDataSourceName;

    /** Name of a public no-arg method invoked reflectively on every delegate by {@link #init()}. */
    private String initMethod;

    /** Name of a public no-arg method invoked reflectively on every delegate by {@link #close()}. */
    private String destoryMethod;

    /** Creates an empty composite; register delegates with {@link #add(NamedDataSource...)}. */
    public CompositeDataSource() {
        super();
    }

    /**
     * Validates the configuration after all properties have been set.
     *
     * @throws IllegalArgumentException if the default name is empty, no data source has been
     *                                  registered, or nothing is registered under the default name
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        Assert.hasText(defaultDataSourceName, "defaultDataSourceName must not be empty");
        Assert.isTrue(!namedDataSources.isEmpty(), "at least one DataSource must be added");
        Assert.isTrue(namedDataSources.get(defaultDataSourceName) != null,
                "no DataSource registered under the default name '" + defaultDataSourceName + "'");
    }

    /**
     * Invokes the configured init method on every registered data source.
     *
     * <p>Does nothing when no init method is configured. Stops at the first failure and
     * propagates it; delegates already initialized are left as they are.
     *
     * @throws Exception the exception raised by the first failing init method, or a reflection
     *                   exception if the method cannot be found or invoked
     */
    public void init() throws Exception {
        LOGGER.debug("init {}", CompositeDataSource.class.getSimpleName());
        if (!StringUtils.hasLength(initMethod)) {
            return;
        }
        for (DataSource ds : this.namedDataSources.values()) {
            reflectionInvocateMethod(ds, initMethod);
        }
    }

    /**
     * Invokes the configured destroy method on every registered data source.
     *
     * <p>Does nothing when no destroy method is configured. Every delegate is attempted even if
     * an earlier one fails, so a single failing delegate cannot leak the others; the first
     * failure is rethrown afterwards with any later failures attached as suppressed exceptions.
     *
     * @throws Exception the first failure encountered while closing the delegates
     */
    public void close() throws Exception {
        LOGGER.debug("close {}", CompositeDataSource.class.getSimpleName());
        if (!StringUtils.hasLength(destoryMethod)) {
            return;
        }
        Exception failure = null;
        for (DataSource ds : this.namedDataSources.values()) {
            try {
                reflectionInvocateMethod(ds, destoryMethod);
            } catch (Exception e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /**
     * Reflectively invokes the public no-argument method {@code methodName} on {@code target}.
     *
     * @param target     object to invoke the method on
     * @param methodName name of a public method taking no arguments
     * @throws NoSuchMethodException if {@code target}'s class has no such public method
     * @throws Exception             the exception thrown by the invoked method itself, unwrapped
     *                               from {@link InvocationTargetException}
     */
    public void reflectionInvocateMethod(Object target, String methodName) throws Exception {
        Method method = target.getClass().getMethod(methodName);
        // getMethod only returns public methods, but the declaring class may itself be non-public
        // (e.g. a public method inherited from a package-private base class); only then do access
        // checks have to be suppressed. This also avoids the deprecated Method#isAccessible().
        if (!Modifier.isPublic(method.getDeclaringClass().getModifiers())) {
            method.setAccessible(true);
        }
        try {
            method.invoke(target);
        } catch (InvocationTargetException e) {
            // Surface the real failure instead of the reflective wrapper.
            Throwable cause = e.getCause();
            if (cause instanceof Exception) {
                throw (Exception) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw e;
        }
    }

    /**
     * Registers data sources under their {@link NamedDataSource#getName() names}.
     *
     * <p>A data source added under a name that is already registered replaces the earlier one.
     *
     * @param namedDataSourceArray data sources to register
     * @return this instance, for call chaining
     */
    public CompositeDataSource add(NamedDataSource... namedDataSourceArray) {
        for (NamedDataSource namedDataSource : namedDataSourceArray) {
            this.namedDataSources.put(namedDataSource.getName(), namedDataSource.getDataSource());
        }
        return this;
    }

    /** Returns the fully qualified name of this class. */
    @Override
    public String getName() {
        return CompositeDataSource.class.getName();
    }

    /**
     * Not supported: a composite has no single underlying data source.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public DataSource getDataSource() {
        throw new UnsupportedOperationException();
    }

    // ---- DataSource methods: each forwards to the delegate selected for the current call ----

    @Override
    public Connection getConnection() throws SQLException {
        return getEffectDataSource().getConnection();
    }

    @Override
    public Connection getConnection(String username, String password) throws SQLException {
        return getEffectDataSource().getConnection(username, password);
    }

    @Override
    public PrintWriter getLogWriter() throws SQLException {
        return getEffectDataSource().getLogWriter();
    }

    @Override
    public void setLogWriter(PrintWriter out) throws SQLException {
        getEffectDataSource().setLogWriter(out);
    }

    @Override
    public void setLoginTimeout(int seconds) throws SQLException {
        getEffectDataSource().setLoginTimeout(seconds);
    }

    @Override
    public int getLoginTimeout() throws SQLException {
        return getEffectDataSource().getLoginTimeout();
    }

    @Override
    public Logger getParentLogger() throws SQLFeatureNotSupportedException {
        return getEffectDataSource().getParentLogger();
    }

    /**
     * Returns this composite if it implements {@code iface} (as required by the
     * {@link java.sql.Wrapper} contract), otherwise unwraps the currently selected delegate.
     */
    @Override
    public <T> T unwrap(Class<T> iface) throws SQLException {
        if (iface.isInstance(this)) {
            return iface.cast(this);
        }
        return getEffectDataSource().unwrap(iface);
    }

    /**
     * Returns {@code true} if this composite implements {@code iface} or the currently selected
     * delegate is (or wraps) an implementation of it.
     */
    @Override
    public boolean isWrapperFor(Class<?> iface) throws SQLException {
        return iface.isInstance(this) || getEffectDataSource().isWrapperFor(iface);
    }

    /**
     * Selects the delegate for the current call.
     *
     * @return the single registered data source, else the one registered under the name from
     *         {@code DataSourceRemoter}, else the one registered under the default name
     * @throws IllegalStateException if none of these resolves to a data source (for example when
     *                               nothing has been registered or no default is configured)
     */
    private DataSource getEffectDataSource() {
        DataSource dataSource;
        if (namedDataSources.size() == 1) {
            // A single delegate needs no routing.
            dataSource = namedDataSources.values().iterator().next();
        } else {
            DataSourceRemoter remoter = DataSourceRemoter.getInstance();
            dataSource = namedDataSources.get(remoter.get());
            if (dataSource == null) {
                dataSource = namedDataSources.get(defaultDataSourceName);
            }
        }
        if (dataSource == null) {
            throw new IllegalStateException("cannot resolve a target DataSource; no DataSource "
                    + "registered under the default name '" + defaultDataSourceName + "'");
        }
        return dataSource;
    }

    /** Returns the name of the delegate used when routing finds no registered data source. */
    public String getDefaultDataSourceName() {
        return defaultDataSourceName;
    }

    /** Sets the name of the delegate used when routing finds no registered data source. */
    public void setDefaultDataSourceName(String defaultDataSourceName) {
        this.defaultDataSourceName = defaultDataSourceName;
    }

    /** Returns the name of the no-arg method {@link #init()} invokes on every delegate. */
    public String getInitMethod() {
        return initMethod;
    }

    /** Sets the name of the no-arg method {@link #init()} invokes on every delegate. */
    public void setInitMethod(String initMethod) {
        this.initMethod = initMethod;
    }

    /** Returns the name of the no-arg method {@link #close()} invokes on every delegate. */
    public String getDestoryMethod() {
        return destoryMethod;
    }

    /** Sets the name of the no-arg method {@link #close()} invokes on every delegate. */
    public void setDestoryMethod(String destoryMethod) {
        this.destoryMethod = destoryMethod;
    }
}