/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.nifi.processors.windows.event.log; import com.sun.jna.Native; import javassist.ClassPool; import javassist.CtClass; import javassist.CtMethod; import org.junit.runner.Description; import org.junit.runner.Runner; import org.junit.runner.notification.RunNotifier; import org.junit.runners.model.InitializationError; import org.mockito.runners.MockitoJUnitRunner; import java.net.URLClassLoader; import java.util.Map; /** * Can't even use the JNA interface classes if the native library won't load. This is a workaround to allow mocking them for unit tests. */ public abstract class JNAOverridingJUnitRunner extends Runner { public static final String NATIVE_CANONICAL_NAME = Native.class.getCanonicalName(); public static final String LOAD_LIBRARY = "loadLibrary"; private final Runner delegate; public JNAOverridingJUnitRunner(Class<?> klass) throws InitializationError { Map<String, Map<String, String>> classOverrideMap = getClassOverrideMap(); ClassLoader jnaMockClassloader = new URLClassLoader(((URLClassLoader) JNAOverridingJUnitRunner.class.getClassLoader()).getURLs(), null) { @Override protected synchronized Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException { Map<String, String> classOverrides = classOverrideMap.get(name); if (classOverrides != null) { ClassPool classPool = ClassPool.getDefault(); try { CtClass ctClass = classPool.get(name); try { for (Map.Entry<String, String> methodAndBody : classOverrides.entrySet()) { for (CtMethod loadLibrary : ctClass.getDeclaredMethods(methodAndBody.getKey())) { loadLibrary.setBody(methodAndBody.getValue()); } } byte[] bytes = ctClass.toBytecode(); Class<?> definedClass = defineClass(name, bytes, 0, bytes.length); if (resolve) { resolveClass(definedClass); } return definedClass; } finally { ctClass.detach(); } } catch (Exception e) { throw new ClassNotFoundException(name, e); } } else if (name.startsWith("org.junit.")) { Class<?> result = JNAOverridingJUnitRunner.class.getClassLoader().loadClass(name); if (resolve) { resolveClass(result); } return result; } return super.loadClass(name, resolve); } }; try { delegate = (Runner) jnaMockClassloader.loadClass(MockitoJUnitRunner.class.getCanonicalName()).getConstructor(Class.class) .newInstance(jnaMockClassloader.loadClass(klass.getCanonicalName())); } catch (Exception e) { throw new InitializationError(e); } } protected abstract Map<String, Map<String, String>> getClassOverrideMap(); @Override public Description getDescription() { return delegate.getDescription(); } @Override public void run(RunNotifier notifier) { delegate.run(notifier); } }