/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.common.util.concurrent;
import org.apache.lucene.util.CloseableThreadLocal;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Setting.Property;
import org.elasticsearch.common.settings.Settings;
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* A ThreadContext is a map of string headers and a transient map of keyed objects that are associated with
* a thread. It allows storing and retrieving header information across method calls, network calls, and threads spawned from a
* thread that has a {@link ThreadContext} associated with it. Threads spawned from a {@link org.elasticsearch.threadpool.ThreadPool}
* have out-of-the-box support for {@link ThreadContext}, and every spawned thread inherits the {@link ThreadContext} of the thread
* it was forked from. Network calls also preserve the sender's headers automatically.
* <p>
* Consumers of ThreadContext usually don't need to add or stash contexts themselves. Every elasticsearch thread is managed by a thread pool or executor
* that is responsible for stashing and restoring the thread's context. For instance, if a network request is received, all headers are deserialized from the network
* and added directly as the headers of the thread's {@link ThreadContext} (see {@link #readHeaders(StreamInput)}). In order not to modify the context that is currently
* active on this thread, the network code uses a try-with-resources pattern to stash its current context and read headers into a fresh one; once the request is handled
* or a handler thread is forked (which in turn inherits the context), it restores the previous context. For instance:
* </p>
* <pre>
* // current context is stashed and replaced with a default context
* try (StoredContext context = threadContext.stashContext()) {
* threadContext.readHeaders(in); // read headers into current context
* if (fork) {
* threadPool.execute(() -> request.handle()); // inherits context
* } else {
* request.handle();
* }
* }
* // previous context is restored on StoredContext#close()
* </pre>
*
*/
public final class ThreadContext implements Closeable, Writeable {
public static final String PREFIX = "request.headers";
public static final Setting<Settings> DEFAULT_HEADERS_SETTING = Setting.groupSetting(PREFIX + ".", Property.NodeScope);
private static final ThreadContextStruct DEFAULT_CONTEXT = new ThreadContextStruct();
private final Map<String, String> defaultHeader;
private final ContextThreadLocal threadLocal;
private boolean isSystemContext;
/**
* Creates a new ThreadContext instance
* @param settings the settings to read the default request headers from
*/
public ThreadContext(Settings settings) {
Settings headers = DEFAULT_HEADERS_SETTING.get(settings);
if (headers == null) {
this.defaultHeader = Collections.emptyMap();
} else {
Map<String, String> defaultHeader = new HashMap<>();
for (String key : headers.names()) {
defaultHeader.put(key, headers.get(key));
}
this.defaultHeader = Collections.unmodifiableMap(defaultHeader);
}
threadLocal = new ContextThreadLocal();
}
@Override
public void close() throws IOException {
threadLocal.close();
}
/**
* Removes the current context and resets a default context. The removed context can be
* restored when closing the returned {@link StoredContext}
*/
public StoredContext stashContext() {
final ThreadContextStruct context = threadLocal.get();
threadLocal.set(null);
return () -> threadLocal.set(context);
}
/**
* Removes the current context and resets a new context that contains a merge of the current headers and the given headers. The removed context can be
* restored when closing the returned {@link StoredContext}. The merge strategy is that headers that are already existing are preserved unless they are defaults.
*/
public StoredContext stashAndMergeHeaders(Map<String, String> headers) {
final ThreadContextStruct context = threadLocal.get();
Map<String, String> newHeader = new HashMap<>(headers);
newHeader.putAll(context.requestHeaders);
threadLocal.set(DEFAULT_CONTEXT.putHeaders(newHeader));
return () -> threadLocal.set(context);
}
/**
* Just like {@link #stashContext()} but no default context is set.
* @param preserveResponseHeaders if set to <code>true</code> the response headers of the restore thread will be preserved.
*/
public StoredContext newStoredContext(boolean preserveResponseHeaders) {
final ThreadContextStruct context = threadLocal.get();
return () -> {
if (preserveResponseHeaders && threadLocal.get() != context) {
threadLocal.set(context.putResponseHeaders(threadLocal.get().responseHeaders));
} else {
threadLocal.set(context);
}
};
}
/**
* Returns a supplier that gathers a {@link #newStoredContext(boolean)} and restores it once the
* returned supplier is invoked. The context returned from the supplier is a stored version of the
* suppliers callers context that should be restored once the originally gathered context is not needed anymore.
* For instance this method should be used like this:
*
* <pre>
* Supplier<ThreadContext.StoredContext> restorable = context.newRestorableContext(true);
* new Thread() {
* public void run() {
* try (ThreadContext.StoredContext ctx = restorable.get()) {
* // execute with the parents context and restore the threads context afterwards
* }
* }
*
* }.start();
* </pre>
*
* @param preserveResponseHeaders if set to <code>true</code> the response headers of the restore thread will be preserved.
* @return a restorable context supplier
*/
public Supplier<StoredContext> newRestorableContext(boolean preserveResponseHeaders) {
return wrapRestorable(newStoredContext(preserveResponseHeaders));
}
/**
* Same as {@link #newRestorableContext(boolean)} but wraps an existing context to restore.
* @param storedContext the context to restore
*/
public Supplier<StoredContext> wrapRestorable(StoredContext storedContext) {
return () -> {
StoredContext context = newStoredContext(false);
storedContext.restore();
return context;
};
}
@Override
public void writeTo(StreamOutput out) throws IOException {
threadLocal.get().writeTo(out, defaultHeader);
}
/**
* Reads the headers from the stream into the current context
*/
public void readHeaders(StreamInput in) throws IOException {
threadLocal.set(new ThreadContext.ThreadContextStruct(in));
}
/**
* Returns the header for the given key or <code>null</code> if not present
*/
public String getHeader(String key) {
String value = threadLocal.get().requestHeaders.get(key);
if (value == null) {
return defaultHeader.get(key);
}
return value;
}
/**
* Returns all of the request contexts headers
*/
public Map<String, String> getHeaders() {
HashMap<String, String> map = new HashMap<>(defaultHeader);
map.putAll(threadLocal.get().requestHeaders);
return Collections.unmodifiableMap(map);
}
/**
* Get a copy of all <em>response</em> headers.
*
* @return Never {@code null}.
*/
public Map<String, List<String>> getResponseHeaders() {
Map<String, List<String>> responseHeaders = threadLocal.get().responseHeaders;
HashMap<String, List<String>> map = new HashMap<>(responseHeaders.size());
for (Map.Entry<String, List<String>> entry : responseHeaders.entrySet()) {
map.put(entry.getKey(), Collections.unmodifiableList(entry.getValue()));
}
return Collections.unmodifiableMap(map);
}
/**
* Copies all header key, value pairs into the current context
*/
public void copyHeaders(Iterable<Map.Entry<String, String>> headers) {
threadLocal.set(threadLocal.get().copyHeaders(headers));
}
/**
* Puts a header into the context
*/
public void putHeader(String key, String value) {
threadLocal.set(threadLocal.get().putRequest(key, value));
}
/**
* Puts all of the given headers into this context
*/
public void putHeader(Map<String, String> header) {
threadLocal.set(threadLocal.get().putHeaders(header));
}
/**
* Puts a transient header object into this context
*/
public void putTransient(String key, Object value) {
threadLocal.set(threadLocal.get().putTransient(key, value));
}
/**
* Returns a transient header object or <code>null</code> if there is no header for the given key
*/
@SuppressWarnings("unchecked") // (T)object
public <T> T getTransient(String key) {
return (T) threadLocal.get().transientHeaders.get(key);
}
/**
* Add the {@code value} for the specified {@code key} Any duplicate {@code value} is ignored.
*
* @param key the header name
* @param value the header value
*/
public void addResponseHeader(final String key, final String value) {
addResponseHeader(key, value, v -> v);
}
/**
* Add the {@code value} for the specified {@code key} with the specified {@code uniqueValue} used for de-duplication. Any duplicate
* {@code value} after applying {@code uniqueValue} is ignored.
*
* @param key the header name
* @param value the header value
* @param uniqueValue the function that produces de-duplication values
*/
public void addResponseHeader(final String key, final String value, final Function<String, String> uniqueValue) {
threadLocal.set(threadLocal.get().putResponse(key, value, uniqueValue));
}
/**
* Saves the current thread context and wraps command in a Runnable that restores that context before running command. If
* <code>command</code> has already been passed through this method then it is returned unaltered rather than wrapped twice.
*/
public Runnable preserveContext(Runnable command) {
if (command instanceof ContextPreservingAbstractRunnable) {
return command;
}
if (command instanceof ContextPreservingRunnable) {
return command;
}
if (command instanceof AbstractRunnable) {
return new ContextPreservingAbstractRunnable((AbstractRunnable) command);
}
return new ContextPreservingRunnable(command);
}
/**
* Unwraps a command that was previously wrapped by {@link #preserveContext(Runnable)}.
*/
public Runnable unwrap(Runnable command) {
if (command instanceof ContextPreservingAbstractRunnable) {
return ((ContextPreservingAbstractRunnable) command).unwrap();
}
if (command instanceof ContextPreservingRunnable) {
return ((ContextPreservingRunnable) command).unwrap();
}
return command;
}
/**
* Returns true if the current context is the default context.
*/
boolean isDefaultContext() {
return threadLocal.get() == DEFAULT_CONTEXT;
}
/**
* Marks this thread context as an internal system context. This signals that actions in this context are issued
* by the system itself rather than by a user action.
*/
public void markAsSystemContext() {
threadLocal.set(threadLocal.get().setSystemContext());
}
/**
* Returns <code>true</code> iff this context is a system context
*/
public boolean isSystemContext() {
return threadLocal.get().isSystemContext;
}
/**
* Returns <code>true</code> if the context is closed, otherwise <code>true</code>
*/
boolean isClosed() {
return threadLocal.closed.get();
}
@FunctionalInterface
public interface StoredContext extends AutoCloseable {
@Override
void close();
default void restore() {
close();
}
}
private static final class ThreadContextStruct {
private final Map<String, String> requestHeaders;
private final Map<String, Object> transientHeaders;
private final Map<String, List<String>> responseHeaders;
private final boolean isSystemContext;
private ThreadContextStruct(StreamInput in) throws IOException {
final int numRequest = in.readVInt();
Map<String, String> requestHeaders = numRequest == 0 ? Collections.emptyMap() : new HashMap<>(numRequest);
for (int i = 0; i < numRequest; i++) {
requestHeaders.put(in.readString(), in.readString());
}
this.requestHeaders = requestHeaders;
this.responseHeaders = in.readMapOfLists(StreamInput::readString, StreamInput::readString);
this.transientHeaders = Collections.emptyMap();
isSystemContext = false; // we never serialize this it's a transient flag
}
private ThreadContextStruct setSystemContext() {
if (isSystemContext) {
return this;
}
return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, true);
}
private ThreadContextStruct(Map<String, String> requestHeaders,
Map<String, List<String>> responseHeaders,
Map<String, Object> transientHeaders, boolean isSystemContext) {
this.requestHeaders = requestHeaders;
this.responseHeaders = responseHeaders;
this.transientHeaders = transientHeaders;
this.isSystemContext = isSystemContext;
}
/**
* This represents the default context and it should only ever be called by {@link #DEFAULT_CONTEXT}.
*/
private ThreadContextStruct() {
this(Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), false);
}
private ThreadContextStruct putRequest(String key, String value) {
Map<String, String> newRequestHeaders = new HashMap<>(this.requestHeaders);
putSingleHeader(key, value, newRequestHeaders);
return new ThreadContextStruct(newRequestHeaders, responseHeaders, transientHeaders, isSystemContext);
}
private void putSingleHeader(String key, String value, Map<String, String> newHeaders) {
if (newHeaders.putIfAbsent(key, value) != null) {
throw new IllegalArgumentException("value for key [" + key + "] already present");
}
}
private ThreadContextStruct putHeaders(Map<String, String> headers) {
if (headers.isEmpty()) {
return this;
} else {
final Map<String, String> newHeaders = new HashMap<>();
for (Map.Entry<String, String> entry : headers.entrySet()) {
putSingleHeader(entry.getKey(), entry.getValue(), newHeaders);
}
newHeaders.putAll(this.requestHeaders);
return new ThreadContextStruct(newHeaders, responseHeaders, transientHeaders, isSystemContext);
}
}
private ThreadContextStruct putResponseHeaders(Map<String, List<String>> headers) {
assert headers != null;
if (headers.isEmpty()) {
return this;
}
final Map<String, List<String>> newResponseHeaders = new HashMap<>(this.responseHeaders);
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
String key = entry.getKey();
final List<String> existingValues = newResponseHeaders.get(key);
if (existingValues != null) {
List<String> newValues = Stream.concat(entry.getValue().stream(),
existingValues.stream()).distinct().collect(Collectors.toList());
newResponseHeaders.put(key, Collections.unmodifiableList(newValues));
} else {
newResponseHeaders.put(key, entry.getValue());
}
}
return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders, isSystemContext);
}
private ThreadContextStruct putResponse(final String key, final String value, final Function<String, String> uniqueValue) {
assert value != null;
final Map<String, List<String>> newResponseHeaders = new HashMap<>(this.responseHeaders);
final List<String> existingValues = newResponseHeaders.get(key);
if (existingValues != null) {
final Set<String> existingUniqueValues = existingValues.stream().map(uniqueValue).collect(Collectors.toSet());
assert existingValues.size() == existingUniqueValues.size();
if (existingUniqueValues.contains(uniqueValue.apply(value))) {
return this;
}
final List<String> newValues = new ArrayList<>(existingValues);
newValues.add(value);
newResponseHeaders.put(key, Collections.unmodifiableList(newValues));
} else {
newResponseHeaders.put(key, Collections.singletonList(value));
}
return new ThreadContextStruct(requestHeaders, newResponseHeaders, transientHeaders, isSystemContext);
}
private ThreadContextStruct putTransient(String key, Object value) {
Map<String, Object> newTransient = new HashMap<>(this.transientHeaders);
if (newTransient.putIfAbsent(key, value) != null) {
throw new IllegalArgumentException("value for key [" + key + "] already present");
}
return new ThreadContextStruct(requestHeaders, responseHeaders, newTransient, isSystemContext);
}
boolean isEmpty() {
return requestHeaders.isEmpty() && responseHeaders.isEmpty() && transientHeaders.isEmpty();
}
private ThreadContextStruct copyHeaders(Iterable<Map.Entry<String, String>> headers) {
Map<String, String> newHeaders = new HashMap<>();
for (Map.Entry<String, String> header : headers) {
newHeaders.put(header.getKey(), header.getValue());
}
return putHeaders(newHeaders);
}
private void writeTo(StreamOutput out, Map<String, String> defaultHeaders) throws IOException {
final Map<String, String> requestHeaders;
if (defaultHeaders.isEmpty()) {
requestHeaders = this.requestHeaders;
} else {
requestHeaders = new HashMap<>(defaultHeaders);
requestHeaders.putAll(this.requestHeaders);
}
out.writeVInt(requestHeaders.size());
for (Map.Entry<String, String> entry : requestHeaders.entrySet()) {
out.writeString(entry.getKey());
out.writeString(entry.getValue());
}
out.writeMapOfLists(responseHeaders, StreamOutput::writeString, StreamOutput::writeString);
}
}
private static class ContextThreadLocal extends CloseableThreadLocal<ThreadContextStruct> {
private final AtomicBoolean closed = new AtomicBoolean(false);
@Override
public void set(ThreadContextStruct object) {
try {
if (object == DEFAULT_CONTEXT) {
super.set(null);
} else {
super.set(object);
}
} catch (NullPointerException ex) {
/* This is odd but CloseableThreadLocal throws a NPE if it was closed but still accessed.
to get a real exception we call ensureOpen() to tell the user we are already closed.*/
ensureOpen();
throw ex;
}
}
@Override
public ThreadContextStruct get() {
try {
ThreadContextStruct threadContextStruct = super.get();
if (threadContextStruct != null) {
return threadContextStruct;
}
return DEFAULT_CONTEXT;
} catch (NullPointerException ex) {
/* This is odd but CloseableThreadLocal throws a NPE if it was closed but still accessed.
to get a real exception we call ensureOpen() to tell the user we are already closed.*/
ensureOpen();
throw ex;
}
}
private void ensureOpen() {
if (closed.get()) {
throw new IllegalStateException("threadcontext is already closed");
}
}
@Override
public void close() {
if (closed.compareAndSet(false, true)) {
super.close();
}
}
}
/**
* Wraps a Runnable to preserve the thread context.
*/
private class ContextPreservingRunnable implements Runnable {
private final Runnable in;
private final ThreadContext.StoredContext ctx;
private ContextPreservingRunnable(Runnable in) {
ctx = newStoredContext(false);
this.in = in;
}
@Override
public void run() {
boolean whileRunning = false;
try (ThreadContext.StoredContext ignore = stashContext()){
ctx.restore();
whileRunning = true;
in.run();
whileRunning = false;
} catch (IllegalStateException ex) {
if (whileRunning || threadLocal.closed.get() == false) {
throw ex;
}
// if we hit an ISE here we have been shutting down
// this comes from the threadcontext and barfs if
// our threadpool has been shutting down
}
}
@Override
public String toString() {
return in.toString();
}
public Runnable unwrap() {
return in;
}
}
/**
* Wraps an AbstractRunnable to preserve the thread context.
*/
private class ContextPreservingAbstractRunnable extends AbstractRunnable {
private final AbstractRunnable in;
private final ThreadContext.StoredContext creatorsContext;
private ThreadContext.StoredContext threadsOriginalContext = null;
private ContextPreservingAbstractRunnable(AbstractRunnable in) {
creatorsContext = newStoredContext(false);
this.in = in;
}
@Override
public boolean isForceExecution() {
return in.isForceExecution();
}
@Override
public void onAfter() {
try {
in.onAfter();
} finally {
if (threadsOriginalContext != null) {
threadsOriginalContext.restore();
}
}
}
@Override
public void onFailure(Exception e) {
in.onFailure(e);
}
@Override
public void onRejection(Exception e) {
in.onRejection(e);
}
@Override
protected void doRun() throws Exception {
boolean whileRunning = false;
threadsOriginalContext = stashContext();
try {
creatorsContext.restore();
whileRunning = true;
in.doRun();
whileRunning = false;
} catch (IllegalStateException ex) {
if (whileRunning || threadLocal.closed.get() == false) {
throw ex;
}
// if we hit an ISE here we have been shutting down
// this comes from the threadcontext and barfs if
// our threadpool has been shutting down
}
}
@Override
public String toString() {
return in.toString();
}
public AbstractRunnable unwrap() {
return in;
}
}
}