package core.framework.impl.db; import core.framework.api.db.Database; import core.framework.api.db.Repository; import core.framework.api.db.Transaction; import core.framework.api.db.UncheckedSQLException; import core.framework.api.log.ActionLogContext; import core.framework.api.log.Markers; import core.framework.api.util.Exceptions; import core.framework.api.util.Maps; import core.framework.api.util.StopWatch; import core.framework.impl.resource.Pool; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.math.BigDecimal; import java.sql.Connection; import java.sql.Driver; import java.sql.SQLException; import java.time.Duration; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.ZonedDateTime; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Properties; /** * @author neo */ public final class DatabaseImpl implements Database { public final Pool<Connection> pool; public final DatabaseOperation operation; private final Logger logger = LoggerFactory.getLogger(DatabaseImpl.class); private final Map<Class<?>, RowMapper<?>> rowMappers = Maps.newHashMap(); public int tooManyRowsReturnedThreshold = 1000; public String url; public String user; public String password; long slowOperationThresholdInNanos = Duration.ofSeconds(5).toNanos(); private Properties driverProperties; private Duration timeout; private Driver driver; public DatabaseImpl() { initializeRowMappers(); pool = new Pool<>(this::createConnection, Connection::close); pool.name("db"); pool.size(5, 50); // default optimization for AWS medium/large instances pool.maxIdleTime(Duration.ofHours(2)); // make sure db server does not kill connection shorter than this, e.g. MySQL default wait_timeout is 8 hours operation = new DatabaseOperation(pool); timeout(Duration.ofSeconds(15)); } private void initializeRowMappers() { rowMappers.put(String.class, new RowMapper.StringRowMapper()); rowMappers.put(Integer.class, new RowMapper.IntegerRowMapper()); rowMappers.put(Long.class, new RowMapper.LongRowMapper()); rowMappers.put(Double.class, new RowMapper.DoubleRowMapper()); rowMappers.put(BigDecimal.class, new RowMapper.BigDecimalRowMapper()); rowMappers.put(Boolean.class, new RowMapper.BooleanRowMapper()); rowMappers.put(LocalDateTime.class, new RowMapper.LocalDateTimeRowMapper()); rowMappers.put(LocalDate.class, new RowMapper.LocalDateRowMapper()); rowMappers.put(ZonedDateTime.class, new RowMapper.ZonedDateTimeRowMapper()); } private Connection createConnection() { if (url == null) throw new Error("url must not be null"); Properties driverProperties = this.driverProperties; if (driverProperties == null) { driverProperties = driverProperties(); this.driverProperties = driverProperties; } try { return driver.connect(url, driverProperties); } catch (SQLException e) { throw new UncheckedSQLException(e); } } private Properties driverProperties() { Properties properties = new Properties(); if (user != null) properties.put("user", user); if (password != null) properties.put("password", password); String timeoutValue = String.valueOf(timeout.toMillis()); if (url.startsWith("jdbc:mysql:")) { properties.put("connectTimeout", timeoutValue); properties.put("socketTimeout", timeoutValue); } else if (url.startsWith("jdbc:oracle:")) { properties.put("oracle.net.CONNECT_TIMEOUT", timeoutValue); properties.put("oracle.jdbc.ReadTimeout", timeoutValue); } return properties; } public void close() { logger.info("close database client, url={}", url); pool.close(); } public void timeout(Duration timeout) { this.timeout = timeout; operation.queryTimeoutInSeconds = (int) timeout.getSeconds(); pool.checkoutTimeout(timeout); } public void url(String url) { if (!url.startsWith("jdbc:")) throw Exceptions.error("jdbc url must start with \"jdbc:\", url={}", url); logger.info("set database connection url, url={}", url); this.url = url; driver = driver(url); } private Driver driver(String url) { try { if (url.startsWith("jdbc:mysql:")) { return (Driver) Class.forName("com.mysql.jdbc.Driver").newInstance(); } else if (url.startsWith("jdbc:hsqldb:")) { return (Driver) Class.forName("org.hsqldb.jdbc.JDBCDriver").newInstance(); } else if (url.startsWith("jdbc:oracle:")) { return (Driver) Class.forName("oracle.jdbc.OracleDriver").newInstance(); } else { throw Exceptions.error("not supported database, please contact arch team, url={}", url); } } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) { throw new Error(e); } } public void slowOperationThreshold(Duration slowOperationThreshold) { slowOperationThresholdInNanos = slowOperationThreshold.toNanos(); } public <T> void view(Class<T> viewClass) { StopWatch watch = new StopWatch(); try { new DatabaseClassValidator(viewClass).validateViewClass(); registerViewClass(viewClass); } finally { logger.info("register db view, viewClass={}, elapsedTime={}", viewClass.getCanonicalName(), watch.elapsedTime()); } } public <T> Repository<T> repository(Class<T> entityClass) { StopWatch watch = new StopWatch(); try { new DatabaseClassValidator(entityClass).validateEntityClass(); RowMapper<T> mapper = registerViewClass(entityClass); return new RepositoryImpl<>(this, entityClass, mapper); } finally { logger.info("register db entity, entityClass={}, elapsedTime={}", entityClass.getCanonicalName(), watch.elapsedTime()); } } @Override public Transaction beginTransaction() { return operation.transactionManager.beginTransaction(); } @Override public <T> List<T> select(String sql, Class<T> viewClass, Object... params) { StopWatch watch = new StopWatch(); try { List<T> results = operation.select(sql, rowMapper(viewClass), params); checkTooManyRowsReturned(results.size()); return results; } finally { long elapsedTime = watch.elapsedTime(); ActionLogContext.track("db", elapsedTime); logger.debug("select, sql={}, params={}, elapsedTime={}", sql, params, elapsedTime); checkSlowOperation(elapsedTime); } } @Override public <T> Optional<T> selectOne(String sql, Class<T> viewClass, Object... params) { StopWatch watch = new StopWatch(); try { return operation.selectOne(sql, rowMapper(viewClass), params); } finally { long elapsedTime = watch.elapsedTime(); ActionLogContext.track("db", elapsedTime); logger.debug("selectOne, sql={}, params={}, elapsedTime={}", sql, params, elapsedTime); checkSlowOperation(elapsedTime); } } @Override public int execute(String sql, Object... params) { StopWatch watch = new StopWatch(); try { return operation.update(sql, params); } finally { long elapsedTime = watch.elapsedTime(); ActionLogContext.track("db", elapsedTime); logger.debug("execute, sql={}, params={}, elapsedTime={}", sql, params, elapsedTime); checkSlowOperation(elapsedTime); } } private <T> RowMapper<T> rowMapper(Class<T> viewClass) { @SuppressWarnings("unchecked") RowMapper<T> mapper = (RowMapper<T>) rowMappers.get(viewClass); if (mapper == null) throw Exceptions.error("view class is not registered, please register in module by db().view(), viewClass={}", viewClass.getCanonicalName()); return mapper; } private <T> RowMapper<T> registerViewClass(Class<T> viewClass) { if (rowMappers.containsKey(viewClass)) { throw Exceptions.error("found duplicate view class, viewClass={}", viewClass.getCanonicalName()); } RowMapper<T> mapper = new RowMapperBuilder<>(viewClass, operation.enumMapper).build(); rowMappers.put(viewClass, mapper); return mapper; } private void checkTooManyRowsReturned(int size) { if (size > tooManyRowsReturnedThreshold) { logger.warn(Markers.errorCode("TOO_MANY_ROWS_RETURNED"), "too many rows returned, returnedRows={}", size); } } private void checkSlowOperation(long elapsedTime) { if (elapsedTime > slowOperationThresholdInNanos) { logger.warn(Markers.errorCode("SLOW_DB"), "slow db operation, elapsedTime={}", elapsedTime); } } }