/* * Copyright 2013-2017 Simba Open Source * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package org.simbasecurity.core.service.thrift; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.lang.reflect.Constructor; import java.util.ArrayList; import java.util.Collection; import java.util.Map; import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import org.apache.thrift.TBaseProcessor; import org.apache.thrift.TException; import org.apache.thrift.TProcessor; import org.apache.thrift.protocol.TProtocol; import org.apache.thrift.protocol.TProtocolFactory; import org.apache.thrift.transport.TIOStreamTransport; import org.apache.thrift.transport.TTransport; import org.springframework.util.ClassUtils; import org.springframework.web.servlet.FrameworkServlet; public class SpringTServlet<I> extends FrameworkServlet { private TProcessor processor; private final Class<? extends TBaseProcessor<I>> processorClass; private final String processorBean; private final TProtocolFactory inProtocolFactory; private final TProtocolFactory outProtocolFactory; private final Collection<Map.Entry<String, String>> customHeaders; protected SpringTServlet(Class<? extends TBaseProcessor<I>> processorClass, String processorBean, TProtocolFactory inProtocolFactory, TProtocolFactory outProtocolFactory) { this.processorClass = processorClass; this.processorBean = processorBean; this.inProtocolFactory = inProtocolFactory; this.outProtocolFactory = outProtocolFactory; this.customHeaders = new ArrayList<Map.Entry<String, String>>(); } protected SpringTServlet(Class<? extends TBaseProcessor<I>> processorClass, String processorBean, TProtocolFactory protocolFactory) { this(processorClass, processorBean, protocolFactory, protocolFactory); } private TProcessor getProcessor() { if (processor == null) { try { Constructor<? extends TBaseProcessor<I>> constructor = findConstructor(); processor = constructor.newInstance(getWebApplicationContext().getBean(processorBean)); } catch (Exception e) { throw new RuntimeException(e); } } return processor; } private Constructor<? extends TBaseProcessor<I>> findConstructor() throws NoSuchMethodException { Class<?> type = ClassUtils.getUserClass(getWebApplicationContext().getType(processorBean)); Constructor<? extends TBaseProcessor<I>> constructor = null; try { constructor = processorClass.getConstructor(type); } catch (NoSuchMethodException ignore) { } if (constructor == null) { for (Class<?> aClass : type.getInterfaces()) { try { constructor = processorClass.getConstructor(aClass); break; } catch (NoSuchMethodException ignore) { } } } if (constructor == null) { throw new IllegalStateException("Can't locate correct constructor on " + processorClass.getName()); } return constructor; } @Override protected void doService(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { TTransport inTransport; TTransport outTransport; try { response.setContentType("application/x-thrift"); if (null != this.customHeaders) { for (Map.Entry<String, String> header : this.customHeaders) { response.addHeader(header.getKey(), header.getValue()); } } InputStream in = request.getInputStream(); OutputStream out = response.getOutputStream(); TTransport transport = new TIOStreamTransport(in, out); inTransport = transport; outTransport = transport; TProtocol inProtocol = inProtocolFactory.getProtocol(inTransport); TProtocol outProtocol = outProtocolFactory.getProtocol(outTransport); getProcessor().process(inProtocol, outProtocol); out.flush(); } catch (TException te) { throw new ServletException(te); } } public void addCustomHeader(final String key, final String value) { this.customHeaders.add(new Map.Entry<String, String>() { public String getKey() { return key; } public String getValue() { return value; } public String setValue(String value) { return null; } }); } public void setCustomHeaders(Collection<Map.Entry<String, String>> headers) { this.customHeaders.clear(); this.customHeaders.addAll(headers); } }