/* * JBoss, Home of Professional Open Source. * Copyright 2011, Red Hat, Inc., and individual contributors * as indicated by the @author tags. See the copyright.txt file in the * distribution for a full listing of individual contributors. * * This is free software; you can redistribute it and/or modify it * under the terms of the GNU Lesser General Public License as * published by the Free Software Foundation; either version 2.1 of * the License, or (at your option) any later version. * * This software is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU * Lesser General Public License for more details. * * You should have received a copy of the GNU Lesser General Public * License along with this software; if not, write to the Free * Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA * 02110-1301 USA, or see the FSF site: http://www.fsf.org. */ package org.teiid.jboss.rest; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.charset.Charset; import java.sql.Array; import java.sql.Blob; import java.sql.CallableStatement; import java.sql.Clob; import java.sql.Connection; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.SQLXML; import java.sql.Statement; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.StreamingOutput; import org.jboss.resteasy.plugins.providers.multipart.InputPart; import org.jboss.resteasy.plugins.providers.multipart.MultipartFormDataInput; import org.teiid.core.types.*; import org.teiid.core.util.Base64; import org.teiid.core.util.ObjectConverterUtil; import org.teiid.core.util.ReaderInputStream; import org.teiid.core.util.StringUtil; import org.teiid.jdbc.TeiidDriver; import org.teiid.query.function.source.XMLSystemFunctions; import org.teiid.query.sql.symbol.XMLSerialize; import org.teiid.query.sql.visitor.SQLStringVisitor; public abstract class TeiidRSProvider { public StreamingOutput execute(final String vdbName, final String version, final String procedureName, final LinkedHashMap<String, String> parameters, final String charSet, final boolean passthroughAuth, final boolean usingReturn) throws SQLException { return new StreamingOutput() { @Override public void write(OutputStream output) throws IOException, WebApplicationException { Connection conn = null; try { conn = getConnection(vdbName, version, passthroughAuth); LinkedHashMap<String, Object> updatedParameters = convertParameters(conn, vdbName, procedureName, parameters); InputStream is = executeProc(conn, procedureName, updatedParameters, charSet, usingReturn); ObjectConverterUtil.write(output, is, -1); } catch (SQLException e) { throw new WebApplicationException(e); } finally { if (conn != null) { try { conn.close(); } catch (SQLException e) { } } } } }; } public StreamingOutput executePost(final String vdbName, final String version, final String procedureName, final MultipartFormDataInput parameters, final String charSet, final boolean passthroughAuth, final boolean usingReturn) throws SQLException { return new StreamingOutput() { @Override public void write(OutputStream output) throws IOException, WebApplicationException { Connection conn = null; try { conn = getConnection(vdbName, version, passthroughAuth); LinkedHashMap<String, Object> updatedParameters = convertParameters(conn, vdbName, procedureName, parameters); InputStream is = executeProc(conn, procedureName, updatedParameters, charSet, usingReturn); ObjectConverterUtil.write(output, is, -1); } catch (SQLException e) { throw new WebApplicationException(e); } finally { if (conn != null) { try { conn.close(); } catch (SQLException e) { } } } } }; } public InputStream executeProc(Connection conn, String procedureName, LinkedHashMap<String, Object> parameters, String charSet, boolean usingReturn) throws SQLException { //the generated code sends a empty string rather than null. if (charSet != null && charSet.trim().isEmpty()) { charSet = null; } Object result = null; StringBuilder sb = new StringBuilder(); sb.append("{ "); //$NON-NLS-1$ if (usingReturn) { sb.append("? = "); //$NON-NLS-1$ } sb.append("CALL ").append(procedureName); //$NON-NLS-1$ sb.append("("); //$NON-NLS-1$ boolean first = true; for (Map.Entry<String, Object> entry : parameters.entrySet()) { if (entry.getValue() == null) { continue; } if (!first) { sb.append(", "); //$NON-NLS-1$ } first = false; sb.append(SQLStringVisitor.escapeSinglePart(entry.getKey())).append("=>?"); //$NON-NLS-1$ } sb.append(") }"); //$NON-NLS-1$ CallableStatement statement = conn.prepareCall(sb.toString()); if (!parameters.isEmpty()) { int i = usingReturn?2:1; for (Object value : parameters.values()) { if (value == null) { continue; } statement.setObject(i++, value); } } final boolean hasResultSet = statement.execute(); if (hasResultSet) { ResultSet rs = statement.getResultSet(); if (rs.next()) { result = rs.getObject(1); } else { throw new SQLException(RestServicePlugin.Util.gs(RestServicePlugin.Event.TEIID28002)); } } else if (!usingReturn){ throw new SQLException(RestServicePlugin.Util.gs(RestServicePlugin.Event.TEIID28002)); } else { result = statement.getObject(1); } return handleResult(charSet, result); } private LinkedHashMap<String, Object> convertParameters(Connection conn, String vdbName, String procedureName, LinkedHashMap<String, String> inputParameters) throws SQLException { Map<String, Class<?>> expectedTypes = getParameterTypes(conn, vdbName, procedureName); LinkedHashMap<String, Object> expectedValues = new LinkedHashMap<String, Object>(); try { for (String columnName : inputParameters.keySet()) { Class<?> runtimeType = expectedTypes.get(columnName); if (runtimeType == null) { throw new SQLException(RestServicePlugin.Util.gs(RestServicePlugin.Event.TEIID28001, columnName, procedureName)); } Object value = inputParameters.get(columnName); if (value != null) { if (Array.class.isAssignableFrom(runtimeType)) { List<String> array = StringUtil.split((String)value, ","); //$NON-NLS-1$ value = array.toArray(new String[array.size()]); } else if (DataTypeManager.DefaultDataClasses.VARBINARY.isAssignableFrom(runtimeType)) { value = Base64.decode((String)value); } else { if (DataTypeManager.isTransformable(String.class, runtimeType)) { Transform t = DataTypeManager.getTransform(String.class, runtimeType); value = t.transform(value, runtimeType); } } } expectedValues.put(columnName, value); } return expectedValues; } catch (TransformationException e) { throw new SQLException(e); } } private LinkedHashMap<String, Object> convertParameters(Connection conn, String vdbName, String procedureName, MultipartFormDataInput form) throws SQLException { Map<String, Class<?>> runtimeTypes = getParameterTypes(conn, vdbName, procedureName); LinkedHashMap<String, Object> expectedValues = new LinkedHashMap<String, Object>(); Map<String, List<InputPart>> inputParameters = form.getFormDataMap(); for (String columnName : inputParameters.keySet()) { Class<?> runtimeType = runtimeTypes.get(columnName); if (runtimeType == null) { throw new SQLException(RestServicePlugin.Util.gs(RestServicePlugin.Event.TEIID28001, columnName, procedureName)); } if (runtimeType.isAssignableFrom(Array.class)) { List<InputPart> valueStreams = inputParameters.get(columnName); ArrayList<Object> array = new ArrayList<Object>(); try { for (InputPart part : valueStreams) { array.add(part.getBodyAsString()); } } catch (IOException e) { throw new SQLException(e); } expectedValues.put(columnName, array.toArray(new Object[array.size()])); } else { final InputPart part = inputParameters.get(columnName).get(0); try { expectedValues.put(columnName, convertToRuntimeType(runtimeType, part)); } catch (IOException e) { throw new SQLException(e); } } } return expectedValues; } private Object convertToRuntimeType(Class<?> runtimeType, final InputPart part) throws IOException, SQLException { if (SQLXML.class.isAssignableFrom(runtimeType)) { SQLXMLImpl xml = new SQLXMLImpl(new InputStreamFactory() { @Override public InputStream getInputStream() throws IOException { return part.getBody(InputStream.class, null); } }); if (charset(part) != null) { xml.setEncoding(charset(part)); } return xml; } else if (Blob.class.isAssignableFrom(runtimeType)) { return new BlobImpl(new InputStreamFactory() { @Override public InputStream getInputStream() throws IOException { return part.getBody(InputStream.class, null); } }); } else if (Clob.class.isAssignableFrom(runtimeType)) { ClobImpl clob = new ClobImpl(new InputStreamFactory() { @Override public InputStream getInputStream() throws IOException { return part.getBody(InputStream.class, null); } }, -1); if (charset(part) != null) { clob.setEncoding(charset(part)); } return clob; } else if (DataTypeManager.DefaultDataClasses.VARBINARY.isAssignableFrom(runtimeType)) { return Base64.decode(part.getBodyAsString()); } else if (DataTypeManager.isTransformable(String.class, runtimeType)) { try { return DataTypeManager.transformValue(part.getBodyAsString(), runtimeType); } catch (TransformationException e) { throw new SQLException(e); } } return part.getBodyAsString(); } private String charset(final InputPart part) { return part.getMediaType().getParameters().get("charset"); //$NON-NLS-1$ } private LinkedHashMap<String, Class<?>> getParameterTypes(Connection conn, String vdbName, String procedureName) throws SQLException { String schemaName = procedureName.substring(0, procedureName.lastIndexOf('.')).replace('\"', ' ').trim(); String procName = procedureName.substring(procedureName.lastIndexOf('.')+1).replace('\"', ' ').trim(); LinkedHashMap<String, Class<?>> expectedTypes = new LinkedHashMap<String, Class<?>>(); try { ResultSet rs = conn.getMetaData().getProcedureColumns(vdbName, schemaName, procName, "%"); //$NON-NLS-1$ while(rs.next()) { String columnName = rs.getString(4); int columnDataType = rs.getInt(6); Class<?> runtimeType = DataTypeManager .getRuntimeType(Class.forName(JDBCSQLTypeInfo.getJavaClassName(columnDataType))); expectedTypes.put(columnName, runtimeType); } rs.close(); return expectedTypes; } catch (ClassNotFoundException e) { throw new SQLException(e); } } private InputStream handleResult(String charSet, Object result) throws SQLException { if (result == null) { return null; //or should this be an empty result? } if (result instanceof SQLXML) { if (charSet != null) { XMLSerialize serialize = new XMLSerialize(); serialize.setTypeString("blob"); //$NON-NLS-1$ serialize.setDeclaration(true); serialize.setEncoding(charSet); serialize.setDocument(true); try { return ((BlobType)XMLSystemFunctions.serialize(serialize, new XMLType((SQLXML)result))).getBinaryStream(); } catch (TransformationException e) { throw new SQLException(e); } } return ((SQLXML)result).getBinaryStream(); } else if (result instanceof Blob) { return ((Blob)result).getBinaryStream(); } else if (result instanceof Clob) { return new ReaderInputStream(((Clob)result).getCharacterStream(), charSet==null?Charset.defaultCharset():Charset.forName(charSet)); } return new ByteArrayInputStream(result.toString().getBytes(charSet==null?Charset.defaultCharset():Charset.forName(charSet))); } public StreamingOutput executeQuery(final String vdbName, final String vdbVersion, final String sql, boolean json, final boolean passthroughAuth) throws SQLException { return new StreamingOutput() { @Override public void write(OutputStream output) throws IOException, WebApplicationException { Connection conn = null; try { conn = getConnection(vdbName, vdbVersion, passthroughAuth); Statement statement = conn.createStatement(); final boolean hasResultSet = statement.execute(sql); Object result = null; if (hasResultSet) { ResultSet rs = statement.getResultSet(); if (rs.next()) { result = rs.getObject(1); } else { throw new SQLException(RestServicePlugin.Util.gs(RestServicePlugin.Event.TEIID28002)); } } InputStream is = handleResult(Charset.defaultCharset().name(), result); ObjectConverterUtil.write(output, is, -1); } catch (SQLException e) { throw new WebApplicationException(e); } finally { try { if (conn != null) { conn.close(); } } catch (SQLException e) { } } } }; } private Connection getConnection(String vdbName, String version, boolean passthough) throws SQLException { TeiidDriver driver = new TeiidDriver(); return driver.connect("jdbc:teiid:"+vdbName+"."+version+";"+(passthough?"PassthroughAuthentication=true;":""), null); //$NON-NLS-1$ //$NON-NLS-2$ //$NON-NLS-3$ //$NON-NLS-4$ //$NON-NLS-5$ } }