/* dCache - http://www.dcache.org/
*
* Copyright (C) 2015 Deutsches Elektronen-Synchrotron
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package javatunnel;
import javatunnel.token.Base64TokenReader;
import javatunnel.token.Base64TokenWriter;
import org.dcache.dss.DssContext;
import org.dcache.dss.DssContextFactory;
import javatunnel.token.UnwrappingInputStream;
import javatunnel.token.WrappingOutputStream;
import javatunnel.token.TokenReader;
import javatunnel.token.TokenWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.security.auth.Subject;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.net.SocketImpl;
import java.net.UnknownHostException;
public class DssSocket extends Socket implements TunnelSocket
{
private static final Logger LOGGER = LoggerFactory.getLogger(DssSocket.class);
private DssContext context;
private final DssContextFactory factory;
private WrappingOutputStream out;
private UnwrappingInputStream in;
DssSocket(DssContextFactory factory)
{
this.factory = factory;
}
DssSocket(SocketImpl impl, DssContextFactory factory) throws SocketException
{
super(impl);
this.factory = factory;
}
DssSocket(InetAddress address, int port, DssContextFactory factory)
throws IOException
{
super(address, port);
this.factory = factory;
}
DssSocket(InetAddress address, int port, InetAddress localAddr, int localPort, DssContextFactory factory)
throws IOException
{
super(address, port, localAddr, localPort);
this.factory = factory;
}
DssSocket(String host, int port, DssContextFactory factory)
throws UnknownHostException, IOException
{
super(host, port);
this.factory = factory;
}
DssSocket(String host, int port, InetAddress localAddr, int localPort, DssContextFactory factory)
throws IOException
{
super(host, port, localAddr, localPort);
this.factory = factory;
}
@Override
public synchronized OutputStream getOutputStream() throws IOException
{
if (isClosed())
throw new SocketException("Socket is closed");
if (!isConnected())
throw new SocketException("Socket is not connected");
if (isOutputShutdown())
throw new SocketException("Socket output is shutdown");
if (context == null || !context.isEstablished()) {
throw new SocketException("Security context is not established");
}
return out;
}
@Override
public synchronized InputStream getInputStream() throws IOException
{
if (isClosed())
throw new SocketException("Socket is closed");
if (!isConnected())
throw new SocketException("Socket is not connected");
if (isInputShutdown())
throw new SocketException("Socket input is shutdown");
if (context == null || !context.isEstablished()) {
throw new SocketException("Security context is not established");
}
return in;
}
private synchronized void acceptSecurityContext() throws IOException
{
try {
context = factory.create((InetSocketAddress) getRemoteSocketAddress(),
(InetSocketAddress) getLocalSocketAddress());
TokenWriter writer = new Base64TokenWriter(super.getOutputStream());
TokenReader reader = new Base64TokenReader(super.getInputStream());
while (!context.isEstablished()) {
byte[] inToken = reader.readToken();
if (inToken == null) {
throw new EOFException();
}
byte[] outToken = context.accept(inToken);
if (outToken != null) {
writer.write(outToken);
}
}
out = new WrappingOutputStream(writer, context);
in = new UnwrappingInputStream(reader, context);
} catch (IOException e) {
try {
close();
} catch (IOException e1) {
e.addSuppressed(e1);
}
throw e;
}
}
private synchronized void initSecurityContext() throws IOException
{
try {
context = factory.create((InetSocketAddress) getRemoteSocketAddress(),
(InetSocketAddress) getLocalSocketAddress());
TokenWriter writer = new Base64TokenWriter(super.getOutputStream());
TokenReader reader = new Base64TokenReader(super.getInputStream());
byte[] outToken = context.init(new byte[0]);
if (outToken != null) {
writer.write(outToken);
}
while (!context.isEstablished()) {
byte[] inToken = reader.readToken();
if (inToken == null) {
throw new EOFException();
}
outToken = context.init(inToken);
if (outToken != null) {
writer.write(outToken);
}
}
out = new WrappingOutputStream(writer, context);
in = new UnwrappingInputStream(reader, context);
} catch (IOException e) {
try {
close();
} catch (IOException e1) {
e.addSuppressed(e1);
}
throw e;
}
}
@Override
public void connect(SocketAddress endpoint) throws IOException
{
super.connect(endpoint);
initSecurityContext();
}
@Override
public void connect(SocketAddress endpoint, int timeout) throws IOException
{
super.connect(endpoint, timeout);
initSecurityContext();
}
@Override
public boolean verify()
{
try {
acceptSecurityContext();
return true;
} catch (IOException e) {
LOGGER.error("Failed to verify: {}", e.toString());
return false;
}
}
@Override
public Subject getSubject()
{
return (context == null || !context.isEstablished()) ? null : context.getSubject();
}
}