/**
* Copyright 2010 Google Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package org.waveprotocol.box.server.robots.active;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.wave.api.OperationRequest;
import com.google.wave.api.OperationType;
import com.google.wave.api.ProtocolVersion;
import com.google.wave.api.RobotSerializer;
import com.google.wave.api.data.converter.EventDataConverterManager;
import junit.framework.TestCase;
import net.oauth.OAuth;
import net.oauth.OAuthAccessor;
import net.oauth.OAuthException;
import net.oauth.OAuthMessage;
import net.oauth.OAuthServiceProvider;
import net.oauth.OAuthValidator;
import org.waveprotocol.box.server.account.RobotAccountDataImpl;
import org.waveprotocol.box.server.persistence.AccountStore;
import org.waveprotocol.box.server.robots.OperationContext;
import org.waveprotocol.box.server.robots.OperationServiceRegistry;
import org.waveprotocol.box.server.robots.operations.OperationService;
import org.waveprotocol.box.server.robots.util.ConversationUtil;
import org.waveprotocol.box.server.waveserver.WaveletProvider;
import org.waveprotocol.wave.model.wave.ParticipantId;
import java.io.BufferedReader;
import java.io.PrintWriter;
import java.io.StringReader;
import java.io.StringWriter;
import java.lang.reflect.Type;
import java.util.Collections;
import java.util.List;
import java.util.StringTokenizer;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* Unit tests for the {@link ActiveApiServlet}.
*
* @author ljvderijk@google.com (Lennard de Rijk)
*/
public class ActiveApiServletTest extends TestCase {
private static final ParticipantId ROBOT = ParticipantId.ofUnsafe("robot@example.com");
private static final ParticipantId UNKNOWN = ParticipantId.ofUnsafe("unknown@example.com");
private RobotSerializer robotSerializer;
private OperationServiceRegistry operationRegistry;
private OAuthValidator validator;
private HttpServletResponse resp;
private HttpServletRequest req;
private StringWriter outputWriter;
private ActiveApiServlet servlet;
@Override
protected void setUp() throws Exception {
robotSerializer = mock(RobotSerializer.class);
operationRegistry = mock(OperationServiceRegistry.class);
validator = mock(OAuthValidator.class);
EventDataConverterManager converterManager = mock(EventDataConverterManager.class);
WaveletProvider waveletProvider = mock(WaveletProvider.class);
ConversationUtil conversationUtil = mock(ConversationUtil.class);
OAuthServiceProvider oAuthServiceProvider = mock(OAuthServiceProvider.class);
AccountStore accountStore = mock(AccountStore.class);
when(accountStore.getAccount(ROBOT)).thenReturn(
new RobotAccountDataImpl(ROBOT, "", "secret", null, true));
req = mock(HttpServletRequest.class);
when(req.getRequestURL()).thenReturn(new StringBuffer("www.example.com/robot"));
when(req.getHeaderNames()).thenReturn(new StringTokenizer("Authorization"));
when(req.getReader()).thenReturn(new BufferedReader(new StringReader("")));
resp = mock(HttpServletResponse.class);
outputWriter = new StringWriter();
when(resp.getWriter()).thenReturn(new PrintWriter(outputWriter));
servlet =
new ActiveApiServlet(robotSerializer, converterManager, waveletProvider, operationRegistry,
conversationUtil, oAuthServiceProvider, validator, accountStore);
}
public void testDoPostExecutesAndWritesResponse() throws Exception {
when(req.getHeaders("Authorization")).thenReturn(generateOAuthHeader(ROBOT.getAddress()));
String operationId = "op1";
OperationRequest operation = new OperationRequest("wavelet.create", operationId);
List<OperationRequest> operations = Collections.singletonList(operation);
when(robotSerializer.deserializeOperations(anyString())).thenReturn(operations);
String responseValue = "response value";
when(robotSerializer.serialize(any(), any(Type.class), any(ProtocolVersion.class))).thenReturn(
responseValue);
OperationService service = mock(OperationService.class);
when(operationRegistry.getServiceFor(any(OperationType.class))).thenReturn(service);
servlet.doPost(req, resp);
verify(operationRegistry).getServiceFor(any(OperationType.class));
verify(service).execute(eq(operation), any(OperationContext.class), eq(ROBOT));
verify(resp).setStatus(HttpServletResponse.SC_OK);
assertEquals("Response should have been written into the servlet", responseValue,
outputWriter.toString());
}
public void testDoPostUnauthorizedWhenParticipantInvalid() throws Exception {
when(req.getHeaders("Authorization")).thenReturn(
generateOAuthHeader("invalid#$example.com"));
servlet.doPost(req, resp);
verify(resp).setStatus(HttpServletResponse.SC_UNAUTHORIZED);
}
public void testDoPostUnauthorizedWhenParticipantUnknown() throws Exception {
when(req.getHeaders("Authorization")).thenReturn(generateOAuthHeader(UNKNOWN.getAddress()));
servlet.doPost(req, resp);
verify(resp).setStatus(HttpServletResponse.SC_UNAUTHORIZED);
}
public void testDoPostUnauthorizedWhenValidationFails() throws Exception {
when(req.getHeaders("Authorization")).thenReturn(generateOAuthHeader(ROBOT.getAddress()));
doThrow(new OAuthException("")).when(validator).validateMessage(
any(OAuthMessage.class), any(OAuthAccessor.class));
servlet.doPost(req, resp);
verify(resp).setStatus(HttpServletResponse.SC_UNAUTHORIZED);
}
private StringTokenizer generateOAuthHeader(String address) {
return new StringTokenizer("OAuth " + OAuth.OAUTH_CONSUMER_KEY + "=\"" + address + "\"", "");
}
}