/* * Copyright 2014-2015 the original author or authors * * Licensed under the Apache License, Version 2.0 (the “License”); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an “AS IS” BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.wplatform.ddal.shards; import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.*; import java.util.concurrent.ThreadPoolExecutor.AbortPolicy; import javax.sql.DataSource; import com.wplatform.ddal.config.Configuration; import com.wplatform.ddal.config.DataSourceException; import com.wplatform.ddal.config.DataSourceProvider; import com.wplatform.ddal.config.ShardConfig; import com.wplatform.ddal.config.ShardConfig.ShardItem; import com.wplatform.ddal.engine.Database; import com.wplatform.ddal.message.Trace; import com.wplatform.ddal.util.JdbcUtils; import com.wplatform.ddal.util.New; import com.wplatform.ddal.util.StringUtils; /** * @author <a href="mailto:jorgie.mail@gmail.com">jorgie li</a> */ public class DataSourceRepository { private final Database database; private final List<DataSourceMarker> registered = New.arrayList(); private final List<DataSourceMarker> abnormalList = New.copyOnWriteArrayList(); private final List<DataSourceMarker> monitor = New.copyOnWriteArrayList(); private final HashMap<String, DataSource> shardMaping = New.hashMap(); private final HashMap<String, DataSource> idMapping = New.hashMap(); private final String defaultShardName; private final DataSourceProvider dataSourceProvider; private final Trace trace; protected ScheduledExecutorService abnormalScheduler; protected ScheduledExecutorService monitorScheduler; private String validationQuery; private int validationQueryTimeout; private ThreadPoolExecutor jdbcExecutor; private ScheduledExecutorService scheduledExecutor; public DataSourceRepository(Database database) { this.database = database; Configuration configuration = database.getConfiguration(); this.defaultShardName = configuration.getSchemaConfig().getShard(); this.validationQuery = database.getSettings().defaultValidationQuery; this.validationQueryTimeout = database.getSettings().defaultValidationQueryTimeout; this.dataSourceProvider = configuration.getDataSourceProvider(); if(dataSourceProvider == null) { throw new IllegalArgumentException(); } this.trace = database.getTrace(Trace.DATASOURCE); Map<String, ShardConfig> shardMapping = configuration.getCluster(); for (ShardConfig value : shardMapping.values()) { List<ShardItem> shardItems = value.getShardItems(); List<DataSourceMarker> shardDs = New.arrayList(shardItems.size()); DataSourceMarker dsMarker = new DataSourceMarker(); for (ShardItem i : shardItems) { String ref = i.getRef(); DataSource dataSource = dataSourceProvider.lookup(ref); if (dataSource == null) { throw new DataSourceException("Can' find data source: " + ref); } dsMarker.setDataSource(dataSource); dsMarker.setShardName(value.getName()); dsMarker.setUid(ref); dsMarker.setReadOnly(i.isReadOnly()); dsMarker.setwWeight(i.getwWeight()); dsMarker.setrWeight(i.getrWeight()); shardDs.add(dsMarker); idMapping.put(ref, dsMarker.getDataSource()); } if (shardDs.size() < 1) { throw new DataSourceException("No datasource in " + value.getName()); } registered.addAll(shardDs); DataSource dataSource = shardDs.size() > 1 ? new SmartDataSource(this, value.getName(), shardDs) : shardDs.get(0).getDataSource(); shardMaping.put(value.getName(), dataSource); } scheduledExecutor = Executors.newScheduledThreadPool(1, New.customThreadFactory("datasource-ha-thread")); scheduledExecutor.scheduleAtFixedRate(new Worker(), 10, 10, TimeUnit.SECONDS); } public DataSource getDataSourceByShardName(String shardName) { DataSource dataSource = shardMaping.get(shardName); if(dataSource == null) { throw new IllegalArgumentException(shardName + " DataSource not found."); } return dataSource; } public DataSource getDataSourceById(String id) { DataSource dataSource = idMapping.get(id); if(dataSource == null) { throw new IllegalArgumentException(); } return dataSource; } /** * @return the validationQuery */ public String getValidationQuery() { return validationQuery; } /** * @param validationQuery the validationQuery to set */ public void setValidationQuery(String validationQuery) { this.validationQuery = validationQuery; } /** * @return the validationQueryTimeout */ public int getValidationQueryTimeout() { return validationQueryTimeout; } /** * @param validationQueryTimeout the validationQueryTimeout to set */ public void setValidationQueryTimeout(int validationQueryTimeout) { this.validationQueryTimeout = validationQueryTimeout; } public Trace getTrace() { return trace; } public int shardCount() { return this.shardMaping.size(); } public String getDefaultShardName() { return defaultShardName; } public DataSource getDefaultShardDataSource() { if(StringUtils.isNullOrEmpty(defaultShardName)) { return null; } return shardMaping.get(defaultShardName); } /** * TODO configurable * @return the jdbcExecutor */ public ThreadPoolExecutor getJdbcExecutor() { if (jdbcExecutor == null) { int corePoolSize = Runtime.getRuntime().availableProcessors(); int maximumPoolSize = 200;// TODO configurable int capacity = maximumPoolSize * 1; int keepAliveTime = database.getSettings().maxQueryTimeout; if (keepAliveTime <= 0) { keepAliveTime = 15 * 60000; // 15 MINUTES } BlockingQueue<Runnable> workQueue = new LinkedBlockingQueue<Runnable>(capacity); jdbcExecutor = new ThreadPoolExecutor(corePoolSize, maximumPoolSize, keepAliveTime, TimeUnit.MILLISECONDS, workQueue, New.customThreadFactory("jdbc-worker"), new AbortPolicy()); jdbcExecutor.allowCoreThreadTimeOut(true); } return jdbcExecutor; } public void close() { try { this.scheduledExecutor.awaitTermination(500, TimeUnit.MILLISECONDS); } catch (InterruptedException e) { trace.error(e, "scheduledExecutor awaitTermination"); } try { if (jdbcExecutor != null) { jdbcExecutor.awaitTermination(1, TimeUnit.SECONDS); } } catch (InterruptedException e) { trace.error(e, "jdbcExecutor awaitTermination"); } } Connection getConnection(DataSourceMarker selected) throws SQLException { DataSource dataSource = selected.getDataSource(); try { return dataSource.getConnection(); } catch (SQLException e) { selected.incrementFailedCount(); monitor.add(selected); throw e; } } Connection getConnection(DataSourceMarker selected, String username, String password) throws SQLException { DataSource dataSource = selected.getDataSource(); try { return dataSource.getConnection(username, password); } catch (SQLException e) { selected.incrementFailedCount(); monitor.add(selected); throw e; } } private class Worker implements Runnable { @Override public void run() { try { handleMonitorList(); } catch (Exception e) { trace.error(e, "datasource-ha-thread handle monitor list error"); } try { hanldeAbnormalList(); } catch (Exception e) { trace.error(e, "datasource-ha-thread handle monitor list error"); } } /** * @throws SQLException */ private void hanldeAbnormalList() throws SQLException { for (DataSourceMarker failed : abnormalList) { DataSource ds = failed.getDataSource(); boolean isOk = validateAvailable(ds); if (isOk) { DataSource dataSource = shardMaping.get(failed.getShardName()); Failover selector = (Failover) dataSource; selector.doHandleWakeup(failed); abnormalList.remove(failed); } } } /** * @throws SQLException */ private void handleMonitorList() throws SQLException { for (DataSourceMarker source : monitor) { DataSource ds = source.getDataSource(); boolean isOk = validateAvailable(ds); if (!isOk) { DataSource dataSource = shardMaping.get(source.getShardName()); Failover selector = (Failover) dataSource; selector.doHandleAbnormal(source); abnormalList.add(source); trace.error(null, source.toString() + " was abnormal,it's remove in " + source.getShardName()); } monitor.remove(source); } } private boolean validateAvailable(DataSource dataSource) throws SQLException { Connection conn = null; try { conn = dataSource.getConnection(); } catch (SQLException ex) { // skip return false; } Statement stmt = null; ResultSet rs = null; try { stmt = conn.createStatement(); if (validationQueryTimeout > 0) { stmt.setQueryTimeout(validationQueryTimeout); } else { stmt.setQueryTimeout(5); } rs = stmt.executeQuery(validationQuery); return true; } catch (SQLException e) { return false; } catch (Exception e) { // LOG.warn("Unexpected error in ping", e); return false; } finally { JdbcUtils.closeSilently(rs); JdbcUtils.closeSilently(stmt); } } } }