/** * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.aurora.scheduler.http.api; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.Optional; import javax.servlet.ServletException; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.ws.rs.core.MediaType; import org.apache.thrift.TException; import org.apache.thrift.TProcessor; import org.apache.thrift.protocol.TProtocol; import org.apache.thrift.protocol.TProtocolFactory; import org.apache.thrift.transport.TIOStreamTransport; import org.apache.thrift.transport.TTransport; import static java.util.Objects.requireNonNull; import static javax.ws.rs.core.HttpHeaders.ACCEPT; /** * An implementation of {@link org.apache.thrift.server.TServlet} that can handle multiple thrift * protocols. The protocols are dispatched on HTTP headers. */ public class TContentAwareServlet extends HttpServlet { private final TProcessor processor; private final InputConfig inputConfig; private final OutputConfig outputConfig; /** * Class which contains the mapping of the factory and the content type of the output. */ static class ContentFactoryPair implements TProtocolFactory { private final TProtocolFactory factory; private final MediaType outputType; ContentFactoryPair(TProtocolFactory factory, MediaType outputType) { this.factory = requireNonNull(factory); this.outputType = requireNonNull(outputType); } MediaType getOutputType() { return outputType; } @Override public TProtocol getProtocol(TTransport tTransport) { return factory.getProtocol(tTransport); } } /** * Configures how to interpret the Content-Type of the request. */ static class InputConfig { // Type to use when there is no Content-Type private final MediaType defaultType; // Mapping of values in Content-Type to protocol to use to deserialize private final Map<MediaType, ContentFactoryPair> inputMapping; InputConfig(MediaType defaultType, Map<MediaType, ContentFactoryPair> inputMapping) { this.defaultType = requireNonNull(defaultType); this.inputMapping = requireNonNull(inputMapping); } Optional<ContentFactoryPair> getFactory(Optional<MediaType> mediaType) { return Optional.ofNullable(inputMapping.get(mediaType.orElse(defaultType))); } } /** * Configures how to interpret the Accept header of the request. The defaultType's factory is * returned for almost all values to maintain backwards compatibility. */ static class OutputConfig { // Type to use when there is no Accept header private final MediaType defaultType; // Mapping of MediaTypes in the Accept header to protocol used to serialize the response private final Map<MediaType, ContentFactoryPair> outputMapping; private final ContentFactoryPair defaultFactory; OutputConfig(MediaType defaultType, Map<MediaType, ContentFactoryPair> outputMapping) { this.defaultType = requireNonNull(defaultType); this.outputMapping = requireNonNull(outputMapping); this.defaultFactory = requireNonNull(outputMapping.get(defaultType)); } ContentFactoryPair getFactory(Optional<MediaType> type) { return Optional.ofNullable(outputMapping.get(type.orElse(defaultType))) .orElse(defaultFactory); } } TContentAwareServlet(TProcessor processor, InputConfig inputConfig, OutputConfig outputConfig) { this.processor = requireNonNull(processor); this.inputConfig = requireNonNull(inputConfig); this.outputConfig = requireNonNull(outputConfig); } @Override protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { Optional<ContentFactoryPair> factoryOptional = inputConfig .getFactory(Optional.ofNullable(request.getContentType()).map(MediaType::valueOf)); if (!factoryOptional.isPresent()) { response.setStatus(HttpServletResponse.SC_UNSUPPORTED_MEDIA_TYPE); String msg = "Unsupported Content-Type: " + request.getContentType(); response.getOutputStream().write(msg.getBytes(StandardCharsets.UTF_8)); return; } TTransport transport = new TIOStreamTransport(request.getInputStream(), response.getOutputStream()); TProtocol inputProtocol = factoryOptional.get().getProtocol(transport); Optional<String> acceptHeader = Optional.ofNullable(request.getHeader(ACCEPT)); Optional<MediaType> acceptType = Optional.empty(); if (acceptHeader.isPresent()) { try { acceptType = acceptHeader.map(MediaType::valueOf); } catch (IllegalArgumentException e) { // Thrown if the Accept header contains more than one type or something else we can't // parse, we just treat is as no header (which will pick up the default value). acceptType = Optional.empty(); } } ContentFactoryPair outputProtocolFactory = outputConfig.getFactory(acceptType); response.setContentType(outputProtocolFactory.getOutputType().toString()); TProtocol outputProtocol = outputProtocolFactory.getProtocol(transport); try { processor.process(inputProtocol, outputProtocol); response.getOutputStream().flush(); } catch (TException e) { throw new ServletException(e); } } @Override protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { doPost(request, response); } }