/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.nifi.cluster.protocol.jaxb; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.ByteBuffer; import javax.xml.bind.JAXBContext; import javax.xml.bind.JAXBException; import javax.xml.bind.Marshaller; import javax.xml.bind.Unmarshaller; import org.apache.nifi.cluster.protocol.ProtocolContext; import org.apache.nifi.cluster.protocol.ProtocolMessageMarshaller; import org.apache.nifi.cluster.protocol.ProtocolMessageUnmarshaller; /** * Implements a context for communicating internally amongst the cluster using * JAXB. * * @param <T> The type of protocol message. * */ public class JaxbProtocolContext<T> implements ProtocolContext { private static final int BUF_SIZE = (int) Math.pow(2, 10); // 1k /* * A sentinel is used to detect corrupted messages. Relying on the integrity * of the message size can cause memory issues if the value is corrupted * and equal to a number larger than the memory size. */ private static final byte MESSAGE_PROTOCOL_START_SENTINEL = 0x5A; private final JAXBContext jaxbCtx; public JaxbProtocolContext(final JAXBContext jaxbCtx) { this.jaxbCtx = jaxbCtx; } @Override public ProtocolMessageMarshaller<T> createMarshaller() { return new ProtocolMessageMarshaller<T>() { @Override public void marshal(final T msg, final OutputStream os) throws IOException { try { // marshal message to output stream final Marshaller marshaller = jaxbCtx.createMarshaller(); final ByteArrayOutputStream msgBytes = new ByteArrayOutputStream(); marshaller.marshal(msg, msgBytes); final DataOutputStream dos = new DataOutputStream(os); // write message protocol sentinel dos.write(MESSAGE_PROTOCOL_START_SENTINEL); // write message size in bytes dos.writeInt(msgBytes.size()); // write message dos.write(msgBytes.toByteArray()); dos.flush(); } catch (final JAXBException je) { throw new IOException("Failed marshalling protocol message due to: " + je, je); } } }; } @Override public ProtocolMessageUnmarshaller<T> createUnmarshaller() { return new ProtocolMessageUnmarshaller<T>() { @Override public T unmarshal(final InputStream is) throws IOException { try { final DataInputStream dis = new DataInputStream(is); // check for the presence of the message protocol sentinel final byte sentinel = (byte) dis.read(); if (sentinel == -1) { throw new EOFException(); } if (MESSAGE_PROTOCOL_START_SENTINEL != sentinel) { throw new IOException("Failed reading protocol message due to malformed header"); } // read the message size final int msgBytesSize = dis.readInt(); // read the message final ByteBuffer buffer = ByteBuffer.allocate(msgBytesSize); int totalBytesRead = 0; do { final int bytesToRead; if ((msgBytesSize - totalBytesRead) >= BUF_SIZE) { bytesToRead = BUF_SIZE; } else { bytesToRead = msgBytesSize - totalBytesRead; } totalBytesRead += dis.read(buffer.array(), totalBytesRead, bytesToRead); } while (totalBytesRead < msgBytesSize); // unmarshall message and return final Unmarshaller unmarshaller = jaxbCtx.createUnmarshaller(); final byte[] msg = new byte[totalBytesRead]; buffer.get(msg); return (T) unmarshaller.unmarshal(new ByteArrayInputStream(msg)); } catch (final JAXBException je) { throw new IOException("Failed unmarshalling protocol message due to: " + je, je); } } }; } }