/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.runtime.operators;
import org.apache.flink.api.common.io.FileOutputFormat;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.io.network.partition.consumer.IteratorWrappingTestSingleInputGate;
import org.apache.flink.runtime.operators.testutils.InfiniteInputIterator;
import org.apache.flink.runtime.operators.testutils.TaskCancelThread;
import org.apache.flink.runtime.operators.testutils.TaskTestBase;
import org.apache.flink.runtime.operators.testutils.UniformRecordGenerator;
import org.apache.flink.runtime.operators.util.LocalStrategy;
import org.apache.flink.runtime.testutils.recordutils.RecordComparatorFactory;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
public class DataSinkTaskTest extends TaskTestBase {
// Logger used by the tests to record exceptions thrown while invoking the task.
private static final Logger LOG = LoggerFactory.getLogger(DataSinkTaskTest.class);
// First argument handed to initEnvironment(...) by every test (3 * 1024 * 1024; presumably bytes — confirm in TaskTestBase).
private static final int MEMORY_MANAGER_SIZE = 3 * 1024 * 1024;
// Second argument handed to initEnvironment(...) by every test (presumably the network buffer size in bytes — confirm in TaskTestBase).
private static final int NETWORK_BUFFER_SIZE = 1024;
// Path of the output file the sink writes to: <java.io.tmpdir>/<fully qualified class name>-dst_test.
// The file is deleted again in cleanUp() after each test.
private final String tempTestPath = constructTestPath(DataSinkTaskTest.class, "dst_test");
/**
 * Removes the output file a test may have left behind at {@code tempTestPath}.
 */
@After
public void cleanUp() {
    final File leftover = new File(tempTestPath);
    if (!leftover.exists()) {
        // nothing was written (or the task already cleaned up after itself)
        return;
    }
    // best effort: the outcome of the deletion is not checked
    leftover.delete();
}
/**
 * Runs a {@link DataSinkTask} with {@link MockOutputFormat} over one generated input and
 * verifies the written file.
 *
 * <p>Every output line is expected to have the form {@code key_value}. The test passes when
 * the file contains exactly {@code keyCnt} distinct keys and each key occurs together with
 * exactly {@code valCnt} distinct values. Any exception raised while setting up, running the
 * task or reading the file fails the test, with the exception attached as the cause.
 */
@Test
public void testDataSinkTask() {
    final int keyCnt = 100;
    final int valCnt = 20;

    try {
        super.initEnvironment(MEMORY_MANAGER_SIZE, NETWORK_BUFFER_SIZE);
        super.addInput(new UniformRecordGenerator(keyCnt, valCnt, false), 0);

        DataSinkTask<Record> testTask = new DataSinkTask<>();
        super.registerFileOutputTask(testTask, MockOutputFormat.class, new File(tempTestPath).toURI().toString());

        testTask.invoke();

        File tempTestFile = new File(this.tempTestPath);
        Assert.assertTrue("Temp output file does not exist", tempTestFile.exists());

        // key -> distinct values that were written for that key
        HashMap<Integer, Set<Integer>> keyValueCountMap = new HashMap<>(keyCnt);

        // try-with-resources closes both readers even when parsing or an assertion fails
        try (FileReader fr = new FileReader(tempTestFile);
                BufferedReader br = new BufferedReader(fr)) {
            String line;
            while ((line = br.readLine()) != null) {
                final int separator = line.indexOf('_');
                Integer key = Integer.parseInt(line.substring(0, separator));
                Integer val = Integer.parseInt(line.substring(separator + 1));

                Set<Integer> values = keyValueCountMap.get(key);
                if (values == null) {
                    values = new HashSet<>();
                    keyValueCountMap.put(key, values);
                }
                values.add(val);
            }
        }

        Assert.assertEquals("Invalid key count in out file.", keyCnt, keyValueCountMap.size());
        for (Integer key : keyValueCountMap.keySet()) {
            Assert.assertEquals("Invalid value count for key: " + key + ".",
                valCnt, keyValueCountMap.get(key).size());
        }
    }
    catch (Exception e) {
        // keep the original exception as cause; e.getMessage() alone may be null
        throw new AssertionError("Data sink task test failed with exception: " + e, e);
    }
}
/**
 * Runs a {@link DataSinkTask} over four generated inputs that are all registered with the
 * same second {@code addInput} argument ({@code 0}; presumably the input group, so that they
 * are read as a union — confirm in TaskTestBase) and verifies the written file.
 *
 * <p>The four generators are created with the offsets {@code 0}, {@code keyCnt},
 * {@code 2 * keyCnt} and {@code 3 * keyCnt} as their third argument, and the output is
 * expected to contain {@code 4 * keyCnt} distinct keys with {@code valCnt} distinct values
 * each. Output lines are expected to have the form {@code key_value}.
 */
@Test
public void testUnionDataSinkTask() {
    final int keyCnt = 10;
    final int valCnt = 20;
    // four inputs with keyCnt keys each
    final int expectedKeyCnt = keyCnt * 4;

    super.initEnvironment(MEMORY_MANAGER_SIZE, NETWORK_BUFFER_SIZE);

    final IteratorWrappingTestSingleInputGate<?>[] readers = new IteratorWrappingTestSingleInputGate[4];
    readers[0] = super.addInput(new UniformRecordGenerator(keyCnt, valCnt, 0, 0, false), 0, false);
    readers[1] = super.addInput(new UniformRecordGenerator(keyCnt, valCnt, keyCnt, 0, false), 0, false);
    readers[2] = super.addInput(new UniformRecordGenerator(keyCnt, valCnt, keyCnt * 2, 0, false), 0, false);
    readers[3] = super.addInput(new UniformRecordGenerator(keyCnt, valCnt, keyCnt * 3, 0, false), 0, false);

    DataSinkTask<Record> testTask = new DataSinkTask<>();
    super.registerFileOutputTask(testTask, MockOutputFormat.class, new File(tempTestPath).toURI().toString());

    try {
        // For the union reader to work, we need to start notifications *after* the union reader
        // has been initialized. This is accomplished via a mockito hack in TestSingleInputGate,
        // which forwards existing notifications on registerListener calls.
        for (IteratorWrappingTestSingleInputGate<?> reader : readers) {
            reader.notifyNonEmpty();
        }
        testTask.invoke();
    } catch (Exception e) {
        LOG.debug("Exception while invoking the test task.", e);
        // attach the cause so the failure is diagnosable without debug logging
        throw new AssertionError("Invoke method caused exception.", e);
    }

    File tempTestFile = new File(this.tempTestPath);
    Assert.assertTrue("Temp output file does not exist", tempTestFile.exists());

    // key -> distinct values that were written for that key
    HashMap<Integer, Set<Integer>> keyValueCountMap = new HashMap<>(expectedKeyCnt);

    // try-with-resources closes both readers on every exit path
    try (FileReader fr = new FileReader(tempTestFile);
            BufferedReader br = new BufferedReader(fr)) {
        String line;
        while ((line = br.readLine()) != null) {
            final int separator = line.indexOf('_');
            Integer key = Integer.parseInt(line.substring(0, separator));
            Integer val = Integer.parseInt(line.substring(separator + 1));

            Set<Integer> values = keyValueCountMap.get(key);
            if (values == null) {
                values = new HashSet<>();
                keyValueCountMap.put(key, values);
            }
            values.add(val);
        }
    } catch (FileNotFoundException e) {
        Assert.fail("Out file got lost...");
    } catch (IOException ioe) {
        Assert.fail("Caught IOE while reading out file");
    }

    // the reported expectation now matches the value that is actually checked (keyCnt * 4)
    Assert.assertEquals("Invalid key count in out file.", expectedKeyCnt, keyValueCountMap.size());
    for (Integer key : keyValueCountMap.keySet()) {
        Assert.assertEquals("Invalid value count for key: " + key + ".",
            valCnt, keyValueCountMap.get(key).size());
    }
}
/**
 * Runs a {@link DataSinkTask} whose input 0 is configured with {@link LocalStrategy#SORT}
 * and a comparator on record field 1 ({@link IntValue}), then verifies the written file.
 *
 * <p>Output lines are expected to have the form {@code key_value}. The test checks that the
 * values never decrease from one line to the next, that no key occurs twice for the same
 * value, and that every value — including the last one in the file — occurs with exactly
 * {@code keyCnt} distinct keys.
 */
@Test
@SuppressWarnings("unchecked")
public void testSortingDataSinkTask() {
    final int keyCnt = 100;
    final int valCnt = 20;
    final double memoryFraction = 1.0;

    super.initEnvironment(MEMORY_MANAGER_SIZE, NETWORK_BUFFER_SIZE);
    super.addInput(new UniformRecordGenerator(keyCnt, valCnt, true), 0);

    DataSinkTask<Record> testTask = new DataSinkTask<>();

    // set sorting
    super.getTaskConfig().setInputLocalStrategy(0, LocalStrategy.SORT);
    super.getTaskConfig().setInputComparator(
        new RecordComparatorFactory(new int[]{1}, (new Class[]{IntValue.class})), 0);
    super.getTaskConfig().setRelativeMemoryInput(0, memoryFraction);
    super.getTaskConfig().setFilehandlesInput(0, 8);
    super.getTaskConfig().setSpillingThresholdInput(0, 0.8f);

    super.registerFileOutputTask(testTask, MockOutputFormat.class, new File(tempTestPath).toURI().toString());

    try {
        testTask.invoke();
    } catch (Exception e) {
        LOG.debug("Exception while invoking the test task.", e);
        // attach the cause so the failure is diagnosable without debug logging
        throw new AssertionError("Invoke method caused exception.", e);
    }

    File tempTestFile = new File(this.tempTestPath);
    Assert.assertTrue("Temp output file does not exist", tempTestFile.exists());

    // keys seen so far for the value group that is currently being read
    Set<Integer> keys = new HashSet<>();
    int curVal = -1;

    // try-with-resources closes both readers on every exit path
    try (FileReader fr = new FileReader(tempTestFile);
            BufferedReader br = new BufferedReader(fr)) {
        String line;
        while ((line = br.readLine()) != null) {
            final int separator = line.indexOf('_');
            Integer key = Integer.parseInt(line.substring(0, separator));
            int val = Integer.parseInt(line.substring(separator + 1));

            // check that values are in correct order
            Assert.assertTrue("Values not in ascending order", val >= curVal);

            // next value hit
            if (val > curVal) {
                if (curVal != -1) {
                    // the group that just ended must contain every key
                    Assert.assertEquals("Keys missing for value " + curVal, keyCnt, keys.size());
                }
                // start collecting keys for the new value
                keys.clear();
                curVal = val;
            }

            Assert.assertTrue("Duplicate key for value", keys.add(key));
        }
    } catch (FileNotFoundException e) {
        Assert.fail("Out file got lost...");
    } catch (IOException ioe) {
        Assert.fail("Caught IOE while reading out file");
    }

    // The loop only validates a group when the next one starts, so the last group
    // (and an entirely empty file) has to be validated here.
    Assert.assertEquals("Keys missing for value " + curVal, keyCnt, keys.size());
}
/**
 * Runs a {@link DataSinkTask} with {@link MockFailingOutputFormat} and expects the failure
 * of the output format to surface from {@code invoke()} and the output file to be gone
 * afterwards.
 */
@Test
public void testFailingDataSinkTask() {
    final int numKeys = 100;
    final int numVals = 20;

    super.initEnvironment(MEMORY_MANAGER_SIZE, NETWORK_BUFFER_SIZE);
    super.addInput(new UniformRecordGenerator(numKeys, numVals, false), 0);

    final DataSinkTask<Record> sinkTask = new DataSinkTask<>();
    final Configuration emptyStubParams = new Configuration();
    super.getTaskConfig().setStubParameters(emptyStubParams);
    super.registerFileOutputTask(sinkTask, MockFailingOutputFormat.class, new File(tempTestPath).toURI().toString());

    boolean exceptionForwarded;
    try {
        sinkTask.invoke();
        // invoke() returned normally, so nothing was forwarded
        exceptionForwarded = false;
    } catch (Exception expected) {
        exceptionForwarded = true;
    }
    Assert.assertTrue("Function exception was not forwarded.", exceptionForwarded);

    // the failed task must not leave its output file behind
    final File outputFile = new File(this.tempTestPath);
    Assert.assertFalse("Temp output file has not been removed", outputFile.exists());
}
/**
 * Same scenario as the plain failing-sink test, but with input 0 configured for
 * {@link LocalStrategy#SORT} on record field 1: the failure of
 * {@link MockFailingOutputFormat} must surface from {@code invoke()} and the output file
 * must be gone afterwards.
 */
@Test
@SuppressWarnings("unchecked")
public void testFailingSortingDataSinkTask() {
    final int numKeys = 100;
    final int numVals = 20;
    final double sortMemoryFraction = 1.0;

    super.initEnvironment(MEMORY_MANAGER_SIZE, NETWORK_BUFFER_SIZE);
    super.addInput(new UniformRecordGenerator(numKeys, numVals, true), 0);

    final DataSinkTask<Record> sinkTask = new DataSinkTask<>();
    final Configuration emptyStubParams = new Configuration();
    super.getTaskConfig().setStubParameters(emptyStubParams);

    // sort input 0 on field 1
    super.getTaskConfig().setInputLocalStrategy(0, LocalStrategy.SORT);
    final RecordComparatorFactory comparatorFactory =
        new RecordComparatorFactory(new int[]{1}, (new Class[]{IntValue.class}));
    super.getTaskConfig().setInputComparator(comparatorFactory, 0);
    super.getTaskConfig().setRelativeMemoryInput(0, sortMemoryFraction);
    super.getTaskConfig().setFilehandlesInput(0, 8);
    super.getTaskConfig().setSpillingThresholdInput(0, 0.8f);

    super.registerFileOutputTask(sinkTask, MockFailingOutputFormat.class, new File(tempTestPath).toURI().toString());

    boolean exceptionForwarded;
    try {
        sinkTask.invoke();
        // invoke() returned normally, so nothing was forwarded
        exceptionForwarded = false;
    } catch (Exception expected) {
        exceptionForwarded = true;
    }
    Assert.assertTrue("Function exception was not forwarded.", exceptionForwarded);

    // the failed task must not leave its output file behind
    final File outputFile = new File(this.tempTestPath);
    Assert.assertFalse("Temp output file has not been removed", outputFile.exists());
}
/**
 * Starts a {@link DataSinkTask} on an infinite input in a separate thread, waits until the
 * output file shows up, cancels the task and expects that {@code invoke()} returned without
 * throwing and that the output file has been removed.
 *
 * <p>A throwable escaping {@code invoke()} is recorded on the runner thread and re-raised on
 * the test thread after the join; an assertion failing on the runner thread itself would not
 * be seen by JUnit.
 *
 * @throws Exception if waiting for the file or joining the runner thread is interrupted
 */
@Test
public void testCancelDataSinkTask() throws Exception {
    super.initEnvironment(MEMORY_MANAGER_SIZE, NETWORK_BUFFER_SIZE);
    super.addInput(new InfiniteInputIterator(), 0);

    final DataSinkTask<Record> testTask = new DataSinkTask<>();
    Configuration stubParams = new Configuration();
    super.getTaskConfig().setStubParameters(stubParams);

    super.registerFileOutputTask(testTask, MockOutputFormat.class, new File(tempTestPath).toURI().toString());

    // Single-slot holder for whatever escapes invoke(). It is written by the runner thread
    // and read only after taskRunner.join(), which makes the write visible here.
    final Throwable[] taskError = new Throwable[1];

    Thread taskRunner = new Thread() {
        @Override
        public void run() {
            try {
                testTask.invoke();
            } catch (Throwable t) {
                taskError[0] = t;
            }
        }
    };
    taskRunner.start();

    File tempTestFile = new File(this.tempTestPath);

    // wait until the task created the file
    long deadline = System.currentTimeMillis() + 60000;
    while (!tempTestFile.exists() && System.currentTimeMillis() < deadline) {
        Thread.sleep(10);
    }

    assertTrue("Task did not create file within 60 seconds", tempTestFile.exists());

    // cancel the task
    Thread.sleep(500);
    testTask.cancel();
    taskRunner.interrupt();

    // wait for the canceling to complete
    taskRunner.join();

    // report a failure of the runner thread on the test thread
    if (taskError[0] != null) {
        throw new AssertionError("Task threw exception although it was properly canceled", taskError[0]);
    }

    // assert that temp file was removed
    assertFalse("Temp output file has not been removed", tempTestFile.exists());
}
/**
 * Starts a {@link DataSinkTask} with input 0 configured for {@link LocalStrategy#SORT} on an
 * infinite input in a separate thread, cancels it through a {@link TaskCancelThread} and
 * expects that {@code invoke()} returned without throwing.
 *
 * <p>A throwable escaping {@code invoke()} is recorded on the runner thread and re-raised on
 * the test thread after the join; an assertion failing on the runner thread itself would not
 * be seen by JUnit.
 */
@Test
@SuppressWarnings("unchecked")
public void testCancelSortingDataSinkTask() {
    final double memoryFraction = 1.0;

    super.initEnvironment(MEMORY_MANAGER_SIZE, NETWORK_BUFFER_SIZE);
    super.addInput(new InfiniteInputIterator(), 0);

    final DataSinkTask<Record> testTask = new DataSinkTask<>();
    Configuration stubParams = new Configuration();
    super.getTaskConfig().setStubParameters(stubParams);

    // set sorting
    super.getTaskConfig().setInputLocalStrategy(0, LocalStrategy.SORT);
    super.getTaskConfig().setInputComparator(
        new RecordComparatorFactory(new int[]{1}, (new Class[]{IntValue.class})), 0);
    super.getTaskConfig().setRelativeMemoryInput(0, memoryFraction);
    super.getTaskConfig().setFilehandlesInput(0, 8);
    super.getTaskConfig().setSpillingThresholdInput(0, 0.8f);

    super.registerFileOutputTask(testTask, MockOutputFormat.class, new File(tempTestPath).toURI().toString());

    // Single-slot holder for whatever escapes invoke(). It is written by the runner thread
    // and read only after taskRunner.join(), which makes the write visible here.
    final Throwable[] taskError = new Throwable[1];

    Thread taskRunner = new Thread() {
        @Override
        public void run() {
            try {
                testTask.invoke();
            } catch (Throwable t) {
                taskError[0] = t;
            }
        }
    };
    taskRunner.start();

    TaskCancelThread tct = new TaskCancelThread(2, taskRunner, testTask);
    tct.start();

    try {
        tct.join();
        taskRunner.join();
    } catch (InterruptedException ie) {
        Assert.fail("Joining threads failed");
    }

    // report a failure of the runner thread on the test thread
    if (taskError[0] != null) {
        throw new AssertionError("Task threw exception although it was properly canceled", taskError[0]);
    }
}
/**
 * File output format used by the tests: every record is written as one text line of the
 * form {@code <field 0>_<field 1>}, with both fields read as {@link IntValue}.
 */
public static class MockOutputFormat extends FileOutputFormat<Record> {

    private static final long serialVersionUID = 1L;

    /** Scratch buffer that is reset and refilled for every record. */
    final StringBuilder bld = new StringBuilder();

    @Override
    public void configure(Configuration parameters) {
        super.configure(parameters);
    }

    /**
     * Formats the record as {@code key_value} followed by a newline and writes the bytes,
     * encoded with {@code ConfigConstants.DEFAULT_CHARSET}, to the output stream.
     *
     * @param rec record whose fields 0 and 1 are read as {@link IntValue}
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeRecord(Record rec) throws IOException {
        final IntValue keyField = rec.getField(0, IntValue.class);
        final IntValue valueField = rec.getField(1, IntValue.class);

        bld.setLength(0);
        bld.append(keyField.getValue())
            .append('_')
            .append(valueField.getValue())
            .append('\n');

        final String line = bld.toString();
        this.stream.write(line.getBytes(ConfigConstants.DEFAULT_CHARSET));
    }
}
/**
 * Variant of {@link MockOutputFormat} that writes the first nine records and throws a
 * {@link RuntimeException} on the tenth and every later call of {@code writeRecord}.
 */
public static class MockFailingOutputFormat extends MockOutputFormat {

    private static final long serialVersionUID = 1L;

    /** Number of {@code writeRecord} calls seen so far. */
    int cnt = 0;

    @Override
    public void configure(Configuration parameters) {
        super.configure(parameters);
    }

    /**
     * Counts the call and either delegates to the parent format or fails.
     *
     * @param rec record to write
     * @throws IOException if the parent format fails to write the record
     */
    @Override
    public void writeRecord(Record rec) throws IOException {
        this.cnt += 1;
        if (this.cnt < 10) {
            super.writeRecord(rec);
            return;
        }
        throw new RuntimeException("Expected Test Exception");
    }
}
/**
 * Builds a path inside the directory named by the {@code java.io.tmpdir} system property.
 * The file name is {@code <fully qualified class name>-<folder>}, so that two test classes
 * using the same {@code folder} do not clash.
 *
 * @param forClass class whose name becomes the first part of the file name
 * @param folder suffix appended after the class name and a dash
 * @return the temp directory, a separator if the directory does not already end with
 *         {@code /} or {@code \}, and the file name
 */
public static String constructTestPath(Class<?> forClass, String folder) {
    final String tmpDir = System.getProperty("java.io.tmpdir");
    final StringBuilder result = new StringBuilder(tmpDir);

    final boolean endsWithSeparator = tmpDir.endsWith("/") || tmpDir.endsWith("\\");
    if (!endsWithSeparator) {
        result.append(System.getProperty("file.separator"));
    }

    return result.append(forClass.getName()).append("-").append(folder).toString();
}
/**
 * Returns the path produced by {@code constructTestPath(forClass, folder)} as a file URI
 * string.
 *
 * @param forClass class whose name becomes the first part of the file name
 * @param folder suffix appended after the class name and a dash
 * @return the string form of the {@code file:} URI for that path
 */
public static String constructTestURI(Class<?> forClass, String folder) {
    final String path = constructTestPath(forClass, folder);
    final File location = new File(path);
    return location.toURI().toString();
}
}