package com.thinkbiganalytics.hive.service;

/*-
 * #%L
 * thinkbig-thrift-proxy-core
 * %%
 * Copyright (C) 2017 ThinkBig Analytics
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import com.thinkbiganalytics.kerberos.KerberosTicketConfiguration;
import com.thinkbiganalytics.kerberos.KerberosUtil;

import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.autoconfigure.jdbc.DataSourceBuilder;
import org.springframework.core.env.Environment;
import org.springframework.jdbc.datasource.DataSourceUtils;
import org.springframework.jdbc.support.JdbcUtils;
import org.springframework.security.core.context.SecurityContextHolder;

import java.io.PrintWriter;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.sql.Statement;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Logger;

import javax.inject.Inject;
import javax.sql.DataSource;

/**
 * A {@link DataSource} that delegates to an underlying data source built from environment properties and
 * rebuilds ("refreshes") that delegate when the configured validation query fails.
 *
 * <p>Connection settings are read from the {@link Environment} under {@code propertyPrefix}
 * ({@code driverClassName}, {@code url}, {@code username}, {@code password}, {@code validationQuery}).
 *
 * <p>When the property {@code hive.userImpersonation.enabled} is {@code true} and the prefix is
 * {@code hive.datasource}, a separate delegate is cached per authenticated principal and
 * {@code ;hive.server2.proxy.user=<principal>} is appended to the JDBC URL. In every other case a single
 * delegate is cached under the key {@code DEFAULT}.
 */
public class RefreshableDataSource implements DataSource {

    private static final org.slf4j.Logger log = LoggerFactory.getLogger(RefreshableDataSource.class);

    private static final String DEFAULT_DATASOURCE_NAME = "DEFAULT";

    /** The only property prefix for which user impersonation is applied. */
    private static final String HIVE_DATASOURCE_PREFIX = "hive.datasource";

    /** Environment property that switches user impersonation on. */
    private static final String USER_IMPERSONATION_PROPERTY = "hive.userImpersonation.enabled";

    String propertyPrefix;

    @Autowired
    Environment env;

    /** Delegates keyed by principal (impersonation) or by {@link #DEFAULT_DATASOURCE_NAME}. */
    private final ConcurrentHashMap<String, DataSource> datasources = new ConcurrentHashMap<>();

    /** Guards {@link #refresh()} so that only one caller rebuilds a delegate at a time. */
    private final AtomicBoolean isRefreshing = new AtomicBoolean(false);

    @Inject
    @Qualifier("kerberosHiveConfiguration")
    private KerberosTicketConfiguration kerberosTicketConfiguration;

    /**
     * @param propertyPrefix prefix (with or without a trailing dot) of the environment properties that
     *                       describe the connection
     */
    public RefreshableDataSource(String propertyPrefix) {
        this.propertyPrefix = propertyPrefix;
    }

    /**
     * Builds a new delegate data source and stores it under the current key, replacing any previous one.
     *
     * <p>If another refresh is already in progress this call returns immediately without doing anything.
     */
    public void refresh() {
        if (!isRefreshing.compareAndSet(false, true)) {
            return;
        }
        try {
            log.info("REFRESHING DATASOURCE for {} ", propertyPrefix);
            if (isUserImpersonationActive()) {
                String currentUser = getCurrentUser();
                datasources.put(currentUser, create(true, currentUser));
            } else {
                datasources.put(DEFAULT_DATASOURCE_NAME, create(false, null));
            }
        } finally {
            // Always clear the flag; otherwise one failed refresh would disable every later refresh.
            isRefreshing.set(false);
        }
    }

    /**
     * Runs the validation query against the current delegate.
     *
     * @return {@code true} when the query executed
     * @throws SQLException if a connection could not be obtained or the query failed
     */
    public boolean testConnection() throws SQLException {
        return testConnection(null, null);
    }

    /**
     * Runs the {@code validationQuery} configured under the property prefix against the current delegate.
     *
     * <p>NOTE(review): {@code username} and {@code password} are accepted but not used by this method;
     * the connection always comes from the configured delegate. Confirm whether that is intended.
     *
     * @param username not used
     * @param password not used
     * @return {@code true} when the query executed
     * @throws SQLException if a connection could not be obtained or the query failed
     */
    public boolean testConnection(String username, String password) throws SQLException {
        Connection connection = null;
        Statement statement = null;
        try {
            String query = env.getProperty(getPrefixWithTrailingDot() + "validationQuery");
            connection = getConnectionForValidation();
            statement = connection.createStatement();
            statement.execute(query);
            return true;
        } finally {
            // Close the statement first, then release the connection exactly once on every path.
            JdbcUtils.closeStatement(statement);
            DataSourceUtils.releaseConnection(connection, getDataSource());
        }
    }

    /**
     * Obtains a connection from the current delegate, creating the delegate first when none is cached.
     */
    private Connection getConnectionForValidation() throws SQLException {
        if (getDataSource() == null) {
            refresh();
        }
        return KerberosUtil.getConnectionWithOrWithoutKerberos(getDataSource(), kerberosTicketConfiguration);
    }

    /**
     * Validates the current delegate, rebuilds it if validation fails, then hands out a connection.
     */
    private Connection testAndRefreshIfInvalid(String username, String password) throws SQLException {
        try {
            testConnection(username, password);
        } catch (SQLException e) {
            // A failed validation is recoverable here: rebuild the delegate and try again below.
            log.warn("Connection validation failed for {}; refreshing the data source", propertyPrefix, e);
            refresh();
        }
        return getConnectionForValidation();
    }

    @Override
    public Connection getConnection() throws SQLException {
        return testAndRefreshIfInvalid(null, null);
    }

    @Override
    public Connection getConnection(String username, String password) throws SQLException {
        return testAndRefreshIfInvalid(username, password);
    }

    /**
     * Looks up the cached delegate for the current key.
     *
     * @return the delegate, or {@code null} when none has been created for the key yet
     */
    private DataSource getDataSource() {
        String key = isUserImpersonationActive() ? getCurrentUser() : DEFAULT_DATASOURCE_NAME;
        return datasources.get(key);
    }

    /**
     * @return {@code true} when impersonation is enabled in the environment and this instance is
     *         configured with the {@code hive.datasource} prefix
     */
    private boolean isUserImpersonationActive() {
        boolean enabled = Boolean.parseBoolean(env.getProperty(USER_IMPERSONATION_PROPERTY));
        return enabled && HIVE_DATASOURCE_PREFIX.equals(propertyPrefix);
    }

    /**
     * Reads the principal of the current security context.
     *
     * <p>NOTE(review): assumes the principal is a {@code String}; any other principal type makes the cast
     * fail with a {@code ClassCastException}. Confirm against the authentication providers in use.
     */
    private static String getCurrentUser() {
        return (String) SecurityContextHolder.getContext().getAuthentication().getPrincipal();
    }

    // The accessors below delegate straight to the current data source.
    // NOTE(review): they dereference getDataSource() without a null check, so they presumably rely on a
    // delegate having been created by an earlier getConnection()/refresh() call; verify against callers.

    @Override
    public PrintWriter getLogWriter() throws SQLException {
        return getDataSource().getLogWriter();
    }

    @Override
    public void setLogWriter(PrintWriter out) throws SQLException {
        getDataSource().setLogWriter(out);
    }

    @Override
    public int getLoginTimeout() throws SQLException {
        return getDataSource().getLoginTimeout();
    }

    @Override
    public void setLoginTimeout(int seconds) throws SQLException {
        getDataSource().setLoginTimeout(seconds);
    }

    @Override
    public Logger getParentLogger() throws SQLFeatureNotSupportedException {
        return getDataSource().getParentLogger();
    }

    @Override
    public <T> T unwrap(Class<T> iface) throws SQLException {
        return getDataSource().unwrap(iface);
    }

    @Override
    public boolean isWrapperFor(Class<?> iface) throws SQLException {
        return getDataSource().isWrapperFor(iface);
    }

    /**
     * @return the property prefix, guaranteed to end with a single appended {@code .} if it had none
     */
    private String getPrefixWithTrailingDot() {
        return propertyPrefix.endsWith(".") ? propertyPrefix : propertyPrefix + ".";
    }

    /**
     * Builds a data source from the environment properties under the prefix.
     *
     * @param proxyUser whether to connect on behalf of {@code principal}; only honoured for the
     *                  {@code hive.datasource} prefix
     * @param principal user to impersonate when {@code proxyUser} applies; replaces the configured
     *                  username and is appended to the URL as {@code hive.server2.proxy.user}
     * @return the newly built data source
     */
    private DataSource create(boolean proxyUser, String principal) {
        String prefix = getPrefixWithTrailingDot();
        String driverClassName = env.getProperty(prefix + "driverClassName");
        String url = env.getProperty(prefix + "url");
        String password = env.getProperty(prefix + "password");
        String username = env.getProperty(prefix + "username");

        if (proxyUser && HIVE_DATASOURCE_PREFIX.equals(propertyPrefix)) {
            username = principal;
            url = url + ";hive.server2.proxy.user=" + principal;
        }
        log.debug("The JDBC URL is {} --- User impersonation enabled: {}", url, proxyUser);

        return DataSourceBuilder.create()
            .driverClassName(driverClassName)
            .url(url)
            .username(username)
            .password(password)
            .build();
    }
}