/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hive.ql.security.authorization.plugin;
import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.hadoop.hive.UtilsForTest;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.conf.HiveConf.ConfVars;
import org.apache.hadoop.hive.ql.CommandNeedRetryException;
import org.apache.hadoop.hive.ql.Driver;
import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse;
import org.apache.hadoop.hive.ql.security.HiveAuthenticationProvider;
import org.apache.hadoop.hive.ql.security.SessionStateUserAuthenticator;
import org.apache.hadoop.hive.ql.security.authorization.plugin.HivePrivilegeObject.HivePrivilegeObjectType;
import org.apache.hadoop.hive.ql.session.SessionState;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.mockito.Mockito;
/**
 * Tests the HiveAuthorizer api invocation for filtering objects.
 *
 * Each test runs a "show databases" or "show tables" command through a
 * {@code Driver} configured with a mocked {@code HiveAuthorizer}, then checks
 * two things: that every database/table was passed to
 * {@code HiveAuthorizer.filterListCmdObjects}, and that the command output
 * contains exactly the objects the authorizer returned.
 */
public class TestHiveAuthorizerShowFilters {
  protected static HiveConf conf;
  protected static Driver driver;

  // Object names are derived from the test class name and lower-cased so they
  // can be compared directly with the names reported back by the commands.
  private static final String tableName1 = (TestHiveAuthorizerShowFilters.class.getSimpleName() + "table1")
      .toLowerCase();
  private static final String tableName2 = (TestHiveAuthorizerShowFilters.class.getSimpleName() + "table2")
      .toLowerCase();
  private static final String dbName1 = (TestHiveAuthorizerShowFilters.class.getSimpleName() + "db1")
      .toLowerCase();
  private static final String dbName2 = (TestHiveAuthorizerShowFilters.class.getSimpleName() + "db2")
      .toLowerCase();

  protected static HiveAuthorizer mockedAuthorizer;

  /** Sorted names of all tables created by {@link #beforeTest()}. */
  static final List<String> AllTables = getSortedList(tableName1, tableName2);
  /** Sorted names of the databases created by {@link #beforeTest()} plus "default". */
  static final List<String> AllDbs = getSortedList("default", dbName1, dbName2);

  /** Last list argument captured from HiveAuthorizer.filterListCmdObjects. */
  private static List<HivePrivilegeObject> filterArguments = null;
  /**
   * Objects the mocked authorizer should return from filterListCmdObjects; when
   * empty, the mocked authorizer returns its argument unchanged.
   */
  private static final List<HivePrivilegeObject> filteredResults = new ArrayList<HivePrivilegeObject>();

  /**
   * This factory creates a mocked HiveAuthorizer class. The mocked class is
   * used to capture the argument passed to HiveAuthorizer.filterListCmdObjects.
   * It returns the filteredResults object for a call to
   * HiveAuthorizer.filterListCmdObjects, and stores the list argument in
   * filterArguments.
   */
  protected static class MockedHiveAuthorizerFactory implements HiveAuthorizerFactory {

    /**
     * Partial HiveAuthorizer implementation that provides only
     * filterListCmdObjects. It is static because it reads and writes only the
     * static fields of the test class and needs no factory instance.
     */
    protected abstract static class AuthorizerWithFilterCmdImpl implements HiveAuthorizer {
      @Override
      public List<HivePrivilegeObject> filterListCmdObjects(List<HivePrivilegeObject> listObjs,
          HiveAuthzContext context) throws HiveAuthzPluginException, HiveAccessControlException {
        // capture arguments in static
        filterArguments = listObjs;
        // return the static variable with results if it is set to some set of
        // values, otherwise return the arguments
        if (filteredResults.isEmpty()) {
          return filterArguments;
        }
        return filteredResults;
      }
    }

    /**
     * Creates a Mockito mock of {@link AuthorizerWithFilterCmdImpl} whose
     * filterListCmdObjects calls the real method, stores it in
     * mockedAuthorizer and returns it.
     */
    @Override
    public HiveAuthorizer createHiveAuthorizer(HiveMetastoreClientFactory metastoreClientFactory,
        HiveConf conf, HiveAuthenticationProvider authenticator, HiveAuthzSessionContext ctx) {
      Mockito.validateMockitoUsage();
      mockedAuthorizer = Mockito.mock(AuthorizerWithFilterCmdImpl.class, Mockito.withSettings()
          .verboseLogging());
      try {
        Mockito.when(
            mockedAuthorizer.filterListCmdObjects((List<HivePrivilegeObject>) any(),
                (HiveAuthzContext) any())).thenCallRealMethod();
      } catch (Exception e) {
        // Fail the test but keep the original exception as the cause so its
        // stack trace is not lost.
        AssertionError failure = new AssertionError("Caught exception " + e);
        failure.initCause(e);
        throw failure;
      }
      return mockedAuthorizer;
    }
  }

  /**
   * Configures the session to use the mocked authorizer, starts the session
   * and driver, and creates the tables and databases the tests list.
   */
  @BeforeClass
  public static void beforeTest() throws Exception {
    conf = new HiveConf();

    // Turn on mocked authorization
    conf.setVar(ConfVars.HIVE_AUTHORIZATION_MANAGER, MockedHiveAuthorizerFactory.class.getName());
    conf.setVar(ConfVars.HIVE_AUTHENTICATOR_MANAGER, SessionStateUserAuthenticator.class.getName());
    conf.setBoolVar(ConfVars.HIVE_AUTHORIZATION_ENABLED, true);
    conf.setBoolVar(ConfVars.HIVE_SERVER2_ENABLE_DOAS, false);
    conf.setBoolVar(ConfVars.HIVE_SUPPORT_CONCURRENCY, false);
    UtilsForTest.setNewDerbyDbLocation(conf, TestHiveAuthorizerShowFilters.class.getSimpleName());

    SessionState.start(conf);
    driver = new Driver(conf);
    runCmd("create table " + tableName1
        + " (i int, j int, k string) partitioned by (city string, `date` string) ");
    runCmd("create table " + tableName2 + "(i int)");
    runCmd("create database " + dbName1);
    runCmd("create database " + dbName2);
  }

  /** Resets the captured arguments and the canned filter results before each test. */
  @Before
  public void setup() {
    filterArguments = null;
    filteredResults.clear();
  }

  /**
   * Drops the objects created by {@link #beforeTest()} and closes the driver.
   * The driver is closed even when one of the drop commands fails.
   */
  @AfterClass
  public static void afterTests() throws Exception {
    if (driver == null) {
      // beforeTest() failed before the driver was created; nothing to clean up
      return;
    }
    try {
      // Drop the tables when we're done. This makes the test work inside an IDE
      runCmd("drop table if exists " + tableName1);
      runCmd("drop table if exists " + tableName2);
      runCmd("drop database if exists " + dbName1);
      runCmd("drop database if exists " + dbName2);
    } finally {
      driver.close();
    }
  }

  @Test
  public void testShowDatabasesAll() throws HiveAuthzPluginException, HiveAccessControlException,
      CommandNeedRetryException, IOException {
    runShowDbTest(AllDbs);
  }

  @Test
  public void testShowDatabasesSelected() throws HiveAuthzPluginException,
      HiveAccessControlException, CommandNeedRetryException, IOException {
    setFilteredResults(HivePrivilegeObjectType.DATABASE, dbName2);
    runShowDbTest(Arrays.asList(dbName2));
  }

  /**
   * Runs "show databases", verifies all databases were passed to the filter
   * and that the command output matches the expected list.
   *
   * @param expectedDbList sorted database names expected in the command output
   */
  private void runShowDbTest(List<String> expectedDbList) throws HiveAuthzPluginException,
      HiveAccessControlException, CommandNeedRetryException, IOException {
    runCmd("show databases");
    verifyAllDb();
    assertEquals("filtered result check ", expectedDbList, getSortedResults());
  }

  @Test
  public void testShowTablesAll() throws HiveAuthzPluginException, HiveAccessControlException,
      CommandNeedRetryException, IOException {
    runShowTablesTest(AllTables);
  }

  @Test
  public void testShowTablesSelected() throws HiveAuthzPluginException, HiveAccessControlException,
      CommandNeedRetryException, IOException {
    setFilteredResults(HivePrivilegeObjectType.TABLE_OR_VIEW, tableName2);
    runShowTablesTest(Arrays.asList(tableName2));
  }

  /**
   * Runs "show tables", verifies all tables were passed to the filter and
   * that the command output matches the expected list.
   *
   * @param expectedTabs sorted table names expected in the command output
   */
  private void runShowTablesTest(List<String> expectedTabs) throws IOException,
      CommandNeedRetryException, HiveAuthzPluginException, HiveAccessControlException {
    runCmd("show tables");
    verifyAllTables();
    assertEquals("filtered result check ", expectedTabs, getSortedResults());
  }

  /**
   * Fetches the results of the last command from the driver.
   *
   * @return the result rows, sorted
   */
  private List<String> getSortedResults() throws IOException, CommandNeedRetryException {
    List<String> res = new ArrayList<String>();
    // fetch the results of the last command into res
    driver.getResults(res);
    Collections.sort(res);
    return res;
  }

  /**
   * Returns the list captured from the last call to
   * HiveAuthorizer.filterListCmdObjects, failing the test with a readable
   * message (rather than a NullPointerException) if nothing was captured.
   */
  private static List<HivePrivilegeObject> getCapturedFilterArguments() {
    org.junit.Assert.assertNotNull(
        "No argument list was captured from HiveAuthorizer.filterListCmdObjects", filterArguments);
    return filterArguments;
  }

  /**
   * Verify that arguments to call to HiveAuthorizer.filterListCmdObjects are of
   * type DATABASE and contain all databases.
   */
  private void verifyAllDb() {
    List<HivePrivilegeObject> privObjs = getCapturedFilterArguments();

    // get the db names out
    List<String> dbArgs = new ArrayList<String>();
    for (HivePrivilegeObject privObj : privObjs) {
      assertEquals("Priv object type should be db", HivePrivilegeObjectType.DATABASE,
          privObj.getType());
      dbArgs.add(privObj.getDbname());
    }

    // sort before comparing with expected results
    Collections.sort(dbArgs);
    assertEquals("All db should be passed as arguments", AllDbs, dbArgs);
  }

  /**
   * Verify that arguments to call to HiveAuthorizer.filterListCmdObjects are of
   * type TABLE_OR_VIEW, belong to the "default" database and contain all
   * tables.
   */
  private void verifyAllTables() {
    List<HivePrivilegeObject> privObjs = getCapturedFilterArguments();

    // get the table names out
    List<String> tables = new ArrayList<String>();
    for (HivePrivilegeObject privObj : privObjs) {
      assertEquals("Priv object type should be table or view",
          HivePrivilegeObjectType.TABLE_OR_VIEW, privObj.getType());
      assertEquals("Database name", "default", privObj.getDbname());
      tables.add(privObj.getObjectName());
    }

    // sort before comparing with expected results
    Collections.sort(tables);
    assertEquals("All tables should be passed as arguments", AllTables, tables);
  }

  /**
   * Replaces the canned results returned by the mocked authorizer.
   *
   * @param type type of the privilege objects to create
   * @param objs database names when type is DATABASE; otherwise object names
   *          placed in the "default" database
   */
  private static void setFilteredResults(HivePrivilegeObjectType type, String... objs) {
    filteredResults.clear();
    for (String obj : objs) {
      String dbname;
      String tabname = null;
      if (type == HivePrivilegeObjectType.DATABASE) {
        dbname = obj;
      } else {
        dbname = "default";
        tabname = obj;
      }
      filteredResults.add(new HivePrivilegeObject(type, dbname, tabname));
    }
  }

  /**
   * Runs a command through the driver and asserts that its response code is 0.
   *
   * @param cmd the command to run
   */
  private static void runCmd(String cmd) throws CommandNeedRetryException {
    CommandProcessorResponse resp = driver.run(cmd);
    // name the command in the message so a failure identifies what was run
    assertEquals("Response code of command: " + cmd, 0, resp.getResponseCode());
  }

  /** Returns a new sorted list containing the given strings. */
  private static List<String> getSortedList(String... strings) {
    return getSortedList(Arrays.asList(strings));
  }

  /** Returns a sorted copy of the given list; the argument is not modified. */
  private static List<String> getSortedList(List<String> columns) {
    List<String> sortedCols = new ArrayList<String>(columns);
    Collections.sort(sortedCols);
    return sortedCols;
  }
}