package org.infinispan.server.core.security;
import java.security.InvalidParameterException;
import java.security.NoSuchAlgorithmException;
import java.security.Provider;
import java.security.Security;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.ServiceLoader;
import java.util.Set;
import javax.security.sasl.SaslClientFactory;
import javax.security.sasl.SaslServerFactory;
/**
 * Utility methods for locating SASL client and server factories.
 * <p>
 * Factories are collected from two sources: providers declared for {@link ServiceLoader} on a class loader
 * ("class path providers") and, optionally, the providers registered globally with {@link Security}.
 *
 * @author <a href="mailto:david.lloyd@redhat.com">David M. Lloyd</a>
 * @author Tristan Tarrant
 */
public final class SaslUtils {

    private SaslUtils() {
        // static utility class; not instantiable
    }

    /**
     * Returns an iterator of all of the registered {@code SaslServerFactory}s where the order is based on the
     * order of the Provider registration and/or class path order. Class path providers are listed before
     * global providers; if a global provider names the same implementation class as a class path provider,
     * only the class path instance is returned.
     *
     * @param classLoader the class loader whose {@link ServiceLoader} providers are consulted, or {@code null}
     *                    to use the system class loader
     * @param includeGlobal {@code true} to include globally registered providers, {@code false} to exclude them
     * @return the {@code Iterator} of {@code SaslServerFactory}s
     */
    public static Iterator<SaslServerFactory> getSaslServerFactories(ClassLoader classLoader, boolean includeGlobal) {
        return getFactories(SaslServerFactory.class, classLoader, includeGlobal);
    }

    /**
     * Returns an iterator of all of the registered {@code SaslServerFactory}s where the order is based on the
     * order of the Provider registration and/or class path order. Equivalent to
     * {@code getSaslServerFactories(null, true)}.
     *
     * @return the {@code Iterator} of {@code SaslServerFactory}s
     */
    public static Iterator<SaslServerFactory> getSaslServerFactories() {
        return getFactories(SaslServerFactory.class, null, true);
    }

    /**
     * Returns an iterator of all of the registered {@code SaslClientFactory}s where the order is based on the
     * order of the Provider registration and/or class path order. Class path providers are listed before
     * global providers; if a global provider names the same implementation class as a class path provider,
     * only the class path instance is returned.
     *
     * @param classLoader the class loader whose {@link ServiceLoader} providers are consulted, or {@code null}
     *                    to use the system class loader
     * @param includeGlobal {@code true} to include globally registered providers, {@code false} to exclude them
     * @return the {@code Iterator} of {@code SaslClientFactory}s
     */
    public static Iterator<SaslClientFactory> getSaslClientFactories(ClassLoader classLoader, boolean includeGlobal) {
        return getFactories(SaslClientFactory.class, classLoader, includeGlobal);
    }

    /**
     * Returns an iterator of all of the registered {@code SaslClientFactory}s where the order is based on the
     * order of the Provider registration and/or class path order. Equivalent to
     * {@code getSaslClientFactories(null, true)}.
     *
     * @return the {@code Iterator} of {@code SaslClientFactory}s
     */
    public static Iterator<SaslClientFactory> getSaslClientFactories() {
        return getFactories(SaslClientFactory.class, null, true);
    }

    /**
     * Collects the factories of the given type: first those found by {@link ServiceLoader} on
     * {@code classLoader}, then (if {@code includeGlobal}) those advertised by the {@link Security} providers.
     */
    private static <T> Iterator<T> getFactories(Class<T> type, ClassLoader classLoader, boolean includeGlobal) {
        // Insertion-ordered so class path factories come first, followed by global ones in provider order.
        final Set<T> factories = new LinkedHashSet<T>();
        // Implementation class names already represented; seeding it with the class path factories is what
        // makes the class path provider win when a global provider names the same class.
        final Set<String> loadedClasses = new HashSet<String>();
        for (T factory : ServiceLoader.load(type, classLoader)) {
            factories.add(factory);
            loadedClasses.add(factory.getClass().getName());
        }
        if (includeGlobal) {
            addGlobalFactories(type, factories, loadedClasses);
        }
        return factories.iterator();
    }

    /**
     * Appends one instance per not-yet-seen implementation class of each {@code type} service advertised by
     * the globally registered security providers, in provider preference order.
     */
    private static <T> void addGlobalFactories(Class<T> type, Set<T> factories, Set<String> loadedClasses) {
        // Provider service types are the simple interface names, e.g. "SaslServerFactory".
        final String serviceType = type.getSimpleName();
        for (Provider provider : Security.getProviders()) {
            for (Provider.Service service : provider.getServices()) {
                if (!serviceType.equals(service.getType()) || !loadedClasses.add(service.getClassName())) {
                    continue;
                }
                try {
                    // Let the provider create its own service instance. Reflectively invoking the constructor
                    // ourselves fails under the JDK's strong encapsulation for classes in non-exported packages
                    // (such as the JDK's bundled SASL factories).
                    final Object instance = service.newInstance(null);
                    if (type.isInstance(instance)) {
                        factories.add(type.cast(instance));
                    }
                } catch (NoSuchAlgorithmException ignored) {
                    // Best effort: an entry whose implementation cannot be loaded or instantiated is skipped
                    // instead of failing the whole lookup.
                } catch (InvalidParameterException ignored) {
                    // Best effort, as above: a service that cannot be created without a parameter is skipped.
                }
            }
        }
    }
}