/* * Copyright 2013-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.springframework.cloud.netflix.ribbon; import java.net.URI; import java.util.ArrayList; import java.util.List; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.springframework.beans.factory.BeanFactoryUtils; import org.springframework.beans.factory.ListableBeanFactory; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.test.util.EnvironmentTestUtils; import org.springframework.cloud.netflix.ribbon.RibbonClientConfiguration.OverrideRestClient; import org.springframework.cloud.netflix.ribbon.apache.RibbonLoadBalancingHttpClient; import org.springframework.cloud.netflix.ribbon.okhttp.OkHttpLoadBalancingClient; import org.springframework.context.annotation.AnnotationConfigApplicationContext; import org.springframework.context.annotation.Configuration; import com.netflix.client.AbstractLoadBalancerAwareClient; import com.netflix.client.config.CommonClientConfigKey; import com.netflix.client.config.DefaultClientConfigImpl; import com.netflix.client.config.IClientConfig; import com.netflix.loadbalancer.Server; import com.netflix.niws.client.http.RestClient; import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.when; /** * @author Spencer Gibb */ public class RibbonClientConfigurationTests { private CountingConfig config; @Mock private ServerIntrospector inspector; @Before public void setup() { MockitoAnnotations.initMocks(this); this.config = new CountingConfig(); this.config.setProperty(CommonClientConfigKey.ConnectTimeout, "1"); this.config.setProperty(CommonClientConfigKey.ReadTimeout, "1"); this.config.setProperty(CommonClientConfigKey.MaxHttpConnectionsPerHost, "1"); this.config.setClientName("testClient"); } @Test public void restClientInitCalledOnce() { new TestRestClient(this.config); assertThat(this.config.count, is(1)); } @Test public void restClientWithSecureServer() throws Exception { CountingConfig config = new CountingConfig(); config.setProperty(CommonClientConfigKey.ConnectTimeout, "1"); config.setProperty(CommonClientConfigKey.ReadTimeout, "1"); config.setProperty(CommonClientConfigKey.MaxHttpConnectionsPerHost, "1"); config.setClientName("bar"); Server server = new Server("example.com", 443); URI uri = new TestRestClient(config).reconstructURIWithServer(server, new URI("/foo")); assertThat(uri.getScheme(), is("https")); assertThat(uri.getHost(), is("example.com")); } static class CountingConfig extends DefaultClientConfigImpl { int count = 0; } @Test public void testSecureUriFromClientConfig() throws Exception { Server server = new Server("foo", 7777); when(this.inspector.isSecure(server)).thenReturn(true); for (AbstractLoadBalancerAwareClient client : clients()) { URI uri = client.reconstructURIWithServer(server, new URI("http://foo/")); assertThat(getReason(client), uri, is(new URI("https://foo:7777/"))); } } @Test public void testInSecureUriFromClientConfig() throws Exception { Server server = new Server("foo", 7777); when(this.inspector.isSecure(server)).thenReturn(false); for (AbstractLoadBalancerAwareClient client : clients()) { URI uri = client.reconstructURIWithServer(server, new URI("http://foo/")); assertThat(getReason(client), uri, is(new URI("http://foo:7777/"))); } } String getReason(AbstractLoadBalancerAwareClient client) { return client.getClass().getSimpleName()+" failed"; } @Test public void testNotDoubleEncodedWhenSecure() throws Exception { Server server = new Server("foo", 7777); when(this.inspector.isSecure(server)).thenReturn(true); for (AbstractLoadBalancerAwareClient client : clients()) { URI uri = client.reconstructURIWithServer(server, new URI("http://foo/%20bar")); assertThat(getReason(client), uri, is(new URI("https://foo:7777/%20bar"))); } } @Test public void testPlusInQueryStringGetsRewrittenWhenServerIsSecure() throws Exception { Server server = new Server("foo", 7777); when(this.inspector.isSecure(server)).thenReturn(true); for (AbstractLoadBalancerAwareClient client : clients()) { URI uri = client.reconstructURIWithServer(server, new URI("http://foo/%20bar?hello=1+2")); assertThat(uri, is(new URI("https://foo:7777/%20bar?hello=1%202"))); } } private List<AbstractLoadBalancerAwareClient> clients() { ArrayList<AbstractLoadBalancerAwareClient> clients = new ArrayList<>(); clients.add(new OverrideRestClient(this.config, this.inspector)); clients.add(new RibbonLoadBalancingHttpClient(this.config, this.inspector)); clients.add(new OkHttpLoadBalancingClient(this.config, this.inspector)); return clients; } @Test public void testDefaultsToApacheHttpClient() { testClient(RibbonLoadBalancingHttpClient.class, null, RestClient.class, OkHttpLoadBalancingClient.class); testClient(RibbonLoadBalancingHttpClient.class, "ribbon.httpclient.enabled", RestClient.class, OkHttpLoadBalancingClient.class); } @Test public void testEnableRestClient() { testClient(RestClient.class, "ribbon.restclient.enabled", RibbonLoadBalancingHttpClient.class, OkHttpLoadBalancingClient.class); } @Test public void testEnableOkHttpClient() { testClient(OkHttpLoadBalancingClient.class, "ribbon.okhttp.enabled", RibbonLoadBalancingHttpClient.class, RestClient.class); } void testClient(Class<?> clientType, String property, Class<?>... excludedTypes) { AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext(); context.register(RibbonAutoConfiguration.class, RibbonClientConfiguration.class); if (property != null) { EnvironmentTestUtils.addEnvironment(context, property); } context.refresh(); context.getBean(clientType); for (Class<?> excludedType : excludedTypes) { assertThat("has "+excludedType.getSimpleName()+ " instance", hasInstance(context, excludedType), is(false)); } context.close(); } private <T> boolean hasInstance(ListableBeanFactory lbf, Class<T> requiredType) { return BeanFactoryUtils.beanNamesForTypeIncludingAncestors(lbf, requiredType).length > 0; } @Configuration @EnableAutoConfiguration protected static class TestLBConfig { } static class TestRestClient extends OverrideRestClient { private TestRestClient(IClientConfig ncc) { super(ncc, new DefaultServerIntrospector()); } @Override public void initWithNiwsConfig(IClientConfig clientConfig) { ((CountingConfig) clientConfig).count++; super.initWithNiwsConfig(clientConfig); } } }