/**
* This file is part of git-as-svn. It is subject to the license terms
* in the LICENSE file found in the top-level directory of this distribution
* and at http://www.gnu.org/licenses/gpl-2.0.html. No part of git-as-svn,
* including this file, may be copied, modified, propagated, or distributed
* except according to the terms contained in the LICENSE file.
*/
package ru.bozaro.protobuf;
import com.google.protobuf.Message;
import org.apache.http.*;
import org.apache.http.client.utils.URLEncodedUtils;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.DefaultHttpResponseFactory;
import org.apache.http.impl.entity.EntityDeserializer;
import org.apache.http.impl.entity.EntitySerializer;
import org.apache.http.impl.entity.LaxContentLengthStrategy;
import org.apache.http.io.HttpMessageParser;
import org.apache.http.io.HttpMessageWriter;
import org.apache.http.io.SessionInputBuffer;
import org.apache.http.io.SessionOutputBuffer;
import org.apache.http.message.BasicHttpResponse;
import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ru.bozaro.protobuf.internal.MethodInfo;
import ru.bozaro.protobuf.internal.ServiceInfo;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
/**
* Servlet wrapper for Protobuf RPC
*
* @author Artem V. Navrotskiy <bozaro@users.noreply.github.com>
*/
public class ProtobufRpcSimpleHttp {
@NotNull
private static final Logger log = LoggerFactory.getLogger(ProtobufRpcSimpleHttp.class);
@NotNull
private final ServiceHolder holder;
public ProtobufRpcSimpleHttp(@NotNull ServiceHolder holder) {
this.holder = holder;
}
@SuppressWarnings("deprecation")
protected void service(@NotNull SessionInputBuffer inputBuffer, @NotNull SessionOutputBuffer outputBuffer, @NotNull HttpMessageParser<HttpRequest> parser, @NotNull HttpMessageWriter<HttpResponse> writer) throws IOException, HttpException {
try {
final HttpRequest request = parser.parse();
final HttpEntity entity;
if (request instanceof HttpEntityEnclosingRequest) {
final EntityDeserializer deserializer = new EntityDeserializer(new LaxContentLengthStrategy());
entity = deserializer.deserialize(inputBuffer, request);
((HttpEntityEnclosingRequest) request).setEntity(entity);
} else {
entity = null;
}
final HttpResponse response = service(request);
if (entity != null) {
entity.getContent().close();
}
if (response.getEntity() != null) {
response.addHeader(HttpHeaders.CONTENT_LENGTH, Long.toString(response.getEntity().getContentLength()));
response.addHeader(response.getEntity().getContentType());
response.addHeader(response.getEntity().getContentEncoding());
}
response.setHeader(HttpHeaders.SERVER, "Protobuf RPC");
writer.write(response);
if (response.getEntity() != null) {
final EntitySerializer serializer = new EntitySerializer(new LaxContentLengthStrategy());
serializer.serialize(outputBuffer, response, response.getEntity());
}
} finally {
outputBuffer.flush();
}
}
@NotNull
protected HttpResponse service(@NotNull HttpRequest req) throws IOException {
final String pathInfo;
try {
pathInfo = getPathInfo(req);
} catch (URISyntaxException e) {
return sendError(req, HttpStatus.SC_BAD_REQUEST, e.getMessage());
}
final int begin = pathInfo.charAt(0) == '/' ? 1 : 0;
final int separator = pathInfo.lastIndexOf('/');
if (separator > 0) {
ServiceInfo serviceInfo = holder.getService(pathInfo.substring(begin, separator));
if (serviceInfo != null) {
return service(req, pathInfo.substring(separator + 1), serviceInfo);
}
}
return sendError(req, HttpStatus.SC_NOT_FOUND, "Service not found: " + pathInfo);
}
@NotNull
public HttpResponse sendError(@NotNull HttpRequest req, int code, @NotNull String reason) {
final BasicHttpResponse response = new BasicHttpResponse(req.getProtocolVersion(), code, reason);
final ContentType contentType = ContentType.create("text/plain", StandardCharsets.UTF_8);
response.setEntity(new StringEntity("ERROR " + code + ": " + reason, contentType));
return response;
}
private @NotNull HttpResponse service(@NotNull HttpRequest req, @NotNull String methodPath, @NotNull ServiceInfo serviceInfo) throws IOException {
final MethodInfo method = serviceInfo.getMethod(methodPath);
if (method == null) {
return sendError(req, HttpStatus.SC_NOT_FOUND, "Method not found: " + methodPath);
}
final Message.Builder msgRequest = method.requestBuilder();
final String httpHethod = req.getRequestLine().getMethod();
if (httpHethod.equals("POST") && (req instanceof HttpEntityEnclosingRequest)) {
final HttpEntity entity = ((HttpEntityEnclosingRequest) req).getEntity();
if (entity != null) {
method.requestByStream(msgRequest, entity.getContent(), getCharset(entity));
} else {
return sendError(req, HttpStatus.SC_NO_CONTENT, "Request payload not found");
}
} else if (!httpHethod.equals("GET")) {
return sendError(req, HttpStatus.SC_METHOD_NOT_ALLOWED, "Unsupported method");
}
try {
method.requestByParams(msgRequest, getParameterMap(req));
} catch (URISyntaxException | ParseException e) {
return sendError(req, HttpStatus.SC_BAD_REQUEST, e.getMessage());
}
try {
final byte[] msgResponse = method.call(msgRequest.build(), StandardCharsets.UTF_8).get();
if (msgResponse != null) {
final HttpResponse response = DefaultHttpResponseFactory.INSTANCE.newHttpResponse(req.getProtocolVersion(), HttpStatus.SC_OK, null);
final ByteArrayEntity entity = new ByteArrayEntity(msgResponse, ContentType.create(method.getFormat().getMimeType(), StandardCharsets.UTF_8));
response.setEntity(entity);
return response;
} else {
return sendError(req, HttpStatus.SC_INTERNAL_SERVER_ERROR, "Illegal method return value");
}
} catch (InterruptedException | ExecutionException e) {
log.error("Method error " + method.getName(), e);
return sendError(req, HttpStatus.SC_INTERNAL_SERVER_ERROR, e.getMessage());
}
}
@NotNull
private String getPathInfo(@NotNull HttpRequest req) throws URISyntaxException {
final URI uri = new URI(req.getRequestLine().getUri());
return uri.getPath() != null ? uri.getPath() : "";
}
@NotNull
private Map<String, String[]> getParameterMap(@NotNull HttpRequest req) throws URISyntaxException {
final Map<String, List<String>> params = new HashMap<>();
final List<NameValuePair> pairList = URLEncodedUtils.parse(new URI(req.getRequestLine().getUri()), StandardCharsets.UTF_8.name());
for (NameValuePair param : pairList) {
params.compute(param.getName(), (item, value) -> {
if (value == null) {
value = new ArrayList<>();
}
value.add(param.getValue());
return value;
});
}
final Map<String, String[]> result = new HashMap<>();
for (Map.Entry<String, List<String>> pair : params.entrySet()) {
result.put(pair.getKey(), pair.getValue().toArray(new String[pair.getValue().size()]));
}
return result;
}
@NotNull
private static Charset getCharset(@NotNull HttpEntity entity) {
final Header charset = entity.getContentEncoding();
return charset == null ? StandardCharsets.UTF_8 : Charset.forName(charset.getValue());
}
}