/* * Copyright © 2015 Cask Data, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); you may not * use this file except in compliance with the License. You may obtain a copy of * the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations under * the License. */ package co.cask.cdap.common.lang; import com.google.common.base.Predicate; import com.google.common.collect.Iterators; import com.google.common.collect.Lists; import java.io.IOException; import java.net.URL; import java.net.URLClassLoader; import java.util.Enumeration; import java.util.List; import javax.annotation.Nullable; /** * A {@link ClassLoader} that filter class based on package name. Classes in the bootstrap ClassLoader is always * loadable from this ClassLoader. */ public class PackageFilterClassLoader extends ClassLoader { private final Predicate<String> predicate; private final ClassLoader bootstrapClassLoader; /** * Constructs a new instance that only allow class's package name passes the given {@link Predicate}. */ public PackageFilterClassLoader(ClassLoader parent, Predicate<String> predicate) { super(parent); this.predicate = predicate; // This is no reliable way to get bootstrap ClassLoader from Java (System.class.getClassLoader() will return null). // A URLClassLoader with no URLs and with a null parent will load class from bootstrap ClassLoader only. this.bootstrapClassLoader = new URLClassLoader(new URL[0], null); } @Override protected synchronized Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException { try { return bootstrapClassLoader.loadClass(name); } catch (ClassNotFoundException e) { if (!predicate.apply(getClassPackage(name))) { throw new ClassNotFoundException("Loading of class " + name + " not allowed"); } return super.loadClass(name, resolve); } } @Override public URL getResource(String name) { URL resource = bootstrapClassLoader.getResource(name); if (resource != null) { return resource; } if (name.endsWith(".class") && !predicate.apply(getResourcePackage(name))) { return null; } return super.getResource(name); } @Override public Enumeration<URL> getResources(String name) throws IOException { Enumeration<URL> resources = bootstrapClassLoader.getResources(name); if (resources.hasMoreElements()) { return resources; } if (name.endsWith(".class") && !predicate.apply(getResourcePackage(name))) { return Iterators.asEnumeration(Iterators.<URL>emptyIterator()); } return super.getResources(name); } @Override protected Package[] getPackages() { List<Package> packages = Lists.newArrayList(); for (Package pkg : super.getPackages()) { if (predicate.apply(pkg.getName())) { packages.add(pkg); } } return packages.toArray(new Package[packages.size()]); } @Override protected Package getPackage(String name) { if (!predicate.apply(name)) { return null; } return super.getPackage(name); } /** * Returns the package of the given class or {@code null} if the class is in default package. * * @param className fully qualified name of the class */ @Nullable private String getClassPackage(String className) { int idx = className.lastIndexOf('.'); return idx < 0 ? null : className.substring(0, idx); } /** * Returns the package name of the given resource name representing a class. * * @param classResource Resource name of the class. */ private String getResourcePackage(String classResource) { String packageName = classResource.substring(0, classResource.length() - ".class".length()).replace('/', '.'); if (packageName.startsWith("/")) { return packageName.substring(1); } return packageName; } }