/**
* JBoss, Home of Professional Open Source
* Copyright Red Hat, Inc., and individual contributors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.jboss.aerogear.simplepush.server;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.hasItem;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.UUID;
import org.jboss.aerogear.simplepush.protocol.Ack;
import org.jboss.aerogear.simplepush.protocol.HelloMessage;
import org.jboss.aerogear.simplepush.protocol.HelloResponse;
import org.jboss.aerogear.simplepush.protocol.MessageType;
import org.jboss.aerogear.simplepush.protocol.RegisterResponse;
import org.jboss.aerogear.simplepush.protocol.impl.AckImpl;
import org.jboss.aerogear.simplepush.protocol.impl.AckMessageImpl;
import org.jboss.aerogear.simplepush.protocol.impl.HelloMessageImpl;
import org.jboss.aerogear.simplepush.protocol.impl.RegisterMessageImpl;
import org.jboss.aerogear.simplepush.server.datastore.ChannelNotFoundException;
import org.jboss.aerogear.simplepush.server.datastore.DataStore;
import org.jboss.aerogear.simplepush.server.datastore.VersionException;
import org.jboss.aerogear.simplepush.util.CryptoUtil;
import org.jboss.aerogear.simplepush.util.UUIDUtil;
import org.junit.Before;
import org.junit.Test;
/**
 * Contract tests for {@link DefaultSimplePushServer}, parameterised over the
 * backing {@link DataStore}.
 * <p>
 * Concrete subclasses choose the store implementation by implementing
 * {@link #createDataStore()}. Before every test, a server is built on top of
 * the store returned from that method, using a config whose password is
 * {@code "test"}.
 */
public abstract class DefaultSimplePushServerTest {

    /** Server under test; rebuilt by {@link #setup()} before each test. */
    private DefaultSimplePushServer server;

    /**
     * Supplies the data store the server under test should be built on.
     *
     * @return the data store to use for the current test
     */
    protected abstract DataStore createDataStore();

    /**
     * Builds the server under test: obtains a store, creates a config with the
     * password {@code "test"}, generates and stores a private key, and wires
     * all three into a new {@link DefaultSimplePushServer}.
     */
    @Before
    public void setup() {
        final DataStore store = createDataStore();
        final SimplePushServerConfig serverConfig = DefaultSimplePushConfig.create().password("test").build();
        final byte[] key = DefaultSimplePushServer.generateAndStorePrivateKey(store, serverConfig);
        server = new DefaultSimplePushServer(store, serverConfig, key);
    }

    /** A hello built with the no-arg constructor still yields a UAID. */
    @Test
    public void handleHandshake() {
        final HelloResponse reply = server.handleHandshake(new HelloMessageImpl());
        assertThat(reply.getUAID(), is(notNullValue()));
    }

    /** A hello carrying a {@code null} UAID still yields a UAID. */
    @Test
    public void handleHandshakeWithNullUaid() {
        final HelloResponse reply = server.handleHandshake(new HelloMessageImpl(null));
        assertThat(reply.getUAID(), is(notNullValue()));
    }

    /** A hello carrying a well-formed UAID gets that same UAID back. */
    @Test
    public void handleHandshakeWithExistingUaid() {
        final String existingUaid = UUIDUtil.newUAID();
        final HelloResponse reply = server.handleHandshake(new HelloMessageImpl(existingUaid));
        assertThat(reply.getUAID(), equalTo(existingUaid));
    }

    /** A hello carrying a malformed UAID still yields a (non-null) UAID. */
    @Test
    public void handleHandshakeWithInvalidUaid() {
        final String malformedUaid = "bajja11122";
        final HelloResponse reply = server.handleHandshake(new HelloMessageImpl(malformedUaid));
        assertThat(reply.getUAID(), is(notNullValue()));
    }

    /** Channels listed in a hello become retrievable from the server. */
    @Test
    public void handleHandshakeWithChannels() throws ChannelNotFoundException {
        final String first = newChannelId();
        final String second = newChannelId();
        final Set<String> listed = channelSet(first, second);
        final HelloMessage hello = new HelloMessageImpl(UUIDUtil.newUAID(), listed);

        final HelloResponse reply = server.handleHandshake(hello);

        assertThat(reply.getUAID(), is(notNullValue()));
        assertThat(server.getChannel(first), is(notNullValue()));
        assertThat(server.getChannel(second), is(notNullValue()));
    }

    /** A hello with an empty channel set leaves the UAID without channels. */
    @Test
    public void handleHandshakeWithEmptyChannels() throws ChannelNotFoundException {
        final Set<String> noChannels = Collections.emptySet();
        final String uaid = UUIDUtil.newUAID();
        final HelloMessage hello = new HelloMessageImpl(uaid, noChannels);

        final HelloResponse reply = server.handleHandshake(hello);

        assertThat(reply.getUAID(), is(notNullValue()));
        assertThat(server.hasChannel(uaid, "channel1"), is(false));
    }

    /**
     * A second hello with an empty channel set drops the channels registered
     * by the first hello for the same UAID.
     */
    @Test
    public void handleHandshakeWithExistingAndEmptyChannelIDsInHello() throws ChannelNotFoundException {
        final String first = newChannelId();
        final String second = newChannelId();
        final Set<String> initial = channelSet(first, second);
        final String uaid = UUIDUtil.newUAID();

        final HelloResponse initialReply = server.handleHandshake(new HelloMessageImpl(uaid, initial));
        assertThat(initialReply.getUAID(), equalTo(uaid));
        assertRegistered(uaid, true, first, second);

        final HelloResponse repeatReply = server.handleHandshake(new HelloMessageImpl(uaid, Collections.<String>emptySet()));
        assertThat(repeatReply.getUAID(), equalTo(uaid));
        assertRegistered(repeatReply.getUAID(), false, first, second);
    }

    /**
     * A second hello for the same UAID replaces the previously registered
     * channels with the ones it lists.
     */
    @Test
    public void handleHandshakeWithExistingAndNewChannels() throws ChannelNotFoundException {
        final String uaid = UUIDUtil.newUAID();
        final String first = newChannelId();
        final String second = newChannelId();
        final Set<String> initial = channelSet(first, second);

        final HelloResponse initialReply = server.handleHandshake(new HelloMessageImpl(uaid, initial));
        assertThat(initialReply.getUAID(), equalTo(uaid));
        assertRegistered(uaid, true, first, second);

        final String third = newChannelId();
        final String fourth = newChannelId();
        final Set<String> replacement = channelSet(third, fourth);

        final HelloResponse repeatReply = server.handleHandshake(new HelloMessageImpl(uaid, replacement));
        assertThat(repeatReply.getUAID(), equalTo(uaid));
        assertRegistered(uaid, false, first, second);
        assertRegistered(uaid, true, third, fourth);
    }

    /**
     * A hello with channels but a {@code null} UAID still yields a UAID, and
     * the listed channels are not reported as registered.
     * <p>
     * NOTE(review): the channel checks use the hello message's UAID rather than
     * the UAID returned in the response — confirm this is the intended check.
     */
    @Test
    public void handleHandshakeWithChannelsButNoUaid() {
        final Set<String> listed = channelSet("channel1", "channel2");
        final HelloMessage hello = new HelloMessageImpl(null, listed);

        final HelloResponse reply = server.handleHandshake(hello);

        assertThat(reply.getUAID(), is(notNullValue()));
        assertThat(server.hasChannel(hello.getUAID(), "channel1"), is(false));
        assertThat(server.hasChannel(hello.getUAID(), "channel2"), is(false));
    }

    /** Registering a channel echoes its id with a 200/OK REGISTER response and a push endpoint URL. */
    @Test
    public void handeRegister() {
        final RegisterResponse reply = server.handleRegister(new RegisterMessageImpl("someChannelId"), UUIDUtil.newUAID());
        assertThat(reply.getChannelId(), equalTo("someChannelId"));
        assertThat(reply.getMessageType(), equalTo(MessageType.Type.REGISTER));
        assertThat(reply.getStatus().getCode(), equalTo(200));
        assertThat(reply.getStatus().getMessage(), equalTo("OK"));
        final boolean expectedPrefix = reply.getPushEndpoint().startsWith("http://127.0.0.1:7777/update");
        assertThat(expectedPrefix, is(true));
    }

    /**
     * A channel can only be removed by the UAID that registered it, and only
     * once.
     */
    @Test
    public void removeChannel() throws ChannelNotFoundException {
        final String channelId = "testChannelId";
        final String owner = UUIDUtil.newUAID();
        server.handleRegister(new RegisterMessageImpl(channelId), owner);
        assertThat(server.getChannel(channelId).getChannelId(), is(equalTo(channelId)));

        final String stranger = UUIDUtil.newUAID();
        assertThat(server.removeChannel(channelId, stranger), is(false));
        assertThat(server.removeChannel(channelId, owner), is(true));
        assertThat(server.removeChannel(channelId, owner), is(false));
    }

    /** The server maps a registered channel back to the UAID that registered it. */
    @Test
    public void getUAID() throws ChannelNotFoundException {
        final String channelId = newChannelId();
        final String owner = UUIDUtil.newUAID();
        server.handleRegister(new RegisterMessageImpl(channelId), owner);
        assertThat(server.getUAID(channelId), is(equalTo(owner)));
    }

    /**
     * Notifications sent to a channel's endpoint token produce an ack for that
     * channel and advance the channel's stored version.
     */
    @Test
    public void handleNotification() throws ChannelNotFoundException {
        final String channelId = newChannelId();
        final String uaid = UUIDUtil.newUAID();
        final String token = registerAndExtractToken(channelId, uaid);

        final Notification firstNotification = server.handleNotification(token, "version=1");
        assertThat(firstNotification.ack(), equalTo((Ack) new AckImpl(channelId, 1L)));
        assertThat(server.getChannel(channelId).getVersion(), is(1L));

        server.handleNotification(token, "version=2");
        assertThat(server.getChannel(channelId).getVersion(), is(2L));
    }

    /** A notification whose version is lower than the stored one is rejected with a {@link VersionException}. */
    @Test (expected = VersionException.class)
    public void handleNotificationWithVersionLessThanCurrentVersion() throws ChannelNotFoundException {
        final String channelId = newChannelId();
        final String uaid = UUIDUtil.newUAID();
        final String token = registerAndExtractToken(channelId, uaid);
        server.handleNotification(token, "version=10");
        server.handleNotification(token, "version=2");
    }

    /**
     * A syntactically valid endpoint token for a channel that was never
     * registered is rejected with a {@link ChannelNotFoundException}.
     */
    @Test (expected = ChannelNotFoundException.class)
    public void handleNotificationNonExistingChannelId() throws ChannelNotFoundException {
        final String channelId = newChannelId();
        final String uaid = UUIDUtil.newUAID();
        final byte[] secret = CryptoUtil.secretKey(server.config().password(), "some salt for testing".getBytes());
        final String forgedToken = CryptoUtil.endpointToken(channelId, uaid, secret);
        server.handleNotification(forgedToken, "version=1");
    }

    /**
     * Acknowledging one of two notified channels leaves the other channel's
     * update in the returned unacknowledged set.
     */
    @Test
    public void handleAck() throws ChannelNotFoundException {
        final String firstChannel = newChannelId();
        final String secondChannel = newChannelId();
        final String uaid = UUIDUtil.newUAID();
        final String firstToken = registerAndExtractToken(firstChannel, uaid);
        final String secondToken = registerAndExtractToken(secondChannel, uaid);
        server.handleNotification(firstToken, "version=10");
        server.handleNotification(secondToken, "version=23");

        final Set<Ack> acknowledged = ackSet(new AckImpl(firstChannel, 10L));
        final Set<Ack> unacked = server.handleAcknowledgement(new AckMessageImpl(acknowledged), uaid);

        assertThat(unacked, hasItem(new AckImpl(secondChannel, 23L)));
    }

    /** Returns a new random channel id (a random UUID in string form). */
    private static String newChannelId() {
        return UUID.randomUUID().toString();
    }

    /** Collects the given channel ids into a new mutable {@link HashSet}. */
    private static Set<String> channelSet(final String... ids) {
        return new HashSet<String>(Arrays.asList(ids));
    }

    /** Collects the given acks into a new mutable {@link HashSet}. */
    private static Set<Ack> ackSet(final Ack... acks) {
        return new HashSet<Ack>(Arrays.asList(acks));
    }

    /**
     * Registers {@code channelId} for {@code uaid} and returns the endpoint
     * token, i.e. the last path segment of the returned push endpoint.
     */
    private String registerAndExtractToken(final String channelId, final String uaid) {
        final RegisterResponse registration = server.handleRegister(new RegisterMessageImpl(channelId), uaid);
        final String pushEndpoint = registration.getPushEndpoint();
        final int lastSlash = pushEndpoint.lastIndexOf('/');
        return pushEndpoint.substring(lastSlash + 1);
    }

    /** Asserts, in argument order, that each channel's registration state for {@code uaid} equals {@code expected}. */
    private void assertRegistered(final String uaid, final boolean expected, final String... channelIds) {
        for (final String channelId : channelIds) {
            assertThat(server.hasChannel(uaid, channelId), is(expected));
        }
    }
}