/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.hadoop.hive.ql.security.authorization.plugin; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.verify; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.hadoop.hive.conf.HiveConf; import org.apache.hadoop.hive.conf.HiveConf.ConfVars; import org.apache.hadoop.hive.ql.CommandNeedRetryException; import org.apache.hadoop.hive.ql.Driver; import org.apache.hadoop.hive.ql.lockmgr.DbTxnManager; import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse; import org.apache.hadoop.hive.ql.security.HiveAuthenticationProvider; import org.apache.hadoop.hive.ql.security.SessionStateUserAuthenticator; import org.apache.hadoop.hive.ql.security.authorization.plugin.HivePrivilegeObject.HivePrivilegeObjectType; import org.apache.hadoop.hive.ql.session.SessionState; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Mockito; /** * Test HiveAuthorizer api invocation */ public class TestHiveAuthorizerCheckInvocation { private final Logger LOG = LoggerFactory.getLogger(this.getClass().getName());; protected static HiveConf conf; protected static Driver driver; private static final String tableName = TestHiveAuthorizerCheckInvocation.class.getSimpleName() + "Table"; private static final String viewName = TestHiveAuthorizerCheckInvocation.class.getSimpleName() + "View"; private static final String inDbTableName = tableName + "_in_db"; private static final String acidTableName = tableName + "_acid"; private static final String dbName = TestHiveAuthorizerCheckInvocation.class.getSimpleName() + "Db"; static HiveAuthorizer mockedAuthorizer; /** * This factory creates a mocked HiveAuthorizer class. Use the mocked class to * capture the argument passed to it in the test case. */ static class MockedHiveAuthorizerFactory implements HiveAuthorizerFactory { @Override public HiveAuthorizer createHiveAuthorizer(HiveMetastoreClientFactory metastoreClientFactory, HiveConf conf, HiveAuthenticationProvider authenticator, HiveAuthzSessionContext ctx) { TestHiveAuthorizerCheckInvocation.mockedAuthorizer = Mockito.mock(HiveAuthorizer.class); return TestHiveAuthorizerCheckInvocation.mockedAuthorizer; } } @BeforeClass public static void beforeTest() throws Exception { conf = new HiveConf(); // Turn on mocked authorization conf.setVar(ConfVars.HIVE_AUTHORIZATION_MANAGER, MockedHiveAuthorizerFactory.class.getName()); conf.setVar(ConfVars.HIVE_AUTHENTICATOR_MANAGER, SessionStateUserAuthenticator.class.getName()); conf.setBoolVar(ConfVars.HIVE_AUTHORIZATION_ENABLED, true); conf.setBoolVar(ConfVars.HIVE_SERVER2_ENABLE_DOAS, false); conf.setBoolVar(ConfVars.HIVE_SUPPORT_CONCURRENCY, true); conf.setVar(ConfVars.HIVE_TXN_MANAGER, DbTxnManager.class.getName()); conf.setVar(HiveConf.ConfVars.HIVEMAPREDMODE, "nonstrict"); SessionState.start(conf); driver = new Driver(conf); runCmd("create table " + tableName + " (i int, j int, k string) partitioned by (city string, `date` string) "); runCmd("create view " + viewName + " as select * from " + tableName); runCmd("create database " + dbName); runCmd("create table " + dbName + "." + inDbTableName + "(i int)"); // Need a separate table for ACID testing since it has to be bucketed and it has to be Acid runCmd("create table " + acidTableName + " (i int, j int, k int) clustered by (k) into 2 buckets " + "stored as orc TBLPROPERTIES ('transactional'='true')"); } private static void runCmd(String cmd) throws CommandNeedRetryException { CommandProcessorResponse resp = driver.run(cmd); assertEquals(0, resp.getResponseCode()); } @AfterClass public static void afterTests() throws Exception { // Drop the tables when we're done. This makes the test work inside an IDE runCmd("drop table if exists " + acidTableName); runCmd("drop table if exists " + tableName); runCmd("drop table if exists " + viewName); runCmd("drop table if exists " + dbName + "." + inDbTableName); runCmd("drop database if exists " + dbName ); driver.close(); } @Test public void testInputSomeColumnsUsed() throws HiveAuthzPluginException, HiveAccessControlException, CommandNeedRetryException { reset(mockedAuthorizer); int status = driver.compile("select i from " + tableName + " where k = 'X' and city = 'Scottsdale-AZ' "); assertEquals(0, status); List<HivePrivilegeObject> inputs = getHivePrivilegeObjectInputs().getLeft(); checkSingleTableInput(inputs); HivePrivilegeObject tableObj = inputs.get(0); assertEquals("no of columns used", 3, tableObj.getColumns().size()); assertEquals("Columns used", Arrays.asList("city", "i", "k"), getSortedList(tableObj.getColumns())); } @Test public void testInputSomeColumnsUsedView() throws HiveAuthzPluginException, HiveAccessControlException, CommandNeedRetryException { reset(mockedAuthorizer); int status = driver.compile("select i from " + viewName + " where k = 'X' and city = 'Scottsdale-AZ' "); assertEquals(0, status); List<HivePrivilegeObject> inputs = getHivePrivilegeObjectInputs().getLeft(); checkSingleViewInput(inputs); HivePrivilegeObject tableObj = inputs.get(0); assertEquals("no of columns used", 3, tableObj.getColumns().size()); assertEquals("Columns used", Arrays.asList("city", "i", "k"), getSortedList(tableObj.getColumns())); } @Test public void testInputSomeColumnsUsedJoin() throws HiveAuthzPluginException, HiveAccessControlException, CommandNeedRetryException { reset(mockedAuthorizer); int status = driver.compile("select " + viewName + ".i, " + tableName + ".city from " + viewName + " join " + tableName + " on " + viewName + ".city = " + tableName + ".city where " + tableName + ".k = 'X'"); assertEquals(0, status); List<HivePrivilegeObject> inputs = getHivePrivilegeObjectInputs().getLeft(); Collections.sort(inputs); assertEquals(inputs.size(), 2); HivePrivilegeObject tableObj = inputs.get(0); assertEquals(tableObj.getObjectName().toLowerCase(), tableName.toLowerCase()); assertEquals("no of columns used", 2, tableObj.getColumns().size()); assertEquals("Columns used", Arrays.asList("city", "k"), getSortedList(tableObj.getColumns())); tableObj = inputs.get(1); assertEquals(tableObj.getObjectName().toLowerCase(), viewName.toLowerCase()); assertEquals("no of columns used", 2, tableObj.getColumns().size()); assertEquals("Columns used", Arrays.asList("city", "i"), getSortedList(tableObj.getColumns())); } private List<String> getSortedList(List<String> columns) { List<String> sortedCols = new ArrayList<String>(columns); Collections.sort(sortedCols); return sortedCols; } @Test public void testInputAllColumnsUsed() throws HiveAuthzPluginException, HiveAccessControlException, CommandNeedRetryException { reset(mockedAuthorizer); int status = driver.compile("select * from " + tableName + " order by i"); assertEquals(0, status); List<HivePrivilegeObject> inputs = getHivePrivilegeObjectInputs().getLeft(); checkSingleTableInput(inputs); HivePrivilegeObject tableObj = inputs.get(0); assertEquals("no of columns used", 5, tableObj.getColumns().size()); assertEquals("Columns used", Arrays.asList("city", "date", "i", "j", "k"), getSortedList(tableObj.getColumns())); } @Test public void testCreateTableWithDb() throws HiveAuthzPluginException, HiveAccessControlException, CommandNeedRetryException { final String newTable = "ctTableWithDb"; checkCreateViewOrTableWithDb(newTable, "create table " + dbName + "." + newTable + "(i int)"); } @Test public void testCreateViewWithDb() throws HiveAuthzPluginException, HiveAccessControlException, CommandNeedRetryException { final String newTable = "ctViewWithDb"; checkCreateViewOrTableWithDb(newTable, "create table " + dbName + "." + newTable + "(i int)"); } private void checkCreateViewOrTableWithDb(String newTable, String cmd) throws HiveAuthzPluginException, HiveAccessControlException { reset(mockedAuthorizer); int status = driver.compile(cmd); assertEquals(0, status); List<HivePrivilegeObject> outputs = getHivePrivilegeObjectInputs().getRight(); assertEquals("num outputs", 2, outputs.size()); for (HivePrivilegeObject output : outputs) { switch (output.getType()) { case DATABASE: assertTrue("database name", output.getDbname().equalsIgnoreCase(dbName)); break; case TABLE_OR_VIEW: assertTrue("database name", output.getDbname().equalsIgnoreCase(dbName)); assertEqualsIgnoreCase("table name", output.getObjectName(), newTable); break; default: fail("Unexpected type : " + output.getType()); } } } private void assertEqualsIgnoreCase(String msg, String expected, String actual) { assertEquals(msg, expected.toLowerCase(), actual.toLowerCase()); } @Test public void testInputNoColumnsUsed() throws HiveAuthzPluginException, HiveAccessControlException, CommandNeedRetryException { reset(mockedAuthorizer); int status = driver.compile("describe " + tableName); assertEquals(0, status); List<HivePrivilegeObject> inputs = getHivePrivilegeObjectInputs().getLeft(); checkSingleTableInput(inputs); HivePrivilegeObject tableObj = inputs.get(0); assertNull("columns used", tableObj.getColumns()); } @Test public void testPermFunction() throws HiveAuthzPluginException, HiveAccessControlException, CommandNeedRetryException { reset(mockedAuthorizer); final String funcName = "testauthfunc1"; int status = driver.compile("create function " + dbName + "." + funcName + " as 'org.apache.hadoop.hive.ql.udf.UDFPI'"); assertEquals(0, status); List<HivePrivilegeObject> outputs = getHivePrivilegeObjectInputs().getRight(); HivePrivilegeObject funcObj; HivePrivilegeObject dbObj; assertEquals("number of output object", 2, outputs.size()); if(outputs.get(0).getType() == HivePrivilegeObjectType.FUNCTION) { funcObj = outputs.get(0); dbObj = outputs.get(1); } else { funcObj = outputs.get(1); dbObj = outputs.get(0); } assertEquals("input type", HivePrivilegeObjectType.FUNCTION, funcObj.getType()); assertTrue("function name", funcName.equalsIgnoreCase(funcObj.getObjectName())); assertTrue("db name", dbName.equalsIgnoreCase(funcObj.getDbname())); assertEquals("input type", HivePrivilegeObjectType.DATABASE, dbObj.getType()); assertTrue("db name", dbName.equalsIgnoreCase(dbObj.getDbname())); } @Test public void testTempFunction() throws HiveAuthzPluginException, HiveAccessControlException, CommandNeedRetryException { reset(mockedAuthorizer); final String funcName = "testAuthFunc2"; int status = driver.compile("create temporary function " + funcName + " as 'org.apache.hadoop.hive.ql.udf.UDFPI'"); assertEquals(0, status); List<HivePrivilegeObject> outputs = getHivePrivilegeObjectInputs().getRight(); HivePrivilegeObject funcObj = outputs.get(0); assertEquals("input type", HivePrivilegeObjectType.FUNCTION, funcObj.getType()); assertTrue("function name", funcName.equalsIgnoreCase(funcObj.getObjectName())); assertEquals("db name", null, funcObj.getDbname()); } @Test public void testUpdateSomeColumnsUsed() throws HiveAuthzPluginException, HiveAccessControlException, CommandNeedRetryException { reset(mockedAuthorizer); int status = driver.compile("update " + acidTableName + " set i = 5 where j = 3"); assertEquals(0, status); Pair<List<HivePrivilegeObject>, List<HivePrivilegeObject>> io = getHivePrivilegeObjectInputs(); List<HivePrivilegeObject> outputs = io.getRight(); HivePrivilegeObject tableObj = outputs.get(0); LOG.debug("Got privilege object " + tableObj); assertEquals("no of columns used", 1, tableObj.getColumns().size()); assertEquals("Column used", "i", tableObj.getColumns().get(0)); List<HivePrivilegeObject> inputs = io.getLeft(); assertEquals(1, inputs.size()); tableObj = inputs.get(0); assertEquals(2, tableObj.getColumns().size()); assertEquals("j", tableObj.getColumns().get(0)); } @Test public void testUpdateSomeColumnsUsedExprInSet() throws HiveAuthzPluginException, HiveAccessControlException, CommandNeedRetryException { reset(mockedAuthorizer); int status = driver.compile("update " + acidTableName + " set i = 5, j = k where j = 3"); assertEquals(0, status); Pair<List<HivePrivilegeObject>, List<HivePrivilegeObject>> io = getHivePrivilegeObjectInputs(); List<HivePrivilegeObject> outputs = io.getRight(); HivePrivilegeObject tableObj = outputs.get(0); LOG.debug("Got privilege object " + tableObj); assertEquals("no of columns used", 2, tableObj.getColumns().size()); assertEquals("Columns used", Arrays.asList("i", "j"), getSortedList(tableObj.getColumns())); List<HivePrivilegeObject> inputs = io.getLeft(); assertEquals(1, inputs.size()); tableObj = inputs.get(0); assertEquals(2, tableObj.getColumns().size()); assertEquals("Columns used", Arrays.asList("j", "k"), getSortedList(tableObj.getColumns())); } @Test public void testDelete() throws HiveAuthzPluginException, HiveAccessControlException, CommandNeedRetryException { reset(mockedAuthorizer); int status = driver.compile("delete from " + acidTableName + " where j = 3"); assertEquals(0, status); Pair<List<HivePrivilegeObject>, List<HivePrivilegeObject>> io = getHivePrivilegeObjectInputs(); List<HivePrivilegeObject> inputs = io.getLeft(); assertEquals(1, inputs.size()); HivePrivilegeObject tableObj = inputs.get(0); assertEquals(1, tableObj.getColumns().size()); assertEquals("j", tableObj.getColumns().get(0)); } @Test public void testShowTables() throws HiveAuthzPluginException, HiveAccessControlException, CommandNeedRetryException { reset(mockedAuthorizer); int status = driver.compile("show tables"); assertEquals(0, status); Pair<List<HivePrivilegeObject>, List<HivePrivilegeObject>> io = getHivePrivilegeObjectInputs(); List<HivePrivilegeObject> inputs = io.getLeft(); assertEquals(1, inputs.size()); HivePrivilegeObject dbObj = inputs.get(0); assertEquals("default", dbObj.getDbname().toLowerCase()); } @Test public void testDescDatabase() throws HiveAuthzPluginException, HiveAccessControlException, CommandNeedRetryException { reset(mockedAuthorizer); int status = driver.compile("describe database " + dbName); assertEquals(0, status); Pair<List<HivePrivilegeObject>, List<HivePrivilegeObject>> io = getHivePrivilegeObjectInputs(); List<HivePrivilegeObject> inputs = io.getLeft(); assertEquals(1, inputs.size()); HivePrivilegeObject dbObj = inputs.get(0); assertEquals(dbName.toLowerCase(), dbObj.getDbname().toLowerCase()); } private void checkSingleTableInput(List<HivePrivilegeObject> inputs) { assertEquals("number of inputs", 1, inputs.size()); HivePrivilegeObject tableObj = inputs.get(0); assertEquals("input type", HivePrivilegeObjectType.TABLE_OR_VIEW, tableObj.getType()); assertTrue("table name", tableName.equalsIgnoreCase(tableObj.getObjectName())); } private void checkSingleViewInput(List<HivePrivilegeObject> inputs) { assertEquals("number of inputs", 1, inputs.size()); HivePrivilegeObject tableObj = inputs.get(0); assertEquals("input type", HivePrivilegeObjectType.TABLE_OR_VIEW, tableObj.getType()); assertTrue("table name", viewName.equalsIgnoreCase(tableObj.getObjectName())); } /** * @return pair with left value as inputs and right value as outputs, * passed in current call to authorizer.checkPrivileges * @throws HiveAuthzPluginException * @throws HiveAccessControlException */ private Pair<List<HivePrivilegeObject>, List<HivePrivilegeObject>> getHivePrivilegeObjectInputs() throws HiveAuthzPluginException, HiveAccessControlException { // Create argument capturer // a class variable cast to this generic of generic class Class<List<HivePrivilegeObject>> class_listPrivObjects = (Class) List.class; ArgumentCaptor<List<HivePrivilegeObject>> inputsCapturer = ArgumentCaptor .forClass(class_listPrivObjects); ArgumentCaptor<List<HivePrivilegeObject>> outputsCapturer = ArgumentCaptor .forClass(class_listPrivObjects); verify(mockedAuthorizer).checkPrivileges(any(HiveOperationType.class), inputsCapturer.capture(), outputsCapturer.capture(), any(HiveAuthzContext.class)); return new ImmutablePair(inputsCapturer.getValue(), outputsCapturer.getValue()); } }