/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package com.aliyun.odps.udf.local.runner; import java.io.IOException; import java.util.ArrayList; import java.util.LinkedList; import java.util.List; import org.junit.Assert; import com.aliyun.odps.Odps; import com.aliyun.odps.local.common.Constants; import com.aliyun.odps.local.common.WareHouse; import com.aliyun.odps.local.common.security.ApplicatitionType; import com.aliyun.odps.local.common.security.SecurityClient; import com.aliyun.odps.udf.UDF; import com.aliyun.odps.udf.local.LocalExecutionContext; import com.aliyun.odps.udf.local.LocalRunException; import com.aliyun.odps.udf.local.datasource.InputSource; public abstract class BaseRunner { boolean hasClosed = false; public BaseRunner(Odps odps) { WareHouse.getInstance().setOdps(odps); initSecurity(); } private void initSecurity() { List<String> codeBase = new LinkedList<String>(); // add odps-udf-local String path = BaseRunner.class.getProtectionDomain().getCodeSource().getLocation().getPath(); path = path.substring(path.indexOf(":") + 1); codeBase.add(path); // add odps-sdk-udf path = UDF.class.getProtectionDomain().getCodeSource().getLocation().getPath(); path = path.substring(path.indexOf(":") + 1); codeBase.add(path); // add odps-common-local path = WareHouse.class.getProtectionDomain().getCodeSource().getLocation().getPath(); path = path.substring(path.indexOf(":") + 1); codeBase.add(path); boolean isSecurityEnabled = System.getProperty(Constants.LOCAL_SECURITY_ENABLE, "false").trim() .equalsIgnoreCase("true"); boolean isJNIEnabled = System.getProperty(Constants.LOCAL_SECURITY_JNI_ENABLE, "false").trim() .equalsIgnoreCase("true"); String userDefinePolicy = System.getProperty(Constants.LOCAL_USER_DEFINE_POLICY, "").trim(); SecurityClient.init(ApplicatitionType.UDF, codeBase, null, isSecurityEnabled, isJNIEnabled, userDefinePolicy); } protected LocalExecutionContext context = new LocalExecutionContext(); protected List<Object[]> buffer = new ArrayList<Object[]>(); protected List<InputSource> inputSources = new LinkedList<InputSource>(); /** * 添加输入源,可添加多个,框架处理数据输入的顺序为: <br/> * 先处理通过feed和feedAll加入的数据(两种按先后顺序), * 最后处理通过addInputSource添加的数据源中的数据(多个InputSource按添加先后顺序处理) * * 代码示例: * * <pre> * BaseRunner runner = new UDFRunner(odps, new UdfExample()); * * // partition table * String project = "project_name"; * String table = "wc_in2"; * String[] partitions = null; * String[] columns = new String[] {"colc", "cola"}; * partitions = new String[] {"p2=1", "p1=2"}; * * // input1 * InputSource inputSource = new TableInputSource(project, table, partitions, columns); * runner.addInputSource(inputSource); * * // input2 * Object[][] inputs1 = new Object[2][]; * inputs1[0] = new Object[] {"one", "one"}; * inputs1[1] = new Object[] {"two", "two"}; * runner.feedAll(inputs1); * * List<Object[]> out = runner.yield(); * Assert.assertEquals(5, out.size()); * Assert.assertEquals("ss2s:one,one", StringUtils.join(out.get(0), ",")); * Assert.assertEquals("ss2s:two,two", StringUtils.join(out.get(1), ",")); * Assert.assertEquals("ss2s:three3,three1", StringUtils.join(out.get(2), ",")); * Assert.assertEquals("ss2s:three3,three1", StringUtils.join(out.get(3), ",")); * Assert.assertEquals("ss2s:three3,three1", StringUtils.join(out.get(4), ",")); * </pre> * * @param inputSource * @return */ public void addInputSource(InputSource inputSource) { if (inputSource != null) { try { inputSource.setup(); } catch (IOException e) { close(); throw new RuntimeException(e); } inputSources.add(inputSource); } } /** * case的输入数据,每次调用只传递一组输入数据 * * 代码示例: * <pre> * BaseRunner runner = new UDFRunner(odps, new UdfExample()); * runner.feed(new Object[] { "one", "one" }) * .feed(new Object[] { "three", "three" }) * .feed(new Object[] { "four", "four" }); * </pre> * * @param input * @return * @throws LocalRunException */ public BaseRunner feed(Object[] input) throws LocalRunException { try { return internalFeed(input); } catch (LocalRunException e) { close(); throw e; } }; protected abstract BaseRunner internalFeed(Object[] input) throws LocalRunException; /** * case的输入数据,每次调用传递多组输入数据 * * 代码示例: * <pre> * BaseRunner runner = new UDFRunner(odps, new UdfExample()); * Object[][] inputs = new Object[3][]; * inputs[0] = new Object[] { "one", "one" }; * inputs[1] = new Object[] { "three", "three" }; * inputs[2] = new Object[] { "four", "four" }; * runner.feedAll(inputs); * </pre> * * @param inputs * @return * @throws LocalRunException */ public BaseRunner feedAll(Object[][] inputs) throws LocalRunException { if (inputs == null) { return this; } for (Object[] input : inputs) { feed(input); } return this; } /** * case的输入数据,每次调用传递多组输入数据 * * 代码示例: * <pre> * BaseRunner runner = new UDFRunner(odps, new UdfExample()); * List<Object[]> inputs = new ArrayList<Object[]>(); * inputs.add(new Object[] { "one", "one" }); * inputs.add(new Object[] { "three", "three" }); * inputs.add(new Object[] { "four", "four" }); * runner.feedAll(inputs); * </pre> * * @param inputs * @return * @throws LocalRunException */ public BaseRunner feedAll(List<Object[]> inputs) throws LocalRunException { if (inputs == null) { return this; } for (Object[] input : inputs) { feed(input); } return this; } /** * 根据case输入数据产生输出结果, * 调用此方法后将不能再次调用feed及feedAll, * 如果需要添加更多case需要重新构造runner, * 用户可以对输出结果进行校验 * * <pre> * BaseRunner runner = new UDFRunner(odps, new UdfExample()); * List<Object[]> out =runner.feed(new Object[] { "one", "one" }) * .feed(new Object[] {"three", "three" }) * .feed(new Object[] { "four", "four" }) * .yield(); * * Assert.assertEquals(3, out.size()); * Assert.assertEquals("ss2s:one,one", StringUtils.join(out.get(0), ",")); * Assert.assertEquals("ss2s:three,three", StringUtils.join(out.get(1), ",")); * Assert.assertEquals("ss2s:four,four", StringUtils.join(out.get(2), ",")); * </pre> * * @return * @throws LocalRunException */ public List<Object[]> yield() throws LocalRunException { for (InputSource inputSource : inputSources) { Object[] data; try { while ((data = inputSource.getNextRow()) != null) { feed(data); } } catch (IOException e) { close(); throw new LocalRunException(e); } } try { return internalYield(); } finally { close(); } } protected abstract List<Object[]> internalYield() throws LocalRunException; /** * 将输出结果与用户期望值进行比较, * 调用此方法后将不能再次调用feed及feedAll, * 如果需要添加更多case需要重新构造runner * * <pre> * * Object[][] inputs = new Object[3][]; * inputs[0] = new Object[] { "one", "one" }; * inputs[1] = new Object[] { "three", "three" }; * inputs[2] = new Object[] { "four", "four" }; * * Object[][] expectedOutputs = new Object[3][]; * expectedOutputs[0]=new Object[] { "ss2s:one,one" }; * expectedOutputs[1]=new Object[] { "ss2s:three,three" }; * expectedOutputs[2]=new Object[] { "ss2s:four,four" }; * * BaseRunner runner = new UDFRunner(odps, new UdfExample()); * runner.feedAll(inputs).runTest(expectedOutputs); * * </pre> * * @param expectedOutputs * @throws LocalRunException */ public void runTest(Object[][] expectedOutputs) throws LocalRunException { List<Object[]> yields = yield(); if (expectedOutputs == null && yields == null) { return; } else if (expectedOutputs == null) { org.junit.Assert.fail("expected: null,but was: not null"); } else if (yields == null) { org.junit.Assert.fail("expected: not null,but was: null"); } int count = expectedOutputs.length; if (count != yields.size()) { org.junit.Assert.fail("expected size:" + count + ",but wase:" + yields.size()); } for (int i = 0; i < count; ++i) { Assert.assertArrayEquals("Row number(start with 0):" + i + " ", expectedOutputs[i], yields.get(i)); } } /** * 将输出结果与用户期望值进行比较, * 调用此方法后将不能再次调用feed及feedAll, * 如果需要添加更多case需要重新构造runner * * <pre> * * List<Object[]> inputs = new ArrayList<Object[]>(); * inputs.add(new Object[] { "one", "one" }); * inputs.add(new Object[] { "three", "three" }); * inputs.add(new Object[] { "four", "four" }); * * List<Object[]> expectedOutputs = new ArrayList<Object[]>(); * expectedOutputs.add(new Object[] { "ss2s:one,one" }); * expectedOutputs.add(new Object[] { "ss2s:three,three" }); * expectedOutputs.add(new Object[] { "ss2s:four,four" }); * * BaseRunner runner = new UDFRunner(odps, new UdfExample()); * runner.feedAll(inputs).runTest(expectedOutputs); * * </pre> * * @param expectedOutputs * @throws LocalRunException */ public void runTest(List<Object[]> expectedOutputs) throws LocalRunException { if (expectedOutputs == null) { List<Object[]> yields = yield(); if (expectedOutputs == null && yields == null) { return; } else if (expectedOutputs == null) { org.junit.Assert.fail("expected: null,but was: not null"); } else if (yields == null) { org.junit.Assert.fail("expected: not null,but was: null"); } } else { Object[][] expected = new Object[expectedOutputs.size()][]; expectedOutputs.toArray(expected); runTest(expected); } } /** * 程序退出之前(正常退出或抛出异常)必须调用该方法,确保资源得到释放 * */ final synchronized protected void close() { if (hasClosed) { return; } for (InputSource inputSource : inputSources) { try { inputSource.close(); } catch (Exception ex) { } } hasClosed = true; } }