/* * Copyright © 2016 Cask Data, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package co.cask.cdap.internal.app.runtime; import co.cask.cdap.app.guice.DefaultProgramRunnerFactory; import co.cask.cdap.app.runtime.ProgramRunner; import co.cask.cdap.app.runtime.ProgramRuntimeProvider; import co.cask.cdap.common.conf.CConfiguration; import co.cask.cdap.common.conf.Constants; import co.cask.cdap.common.utils.DirUtils; import co.cask.cdap.proto.ProgramType; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Function; import com.google.common.base.Objects; import com.google.common.base.Splitter; import com.google.common.base.Throwables; import com.google.common.cache.CacheBuilder; import com.google.common.cache.CacheLoader; import com.google.common.cache.LoadingCache; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.inject.Inject; import com.google.inject.Injector; import com.google.inject.Singleton; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.File; import java.net.MalformedURLException; import java.net.URL; import java.net.URLClassLoader; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.ServiceLoader; import javax.annotation.Nullable; /** * A singleton class for discovering {@link ProgramRuntimeProvider} through the runtime extension mechanism that uses * the Java {@link ServiceLoader} architecture. */ @Singleton public class ProgramRuntimeProviderLoader { private static final Logger LOG = LoggerFactory.getLogger(ProgramRuntimeProviderLoader.class); // The ServiceLoader that loads ProgramRunnerProvider implementation from the CDAP system classloader. private static final ServiceLoader<ProgramRuntimeProvider> SYSTEM_PROGRAM_RUNNER_PROVIDER_LOADER = ServiceLoader.load(ProgramRuntimeProvider.class); // The ProgramRunnerProvider serves as a tagging instance to indicate there is not // provider supported for a given program type private static final ProgramRuntimeProvider NOT_SUPPORTED_PROVIDER = new ProgramRuntimeProvider() { @Override public ProgramRunner createProgramRunner(ProgramType type, Mode mode, Injector injector) { throw new UnsupportedOperationException(); } }; private final LoadingCache<ProgramType, ProgramRuntimeProvider> programRunnerProviderCache; @VisibleForTesting @Inject public ProgramRuntimeProviderLoader(CConfiguration cConf) { this.programRunnerProviderCache = createProgramRunnerProviderCache(cConf); } /** * Returns a {@link ProgramRuntimeProvider} if one is found for the given {@link ProgramType}; * otherwise {@code null} will be returned. */ @Nullable public ProgramRuntimeProvider get(ProgramType programType) { try { ProgramRuntimeProvider provider = programRunnerProviderCache.get(programType); if (provider != NOT_SUPPORTED_PROVIDER) { return provider; } } catch (Throwable t) { LOG.warn("Failed to load ProgramRunnerProvider for {} program.", programType, t); } return null; } /** * Creates a cache for caching {@link ProgramRuntimeProvider} for different {@link ProgramType}. */ private LoadingCache<ProgramType, ProgramRuntimeProvider> createProgramRunnerProviderCache(CConfiguration cConf) { // A LoadingCache from extension directory to ServiceLoader final LoadingCache<File, ServiceLoader<ProgramRuntimeProvider>> serviceLoaderCache = createServiceLoaderCache(); // List of extension directories to scan String extDirs = cConf.get(Constants.AppFabric.RUNTIME_EXT_DIR, ""); final List<String> dirs = ImmutableList.copyOf(Splitter.on(';').omitEmptyStrings().trimResults().split(extDirs)); return CacheBuilder.newBuilder().build(new CacheLoader<ProgramType, ProgramRuntimeProvider>() { @Override public ProgramRuntimeProvider load(ProgramType programType) throws Exception { // Goes through all extension directory and see which service loader supports the give program type for (String dir : dirs) { File extDir = new File(dir); if (!extDir.isDirectory()) { continue; } // Each module would be under a directory of the extension directory for (File moduleDir : DirUtils.listFiles(extDir)) { if (!moduleDir.isDirectory()) { continue; } // Try to find a provider that can support the given program type. try { ProgramRuntimeProvider provider = findProvider(serviceLoaderCache.getUnchecked(moduleDir), programType); if (provider != null) { return provider; } } catch (Exception e) { LOG.warn("Exception raised when loading a ProgramRuntimeProvider from {}. Extension ignored.", moduleDir, e); } } } // If there is none found in the ext dir, try to look it up from the CDAP system class ClassLoader. // This is for the unit-test case, which extensions are part of the test dependency, hence in the // unit-test ClassLoader. // If no provider was found, returns the NOT_SUPPORTED_PROVIDER so that we won't search again for // this program type. // Cannot use null because LoadingCache doesn't allow null value return Objects.firstNonNull(findProvider(SYSTEM_PROGRAM_RUNNER_PROVIDER_LOADER, programType), NOT_SUPPORTED_PROVIDER); } }); } /** * Creates a cache for caching extension directory to {@link ServiceLoader} of {@link ProgramRuntimeProvider}. */ private LoadingCache<File, ServiceLoader<ProgramRuntimeProvider>> createServiceLoaderCache() { return CacheBuilder.newBuilder().build(new CacheLoader<File, ServiceLoader<ProgramRuntimeProvider>>() { @Override public ServiceLoader<ProgramRuntimeProvider> load(File dir) throws Exception { return createServiceLoader(dir); } }); } /** * Creates a {@link ServiceLoader} from the {@link ClassLoader} created by all jar files under the given directory. */ private ServiceLoader<ProgramRuntimeProvider> createServiceLoader(File dir) { List<File> files = new ArrayList<>(DirUtils.listFiles(dir, "jar")); Collections.sort(files); URL[] urls = Iterables.toArray(Iterables.transform(files, new Function<File, URL>() { @Override public URL apply(File input) { try { return input.toURI().toURL(); } catch (MalformedURLException e) { // Shouldn't happen throw Throwables.propagate(e); } } }), URL.class); URLClassLoader classLoader = new URLClassLoader(urls, DefaultProgramRunnerFactory.class.getClassLoader()); return ServiceLoader.load(ProgramRuntimeProvider.class, classLoader); } /** * Finds a {@link ProgramRuntimeProvider} from the given {@link ServiceLoader} that can support the given * {@link ProgramType}. */ @Nullable private ProgramRuntimeProvider findProvider(ServiceLoader<ProgramRuntimeProvider> serviceLoader, ProgramType programType) { for (ProgramRuntimeProvider provider : serviceLoader) { Class<? extends ProgramRuntimeProvider> providerClass = provider.getClass(); // See if the provide supports the required program type ProgramRuntimeProvider.SupportedProgramType supportedType = providerClass.getAnnotation(ProgramRuntimeProvider.SupportedProgramType.class); if (supportedType == null || !ImmutableSet.copyOf(supportedType.value()).contains(programType)) { continue; } // Found the provider. LOG.debug("ProgramRunnerProvider {} found for {} program.", provider, programType); return provider; } return null; } }