package com.ctriposs.baiji.rpc.server;
import com.ctriposs.baiji.exception.BaijiRuntimeException;
import com.ctriposs.baiji.rpc.common.BaijiContract;
import com.ctriposs.baiji.rpc.common.HasResponseStatus;
import com.ctriposs.baiji.rpc.common.formatter.ContentFormatter;
import com.ctriposs.baiji.rpc.common.types.*;
import com.ctriposs.baiji.specific.SpecificRecord;
import org.apache.commons.beanutils.BeanUtils;
import org.apache.commons.lang.exception.ExceptionUtils;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URLDecoder;
import java.sql.Timestamp;
import java.util.*;
public class BaijiHttpRequestRouter implements HttpRequestRouter {
// Class-wide SLF4J logger.
private static final Logger _logger = LoggerFactory.getLogger(BaijiHttpRequestRouter.class);
// Operation handlers keyed by the lower-cased contract method name.
// Entries are added only by the constructor and looked up once per request.
private final Map<RequestPath, OperationHandler> _handlers = new HashMap<RequestPath, OperationHandler>();
// Service configuration handed to the constructor (formatters, stack-trace output flag).
private final ServiceConfig _config;
// Metadata read from the contract interface's BaijiContract annotation.
private final ServiceMetadata _serviceMetadata;
/**
 * Creates a router for one service implementation.
 *
 * <p>The constructor validates the configuration, locates the interface annotated with
 * {@link BaijiContract} on {@code serviceType}, captures the service metadata from that
 * annotation and registers one {@link OperationHandler} per contract method, keyed by the
 * lower-cased method name.
 *
 * @param config      service configuration; its default formatter must be non-null
 * @param serviceType service implementation class
 * @throws BaijiRuntimeException if the default formatter is missing, no contract interface is
 *                               found, or two contract methods map to the same lower-cased name
 */
public BaijiHttpRequestRouter(ServiceConfig config, Class<?> serviceType) {
    this._config = config;
    if (this._config.getDefaultFormatter() == null) { // default formatter must be provided
        String errMsg = "Missing mandatory default content formatter in service config";
        _logger.error(errMsg);
        throw new BaijiRuntimeException(errMsg);
    }

    Class<?> contract = findContract(serviceType);
    if (contract == null) {
        String errMsg = "Can't find BaijiContract on service type " + serviceType;
        // Logged like the other construction failures so the cause shows up in the service log.
        _logger.error(errMsg);
        throw new BaijiRuntimeException(errMsg);
    }
    _serviceMetadata = extractServiceMetaData(contract);

    // Cache all operation methods
    for (Method method : contract.getMethods()) {
        // Create handler
        OperationHandler handler = new OperationHandler(config, serviceType, method);
        RequestPath path = new RequestPath(method.getName().toLowerCase());

        // Routing is by lower-cased name only, so overloads and names that differ
        // only by case cannot be told apart and are rejected.
        if (_handlers.containsKey(path)) {
            String errMsg = String.format("duplicate method %s on service type %s is not allowed", method.getName(), serviceType);
            _logger.error(errMsg);
            throw new BaijiRuntimeException(errMsg);
        }
        _handlers.put(path, handler);
    }
}
/**
 * Returns the service metadata captured at construction time.
 *
 * @return the metadata object built from the contract annotation
 */
public ServiceMetadata getServiceMetaData() {
    return this._serviceMetadata;
}
/**
 * Copies service name, namespace and code generator version from the
 * {@link BaijiContract} annotation of the given contract into a new metadata object.
 *
 * @param contract contract interface carrying the annotation
 * @return a freshly populated {@link ServiceMetadata}
 */
private static ServiceMetadata extractServiceMetaData(Class<?> contract) {
    final BaijiContract contractInfo = contract.getAnnotation(BaijiContract.class);

    final ServiceMetadata result = new ServiceMetadata();
    result.setServiceName(contractInfo.serviceName());
    result.setServiceNamespace(contractInfo.serviceNamespace());
    result.setCodeGeneratorVersion(contractInfo.codeGeneratorVersion());
    return result;
}
/**
 * Finds the interface annotated with {@link BaijiContract} that the service type implements.
 *
 * <p>{@link Class#getInterfaces()} only reports interfaces declared directly on a class, so
 * the superclass chain is walked as well. This lets a service that inherits the contract
 * from a base class be resolved. Interfaces declared closer to {@code serviceType} win.
 *
 * @param serviceType service implementation class; {@code null} yields {@code null}
 * @return the first annotated interface found, or {@code null} if there is none
 */
private static Class<?> findContract(Class<?> serviceType) {
    for (Class<?> type = serviceType; type != null; type = type.getSuperclass()) {
        for (Class<?> intf : type.getInterfaces()) {
            if (intf.getAnnotation(BaijiContract.class) != null) {
                return intf;
            }
        }
    }
    return null;
}
/**
 * Resolves the operation handler addressed by the request path.
 *
 * <p>Leading and trailing slashes are ignored. The remaining path must be a single
 * segment of the form {@code operation} or {@code operation.ext}. When an extension is
 * present it is stored in {@code environment.RequestExtention} as a side effect and
 * removed from the operation name before the case-insensitive lookup.
 *
 * @param environment request context providing the path
 * @return the matching handler, or {@code null} when the path is empty, has more than
 *         one segment, or names no registered operation
 */
private OperationHandler selectHandler(RequestContext environment) {
    final String rawPath = environment.RequestPath;
    if (rawPath == null || rawPath.isEmpty()) {
        return null;
    }

    // Trim every leading and trailing "/" by moving two cursors inwards.
    int begin = 0;
    int end = rawPath.length();
    while (begin < end && rawPath.charAt(begin) == '/') {
        begin++;
    }
    while (end > begin && rawPath.charAt(end - 1) == '/') {
        end--;
    }
    String operation = rawPath.substring(begin, end);

    // Any slash left over means the path has more than one segment.
    if (operation.indexOf('/') >= 0) {
        return null;
    }

    // Extract "extension" for media type
    final int dot = operation.lastIndexOf('.');
    if (dot != -1) {
        environment.RequestExtention = operation.substring(dot + 1);
        operation = operation.substring(0, dot);
    }

    return _handlers.get(new RequestPath(operation.toLowerCase()));
}
/**
 * Picks the content formatter for a request.
 *
 * <p>A formatter registered for the request's extension is preferred; when the request
 * has no extension, no specific formatters are configured, or none matches, the
 * configured default formatter is used.
 *
 * @param environment request context whose {@code RequestExtention} is consulted
 * @return the formatter to use for this request
 */
private ContentFormatter negotiateFormat(RequestContext environment) {
    final String extension = environment.RequestExtention;
    if (extension != null && !extension.isEmpty()) {
        // Try specified
        final Map<String, ContentFormatter> byExtension = _config.getSpecifiedFormatters();
        if (byExtension != null) {
            final ContentFormatter match = byExtension.get(extension);
            if (match != null) {
                return match;
            }
        }
    }
    // Use default when no suitable specified formatter is found.
    return _config.getDefaultFormatter();
}
/**
 * Parses a URL query string into an ordered map of decoded names to decoded values.
 *
 * <p>Pairs are separated by {@code &} and split at the first {@code =}. A pair without
 * {@code =} (for example {@code flag}) is treated as a name with an empty value, and empty
 * pairs (as produced by {@code a=1&&b=2}) are skipped. When a name occurs more than once
 * the last value wins.
 *
 * @param queryString raw query string without the leading {@code ?}; may be {@code null}
 * @return the parsed parameters in order of first appearance, or {@code null} when
 *         {@code queryString} is {@code null} or empty
 * @throws UnsupportedEncodingException if UTF-8 is not supported by the runtime
 */
private static Map<String, String> splitQuery(String queryString) throws UnsupportedEncodingException {
    if (queryString == null || queryString.isEmpty()) {
        return null;
    }
    Map<String, String> queryMap = new LinkedHashMap<String, String>();
    for (String pair : queryString.split("&")) {
        if (pair.isEmpty()) {
            continue; // stray separator, nothing to bind
        }
        int idx = pair.indexOf('=');
        // Without '=' the whole pair is the name; substring(0, -1) would throw.
        String rawName = (idx < 0) ? pair : pair.substring(0, idx);
        String rawValue = (idx < 0) ? "" : pair.substring(idx + 1);
        queryMap.put(URLDecoder.decode(rawName, "UTF-8"), URLDecoder.decode(rawValue, "UTF-8"));
    }
    return queryMap;
}
// Routing -> ParameterBinding/Deserialization -> Invocation -> Serialization -> Write response
/**
 * Handles one request end to end: selects the operation handler, builds the request
 * object, invokes the service method and writes the response.
 *
 * <p>Outcomes visible in this method:
 * <ul>
 * <li>no handler for the path: HTTP 404 without a body;</li>
 * <li>HTTP method other than GET or POST: HTTP 405 without a body;</li>
 * <li>failures once handler and formatter are known: a serialized error response written
 * through {@code writeBaijiResponse}, which sets HTTP 200, so the failure is carried in
 * the response status of the body;</li>
 * <li>failures before that point, or while writing the error response itself: HTTP 500.</li>
 * </ul>
 * The request body stream and, if one was obtained, the response body stream are closed
 * before returning.
 *
 * @param request        request context (path, method, query string, body stream)
 * @param responseWriter wrapper used to set status and headers and to send the response
 */
public void process(RequestContext request, HttpResponseWrapper responseWriter) {
// Both stay null until resolved; the outer catch uses that to decide whether a
// formatted error response can be produced.
ContentFormatter formatter = null;
OperationHandler handler = null;
try {
// Routing
handler = this.selectHandler(request);
if (handler == null) {
_logger.error("No handler found: " + request.RequestPath);
this.writeHttpResponse(responseWriter, HttpStatus.SC_NOT_FOUND);
return; // Nothing more to do
}
SpecificRecord requestObject = null;
// selectHandler has already stored the path extension (if any) on the request.
formatter = this.negotiateFormat(request);
// ParameterBinding or Deserialization
if ("GET".equalsIgnoreCase(request.RequestMethod)) { // REST call, for testing only
// binding parameters
requestObject = handler.getEmptyRequestInstance();
// request parameters binding
Map<String, String> requestQueryMap = splitQuery(request.RequestQueryString);
if (requestQueryMap != null && requestQueryMap.size() > 0) {
// Query parameter names are matched against bean properties of the request object.
BeanUtils.populate(requestObject, requestQueryMap);
}
} else if ("POST".equalsIgnoreCase(request.RequestMethod)) { // RPC call
requestObject = formatter.deserialize(handler.getRequestType(), request.RequestBody);
} else { // for Baiji RPC, only GET & POST are allowed
this.writeHttpResponse(responseWriter, HttpStatus.SC_METHOD_NOT_ALLOWED);
return; // Nothing more to do
}
if (requestObject == null) { // defensive programming
String errMsg = "Unable to bind request with request object of type " + handler.getRequestType();
_logger.error(errMsg);
SpecificRecord errorResponse = (SpecificRecord) this.buildErrorResponse(
handler, "NoRequestObject", errMsg, ErrorClassificationCodeType.FRAMEWORK_ERROR, null);
this.writeBaijiResponse(responseWriter, errorResponse, request, formatter);
return; // Nothing more to do
}
// Invocation
OperationContext operationContext = new OperationContext(request, handler.getMethodName(), requestObject);
SpecificRecord responseObject;
try {
responseObject = handler.invoke(operationContext);
} catch (Exception ex) {
// Report the exception thrown by the service method rather than the reflection wrapper.
Throwable actualEx;
if (ex instanceof InvocationTargetException) {
actualEx = ((InvocationTargetException) ex).getTargetException();
} else {
actualEx = ex;
}
String errMsg = actualEx.getClass().getName() + " - " + actualEx.getMessage();
_logger.error("Fail to invoke target service method " + handler.getMethodName() + ": " + errMsg,
actualEx);
SpecificRecord errorResponse = (SpecificRecord) this.buildErrorResponse(
handler, "ServiceInvocationError", errMsg, ErrorClassificationCodeType.SERVICE_ERROR, actualEx);
this.writeBaijiResponse(responseWriter, errorResponse, request, formatter);
return; // Nothing more to do
}
if (responseObject == null) { // defensive programming
String errMsg = "Fail to get response object when invoking the service";
_logger.error(errMsg);
SpecificRecord errorResponse = (SpecificRecord) this.buildErrorResponse(
handler, "NoResponseObject", errMsg, ErrorClassificationCodeType.FRAMEWORK_ERROR, null);
this.writeBaijiResponse(responseWriter, errorResponse, request, formatter);
return; // Nothing more to do
}
this.writeBaijiResponse(responseWriter, responseObject, request, formatter);
} catch (Throwable t) {
// Catch-all for anything not handled above, including failures while binding the
// request or while writing one of the responses.
if (request != null && formatter != null && handler != null) {
String errMsg = t.getMessage();
_logger.error(errMsg, t);
try {
SpecificRecord errorResponse = (SpecificRecord) this.buildErrorResponse(
handler, "RequestException", errMsg, ErrorClassificationCodeType.FRAMEWORK_ERROR, t);
this.writeBaijiResponse(responseWriter, errorResponse, request, formatter);
} catch (Exception e) {
// Even the error response could not be produced; fall back to a bare 500.
_logger.error("Internal server error", e);
this.writeHttpResponse(responseWriter, HttpStatus.SC_INTERNAL_SERVER_ERROR);
}
return; // Nothing more to do
} else {
// Handler or formatter was never resolved, so no typed error response can be built.
_logger.error("Internal server error", t);
this.writeHttpResponse(responseWriter, HttpStatus.SC_INTERNAL_SERVER_ERROR);
return; // Nothing more to do
}
} finally {
// Close both streams on every exit path; close failures are logged, not rethrown.
if (request != null) {
if (request.RequestBody != null) {
try {
request.RequestBody.close();
} catch (IOException e) {
_logger.error("Fail to close request input stream", e);
}
}
// ResponseBody is assigned by writeBaijiResponse, so it is null on the 404/405/500 paths.
if (request.ResponseBody != null) {
try {
request.ResponseBody.close();
} catch (IOException e) {
_logger.error("Fail to close response output stream", e);
}
}
}
}
}
/**
 * Creates an empty response of the handler's response type and fills its response status
 * with a single error entry.
 *
 * <p>The status is marked {@code FAILURE} and stamped with the current time. The error
 * entry has severity {@code ERROR}; the stack trace of {@code t} is attached only when
 * {@code t} is non-null and the configuration enables stack-trace output.
 *
 * @param handler                 handler whose response type is instantiated
 * @param errorCode               error code to report
 * @param errorMessage            error message to report
 * @param errorClassificationCode classification of the error
 * @param t                       cause used for the optional stack trace; may be {@code null}
 * @return the populated error response
 * @throws Exception if the empty response instance cannot be created
 */
private HasResponseStatus buildErrorResponse(OperationHandler handler, String errorCode, String errorMessage,
                                             ErrorClassificationCodeType errorClassificationCode, Throwable t) throws Exception {
    final HasResponseStatus errorResponse = (HasResponseStatus) handler.getEmptyResponseInstance();

    // Describe the single error first ...
    final ErrorDataType error = new ErrorDataType();
    error.errorCode = errorCode;
    error.errorClassification = errorClassificationCode;
    error.message = errorMessage;
    error.severityCode = SeverityCodeType.ERROR;
    if (t != null && this._config.isOutputExceptionStackTrace()) {
        error.stackTrace = ExceptionUtils.getStackTrace(t);
    }

    // ... then wrap it in a failure status.
    final ArrayList<ErrorDataType> errorList = new ArrayList<ErrorDataType>();
    errorList.add(error);

    final ResponseStatusType status = new ResponseStatusType();
    status.ack = AckCodeType.FAILURE;
    status.timestamp = this.getCurrentTimestamp();
    status.errors = errorList;

    errorResponse.setResponseStatus(status);
    return errorResponse;
}
/**
 * Returns the current time rendered by {@link Timestamp#toString()}.
 *
 * @return the current timestamp as text
 */
private String getCurrentTimestamp() {
    final long nowMillis = System.currentTimeMillis();
    return new Timestamp(nowMillis).toString();
}
/**
 * Serializes a response object and sends it with HTTP status 200.
 *
 * <p>Before serialization the response status is completed: a missing status object is
 * created, a missing ack is derived from the presence of severe errors, and a missing
 * timestamp is set to the current time. The response stream is also stored in
 * {@code environment.ResponseBody}.
 *
 * @param responseWriter wrapper used to set status and headers and to send the response
 * @param responseObject response to serialize; must implement {@link HasResponseStatus}
 * @param environment    request context that receives the response stream
 * @param formatter      formatter providing serialization, media type and encoding
 * @throws Exception if serialization or sending fails
 */
private void writeBaijiResponse(HttpResponseWrapper responseWriter, SpecificRecord responseObject,
                                RequestContext environment, ContentFormatter formatter) throws Exception {
    responseWriter.setStatus(HttpStatus.SC_OK);
    OutputStream outputStream = responseWriter.getResponseStream();
    environment.ResponseBody = outputStream;

    // Populate response status
    HasResponseStatus statusHolder = (HasResponseStatus) responseObject;
    ResponseStatusType responseStatus = statusHolder.getResponseStatus();
    if (responseStatus == null) {
        responseStatus = new ResponseStatusType();
        statusHolder.setResponseStatus(responseStatus);
    }
    if (responseStatus.ack == null) { // populate mandatory ack
        responseStatus.ack = this.containSevereError(responseStatus) ? AckCodeType.FAILURE : AckCodeType.SUCCESS;
    }
    if (responseStatus.timestamp == null) {
        responseStatus.timestamp = this.getCurrentTimestamp();
    }

    // Set Content-Type before any body byte is written. Whether the wrapper buffers the
    // body is not visible here; if it streams, a header set after serialization could be
    // dropped once the response is committed.
    String encoding = formatter.getEncoding();
    String contentType = formatter.getMediaType() + ((encoding == null) ? "" : "; charset=" + encoding);
    responseWriter.setHeader(HttpHeaders.CONTENT_TYPE, contentType);

    // Serialization
    formatter.serialize(outputStream, responseObject);

    // Write response
    responseWriter.sendResponse();
}
/**
 * Tells whether the response status carries at least one error of severity {@code ERROR}.
 *
 * @param responseStatus status whose error list is inspected
 * @return {@code true} if any listed error has severity {@code ERROR}; {@code false} when
 *         the list is {@code null}, empty, or holds only other severities
 */
private boolean containSevereError(ResponseStatusType responseStatus) {
    final List<ErrorDataType> reported = responseStatus.getErrors();
    if (reported == null || reported.isEmpty()) {
        return false;
    }
    for (ErrorDataType entry : reported) {
        final boolean severe = entry.severityCode != null && entry.severityCode == SeverityCodeType.ERROR;
        if (severe) {
            return true;
        }
    }
    return false;
}
/**
 * Sends a response that consists of an HTTP status code only.
 *
 * @param responseWrapper wrapper used to set the status and send the response
 * @param responseStatus  HTTP status code to send
 */
private void writeHttpResponse(HttpResponseWrapper responseWrapper, int responseStatus) {
    final int statusCode = responseStatus;
    responseWrapper.setStatus(statusCode);
    responseWrapper.sendResponse();
}
}