/*
* Copyright 2016 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.rotation;
import java.security.Key;
import java.security.KeyManagementException;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
/**
* {@link KeyLocator} that represents a list of multiple {@link KeyLocator}s. Key is searched
* from the first to the last {@link KeyLocator} in the order given by the list. If there are
* multiple {@link KeyLocator}s providing key with the same key ID, the first matching key is
* returned.
*
* @author hmlnarik
*/
public class CompositeKeyLocator implements KeyLocator, Iterable<Key> {
private final List<KeyLocator> keyLocators = new LinkedList<>();
@Override
public Key getKey(String kid) throws KeyManagementException {
for (KeyLocator keyLocator : keyLocators) {
Key k = keyLocator.getKey(kid);
if (k != null) {
return k;
}
}
return null;
}
@Override
public void refreshKeyCache() {
for (KeyLocator keyLocator : keyLocators) {
keyLocator.refreshKeyCache();
}
}
/**
* Registers a given {@link KeyLocator} as the first {@link KeyLocator}.
*/
public void addFirst(KeyLocator keyLocator) {
this.keyLocators.add(0, keyLocator);
}
/**
* Registers a given {@link KeyLocator} as the last {@link KeyLocator}.
*/
public void add(KeyLocator keyLocator) {
this.keyLocators.add(keyLocator);
}
/**
* Clears the list of registered {@link KeyLocator}s
*/
public void clear() {
this.keyLocators.clear();
}
@Override
public String toString() {
if (this.keyLocators.size() == 1) {
return this.keyLocators.get(0).toString();
}
StringBuilder sb = new StringBuilder("Key locator chain: [");
for (Iterator<KeyLocator> it = keyLocators.iterator(); it.hasNext();) {
KeyLocator keyLocator = it.next();
sb.append(keyLocator.toString());
if (it.hasNext()) {
sb.append(", ");
}
}
return sb.append("]").toString();
}
@Override
public Iterator<Key> iterator() {
final Iterator<Iterable<Key>> iterablesIterator = getKeyLocatorIterators().iterator();
return new JointKeyIterator(iterablesIterator).iterator();
}
@SuppressWarnings("unchecked")
private Iterable<Iterable<Key>> getKeyLocatorIterators() {
List<Iterable<Key>> res = new LinkedList<>();
for (KeyLocator kl : this.keyLocators) {
if (kl instanceof Iterable) {
res.add(((Iterable<Key>) kl));
}
}
return Collections.unmodifiableCollection(res);
}
private class JointKeyIterator implements Iterable<Key> {
// based on http://stackoverflow.com/a/34126154/6930869
private final Iterator<Iterable<Key>> iterablesIterator;
public JointKeyIterator(Iterator<Iterable<Key>> iterablesIterator) {
this.iterablesIterator = iterablesIterator;
}
@Override
public Iterator<Key> iterator() {
if (! iterablesIterator.hasNext()) {
return Collections.<Key>emptyIterator();
}
return new Iterator<Key>() {
private Iterator<Key> currentIterator = nextIterator();
@Override
public boolean hasNext() {
return currentIterator.hasNext();
}
@Override
public Key next() {
final Key next = currentIterator.next();
findNext();
return next;
}
private Iterator<Key> nextIterator() {
return iterablesIterator.next().iterator();
}
private Iterator<Key> findNext() {
while (! currentIterator.hasNext()) {
if (! iterablesIterator.hasNext()) {
break;
}
currentIterator = nextIterator();
}
return this;
}
}.findNext();
}
}
}