/** * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package org.apache.cxf.systest.ws.util; import java.io.ByteArrayInputStream; import java.util.ArrayList; import java.util.List; import javax.xml.namespace.QName; import org.w3c.dom.Document; import org.w3c.dom.Element; import org.w3c.dom.Node; import org.apache.cxf.staxutils.StaxUtils; import org.apache.cxf.ws.addressing.Names; import org.apache.cxf.ws.rm.RMConstants; import org.junit.Assert; public class MessageFlow extends Assert { private final String addressingNamespace; private final String rmNamespace; private List<byte[]> inStreams; private List<byte[]> outStreams; private List<Document> outboundMessages; private List<Document> inboundMessages; public MessageFlow(List<byte[]> out, List<byte[]> in, String addrns, String rmns) throws Exception { addressingNamespace = addrns; rmNamespace = rmns; inboundMessages = new ArrayList<>(); outboundMessages = new ArrayList<>(); reset(out, in); } public MessageFlow(List<byte[]> out, List<byte[]> in) throws Exception { this(out, in, Names.WSA_NAMESPACE_NAME, null); } public void clear() throws Exception { inStreams.clear(); outStreams.clear(); } public final void reset(List<byte[]> out, List<byte[]> in) throws Exception { for (int i = 0; i < inboundMessages.size(); i++) { in.remove(0); } inStreams = in; for (int i = 0; i < outboundMessages.size(); i++) { out.remove(0); } outStreams = out; inboundMessages.clear(); for (int i = 0; i < inStreams.size(); i++) { byte[] bytes = inStreams.get(i); ByteArrayInputStream is = new ByteArrayInputStream(bytes); Document document = StaxUtils.read(is); inboundMessages.add(document); } outboundMessages.clear(); for (int i = 0; i < outStreams.size(); i++) { byte[] bytes = outStreams.get(i); ByteArrayInputStream is = new ByteArrayInputStream(bytes); Document document = StaxUtils.read(is); outboundMessages.add(document); } } public Document getMessage(int i, boolean outbound) { return outbound ? outboundMessages.get(i) : inboundMessages.get(i); } public void verifyActions(String[] expectedActions, boolean outbound) throws Exception { assertEquals(expectedActions.length, outbound ? outboundMessages.size() : inboundMessages.size()); for (int i = 0; i < expectedActions.length; i++) { Document doc = outbound ? outboundMessages.get(i) : inboundMessages.get(i); String action = getAction(doc); if (null == expectedActions[i]) { assertNull((outbound ? "Outbound " : "Inbound") + " message " + i + " has unexpected action: " + action, action); } else { assertEquals((outbound ? "Outbound " : "Inbound") + " message " + i + " does not contain expected action header" + System.getProperty("line.separator"), expectedActions[i], action); } } } public void verifyActionsIgnoringPartialResponses(String[] expectedActions) throws Exception { int j = 0; for (int i = 0; i < inboundMessages.size() && j < expectedActions.length; i++) { String action = getAction(inboundMessages.get(i)); if (null == action && emptyBody(inboundMessages.get(i))) { continue; } if (null == expectedActions[j]) { assertNull("Inbound message " + i + " has unexpected action: " + action, action); } else { assertEquals("Inbound message " + i + " has unexpected action: ", expectedActions[j], action); } j++; } if (j < expectedActions.length) { fail("Inbound messages do not contain all expected actions."); } } public boolean checkActions(String[] expectedActions, boolean outbound) throws Exception { if (expectedActions.length != (outbound ? outboundMessages.size() : inboundMessages.size())) { return false; } for (int i = 0; i < expectedActions.length; i++) { String action = outbound ? getAction(outboundMessages.get(i)) : getAction(inboundMessages.get(i)); if (null == expectedActions[i]) { if (action != null) { return false; } } else { if (!expectedActions[i].equals(action)) { return false; } } } return true; } public void verifyAction(String expectedAction, int expectedCount, boolean outbound, boolean exact) throws Exception { int messageCount = outbound ? outboundMessages.size() : inboundMessages.size(); int count = 0; for (int i = 0; i < messageCount; i++) { String action = outbound ? getAction(outboundMessages.get(i)) : getAction(inboundMessages.get(i)); if (null == expectedAction) { if (action == null) { count++; } } else { if (expectedAction.equals(action)) { count++; } } } if (exact) { assertEquals("unexpected count for action: " + expectedAction, expectedCount, count); } else { assertTrue("unexpected count for action: " + expectedAction + ": " + count, expectedCount <= count); } } public void verifyMessageNumbers(String[] expectedMessageNumbers, boolean outbound) throws Exception { verifyMessageNumbers(expectedMessageNumbers, outbound, true); } public void verifyMessageNumbers(String[] expectedMessageNumbers, boolean outbound, boolean exact) throws Exception { int actualMessageCount = outbound ? outboundMessages.size() : inboundMessages.size(); if (exact) { assertEquals(expectedMessageNumbers.length, actualMessageCount); } else { assertTrue(expectedMessageNumbers.length <= actualMessageCount); } if (exact) { for (int i = 0; i < expectedMessageNumbers.length; i++) { Document doc = outbound ? outboundMessages.get(i) : inboundMessages.get(i); Element e = getSequence(doc); if (null == expectedMessageNumbers[i]) { assertNull((outbound ? "Outbound" : "Inbound") + " message " + i + " contains unexpected message number ", e); } else { assertEquals((outbound ? "Outbound" : "Inbound") + " message " + i + " does not contain expected message number " + expectedMessageNumbers[i], expectedMessageNumbers[i], getMessageNumber(e)); } } } else { boolean[] matches = new boolean[expectedMessageNumbers.length]; for (int i = 0; i < actualMessageCount; i++) { String messageNumber = null; Element e = outbound ? getSequence(outboundMessages.get(i)) : getSequence(inboundMessages.get(i)); messageNumber = null == e ? null : getMessageNumber(e); for (int j = 0; j < expectedMessageNumbers.length; j++) { if (messageNumber == null) { if (expectedMessageNumbers[j] == null && !matches[j]) { matches[j] = true; break; } } else { if (messageNumber.equals(expectedMessageNumbers[j]) && !matches[j]) { matches[j] = true; break; } } } } for (int k = 0; k < expectedMessageNumbers.length; k++) { assertTrue("no match for message number: " + expectedMessageNumbers[k], matches[k]); } } } public void verifyLastMessage(boolean[] expectedLastMessages, boolean outbound) throws Exception { verifyLastMessage(expectedLastMessages, outbound, true); } public void verifyLastMessage(boolean[] expectedLastMessages, boolean outbound, boolean exact) throws Exception { int actualMessageCount = outbound ? outboundMessages.size() : inboundMessages.size(); if (exact) { assertEquals(expectedLastMessages.length, actualMessageCount); } else { assertTrue(expectedLastMessages.length <= actualMessageCount); } for (int i = 0; i < expectedLastMessages.length; i++) { boolean lastMessage; Element e = outbound ? getSequence(outboundMessages.get(i)) : getSequence(inboundMessages.get(i)); lastMessage = null == e ? false : getLastMessage(e); assertEquals("Outbound message " + i + (expectedLastMessages[i] ? " does not contain expected last message element." : " contains last message element."), expectedLastMessages[i], lastMessage); } } public void verifyAcknowledgements(boolean[] expectedAcks, boolean outbound) throws Exception { assertEquals(expectedAcks.length, outbound ? outboundMessages.size() : inboundMessages.size()); for (int i = 0; i < expectedAcks.length; i++) { boolean ack = outbound ? (null != getAcknowledgment(outboundMessages.get(i))) : (null != getAcknowledgment(inboundMessages.get(i))); if (expectedAcks[i]) { assertTrue((outbound ? "Outbound" : "Inbound") + " message " + i + " does not contain expected acknowledgement", ack); } else { assertFalse((outbound ? "Outbound" : "Inbound") + " message " + i + " contains unexpected acknowledgement", ack); } } } public void verifyAcknowledgements(int expectedAcks, boolean outbound, boolean exact) throws Exception { int actualMessageCount = outbound ? outboundMessages.size() : inboundMessages.size(); int ackCount = 0; for (int i = 0; i < actualMessageCount; i++) { boolean ack = outbound ? (null != getAcknowledgment(outboundMessages.get(i))) : (null != getAcknowledgment(inboundMessages.get(i))); if (ack) { ackCount++; } } if (exact) { assertEquals("unexpected number of acks", expectedAcks, ackCount); } else { assertTrue("unexpected number of acks: " + ackCount, expectedAcks <= ackCount); } } public void verifyAckRequestedOutbound(boolean outbound) throws Exception { boolean found = false; List<Document> messages = outbound ? outboundMessages : inboundMessages; for (Document d : messages) { Element se = getAckRequested(d); if (se != null) { found = true; break; } } assertTrue("expected AckRequested", found); } public void verifySequenceFault(QName code, boolean outbound, int index) throws Exception { Document d = outbound ? outboundMessages.get(index) : inboundMessages.get(index); assert null != getRMHeaderElement(d, RMConstants.SEQUENCE_FAULT_NAME); } public void verifyHeader(QName name, boolean outbound, int index) throws Exception { Document d = outbound ? outboundMessages.get(index) : inboundMessages.get(index); assertNotNull((outbound ? "Outbound" : "Inbound") + " message " + index + " does not have " + name + "header.", getHeaderElement(d, name.getNamespaceURI(), name.getLocalPart())); } public void verifyNoHeader(QName name, boolean outbound, int index) throws Exception { Document d = outbound ? outboundMessages.get(index) : inboundMessages.get(index); assertNull((outbound ? "Outbound" : "Inbound") + " message " + index + " has " + name + "header.", getHeaderElement(d, name.getNamespaceURI(), name.getLocalPart())); } protected String getAction(Document document) throws Exception { Element e = getHeaderElement(document, addressingNamespace, "Action"); if (null != e) { return getText(e); } return null; } protected Element getSequence(Document document) throws Exception { return getRMHeaderElement(document, RMConstants.SEQUENCE_NAME); } public String getMessageNumber(Element elem) throws Exception { for (Node nd = elem.getFirstChild(); nd != null; nd = nd.getNextSibling()) { if (Node.ELEMENT_NODE == nd.getNodeType() && "MessageNumber".equals(nd.getLocalName())) { return getText(nd); } } return null; } private boolean getLastMessage(Element element) throws Exception { for (Node nd = element.getFirstChild(); nd != null; nd = nd.getNextSibling()) { if (Node.ELEMENT_NODE == nd.getNodeType() && "LastMessage".equals(nd.getLocalName())) { return true; } } return false; } protected Element getAcknowledgment(Document document) throws Exception { return getRMHeaderElement(document, RMConstants.SEQUENCE_ACK_NAME); } private Element getAckRequested(Document document) throws Exception { return getRMHeaderElement(document, RMConstants.ACK_REQUESTED_NAME); } private Element getRMHeaderElement(Document document, String name) throws Exception { return getHeaderElement(document, rmNamespace, name); } private Element getHeaderElement(Document document, String namespace, String localName) throws Exception { Element envelopeElement = document.getDocumentElement(); Element headerElement = null; for (Node nd = envelopeElement.getFirstChild(); nd != null; nd = nd.getNextSibling()) { if (Node.ELEMENT_NODE == nd.getNodeType() && "Header".equals(nd.getLocalName())) { headerElement = (Element)nd; break; } } if (null == headerElement) { return null; } for (Node nd = headerElement.getFirstChild(); nd != null; nd = nd.getNextSibling()) { if (Node.ELEMENT_NODE != nd.getNodeType()) { continue; } Element element = (Element)nd; String ns = element.getNamespaceURI(); String ln = element.getLocalName(); if (namespace.equals(ns) && localName.equals(ln)) { return element; } } return null; } public void verifyMessages(int nExpected, boolean outbound) { verifyMessages(nExpected, outbound, true); } public void verifyMessages(int nExpected, boolean outbound, boolean exact) { if (outbound) { if (exact) { assertEquals("Unexpected number of outbound messages" + dump(outStreams), nExpected, outboundMessages.size()); } else { assertTrue("Unexpected number of outbound messages: " + dump(outStreams), nExpected <= outboundMessages.size()); } } else { if (exact) { assertEquals("Unexpected number of inbound messages" + dump(inStreams), nExpected, inboundMessages.size()); } else { assertTrue("Unexpected number of inbound messages: " + dump(inStreams), nExpected <= inboundMessages.size()); } } } public void verifyAcknowledgementRange(long lower, long upper) throws Exception { long currentLower = 0; long currentUpper = 0; // get the final ack range for (Document doc : inboundMessages) { Element e = getRMHeaderElement(doc, RMConstants.SEQUENCE_ACK_NAME); // let the newer messages take precedence over the older messages in getting the final range if (null != e) { e = getNamedElement(e, "AcknowledgementRange"); if (null != e) { currentLower = Long.parseLong(e.getAttribute("Lower")); currentUpper = Long.parseLong(e.getAttribute("Upper")); } } } assertEquals("Unexpected acknowledgement lower range", lower, currentLower); assertEquals("Unexpected acknowledgement upper range", upper, currentUpper); } // note that this method picks the first match and returns public static Element getNamedElement(Element element, String lcname) throws Exception { for (Node nd = element.getFirstChild(); nd != null; nd = nd.getNextSibling()) { if (Node.ELEMENT_NODE == nd.getNodeType() && lcname.equals(nd.getLocalName())) { return (Element)nd; } } return null; } public void purgePartialResponses() throws Exception { for (int i = inboundMessages.size() - 1; i >= 0; i--) { if (isPartialResponse(inboundMessages.get(i))) { inboundMessages.remove(i); } } } public void purge() { inboundMessages.clear(); outboundMessages.clear(); inStreams.clear(); outStreams.clear(); } public void verifyPartialResponses(int nExpected) throws Exception { verifyPartialResponses(nExpected, null); } public void verifyPartialResponses(int nExpected, boolean[] piggybackedAcks) throws Exception { int npr = 0; for (int i = 0; i < inboundMessages.size(); i++) { if (isPartialResponse(inboundMessages.get(i))) { if (piggybackedAcks != null) { Element ack = getAcknowledgment(inboundMessages.get(i)); if (piggybackedAcks[npr]) { assertNotNull("Partial response " + npr + " does not include acknowledgement.", ack); } else { assertNull("Partial response " + npr + " has unexpected acknowledgement.", ack); } } npr++; } } assertEquals("Inbound messages did not contain expected number of partial responses.", nExpected, npr); } public boolean isPartialResponse(Document d) throws Exception { return null == getAction(d) && emptyBody(d); } public boolean emptyBody(Document d) throws Exception { Element envelopeElement = d.getDocumentElement(); Element bodyElement = null; for (Node nd = envelopeElement.getFirstChild(); nd != null; nd = nd.getNextSibling()) { if (Node.ELEMENT_NODE == nd.getNodeType() && "Body".equals(nd.getLocalName())) { bodyElement = (Element)nd; break; } } return !(null != bodyElement && bodyElement.hasChildNodes()); } String dump(List<byte[]> streams) { StringBuilder buf = new StringBuilder(); try { buf.append(System.getProperty("line.separator")); for (int i = 0; i < streams.size(); i++) { buf.append("["); buf.append(i); buf.append("] : "); buf.append(new String(streams.get(i))); buf.append(System.getProperty("line.separator")); } } catch (Exception ex) { return ""; } return buf.toString(); } public static String getText(Node node) { for (Node nd = node.getFirstChild(); nd != null; nd = nd.getNextSibling()) { if (Node.TEXT_NODE == nd.getNodeType()) { return nd.getNodeValue(); } } return null; } protected QName getNodeName(Node nd) { return new QName(nd.getNamespaceURI(), nd.getLocalName()); } }