/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.flink.test.optimizer.jsonplan; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; import org.apache.flink.api.common.JobExecutionResult; import org.apache.flink.api.common.Plan; import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.ExecutionEnvironmentFactory; import org.apache.flink.configuration.ConfigConstants; import org.apache.flink.configuration.Configuration; import org.apache.flink.examples.java.clustering.KMeans; import org.apache.flink.examples.java.graph.ConnectedComponents; import org.apache.flink.examples.java.relational.WebLogAnalysis; import org.apache.flink.examples.java.wordcount.WordCount; import org.apache.flink.optimizer.Optimizer; import org.apache.flink.optimizer.plan.OptimizedPlan; import org.apache.flink.optimizer.plantranslate.JobGraphGenerator; import org.apache.flink.runtime.jobgraph.JobGraph; import org.apache.flink.runtime.jobgraph.jsonplan.JsonPlanGenerator; import org.junit.After; import org.junit.Before; import org.junit.Test; import java.io.OutputStream; import java.io.PrintStream; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import static org.junit.Assert.*; public class JsonJobGraphGenerationTest { private PrintStream out; private PrintStream err; @Before public void redirectStreams() { this.out = System.out; this.err = System.err; OutputStream discards = new OutputStream() { @Override public void write(int b) {} }; System.setOut(new PrintStream(discards)); System.setErr(new PrintStream(discards)); } @After public void restoreStreams() { if (out != null) { System.setOut(out); } if (err != null) { System.setOut(err); } } @Test public void testWordCountPlan() { try { // without arguments try { final int parallelism = 1; // some ops have DOP 1 forced JsonValidator validator = new GenericValidator(parallelism, 3); TestingExecutionEnvironment.setAsNext(validator, parallelism); WordCount.main(new String[0]); } catch (AbortError ignored) {} // with arguments try { final int parallelism = 17; JsonValidator validator = new GenericValidator(parallelism, 3); TestingExecutionEnvironment.setAsNext(validator, parallelism); String tmpDir = ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH; WordCount.main(new String[] { "--input", tmpDir, "--output", tmpDir}); } catch (AbortError ignored) {} } catch (Exception e) { restoreStreams(); e.printStackTrace(); fail(e.getMessage()); } } @Test public void testWebLogAnalysis() { try { // without arguments try { final int parallelism = 1; // some ops have DOP 1 forced JsonValidator validator = new GenericValidator(parallelism, 6); TestingExecutionEnvironment.setAsNext(validator, parallelism); WebLogAnalysis.main(new String[0]); } catch (AbortError ignored) {} // with arguments try { final int parallelism = 17; JsonValidator validator = new GenericValidator(parallelism, 6); TestingExecutionEnvironment.setAsNext(validator, parallelism); String tmpDir = ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH; WebLogAnalysis.main(new String[] { "--documents", tmpDir, "--ranks", tmpDir, "--visits", tmpDir, "--output", tmpDir}); } catch (AbortError ignored) {} } catch (Exception e) { restoreStreams(); e.printStackTrace(); fail(e.getMessage()); } } @Test public void testKMeans() { try { // without arguments try { final int parallelism = 1; // some ops have DOP 1 forced JsonValidator validator = new GenericValidator(parallelism, 9); TestingExecutionEnvironment.setAsNext(validator, parallelism); KMeans.main(new String[0]); } catch (AbortError ignored) {} // with arguments try { final int parallelism = 42; JsonValidator validator = new GenericValidator(parallelism, 9); TestingExecutionEnvironment.setAsNext(validator, parallelism); String tmpDir = ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH; KMeans.main(new String[] { "--points", tmpDir, "--centroids", tmpDir, "--output", tmpDir, "--iterations", "100"}); } catch (AbortError ignored) {} } catch (Exception e) { restoreStreams(); e.printStackTrace(); fail(e.getMessage()); } } @Test public void testConnectedComponents() { try { // without arguments try { final int parallelism = 1; // some ops have DOP 1 forced JsonValidator validator = new GenericValidator(parallelism, 9); TestingExecutionEnvironment.setAsNext(validator, parallelism); ConnectedComponents.main(); } catch (AbortError ignored) {} // with arguments try { final int parallelism = 23; JsonValidator validator = new GenericValidator(parallelism, 9); TestingExecutionEnvironment.setAsNext(validator, parallelism); String tmpDir = ConfigConstants.DEFAULT_TASK_MANAGER_TMP_PATH; ConnectedComponents.main( "--vertices", tmpDir, "--edges", tmpDir, "--output", tmpDir, "--iterations", "100"); } catch (AbortError ignored) {} } catch (Exception e) { restoreStreams(); e.printStackTrace(); fail(e.getMessage()); } } // ------------------------------------------------------------------------ private static interface JsonValidator { void validateJson(String json) throws Exception; } private static class GenericValidator implements JsonValidator { private final int expectedParallelism; private final int numNodes; GenericValidator(int expectedParallelism, int numNodes) { this.expectedParallelism = expectedParallelism; this.numNodes = numNodes; } @Override public void validateJson(String json) throws Exception { final Map<String, JsonNode> idToNode = new HashMap<>(); // validate the produced JSON ObjectMapper m = new ObjectMapper(); JsonNode rootNode = m.readTree(json); JsonNode idField = rootNode.get("jid"); JsonNode nameField = rootNode.get("name"); JsonNode arrayField = rootNode.get("nodes"); assertNotNull(idField); assertNotNull(nameField); assertNotNull(arrayField); assertTrue(idField.isTextual()); assertTrue(nameField.isTextual()); assertTrue(arrayField.isArray()); ArrayNode array = (ArrayNode) arrayField; for (Iterator<JsonNode> iter = array.elements(); iter.hasNext(); ) { JsonNode vertex = iter.next(); JsonNode vertexIdField = vertex.get("id"); JsonNode parallelismField = vertex.get("parallelism"); JsonNode contentsFields = vertex.get("description"); JsonNode operatorField = vertex.get("operator"); assertNotNull(vertexIdField); assertTrue(vertexIdField.isTextual()); assertNotNull(parallelismField); assertTrue(parallelismField.isNumber()); assertNotNull(contentsFields); assertTrue(contentsFields.isTextual()); assertNotNull(operatorField); assertTrue(operatorField.isTextual()); if (contentsFields.asText().startsWith("Sync")) { assertEquals(1, parallelismField.asInt()); } else { assertEquals(expectedParallelism, parallelismField.asInt()); } idToNode.put(vertexIdField.asText(), vertex); } assertEquals(numNodes, idToNode.size()); // check that all inputs are contained for (JsonNode node : idToNode.values()) { JsonNode inputsField = node.get("inputs"); if (inputsField != null) { for (Iterator<JsonNode> inputsIter = inputsField.elements(); inputsIter.hasNext(); ) { JsonNode inputNode = inputsIter.next(); JsonNode inputIdField = inputNode.get("id"); assertNotNull(inputIdField); assertTrue(inputIdField.isTextual()); String inputIdString = inputIdField.asText(); assertTrue(idToNode.containsKey(inputIdString)); } } } } } // ------------------------------------------------------------------------ private static class AbortError extends Error { private static final long serialVersionUID = 152179957828703919L; } // ------------------------------------------------------------------------ private static class TestingExecutionEnvironment extends ExecutionEnvironment { private final JsonValidator validator; private TestingExecutionEnvironment(JsonValidator validator) { this.validator = validator; } @Override public void startNewSession() throws Exception { } @Override public JobExecutionResult execute(String jobName) throws Exception { Plan plan = createProgramPlan(jobName); Optimizer pc = new Optimizer(new Configuration()); OptimizedPlan op = pc.compile(plan); JobGraphGenerator jgg = new JobGraphGenerator(); JobGraph jobGraph = jgg.compileJobGraph(op); String jsonPlan = JsonPlanGenerator.generatePlan(jobGraph); // first check that the JSON is valid JsonParser parser = new JsonFactory().createJsonParser(jsonPlan); while (parser.nextToken() != null); validator.validateJson(jsonPlan); throw new AbortError(); } @Override public String getExecutionPlan() throws Exception { throw new UnsupportedOperationException(); } public static void setAsNext(final JsonValidator validator, final int defaultParallelism) { initializeContextEnvironment(new ExecutionEnvironmentFactory() { @Override public ExecutionEnvironment createExecutionEnvironment() { ExecutionEnvironment env = new TestingExecutionEnvironment(validator); env.setParallelism(defaultParallelism); return env; } }); } } }