/* * Copyright © 2015 Cask Data, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package co.cask.cdap.common.lang; import com.google.common.collect.Lists; import java.io.IOException; import java.io.InputStream; import java.net.URL; import java.net.URLClassLoader; import java.util.Collections; import java.util.Enumeration; import java.util.HashSet; import java.util.List; import java.util.Set; /** * ClassLoader that filters out certain resources. */ public final class FilterClassLoader extends ClassLoader { private final ClassLoader extensionClassLoader; private final Filter filter; /** * Represents filtering that the {@link FilterClassLoader} needs to apply. */ public interface Filter { /** * Returns the result of whether the given resource is accepted or not. */ boolean acceptResource(String resource); /** * Returns the result of whether the given package is accepted or not. */ boolean acceptPackage(String packageName); } /** * Returns the default filter that should applies to all program type. By default * all hadoop classes and cdap-api classes (and dependencies) are allowed. */ public static Filter defaultFilter() { final Set<String> visibleResources = ProgramResources.getVisibleResources(); final Set<String> visiblePackages = new HashSet<>(); for (String resource : visibleResources) { if (resource.endsWith(".class")) { int idx = resource.lastIndexOf('/'); // Ignore empty package if (idx > 0) { visiblePackages.add(resource.substring(0, idx).replace('/', '.')); } } } return new Filter() { @Override public boolean acceptResource(String resource) { return visibleResources.contains(resource); } @Override public boolean acceptPackage(String packageName) { return visiblePackages.contains(packageName); } }; } /** * Creates a new {@link FilterClassLoader} that filter classes based on the {@link #defaultFilter()} on the * given parent ClassLoader * * @param parentClassLoader the ClassLoader to filter from. * @return a new intance of {@link FilterClassLoader}. */ public static FilterClassLoader create(ClassLoader parentClassLoader) { return new FilterClassLoader(parentClassLoader, defaultFilter()); } /** * Create a {@link FilterClassLoader} that filter classes based on the given {@link Filter} on the given * parent ClassLoader. * * @param parentClassLoader Parent ClassLoader * @param filter Filter to apply for the ClassLoader */ public FilterClassLoader(ClassLoader parentClassLoader, Filter filter) { super(parentClassLoader); this.extensionClassLoader = new URLClassLoader(new URL[0], ClassLoader.getSystemClassLoader().getParent()); this.filter = filter; } @Override protected synchronized Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException { // Try to load it from bootstrap class loader first try { return extensionClassLoader.loadClass(name); } catch (ClassNotFoundException e) { if (filter.acceptResource(classNameToResourceName(name))) { return super.loadClass(name, resolve); } throw e; } } @Override protected Package[] getPackages() { List<Package> packages = Lists.newArrayList(); for (Package pkg : super.getPackages()) { if (filter.acceptPackage(pkg.getName())) { packages.add(pkg); } } return packages.toArray(new Package[packages.size()]); } @Override protected Package getPackage(String name) { // Replace all '/' with '.' since Java allow both names like "java/lang" or "java.lang" as the name to lookup return (filter.acceptPackage(name.replace('/', '.'))) ? super.getPackage(name) : null; } @Override public URL getResource(String name) { URL resource = extensionClassLoader.getResource(name); if (resource != null) { return resource; } return filter.acceptResource(name) ? super.getResource(name) : null; } @Override public Enumeration<URL> getResources(String name) throws IOException { Enumeration<URL> resources = extensionClassLoader.getResources(name); if (resources.hasMoreElements()) { return resources; } return filter.acceptResource(name) ? super.getResources(name) : Collections.<URL>emptyEnumeration(); } @Override public InputStream getResourceAsStream(String name) { InputStream resourceStream = extensionClassLoader.getResourceAsStream(name); if (resourceStream != null) { return resourceStream; } return filter.acceptResource(name) ? super.getResourceAsStream(name) : null; } private String classNameToResourceName(String className) { return className.replace('.', '/') + ".class"; } }