/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.activemq.artemis.protocol.amqp.proton.handler;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;
import org.apache.activemq.artemis.protocol.amqp.proton.ProtonInitializable;
import org.apache.activemq.artemis.protocol.amqp.sasl.SASLResult;
import org.apache.activemq.artemis.protocol.amqp.sasl.ServerSASL;
import org.apache.activemq.artemis.spi.core.remoting.ReadyListener;
import org.apache.activemq.artemis.utils.ByteUtil;
import org.apache.qpid.proton.Proton;
import org.apache.qpid.proton.amqp.Symbol;
import org.apache.qpid.proton.amqp.transport.AmqpError;
import org.apache.qpid.proton.amqp.transport.ErrorCondition;
import org.apache.qpid.proton.engine.Collector;
import org.apache.qpid.proton.engine.Connection;
import org.apache.qpid.proton.engine.EndpointState;
import org.apache.qpid.proton.engine.Event;
import org.apache.qpid.proton.engine.Sasl;
import org.apache.qpid.proton.engine.Transport;
import org.jboss.logging.Logger;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
/**
 * Drives a proton-j {@link Transport} / {@link Connection} pair for a single network connection.
 * <p>
 * Inbound bytes are fed into the transport through {@link #inputBuffer(ByteBuf)}. Events gathered by the
 * {@link Collector} are dispatched to the registered {@link EventHandler}s. Bytes pending on the transport are
 * copied into pooled direct {@link ByteBuf}s and pushed to the handlers by {@link #flushBytes()}.
 * <p>
 * Inside this class, access to the proton objects is guarded by a {@link ReentrantLock}. Callers can take the
 * same lock through {@link #lock()}, {@link #unlock()} and {@link #tryLock(long, TimeUnit)}.
 */
public class ProtonHandler extends ProtonInitializable {

   private static final Logger log = Logger.getLogger(ProtonHandler.class);

   /** Protocol id in the AMQP protocol header announcing a SASL security layer. */
   private static final byte SASL = 0x03;

   /** Protocol id in the AMQP protocol header announcing a bare AMQP connection. */
   private static final byte BARE = 0x00;

   /** Offset of the protocol-id byte in the AMQP header: "AMQP", protocol-id, major, minor, revision. */
   private static final int PROTOCOL_ID_OFFSET = 4;

   private final Transport transport = Proton.transport();

   private final Connection connection = Proton.connection();

   private final Collector collector = Proton.collector();

   /** Handlers receiving dispatched events and outbound bytes; only ever appended to. */
   private final List<EventHandler> handlers = new ArrayList<>();

   /** Server-side SASL endpoint; null when SASL is not configured or negotiation has finished. */
   private Sasl serverSasl;

   private final ReentrantLock lock = new ReentrantLock();

   private final long creationTime;

   /** Registered SASL mechanisms keyed by mechanism name. */
   private Map<String, ServerSASL> saslHandlers;

   private SASLResult saslResult;

   protected volatile boolean dataReceived;

   protected boolean receivedFirstPacket = false;

   private final Executor flushExecutor;

   /** Handed to each handler's flowControl(); when fired, it schedules a flush on the flush executor. */
   protected final ReadyListener readyListener;

   /** Re-entrancy guard for {@link #dispatch()}; only read/written while holding {@code lock}. */
   boolean inDispatch = false;

   /**
    * Creates the handler and binds the transport, connection and collector together.
    *
    * @param flushExecutor executor on which deferred flushes (triggered through {@link #readyListener}) run
    */
   public ProtonHandler(Executor flushExecutor) {
      this.flushExecutor = flushExecutor;
      this.readyListener = () -> flushExecutor.execute(this::flush);
      this.creationTime = System.currentTimeMillis();

      transport.bind(connection);
      connection.collect(collector);
   }

   /** Monotonic clock value, in milliseconds, handed to {@link Transport#tick(long)}. */
   private static long nowMillis() {
      return TimeUnit.NANOSECONDS.toMillis(System.nanoTime());
   }

   /**
    * Advances the transport's idle-timeout processing and then flushes any pending outbound bytes.
    *
    * @param firstTick true for the initial tick; the transport is ticked unconditionally in that case
    * @return the value returned by {@link Transport#tick(long)}, or 0 if the connection is locally closed
    *         or the tick failed
    */
   public long tick(boolean firstTick) {
      lock.lock();
      try {
         if (firstTick) {
            return transport.tick(nowMillis());
         }
         try {
            if (connection.getLocalState() != EndpointState.CLOSED) {
               long rescheduleAt = transport.tick(nowMillis());
               if (transport.isClosed()) {
                  // The transport closed itself during the tick (e.g. idle timeout); treat like a failure below.
                  throw new IllegalStateException("Channel was inactive for too long");
               }
               return rescheduleAt;
            }
         } catch (Exception e) {
            log.warn(e.getMessage(), e);
            transport.close();
            connection.setCondition(new ErrorCondition());
         }
         return 0;
      } finally {
         lock.unlock();
         // Flushed outside the lock; flushBytes() takes the lock itself.
         flushBytes();
      }
   }

   /** @return the transport's current input capacity, read under the lock */
   public int capacity() {
      lock.lock();
      try {
         return transport.capacity();
      } finally {
         lock.unlock();
      }
   }

   /** Acquires the handler lock. */
   public void lock() {
      lock.lock();
   }

   /** Releases the handler lock. */
   public void unlock() {
      lock.unlock();
   }

   /**
    * Attempts to acquire the handler lock within the given time.
    *
    * @return true if the lock was acquired; false on timeout or interruption (the interrupt flag is restored)
    */
   public boolean tryLock(long time, TimeUnit timeUnit) {
      try {
         return lock.tryLock(time, timeUnit);
      } catch (InterruptedException e) {
         Thread.currentThread().interrupt();
         return false;
      }
   }

   /** @return the underlying transport; NOTE(review): not locked here, callers presumably hold {@link #lock()} */
   public Transport getTransport() {
      return transport;
   }

   /** @return the underlying connection; NOTE(review): not locked here, callers presumably hold {@link #lock()} */
   public Connection getConnection() {
      return connection;
   }

   /**
    * Registers a handler for dispatched events and outbound bytes.
    *
    * @return this handler, for chaining
    */
   public ProtonHandler addEventHandler(EventHandler handler) {
      handlers.add(handler);
      return this;
   }

   /**
    * Enables server-side SASL on the transport and advertises the given mechanisms, in array order.
    *
    * @param mechanisms the mechanisms to offer, looked up by {@link ServerSASL#getName()}
    */
   public void createServerSASL(ServerSASL[] mechanisms) {
      this.serverSasl = transport.sasl();
      saslHandlers = new HashMap<>();
      String[] names = new String[mechanisms.length];
      int count = 0;
      for (ServerSASL mechanism : mechanisms) {
         saslHandlers.put(mechanism.getName(), mechanism);
         names[count++] = mechanism.getName();
      }
      this.serverSasl.server();
      serverSasl.setMechanisms(names);
   }

   /**
    * Copies every byte pending on the transport into pooled direct buffers and pushes each buffer to all handlers.
    * <p>
    * Nothing is written if any handler's flowControl() returns false. That handler is given
    * {@link #readyListener}, which schedules a later flush.
    */
   public void flushBytes() {
      for (EventHandler handler : handlers) {
         if (!handler.flowControl(readyListener)) {
            return;
         }
      }

      lock.lock();
      try {
         int pending;
         while ((pending = transport.pending()) > 0) {
            // NOTE(review): the same buffer goes to every handler and is not released here; ownership presumably
            // passes to the handler (confirm behavior when more than one handler is registered).
            ByteBuf buffer = PooledByteBufAllocator.DEFAULT.directBuffer(pending);
            ByteBuffer head = transport.head();
            buffer.writeBytes(head);

            for (EventHandler handler : handlers) {
               handler.pushBytes(buffer);
            }

            transport.pop(pending);
         }
      } finally {
         lock.unlock();
      }
   }

   /** @return the result of the last SASL exchange, or null if none has been processed */
   public SASLResult getSASLResult() {
      return saslResult;
   }

   /**
    * Feeds inbound network bytes into the transport, processing them after each chunk.
    * <p>
    * On the first packet, the protocol-id byte of the AMQP header decides whether handlers are told to expect
    * SASL or a bare connection. Bytes that do not fit are dropped when the transport reports zero or negative
    * capacity.
    *
    * @param buffer inbound bytes; its reader index is advanced by the bytes consumed
    */
   public void inputBuffer(ByteBuf buffer) {
      dataReceived = true;
      lock.lock();
      try {
         while (buffer.readableBytes() > 0) {
            int capacity = transport.capacity();

            if (!receivedFirstPacket) {
               if (dispatchAuthFromHeader(buffer)) {
                  // The SASL/bare decision may reconfigure the transport layers, changing its capacity.
                  capacity = transport.capacity();
               }
               receivedFirstPacket = true;
            }

            if (capacity > 0) {
               ByteBuffer tail = transport.tail();
               int min = Math.min(capacity, buffer.readableBytes());
               tail.limit(min);
               buffer.readBytes(tail);

               flush();
            } else {
               if (capacity == 0) {
                  log.debugf("abandoning: readableBytes=%d", buffer.readableBytes());
               } else {
                  log.debugf("transport closed, discarding: readableBytes=%d, capacity=%d", buffer.readableBytes(), transport.capacity());
               }
               break;
            }
         }
      } finally {
         lock.unlock();
      }
   }

   /**
    * Reads the protocol-id byte of the AMQP header at the buffer's reader position, without consuming it.
    * Notifies the handlers when the byte is SASL or bare.
    *
    * @return true if the handlers were notified
    */
   private boolean dispatchAuthFromHeader(ByteBuf buffer) {
      if (buffer.readableBytes() <= PROTOCOL_ID_OFFSET) {
         // Too few bytes to contain the protocol id; do not peek past the readable region.
         log.debugf("header too short to detect authentication mode: readableBytes=%d", buffer.readableBytes());
         return false;
      }
      try {
         // getByte() takes an absolute index, so offset it from the current reader position.
         byte auth = buffer.getByte(buffer.readerIndex() + PROTOCOL_ID_OFFSET);
         if (auth == SASL || auth == BARE) {
            dispatchAuth(auth == SASL);
            return true;
         }
      } catch (Throwable e) {
         log.warn(e.getMessage(), e);
      }
      return false;
   }

   /**
    * Returns whether any data arrived since the previous call and resets the flag.
    * NOTE(review): the read and reset are two separate volatile accesses, not one atomic step.
    */
   public boolean checkDataReceived() {
      boolean res = dataReceived;
      dataReceived = false;
      return res;
   }

   /** @return the wall-clock time, in milliseconds, at which this handler was created */
   public long getCreationTime() {
      return creationTime;
   }

   /**
    * Processes the transport and any pending SASL exchange under the lock. Then dispatches collected events,
    * which also flushes outbound bytes.
    */
   public void flush() {
      lock.lock();
      try {
         transport.process();
         checkServerSASL();
      } finally {
         lock.unlock();
      }

      dispatch();
   }

   /**
    * Closes the connection locally and flushes the resulting frames.
    *
    * @param errorCondition condition to report to the peer, or null for none
    */
   public void close(ErrorCondition errorCondition) {
      lock.lock();
      try {
         if (errorCondition != null) {
            connection.setCondition(errorCondition);
         }
         connection.close();
      } finally {
         lock.unlock();
      }

      flush();
   }

   /**
    * Completes server-side SASL negotiation once the client has selected a mechanism.
    * <p>
    * Every terminal outcome (OK, AUTH or SYS) clears {@link #serverSasl}, so negotiation is evaluated only once.
    */
   protected void checkServerSASL() {
      if (serverSasl == null || serverSasl.getRemoteMechanisms().length == 0) {
         return;
      }

      // TODO: should we look at the first only? (presumably the client selects a single mechanism)
      ServerSASL mechanism = saslHandlers.get(serverSasl.getRemoteMechanisms()[0]);
      if (mechanism == null) {
         // no auth available, system error
         serverSasl.done(Sasl.SaslOutcome.PN_SASL_SYS);
         // Negotiation is over; without this, every later flush() would look up and fail again.
         serverSasl = null;
         return;
      }

      byte[] dataSASL = new byte[serverSasl.pending()];
      serverSasl.recv(dataSASL, 0, dataSASL.length);

      if (log.isTraceEnabled()) {
         // NOTE(review): for mechanisms such as PLAIN this dump contains the client's credentials.
         log.trace("Working on sasl::" + (dataSASL.length > 0 ? ByteUtil.bytesToHex(dataSASL, 2) : "Anonymous"));
      }

      saslResult = mechanism.processSASL(dataSASL);

      if (saslResult != null && saslResult.isSuccess()) {
         serverSasl.done(Sasl.SaslOutcome.PN_SASL_OK);
         saslHandlers.clear();
         saslHandlers = null;
      } else {
         serverSasl.done(Sasl.SaslOutcome.PN_SASL_AUTH);
      }
      serverSasl = null;
   }

   /** Tells every handler whether the connection starts with a SASL layer. */
   private void dispatchAuth(boolean sasl) {
      for (EventHandler h : handlers) {
         h.onAuthInit(this, getConnection(), sasl);
      }
   }

   /**
    * Delivers every collected event to every handler, then flushes outbound bytes.
    * <p>
    * A handler exception closes the connection with an internal-error condition; the remaining handlers and
    * events are still dispatched. Re-entrant calls made while dispatching return immediately without flushing.
    */
   private void dispatch() {
      Event ev;

      lock.lock();
      try {
         if (inDispatch) {
            // Avoid recursion from events
            return;
         }
         try {
            inDispatch = true;
            while ((ev = collector.peek()) != null) {
               for (EventHandler h : handlers) {
                  if (log.isTraceEnabled()) {
                     log.trace("Handling " + ev + " towards " + h);
                  }
                  try {
                     Events.dispatch(ev, h);
                  } catch (Exception e) {
                     log.warn(e.getMessage(), e);
                     ErrorCondition error = new ErrorCondition();
                     error.setCondition(AmqpError.INTERNAL_ERROR);
                     error.setDescription("Unrecoverable error: " +
                        (e.getMessage() == null ? e.getClass().getSimpleName() : e.getMessage()));
                     connection.setCondition(error);
                     connection.close();
                  }
               }

               collector.pop();
            }

         } finally {
            inDispatch = false;
         }
      } finally {
         lock.unlock();
      }

      flushBytes();
   }

   /**
    * Opens the transport and connection with the given container id and properties, and flushes the open frame.
    */
   public void open(String containerId, Map<Symbol, Object> connectionProperties) {
      this.transport.open();
      this.connection.setContainer(containerId);
      this.connection.setProperties(connectionProperties);
      this.connection.open();
      flush();
   }
}