package hk.hku.cecid.ebms.spa.handler.jms;
import static org.hamcrest.MatcherAssert.*;
import static org.hamcrest.Matchers.*;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.isNull;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import hk.hku.cecid.ebms.pkg.EbxmlMessage;
import hk.hku.cecid.ebms.pkg.MessageHeader.PartyId;
import hk.hku.cecid.ebms.spa.handler.MessageServiceHandler;
import hk.hku.cecid.ebms.spa.listener.EbmsRequest;
import hk.hku.cecid.ebms.spa.listener.EbmsResponse;
import hk.hku.cecid.piazza.commons.message.Message;
import hk.hku.cecid.piazza.commons.util.Logger;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
/**
 * Unit tests for {@link EbmsMessageHandler}.
 *
 * <p>The handler is wrapped in a Mockito spy so that its logger, channel validation and
 * {@link MessageServiceHandler} lookup can be replaced with test doubles.
 */
public class EbmsMessageHandlerTest {
    /** Spy on the handler under test; collaborators are stubbed in {@link #setup()}. */
    EbmsMessageHandler mh = null;
    /** Mocked message service handler returned by the spy's {@code getMSH()}. */
    MessageServiceHandler msh = null;

    /**
     * Creates the spied handler, silences its logger, forces channel validation to succeed
     * and wires in a mocked {@link MessageServiceHandler}.
     */
    @Before
    public void setup() throws Exception {
        mh = spy(new EbmsMessageHandler());
        doReturn(mock(Logger.class)).when(mh).log();
        // Accept any channel so the test exercises message conversion, not channel checks.
        doReturn(true).when(mh).checkValidChannel(anyString(), anyString(), anyString());
        msh = mock(MessageServiceHandler.class);
        doReturn(msh).when(mh).getMSH();
    }

    /**
     * Verifies that a well-formed incoming message is converted into an {@link EbmsRequest}
     * carrying the expected ebXML header values and payload, and is handed to
     * {@code processOutboundMessage} exactly once with a null response.
     */
    @Test
    public void testEbmsMessageHandlerValidMessage() throws Exception {
        Message msg = buildMessage("TestMessage");

        mh.onMessage(msg);

        // Capture the request and assert after onMessage returns, rather than asserting
        // inside a stubbed Answer: an AssertionError raised inside the mocked call could be
        // caught by onMessage's own error handling (not visible here) and go unnoticed.
        ArgumentCaptor<EbmsRequest> requestCaptor = ArgumentCaptor.forClass(EbmsRequest.class);
        verify(msh).processOutboundMessage(requestCaptor.capture(), (EbmsResponse) isNull());

        EbmsRequest request = requestCaptor.getValue();
        assertThat(request, notNullValue());

        // Check that the ebXML message header was populated from the message headers.
        EbxmlMessage ebxml = request.getMessage();
        assertThat(ebxml, notNullValue());
        assertThat(ebxml.getCpaId(), equalTo("cpaId"));
        assertThat(ebxml.getService(), equalTo("http://localhost/inbound"));
        assertThat(ebxml.getAction(), equalTo("action"));
        assertThat(ebxml.getConversationId(), equalTo("conversationId"));
        assertThat(first(ebxml.getFromPartyIds()), equalTo("test_a"));
        assertThat(first(ebxml.getToPartyIds()), equalTo("test_b"));
        assertThat(ebxml.getTimeToLive(), notNullValue());
        assertThat(ebxml.getMessageId(), equalTo("TEST-123456789"));

        // Check the payload.
        assertThat(ebxml.getPayloadCount(), equalTo(1));
        assertThat(ebxml.getPayloadContainer("Payload-0").getContentType(),
                equalTo("text/xml; charset=UTF-8"));
    }

    /**
     * Builds a mocked {@link Message} with a fixed set of ebMS headers and a single payload.
     *
     * @param data text of the single payload, encoded as UTF-8; must not be {@code null}
     * @return a mock whose {@code getSource()}, {@code getHeader()} and {@code getPayloads()}
     *         return the prepared values
     */
    public Message buildMessage(String data) {
        Message msg = mock(Message.class);

        Map<String, Object> headers = new HashMap<String, Object>();
        headers.put("cpaId", "cpaId");
        headers.put("service", "http://localhost/inbound");
        headers.put("action", "action");
        headers.put("conversationId", "conversationId");
        headers.put("fromPartyId", "test_a");
        headers.put("fromPartyType", "");
        headers.put("toPartyId", "test_b");
        headers.put("toPartyType", "");
        headers.put("timeToLiveOffset", "10800");
        headers.put("ebxmlMessageId", "TEST-123456789");

        // Encode explicitly as UTF-8 so the bytes match the charset the test expects,
        // independent of the platform default charset.
        List<byte[]> payloads = new ArrayList<byte[]>();
        payloads.add(data.getBytes(Charset.forName("UTF-8")));

        doReturn("").when(msg).getSource();
        doReturn(headers).when(msg).getHeader();
        doReturn(payloads).when(msg).getPayloads();
        return msg;
    }

    /**
     * Returns the id of the first {@link PartyId} produced by the iterator.
     *
     * @param iter iterator over party ids; may be {@code null} or empty
     * @return the first party's id, or {@code null} if the iterator is null, empty, or its
     *         first element is null (so the caller's assertion fails with a readable message
     *         instead of a {@code NoSuchElementException})
     */
    private String first(Iterator<?> iter) {
        if (iter == null || !iter.hasNext()) {
            return null;
        }
        Object o = iter.next();
        if (o == null) {
            return null;
        }
        return ((PartyId) o).getId();
    }
}