/* Copyright 2016 predic8 GmbH, www.predic8.com
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
package com.predic8.membrane.core.rules;
import com.google.common.base.Objects;
import com.predic8.membrane.annot.MCAttribute;
import com.predic8.membrane.annot.MCChildElement;
import com.predic8.membrane.annot.MCElement;
import com.predic8.membrane.core.Router;
import com.predic8.membrane.core.config.security.SSLParser;
import com.predic8.membrane.core.exchange.Exchange;
import com.predic8.membrane.core.interceptor.Interceptor;
import com.predic8.membrane.core.transport.http.Connection;
import com.predic8.membrane.core.transport.http.ConnectionManager;
import com.predic8.membrane.core.transport.http.StreamPump;
import com.predic8.membrane.core.transport.http.client.ConnectionConfiguration;
import com.predic8.membrane.core.transport.ssl.StaticSSLContext;
import com.predic8.membrane.core.transport.ssl.SSLContext;
import com.predic8.membrane.core.transport.ssl.SSLProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Required;
import java.io.IOException;
import java.net.Socket;
import java.net.SocketException;
import java.util.List;
import java.util.Map;
@MCElement(name="sslProxy")
public class SSLProxy implements Rule {
private static Logger log = LoggerFactory.getLogger(SSLProxy.class.getName());
private Target target;
private ConnectionConfiguration connectionConfiguration = new ConnectionConfiguration();
@MCElement(id = "sslProxy-target", name="target", topLevel = false)
public static class Target {
private int port = -1;
private String host;
public int getPort() {
return port;
}
@MCAttribute
public void setPort(int port) {
this.port = port;
}
public String getHost() {
return host;
}
@MCAttribute
public void setHost(String host) {
this.host = host;
}
}
public ConnectionConfiguration getConnectionConfiguration() {
return connectionConfiguration;
}
@MCChildElement(order = 0)
public void setConnectionConfiguration(ConnectionConfiguration connectionConfiguration) {
this.connectionConfiguration = connectionConfiguration;
}
public Target getTarget() {
return target;
}
@Required
@MCChildElement(order = 100)
public void setTarget(Target target) {
this.target = target;
}
@Override
public List<Interceptor> getInterceptors() {
return null;
}
@Override
public void setInterceptors(List<Interceptor> interceptors) {
}
@Override
public boolean isBlockRequest() {
return false;
}
@Override
public boolean isBlockResponse() {
return false;
}
int port;
public int getPort() {
return port;
}
@MCAttribute
public void setPort(int port) {
this.port = port;
}
String host;
public String getHost() {
return host;
}
@MCAttribute
public void setHost(String host) {
this.host = host;
}
@Override
public RuleKey getKey() {
return new MyRuleKey();
}
@Override
public void setKey(RuleKey ruleKey) {
}
@Override
public void setName(String name) {
}
@Override
public String getName() {
return "SSL " + getHost() + ":" + getPort();
}
@Override
public void setBlockRequest(boolean blockStatus) {
}
@Override
public void setBlockResponse(boolean blockStatus) {
}
@Override
public void collectStatisticsFrom(Exchange exc) {
}
@Override
public Map<Integer, StatisticCollector> getStatisticsByStatusCodes() {
return null;
}
@Override
public int getCount() {
return 0;
}
ConnectionManager cm;
@Override
public SSLContext getSslInboundContext() {
return new StaticSSLContext(new SSLParser(), router.getResolverMap(), router.getBaseLocation()) {
@Override
public Socket wrap(Socket socket, byte[] buffer, int position) throws IOException {
int port = target.getPort();
if (port == -1)
port = getPort();
StreamPump.StreamPumpStats streamPumpStats = router.getStatistics().getStreamPumpStats();
String protocol = "SSL";
Connection con = cm.getConnection(target.getHost(), port, connectionConfiguration.getLocalAddr(), null, connectionConfiguration.getTimeout());
con.out.write(buffer, 0, position);
con.out.flush();
String source = socket.getRemoteSocketAddress().toString();
String dest = con.toString();
final StreamPump a = new StreamPump(con.in, socket.getOutputStream(), streamPumpStats, protocol + " " + source + " <- " + dest, SSLProxy.this);
final StreamPump b = new StreamPump(socket.getInputStream(), con.out, streamPumpStats, protocol + " " + source + " -> " + dest, SSLProxy.this);
socket.setSoTimeout(0);
String threadName = Thread.currentThread().getName();
new Thread(a, threadName + " " + protocol + " Backward Thread").start();
try {
Thread.currentThread().setName(threadName + " " + protocol + " Onward Thread");
b.run();
} finally {
try {
con.close();
} catch (IOException e) {
log.debug("", e);
}
}
throw new SocketException("SSL Forwarding Connection closed.");
}
@Override
public String constructHostNamePattern() {
return getKey().getHost();
}
};
}
@Override
public SSLProvider getSslOutboundContext() {
return null;
}
Router router;
@Override
public void init(Router router) throws Exception {
this.router = router;
cm = new ConnectionManager(connectionConfiguration.getKeepAliveTimeout());
}
@Override
public boolean isTargetAdjustHostHeader() {
return false;
}
@Override
public boolean isActive() {
return true;
}
@Override
public String getErrorState() {
return null;
}
@Override
public SSLProxy clone() throws CloneNotSupportedException {
SSLProxy clone = (SSLProxy) super.clone();
try {
clone.init(router);
} catch (Exception e) {
throw new RuntimeException(e);
}
return clone;
}
private class MyRuleKey implements RuleKey {
@Override
public int getPort() {
return port;
}
@Override
public String getMethod() {
return null;
}
@Override
public String getPath() {
return null;
}
@Override
public String getHost() {
return host;
}
@Override
public boolean isMethodWildcard() {
return false;
}
@Override
public boolean isPathRegExp() {
return false;
}
@Override
public boolean isUsePathPattern() {
return false;
}
@Override
public void setUsePathPattern(boolean usePathPattern) {
}
@Override
public void setPathRegExp(boolean pathRegExp) {
}
@Override
public void setPath(String path) {
}
@Override
public boolean matchesPath(String path) {
return false;
}
@Override
public String getIp() {
return null;
}
@Override
public void setIp(String ip) {
}
@Override
public boolean matchesHostHeader(String hostHeader) {
return false;
}
@Override
public boolean matchesVersion(String version) {
return false;
}
@Override
public boolean complexMatch(String hostHeader, String method, String uri, String version, int port, String localIP) {
return false;
}
@Override
public boolean equals(Object obj) {
if (!(obj instanceof MyRuleKey))
return false;
MyRuleKey other = (MyRuleKey)obj;
return Objects.equal(getHost(), other.getHost()) && getPort() == other.getPort();
}
}
}