/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sshd.common.io.nio2;
import java.io.IOException;
import java.net.SocketAddress;
import java.nio.channels.AsynchronousChannelGroup;
import java.nio.channels.AsynchronousServerSocketChannel;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.CompletionHandler;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.sshd.common.FactoryManager;
import org.apache.sshd.common.future.CloseFuture;
import org.apache.sshd.common.io.IoAcceptor;
import org.apache.sshd.common.io.IoHandler;
import org.apache.sshd.common.util.ValidateUtils;
/**
 * An {@link IoAcceptor} implemented on top of the NIO2 asynchronous channel
 * API. Each bound address is served by its own
 * {@link AsynchronousServerSocketChannel}; every accepted connection is
 * wrapped in a {@link Nio2Session} and handed to the configured
 * {@link IoHandler}.
 *
 * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
 */
public class Nio2Acceptor extends Nio2Service implements IoAcceptor {
    /**
     * Currently listening channels, keyed by the <em>local</em> address
     * reported by the channel after it was bound (which may differ from the
     * address that was requested, e.g. when an ephemeral port is used).
     */
    protected final Map<SocketAddress, AsynchronousServerSocketChannel> channels = new ConcurrentHashMap<>();

    /** Backlog value passed to every {@code bind} call of a server channel. */
    private final int backlog;

    /**
     * @param manager the {@link FactoryManager}; also consulted for the
     *                {@link FactoryManager#SOCKET_BACKLOG} property, falling
     *                back to {@code DEFAULT_BACKLOG} when it is not set
     * @param handler the {@link IoHandler} notified about accepted sessions
     * @param group   the {@link AsynchronousChannelGroup} used to open the
     *                server channels
     */
    public Nio2Acceptor(FactoryManager manager, IoHandler handler, AsynchronousChannelGroup group) {
        super(manager, handler, group);
        backlog = manager.getIntProperty(FactoryManager.SOCKET_BACKLOG, DEFAULT_BACKLOG);
    }

    /**
     * Opens, binds and starts accepting on one server channel per requested
     * address. If any step fails for an address, the channel opened for that
     * address is closed (and un-registered) before the exception is
     * propagated; addresses that were successfully bound earlier in the same
     * call remain bound.
     *
     * @param addresses the addresses to listen on
     * @throws IOException if a channel cannot be opened, configured or bound
     */
    @Override
    public void bind(Collection<? extends SocketAddress> addresses) throws IOException {
        AsynchronousChannelGroup group = getChannelGroup();
        for (SocketAddress address : addresses) {
            if (log.isDebugEnabled()) {
                log.debug("Binding Nio2Acceptor to address {}", address);
            }

            AsynchronousServerSocketChannel socket = openAsynchronousServerSocketChannel(address, group);
            SocketAddress local = null;
            boolean accepting = false;
            try {
                socket = setSocketOptions(socket);
                socket.bind(address, backlog);
                local = socket.getLocalAddress();
                channels.put(local, socket);

                CompletionHandler<AsynchronousSocketChannel, ? super SocketAddress> handler =
                        ValidateUtils.checkNotNull(createSocketCompletionHandler(channels, socket),
                                "No completion handler created for address=%s",
                                address);
                socket.accept(local, handler);
                accepting = true;
            } finally {
                // Any failure above would otherwise leave an open (and possibly
                // registered) channel behind that nobody is accepting on.
                if (!accepting) {
                    abortBind(address, local, socket);
                }
            }
        }
    }

    /**
     * Releases a server channel whose bind sequence did not complete.
     *
     * @param address the address originally requested (used for logging only)
     * @param local   the local address the channel was registered under, or
     *                {@code null} if it was never registered
     * @param socket  the channel to close - ignored if {@code null}
     */
    private void abortBind(SocketAddress address, SocketAddress local, AsynchronousServerSocketChannel socket) {
        // only drop the mapping if it is still the one this bind attempt created
        if ((local != null) && (channels.get(local) == socket)) {
            channels.remove(local);
        }

        if (socket == null) {
            return;
        }

        try {
            socket.close();
        } catch (IOException e) {
            // best effort - the original bind failure is the one being reported
            if (log.isDebugEnabled()) {
                log.debug("bind(" + address + ") failed to close channel after aborted bind", e);
            }
        }
    }

    /**
     * Extension point for opening the server channel used for an address.
     *
     * @param address the address about to be bound (unused by this
     *                implementation)
     * @param group   the channel group to open the channel in
     * @return a newly opened, still unbound, server channel
     * @throws IOException if the channel cannot be opened
     */
    protected AsynchronousServerSocketChannel openAsynchronousServerSocketChannel(
            SocketAddress address, AsynchronousChannelGroup group) throws IOException {
        return AsynchronousServerSocketChannel.open(group);
    }

    /**
     * Extension point for creating the handler invoked for every accept
     * attempt on the given server channel.
     *
     * @param channelsMap the map of currently bound channels (unused by this
     *                    implementation)
     * @param socket      the server channel the handler will serve
     * @return the completion handler - must not be {@code null}
     * @throws IOException if the handler cannot be created
     */
    protected CompletionHandler<AsynchronousSocketChannel, ? super SocketAddress> createSocketCompletionHandler(
            Map<SocketAddress, AsynchronousServerSocketChannel> channelsMap, AsynchronousServerSocketChannel socket) throws IOException {
        return new AcceptCompletionHandler(socket);
    }

    /**
     * Binds a single address - see {@link #bind(Collection)}.
     */
    @Override
    public void bind(SocketAddress address) throws IOException {
        bind(Collections.singleton(address));
    }

    /**
     * Unbinds every address that is currently bound.
     */
    @Override
    public void unbind() {
        log.debug("Unbinding");
        unbind(getBoundAddresses());
    }

    /**
     * Removes and closes the server channel of each given address. Addresses
     * that are not bound are skipped; a failure to close one channel is
     * logged and does not prevent the remaining ones from being processed.
     *
     * @param addresses the local addresses to stop listening on
     */
    @Override
    public void unbind(Collection<? extends SocketAddress> addresses) {
        for (SocketAddress address : addresses) {
            AsynchronousServerSocketChannel channel = channels.remove(address);
            if (channel != null) {
                try {
                    if (log.isTraceEnabled()) {
                        log.trace("unbind({})", address);
                    }
                    channel.close();
                } catch (IOException e) {
                    log.warn("unbind({}) {} while unbinding channel: {}",
                             address, e.getClass().getSimpleName(), e.getMessage());
                    if (log.isDebugEnabled()) {
                        log.debug("unbind(" + address + ") failure details", e);
                    }
                }
            } else {
                if (log.isTraceEnabled()) {
                    log.trace("No active channel to unbind {}", address);
                }
            }
        }
    }

    /**
     * Unbinds a single address - see {@link #unbind(Collection)}.
     */
    @Override
    public void unbind(SocketAddress address) {
        unbind(Collections.singleton(address));
    }

    /**
     * @return a snapshot copy of the local addresses currently registered;
     *         modifying the returned set does not affect this acceptor
     */
    @Override
    public Set<SocketAddress> getBoundAddresses() {
        return new HashSet<>(channels.keySet());
    }

    /**
     * Stops listening on all addresses and then delegates to the super-class
     * close logic.
     */
    @Override
    public CloseFuture close(boolean immediately) {
        unbind();
        return super.close(immediately);
    }

    /**
     * Closes any server channel that is still registered and then delegates
     * to the super-class. Works on a snapshot of the bound addresses and
     * removes each entry before closing it, so a concurrent unbind cannot
     * cause a {@code null} dereference and closed channels are no longer
     * reported as bound.
     */
    @Override
    public void doCloseImmediately() {
        for (SocketAddress address : getBoundAddresses()) {
            AsynchronousServerSocketChannel channel = channels.remove(address);
            if (channel == null) {
                continue;   // already unbound by someone else
            }

            try {
                channel.close();
            } catch (IOException e) {
                log.debug("Exception caught while closing channel", e);
            }
        }
        super.doCloseImmediately();
    }

    /**
     * Handles the outcome of an accept operation on one server channel: turns
     * an accepted connection into a {@link Nio2Session} and re-arms the accept
     * for the next connection. The attachment is the local address the server
     * channel was registered under.
     */
    protected class AcceptCompletionHandler extends Nio2CompletionHandler<AsynchronousSocketChannel, SocketAddress> {
        /** The server channel this handler keeps accepting on. */
        protected final AsynchronousServerSocketChannel socket;

        AcceptCompletionHandler(AsynchronousServerSocketChannel socket) {
            this.socket = socket;
        }

        /**
         * Creates, registers and starts a session for the accepted connection
         * and then issues the next accept. If the address has been unbound in
         * the meantime the connection is closed and no further accept is
         * issued. If session setup fails, the failure is reported through
         * {@code failed} and the connection is closed - via the session if one
         * was created, otherwise directly.
         *
         * @param result  the accepted connection
         * @param address the local address the connection was accepted on
         */
        @Override
        @SuppressWarnings("synthetic-access")
        protected void onCompleted(AsynchronousSocketChannel result, SocketAddress address) {
            // Verify that the address has not been unbound
            if (!channels.containsKey(address)) {
                // nobody is going to service this connection - do not leak it
                closeAcceptedChannel(result, address);
                return;
            }

            Nio2Session session = null;
            try {
                // Create a session
                IoHandler handler = getIoHandler();
                setSocketOptions(result);
                session = Objects.requireNonNull(createSession(Nio2Acceptor.this, address, result, handler), "No NIO2 session created");
                handler.sessionCreated(session);
                sessions.put(session.getId(), session);
                session.startReading();
            } catch (Throwable exc) {
                failed(exc, address);

                // fail fast the accepted connection
                if (session == null) {
                    // failure occurred before a session took ownership of the channel
                    closeAcceptedChannel(result, address);
                } else {
                    try {
                        session.close();
                    } catch (Throwable t) {
                        log.warn("Failed (" + t.getClass().getSimpleName() + ")"
                               + " to close accepted connection from " + address
                               + ": " + t.getMessage(),
                                 t);
                    }
                }
            }

            try {
                // Accept new connections
                socket.accept(address, this);
            } catch (Throwable exc) {
                failed(exc, address);
            }
        }

        /**
         * Closes an accepted connection that is not owned by any session.
         *
         * @param channel the accepted connection - ignored if {@code null}
         * @param address the local address it was accepted on (logging only)
         */
        @SuppressWarnings("synthetic-access")
        private void closeAcceptedChannel(AsynchronousSocketChannel channel, SocketAddress address) {
            if (channel == null) {
                return;
            }

            try {
                channel.close();
            } catch (IOException e) {
                // best effort cleanup - nothing more can be done with this connection
                if (log.isDebugEnabled()) {
                    log.debug("Failed to close accepted connection on " + address, e);
                }
            }
        }

        /**
         * Extension point for creating the session that wraps an accepted
         * connection.
         *
         * @param acceptor the owning acceptor
         * @param address  the local address the connection was accepted on
         * @param channel  the accepted connection
         * @param handler  the handler to associate with the session
         * @return the created session
         * @throws Throwable if the session cannot be created
         */
        @SuppressWarnings("synthetic-access")
        protected Nio2Session createSession(Nio2Acceptor acceptor, SocketAddress address, AsynchronousSocketChannel channel, IoHandler handler) throws Throwable {
            if (log.isTraceEnabled()) {
                log.trace("createNio2Session({}) address={}", acceptor, address);
            }
            return new Nio2Session(acceptor, getFactoryManager(), handler, channel);
        }

        /**
         * Logs an accept/session-setup failure, unless the address is no
         * longer bound or the acceptor is being disposed - in which case the
         * failure is treated as expected and is not logged.
         *
         * @param exc     the failure
         * @param address the local address the failure relates to
         */
        @Override
        @SuppressWarnings("synthetic-access")
        protected void onFailed(final Throwable exc, final SocketAddress address) {
            if (channels.containsKey(address) && !disposing.get()) {
                log.warn("Caught " + exc.getClass().getSimpleName()
                       + " while accepting incoming connection from " + address
                       + ": " + exc.getMessage(),
                         exc);
            }
        }
    }
}