/* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.elasticsearch.common.util.concurrent; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.test.ESTestCase; import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Supplier; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.sameInstance; public class ThreadContextTests extends ESTestCase { public void testStashContext() { Settings build = Settings.builder().put("request.headers.default", "1").build(); ThreadContext threadContext = new ThreadContext(build); threadContext.putHeader("foo", "bar"); threadContext.putTransient("ctx.foo", 1); assertEquals("bar", threadContext.getHeader("foo")); assertEquals(new Integer(1), threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { assertNull(threadContext.getHeader("foo")); assertNull(threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); } assertEquals("bar", threadContext.getHeader("foo")); assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); } public void testStashAndMerge() { Settings build = Settings.builder().put("request.headers.default", "1").build(); ThreadContext threadContext = new ThreadContext(build); threadContext.putHeader("foo", "bar"); threadContext.putTransient("ctx.foo", 1); assertEquals("bar", threadContext.getHeader("foo")); assertEquals(new Integer(1), threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); HashMap<String, String> toMerge = new HashMap<>(); toMerge.put("foo", "baz"); toMerge.put("simon", "says"); try (ThreadContext.StoredContext ctx = threadContext.stashAndMergeHeaders(toMerge)) { assertEquals("bar", threadContext.getHeader("foo")); assertEquals("says", threadContext.getHeader("simon")); assertNull(threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); } assertNull(threadContext.getHeader("simon")); assertEquals("bar", threadContext.getHeader("foo")); assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); } public void testStoreContext() { Settings build = Settings.builder().put("request.headers.default", "1").build(); ThreadContext threadContext = new ThreadContext(build); threadContext.putHeader("foo", "bar"); threadContext.putTransient("ctx.foo", 1); assertEquals("bar", threadContext.getHeader("foo")); assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); ThreadContext.StoredContext storedContext = threadContext.newStoredContext(false); threadContext.putHeader("foo.bar", "baz"); try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { assertNull(threadContext.getHeader("foo")); assertNull(threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); } assertEquals("bar", threadContext.getHeader("foo")); assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); assertEquals("baz", threadContext.getHeader("foo.bar")); if (randomBoolean()) { storedContext.restore(); } else { storedContext.close(); } assertEquals("bar", threadContext.getHeader("foo")); assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); assertNull(threadContext.getHeader("foo.bar")); } public void testRestorableContext() { Settings build = Settings.builder().put("request.headers.default", "1").build(); ThreadContext threadContext = new ThreadContext(build); threadContext.putHeader("foo", "bar"); threadContext.putTransient("ctx.foo", 1); threadContext.addResponseHeader("resp.header", "baaaam"); Supplier<ThreadContext.StoredContext> contextSupplier = threadContext.newRestorableContext(true); try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { assertNull(threadContext.getHeader("foo")); assertEquals("1", threadContext.getHeader("default")); threadContext.addResponseHeader("resp.header", "boom"); try (ThreadContext.StoredContext tmp = contextSupplier.get()) { assertEquals("bar", threadContext.getHeader("foo")); assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); assertEquals(2, threadContext.getResponseHeaders().get("resp.header").size()); assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0)); assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(1)); } assertNull(threadContext.getHeader("foo")); assertNull(threadContext.getTransient("ctx.foo")); assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size()); assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0)); } assertEquals("bar", threadContext.getHeader("foo")); assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size()); assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0)); contextSupplier = threadContext.newRestorableContext(false); try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { assertNull(threadContext.getHeader("foo")); assertEquals("1", threadContext.getHeader("default")); threadContext.addResponseHeader("resp.header", "boom"); try (ThreadContext.StoredContext tmp = contextSupplier.get()) { assertEquals("bar", threadContext.getHeader("foo")); assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size()); assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0)); } assertNull(threadContext.getHeader("foo")); assertNull(threadContext.getTransient("ctx.foo")); assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size()); assertEquals("boom", threadContext.getResponseHeaders().get("resp.header").get(0)); } assertEquals("bar", threadContext.getHeader("foo")); assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); assertEquals(1, threadContext.getResponseHeaders().get("resp.header").size()); assertEquals("baaaam", threadContext.getResponseHeaders().get("resp.header").get(0)); } public void testResponseHeaders() { final boolean expectThird = randomBoolean(); final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); threadContext.addResponseHeader("foo", "bar"); // pretend that another thread created the same response if (randomBoolean()) { threadContext.addResponseHeader("foo", "bar"); } final String value = DeprecationLogger.formatWarning("qux"); threadContext.addResponseHeader("baz", value, DeprecationLogger::extractWarningValueFromWarningHeader); // pretend that another thread created the same response at a different time if (randomBoolean()) { final String duplicateValue = DeprecationLogger.formatWarning("qux"); threadContext.addResponseHeader("baz", duplicateValue, DeprecationLogger::extractWarningValueFromWarningHeader); } threadContext.addResponseHeader("Warning", "One is the loneliest number"); threadContext.addResponseHeader("Warning", "Two can be as bad as one"); if (expectThird) { threadContext.addResponseHeader("Warning", "No is the saddest experience"); } final Map<String, List<String>> responseHeaders = threadContext.getResponseHeaders(); final List<String> foo = responseHeaders.get("foo"); final List<String> baz = responseHeaders.get("baz"); final List<String> warnings = responseHeaders.get("Warning"); final int expectedWarnings = expectThird ? 3 : 2; assertThat(foo, hasSize(1)); assertThat(baz, hasSize(1)); assertEquals("bar", foo.get(0)); assertEquals(value, baz.get(0)); assertThat(warnings, hasSize(expectedWarnings)); assertThat(warnings, hasItem(equalTo("One is the loneliest number"))); assertThat(warnings, hasItem(equalTo("Two can be as bad as one"))); if (expectThird) { assertThat(warnings, hasItem(equalTo("No is the saddest experience"))); } } public void testCopyHeaders() { Settings build = Settings.builder().put("request.headers.default", "1").build(); ThreadContext threadContext = new ThreadContext(build); threadContext.copyHeaders(Collections.<String,String>emptyMap().entrySet()); threadContext.copyHeaders(Collections.<String,String>singletonMap("foo", "bar").entrySet()); assertEquals("bar", threadContext.getHeader("foo")); } public void testAccessClosed() throws IOException { Settings build = Settings.builder().put("request.headers.default", "1").build(); ThreadContext threadContext = new ThreadContext(build); threadContext.putHeader("foo", "bar"); threadContext.putTransient("ctx.foo", 1); threadContext.close(); try { threadContext.getHeader("foo"); fail(); } catch (IllegalStateException ise) { assertEquals("threadcontext is already closed", ise.getMessage()); } try { threadContext.putTransient("foo", new Object()); fail(); } catch (IllegalStateException ise) { assertEquals("threadcontext is already closed", ise.getMessage()); } try { threadContext.putHeader("boom", "boom"); fail(); } catch (IllegalStateException ise) { assertEquals("threadcontext is already closed", ise.getMessage()); } } public void testSerialize() throws IOException { Settings build = Settings.builder().put("request.headers.default", "1").build(); ThreadContext threadContext = new ThreadContext(build); threadContext.putHeader("foo", "bar"); threadContext.putTransient("ctx.foo", 1); threadContext.addResponseHeader("Warning", "123456"); if (rarely()) { threadContext.addResponseHeader("Warning", "123456"); } threadContext.addResponseHeader("Warning", "234567"); BytesStreamOutput out = new BytesStreamOutput(); threadContext.writeTo(out); try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { assertNull(threadContext.getHeader("foo")); assertNull(threadContext.getTransient("ctx.foo")); assertTrue(threadContext.getResponseHeaders().isEmpty()); assertEquals("1", threadContext.getHeader("default")); threadContext.readHeaders(out.bytes().streamInput()); assertEquals("bar", threadContext.getHeader("foo")); assertNull(threadContext.getTransient("ctx.foo")); final Map<String, List<String>> responseHeaders = threadContext.getResponseHeaders(); final List<String> warnings = responseHeaders.get("Warning"); assertThat(responseHeaders.keySet(), hasSize(1)); assertThat(warnings, hasSize(2)); assertThat(warnings, hasItem(equalTo("123456"))); assertThat(warnings, hasItem(equalTo("234567"))); } assertEquals("bar", threadContext.getHeader("foo")); assertEquals(Integer.valueOf(1), threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); } public void testSerializeInDifferentContext() throws IOException { BytesStreamOutput out = new BytesStreamOutput(); { Settings build = Settings.builder().put("request.headers.default", "1").build(); ThreadContext threadContext = new ThreadContext(build); threadContext.putHeader("foo", "bar"); threadContext.putTransient("ctx.foo", 1); threadContext.addResponseHeader("Warning", "123456"); if (rarely()) { threadContext.addResponseHeader("Warning", "123456"); } threadContext.addResponseHeader("Warning", "234567"); assertEquals("bar", threadContext.getHeader("foo")); assertNotNull(threadContext.getTransient("ctx.foo")); assertEquals("1", threadContext.getHeader("default")); assertThat(threadContext.getResponseHeaders().keySet(), hasSize(1)); threadContext.writeTo(out); } { Settings otherSettings = Settings.builder().put("request.headers.default", "5").build(); ThreadContext otherThreadContext = new ThreadContext(otherSettings); otherThreadContext.readHeaders(out.bytes().streamInput()); assertEquals("bar", otherThreadContext.getHeader("foo")); assertNull(otherThreadContext.getTransient("ctx.foo")); assertEquals("1", otherThreadContext.getHeader("default")); final Map<String, List<String>> responseHeaders = otherThreadContext.getResponseHeaders(); final List<String> warnings = responseHeaders.get("Warning"); assertThat(responseHeaders.keySet(), hasSize(1)); assertThat(warnings, hasSize(2)); assertThat(warnings, hasItem(equalTo("123456"))); assertThat(warnings, hasItem(equalTo("234567"))); } } public void testSerializeInDifferentContextNoDefaults() throws IOException { BytesStreamOutput out = new BytesStreamOutput(); { ThreadContext threadContext = new ThreadContext(Settings.EMPTY); threadContext.putHeader("foo", "bar"); threadContext.putTransient("ctx.foo", 1); assertEquals("bar", threadContext.getHeader("foo")); assertNotNull(threadContext.getTransient("ctx.foo")); assertNull(threadContext.getHeader("default")); threadContext.writeTo(out); } { Settings otherSettings = Settings.builder().put("request.headers.default", "5").build(); ThreadContext otherhreadContext = new ThreadContext(otherSettings); otherhreadContext.readHeaders(out.bytes().streamInput()); assertEquals("bar", otherhreadContext.getHeader("foo")); assertNull(otherhreadContext.getTransient("ctx.foo")); assertEquals("5", otherhreadContext.getHeader("default")); } } public void testCanResetDefault() { Settings build = Settings.builder().put("request.headers.default", "1").build(); ThreadContext threadContext = new ThreadContext(build); threadContext.putHeader("default", "2"); assertEquals("2", threadContext.getHeader("default")); } public void testStashAndMergeWithModifiedDefaults() { Settings build = Settings.builder().put("request.headers.default", "1").build(); ThreadContext threadContext = new ThreadContext(build); HashMap<String, String> toMerge = new HashMap<>(); toMerge.put("default", "2"); try (ThreadContext.StoredContext ctx = threadContext.stashAndMergeHeaders(toMerge)) { assertEquals("2", threadContext.getHeader("default")); } build = Settings.builder().put("request.headers.default", "1").build(); threadContext = new ThreadContext(build); threadContext.putHeader("default", "4"); toMerge = new HashMap<>(); toMerge.put("default", "2"); try (ThreadContext.StoredContext ctx = threadContext.stashAndMergeHeaders(toMerge)) { assertEquals("4", threadContext.getHeader("default")); } } public void testPreserveContext() throws IOException { try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) { Runnable withContext; // Create a runnable that should run with some header try (ThreadContext.StoredContext ignored = threadContext.stashContext()) { threadContext.putHeader("foo", "bar"); withContext = threadContext.preserveContext(sometimesAbstractRunnable(() -> { assertEquals("bar", threadContext.getHeader("foo")); })); } // We don't see the header outside of the runnable assertNull(threadContext.getHeader("foo")); // But we do inside of it withContext.run(); // but not after assertNull(threadContext.getHeader("foo")); } } public void testPreserveContextKeepsOriginalContextWhenCalledTwice() throws IOException { try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) { Runnable originalWithContext; Runnable withContext; // Create a runnable that should run with some header try (ThreadContext.StoredContext ignored = threadContext.stashContext()) { threadContext.putHeader("foo", "bar"); withContext = threadContext.preserveContext(sometimesAbstractRunnable(() -> { assertEquals("bar", threadContext.getHeader("foo")); })); } // Now attempt to rewrap it originalWithContext = withContext; try (ThreadContext.StoredContext ignored = threadContext.stashContext()) { threadContext.putHeader("foo", "zot"); withContext = threadContext.preserveContext(withContext); } // We get the original context inside the runnable withContext.run(); // In fact the second wrapping didn't even change it assertThat(withContext, sameInstance(originalWithContext)); } } public void testPreservesThreadsOriginalContextOnRunException() throws IOException { try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) { Runnable withContext; // create a abstract runnable, add headers and transient objects and verify in the methods try (ThreadContext.StoredContext ignored = threadContext.stashContext()) { threadContext.putHeader("foo", "bar"); boolean systemContext = randomBoolean(); if (systemContext) { threadContext.markAsSystemContext(); } threadContext.putTransient("foo", "bar_transient"); withContext = threadContext.preserveContext(new AbstractRunnable() { @Override public void onAfter() { assertEquals(systemContext, threadContext.isSystemContext()); assertEquals("bar", threadContext.getHeader("foo")); assertEquals("bar_transient", threadContext.getTransient("foo")); assertNotNull(threadContext.getTransient("failure")); assertEquals("exception from doRun", ((RuntimeException)threadContext.getTransient("failure")).getMessage()); assertFalse(threadContext.isDefaultContext()); threadContext.putTransient("after", "after"); } @Override public void onFailure(Exception e) { assertEquals(systemContext, threadContext.isSystemContext()); assertEquals("exception from doRun", e.getMessage()); assertEquals("bar", threadContext.getHeader("foo")); assertEquals("bar_transient", threadContext.getTransient("foo")); assertFalse(threadContext.isDefaultContext()); threadContext.putTransient("failure", e); } @Override protected void doRun() throws Exception { assertEquals(systemContext, threadContext.isSystemContext()); assertEquals("bar", threadContext.getHeader("foo")); assertEquals("bar_transient", threadContext.getTransient("foo")); assertFalse(threadContext.isDefaultContext()); throw new RuntimeException("exception from doRun"); } }); } // We don't see the header outside of the runnable assertNull(threadContext.getHeader("foo")); assertNull(threadContext.getTransient("foo")); assertNull(threadContext.getTransient("failure")); assertNull(threadContext.getTransient("after")); assertTrue(threadContext.isDefaultContext()); // But we do inside of it withContext.run(); // verify not seen after assertNull(threadContext.getHeader("foo")); assertNull(threadContext.getTransient("foo")); assertNull(threadContext.getTransient("failure")); assertNull(threadContext.getTransient("after")); assertTrue(threadContext.isDefaultContext()); // repeat with regular runnable try (ThreadContext.StoredContext ignored = threadContext.stashContext()) { threadContext.putHeader("foo", "bar"); threadContext.putTransient("foo", "bar_transient"); withContext = threadContext.preserveContext(() -> { assertEquals("bar", threadContext.getHeader("foo")); assertEquals("bar_transient", threadContext.getTransient("foo")); assertFalse(threadContext.isDefaultContext()); threadContext.putTransient("run", true); throw new RuntimeException("exception from run"); }); } assertNull(threadContext.getHeader("foo")); assertNull(threadContext.getTransient("foo")); assertNull(threadContext.getTransient("run")); assertTrue(threadContext.isDefaultContext()); final Runnable runnable = withContext; RuntimeException e = expectThrows(RuntimeException.class, runnable::run); assertEquals("exception from run", e.getMessage()); assertNull(threadContext.getHeader("foo")); assertNull(threadContext.getTransient("foo")); assertNull(threadContext.getTransient("run")); assertTrue(threadContext.isDefaultContext()); } } public void testPreservesThreadsOriginalContextOnFailureException() throws IOException { try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) { Runnable withContext; // a runnable that throws from onFailure try (ThreadContext.StoredContext ignored = threadContext.stashContext()) { threadContext.putHeader("foo", "bar"); threadContext.putTransient("foo", "bar_transient"); withContext = threadContext.preserveContext(new AbstractRunnable() { @Override public void onFailure(Exception e) { throw new RuntimeException("from onFailure", e); } @Override protected void doRun() throws Exception { assertEquals("bar", threadContext.getHeader("foo")); assertEquals("bar_transient", threadContext.getTransient("foo")); assertFalse(threadContext.isDefaultContext()); throw new RuntimeException("from doRun"); } }); } // We don't see the header outside of the runnable assertNull(threadContext.getHeader("foo")); assertNull(threadContext.getTransient("foo")); assertTrue(threadContext.isDefaultContext()); // But we do inside of it RuntimeException e = expectThrows(RuntimeException.class, withContext::run); assertEquals("from onFailure", e.getMessage()); assertEquals("from doRun", e.getCause().getMessage()); // but not after assertNull(threadContext.getHeader("foo")); assertNull(threadContext.getTransient("foo")); assertTrue(threadContext.isDefaultContext()); } } public void testPreservesThreadsOriginalContextOnAfterException() throws IOException { try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) { Runnable withContext; // a runnable that throws from onAfter try (ThreadContext.StoredContext ignored = threadContext.stashContext()) { threadContext.putHeader("foo", "bar"); threadContext.putTransient("foo", "bar_transient"); withContext = threadContext.preserveContext(new AbstractRunnable() { @Override public void onAfter() { throw new RuntimeException("from onAfter"); } @Override public void onFailure(Exception e) { throw new RuntimeException("from onFailure", e); } @Override protected void doRun() throws Exception { assertEquals("bar", threadContext.getHeader("foo")); assertEquals("bar_transient", threadContext.getTransient("foo")); assertFalse(threadContext.isDefaultContext()); } }); } // We don't see the header outside of the runnable assertNull(threadContext.getHeader("foo")); assertNull(threadContext.getTransient("foo")); assertTrue(threadContext.isDefaultContext()); // But we do inside of it RuntimeException e = expectThrows(RuntimeException.class, withContext::run); assertEquals("from onAfter", e.getMessage()); assertNull(e.getCause()); // but not after assertNull(threadContext.getHeader("foo")); assertNull(threadContext.getTransient("foo")); assertTrue(threadContext.isDefaultContext()); } } public void testMarkAsSystemContext() throws IOException { try (ThreadContext threadContext = new ThreadContext(Settings.EMPTY)) { assertFalse(threadContext.isSystemContext()); try(ThreadContext.StoredContext context = threadContext.stashContext()){ assertFalse(threadContext.isSystemContext()); threadContext.markAsSystemContext(); assertTrue(threadContext.isSystemContext()); } assertFalse(threadContext.isSystemContext()); } } /** * Sometimes wraps a Runnable in an AbstractRunnable. */ private Runnable sometimesAbstractRunnable(Runnable r) { if (random().nextBoolean()) { return r; } return new AbstractRunnable() { @Override public void onFailure(Exception e) { throw new RuntimeException(e); } @Override protected void doRun() throws Exception { r.run(); } }; } }