package ameba.db; import ameba.container.event.ShutdownEvent; import ameba.core.Addon; import ameba.core.Application; import ameba.db.model.ModelManager; import ameba.event.SystemEventBus; import com.alibaba.druid.pool.DruidDataSource; import com.alibaba.druid.pool.DruidDataSourceFactory; import com.google.common.collect.Maps; import org.apache.commons.lang3.StringUtils; import org.glassfish.jersey.internal.inject.AbstractBinder; import org.glassfish.jersey.internal.inject.InstanceBinding; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import javax.sql.DataSource; import java.util.Map; import java.util.Set; /** * <p>DataSourceManager class.</p> * * @author 张立鑫 IntelligentCode * @since 2013-08-07 */ public class DataSourceManager extends Addon { private static final Map<String, DruidDataSource> dataSourceMap = Maps.newLinkedHashMap(); private static final Logger logger = LoggerFactory.getLogger(DataSourceManager.class); private static String DEFAULT_DS_NAME = "default"; /** * <p>getDefaultDataSourceName.</p> * * @return a {@link java.lang.String} object. */ public static String getDefaultDataSourceName() { return DEFAULT_DS_NAME; } /** * 根据数据源名称获取数据源 * * @param name data source name * @return DataSource */ public static DataSource getDataSource(String name) { return dataSourceMap.get(name); } /** * 获取所有数据源名称 * * @return data source name set */ public static Set<String> getDataSourceNames() { return dataSourceMap.keySet(); } /** * {@inheritDoc} */ @Override public void setup(final Application app) { Map<String, Object> config = app.getSrcProperties(); String dsName = (String) config.get("db.default"); if (StringUtils.isNotBlank(dsName)) { DEFAULT_DS_NAME = StringUtils.deleteWhitespace(dsName); } Map<String, Map<String, String>> map = Maps.newHashMap(); for (String key : config.keySet()) { key = StringUtils.deleteWhitespace(key); key = key.replaceAll("\\.{2,}", "."); if (key.startsWith(ModelManager.MODULE_MODELS_KEY_PREFIX)) continue; //db.[DataSourceName].[ConfigKey] String[] keys = key.split("\\."); if (keys.length > 2 && "db".equals(keys[0])) { Map<String, String> sourceConfig = map.computeIfAbsent(keys[1], k -> Maps.newHashMap()); if (StringUtils.isNotBlank(keys[2])) { sourceConfig.put(keys[2], String.valueOf(config.get(key))); } } } for (String name : map.keySet()) { try { Map<String, String> conf = map.get(name); String value = conf.get("init"); if (StringUtils.isBlank(value)) { conf.put("init", "true"); } DruidDataSource ds = (DruidDataSource) DruidDataSourceFactory.createDataSource(conf); ds.setName(name); ds.setDefaultAutoCommit(false); dataSourceMap.put(name, ds); } catch (Exception e) { logger.error("配置数据源出错", e); } } SystemEventBus.subscribe(ShutdownEvent.class, (ShutdownEvent event) -> { dataSourceMap.forEach((name, dataSource) -> { if (!dataSource.isClosed()) dataSource.close(); }); dataSourceMap.clear(); }); app.register(new AbstractBinder() { @Override protected void configure() { for (Map.Entry<String, DruidDataSource> entry : dataSourceMap.entrySet()) { DruidDataSource ds = entry.getValue(); String name = entry.getKey(); createBuilder(ds).named(name); if (getDefaultDataSourceName().equals(name)) { createBuilder(ds); } } } private InstanceBinding<DruidDataSource> createBuilder(DruidDataSource dataSource) { return bind(dataSource) .to(DruidDataSource.class) .to(DataSource.class) .proxy(false); } }); } }