/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.util.hive;
import com.google.common.util.concurrent.UncheckedExecutionException;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hive.hcatalog.streaming.HiveEndPoint;
import org.apache.hive.hcatalog.streaming.InvalidTable;
import org.apache.hive.hcatalog.streaming.RecordWriter;
import org.apache.hive.hcatalog.streaming.StreamingConnection;
import org.apache.hive.hcatalog.streaming.StreamingException;
import org.apache.hive.hcatalog.streaming.TransactionBatch;
import org.junit.Before;
import org.junit.Test;
import org.mockito.stubbing.Answer;
import java.io.IOException;
import java.lang.reflect.UndeclaredThrowableException;
import java.security.PrivilegedExceptionAction;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyLong;
import static org.mockito.Matchers.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Tests construction of {@link HiveWriter} against Mockito doubles of its collaborators
 * (end point, executor, user group information, streaming connection, transaction batch).
 */
public class HiveWriterTest {
    private HiveEndPoint hiveEndPoint;
    private int txnsPerBatch;
    private boolean autoCreatePartitions;
    private int callTimeout;
    private ExecutorService executorService;
    private UserGroupInformation userGroupInformation;
    private HiveConf hiveConf;
    private HiveWriter hiveWriter;
    private StreamingConnection streamingConnection;
    private RecordWriter recordWriter;
    private Callable<RecordWriter> recordWriterCallable;
    private TransactionBatch transactionBatch;

    /**
     * Builds the mocked collaborators, wires their default (successful) behaviour and
     * constructs the writer under test.
     *
     * @throws Exception if stubbing or writer construction fails
     */
    @Before
    public void setup() throws Exception {
        // Plain configuration values handed to the HiveWriter constructor.
        txnsPerBatch = 100;
        autoCreatePartitions = true;
        callTimeout = 0;

        // Collaborator doubles.
        hiveEndPoint = mock(HiveEndPoint.class);
        hiveConf = mock(HiveConf.class);
        userGroupInformation = mock(UserGroupInformation.class);
        executorService = mock(ExecutorService.class);
        streamingConnection = mock(StreamingConnection.class);
        recordWriter = mock(RecordWriter.class);
        transactionBatch = mock(TransactionBatch.class);
        recordWriterCallable = mock(Callable.class);

        // Default happy-path stubbing.
        when(recordWriterCallable.call()).thenReturn(recordWriter);
        when(hiveEndPoint.newConnection(autoCreatePartitions, hiveConf, userGroupInformation)).thenReturn(streamingConnection);
        when(streamingConnection.fetchTransactionBatch(txnsPerBatch, recordWriter)).thenReturn(transactionBatch);

        // Submitted tasks are not run on another thread: each returned Future runs its task when get() is called.
        when(executorService.submit(isA(Callable.class)))
                .thenAnswer(invocation -> futureRunning((Callable) invocation.getArguments()[0]));

        // doAs runs the action in place; the catch clauses decide which failures pass through
        // unchanged and which are wrapped in an UndeclaredThrowableException.
        when(userGroupInformation.doAs(isA(PrivilegedExceptionAction.class))).thenAnswer(invocation -> {
            try {
                return runWithWriterFallback((PrivilegedExceptionAction) invocation.getArguments()[0]);
            } catch (IOException | Error | RuntimeException | InterruptedException passThrough) {
                throw passThrough;
            } catch (Throwable other) {
                throw new UndeclaredThrowableException(other);
            }
        });

        initWriter();
    }

    /**
     * Creates a mocked {@link Future} whose {@code get()} and {@code get(long, TimeUnit)}
     * both invoke the given task and return its result.
     *
     * @param task the callable that was handed to the mocked executor
     * @return the stubbed future
     * @throws Exception declared only because the stubbed {@code get} methods declare checked exceptions
     */
    private Future futureRunning(Callable task) throws Exception {
        Future pending = mock(Future.class);
        Answer<Object> runTask = ignored -> task.call();
        when(pending.get()).thenAnswer(runTask);
        when(pending.get(anyLong(), any(TimeUnit.class))).thenAnswer(runTask);
        return pending;
    }

    /**
     * Runs the privileged action. If it fails with an {@link UncheckedExecutionException} whose
     * stack trace passes through the StrictJsonWriter constructor, the result of the current
     * {@code recordWriterCallable} is returned instead; any other such exception is rethrown.
     *
     * @param action the action passed to the mocked {@code doAs}
     * @return the action's result, or the substitute record writer
     * @throws Exception whatever the action or the substitute callable throws
     */
    private Object runWithWriterFallback(PrivilegedExceptionAction action) throws Exception {
        try {
            return action.run();
        } catch (UncheckedExecutionException failure) {
            // Creating the strict JSON writer is expected to fail here because of external
            // dependencies; intercepting that failure lets the test supply its own writer.
            if (passedThroughStrictJsonWriterConstructor(failure)) {
                return recordWriterCallable.call();
            }
            throw failure;
        }
    }

    /**
     * Reports whether any frame of the failure's stack trace starts with the
     * StrictJsonWriter constructor signature prefix.
     *
     * @param failure the exception whose stack trace is inspected
     * @return {@code true} if a matching frame is present
     */
    private static boolean passedThroughStrictJsonWriterConstructor(Throwable failure) {
        for (StackTraceElement frame : failure.getStackTrace()) {
            if (frame.toString().startsWith("org.apache.hive.hcatalog.streaming.StrictJsonWriter.<init>(")) {
                return true;
            }
        }
        return false;
    }

    /**
     * (Re)creates the writer under test from the current field values, so a test can swap a
     * collaborator first and then call this again.
     *
     * @throws Exception if the {@link HiveWriter} constructor fails
     */
    private void initWriter() throws Exception {
        hiveWriter = new HiveWriter(hiveEndPoint, txnsPerBatch, autoCreatePartitions, callTimeout, executorService, userGroupInformation, hiveConf);
    }

    /** With the default stubbing, {@link #setup()} leaves a non-null writer behind. */
    @Test
    public void testNormal() {
        assertNotNull(hiveWriter);
    }

    /**
     * When the end point throws {@link InvalidTable} on connect, construction must fail with a
     * {@code ConnectFailure} whose cause equals that exception.
     */
    @Test(expected = HiveWriter.ConnectFailure.class)
    public void testNewConnectionInvalidTable() throws Exception {
        InvalidTable expectedCause = new InvalidTable("badDb", "badTable");
        hiveEndPoint = mock(HiveEndPoint.class);
        when(hiveEndPoint.newConnection(autoCreatePartitions, hiveConf, userGroupInformation)).thenThrow(expectedCause);
        try {
            initWriter();
        } catch (HiveWriter.ConnectFailure failure) {
            // Verify the cause, then rethrow so the expected-exception check on @Test still applies.
            assertEquals(expectedCause, failure.getCause());
            throw failure;
        }
    }

    /**
     * When the substitute record writer callable throws a {@link StreamingException},
     * construction must fail with a {@code ConnectFailure} whose cause equals that exception.
     */
    @Test(expected = HiveWriter.ConnectFailure.class)
    public void testRecordWriterStreamingException() throws Exception {
        StreamingException expectedCause = new StreamingException("Test Exception");
        recordWriterCallable = mock(Callable.class);
        when(recordWriterCallable.call()).thenThrow(expectedCause);
        try {
            initWriter();
        } catch (HiveWriter.ConnectFailure failure) {
            // Verify the cause, then rethrow so the expected-exception check on @Test still applies.
            assertEquals(expectedCause, failure.getCause());
            throw failure;
        }
    }
}