/*- * * * Copyright 2015 Skymind,Inc. * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * * */ package org.nd4j.jdbc.loader.impl; import com.mchange.v2.c3p0.ComboPooledDataSource; import org.nd4j.jdbc.driverfinder.DriverFinder; import org.nd4j.jdbc.loader.api.JDBCNDArrayIO; import org.nd4j.linalg.api.complex.IComplexNDArray; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import javax.sql.DataSource; import java.io.ByteArrayOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; import java.sql.*; /** * Base class for loading ndarrays via org.nd4j.jdbc * * @author Adam Gibson */ public abstract class BaseLoader implements JDBCNDArrayIO { protected String tableName, columnName, idColumnName, jdbcUrl; protected DataSource dataSource; protected BaseLoader(DataSource dataSource, String jdbcUrl, String tableName, String idColumnName, String columnName) throws Exception { this.dataSource = dataSource; this.jdbcUrl = jdbcUrl; this.tableName = tableName; this.columnName = columnName; this.idColumnName = idColumnName; if (dataSource == null) { dataSource = new ComboPooledDataSource(); ComboPooledDataSource c = (ComboPooledDataSource) dataSource; c.setJdbcUrl(jdbcUrl); c.setDriverClass(DriverFinder.getDriver().getClass().getName()); } } protected BaseLoader(String jdbcUrl, String tableName, String idColumnName, String columnName) throws Exception { this.jdbcUrl = jdbcUrl; this.tableName = tableName; this.columnName = columnName; dataSource = new ComboPooledDataSource(); ComboPooledDataSource c = (ComboPooledDataSource) dataSource; c.setJdbcUrl(jdbcUrl); c.setDriverClass(DriverFinder.getDriver().getClass().getName()); this.idColumnName = idColumnName; } protected BaseLoader(DataSource dataSource, String jdbcUrl, String tableName, String columnName) throws Exception { this(dataSource, jdbcUrl, tableName, "id", columnName); } /** * Convert an ndarray to a blob * * @param toConvert the complex ndarray to convert * @return the converted complex ndarray */ @Override public Blob convert(IComplexNDArray toConvert) throws IOException, SQLException { ByteArrayOutputStream bos = new ByteArrayOutputStream(); DataOutputStream dos = new DataOutputStream(bos); Nd4j.writeComplex(toConvert, dos); byte[] bytes = bos.toByteArray(); Connection c = dataSource.getConnection(); Blob b = c.createBlob(); b.setBytes(1, bytes); c.close(); return b; } /** * Convert an ndarray to a blob * * @param toConvert the ndarray to convert * @return the converted ndarray */ @Override public Blob convert(INDArray toConvert) throws SQLException, IOException { ByteArrayOutputStream bos = new ByteArrayOutputStream(); DataOutputStream dos = new DataOutputStream(bos); Nd4j.write(toConvert, dos); byte[] bytes = bos.toByteArray(); Connection c = dataSource.getConnection(); Blob b = c.createBlob(); b.setBytes(1, bytes); c.close(); return b; } /** * Load an ndarray from a blob * * @param blob the blob to load from * @return the loaded ndarray */ @Override public INDArray load(Blob blob) throws SQLException, IOException { if (blob == null) return null; DataInputStream dis = new DataInputStream(blob.getBinaryStream()); return Nd4j.read(dis); } /** * Load a complex ndarray from a blob * * @param blob the blob to load from * @return the complex ndarray */ @Override public IComplexNDArray loadComplex(Blob blob) throws SQLException, IOException { DataInputStream dis = new DataInputStream(blob.getBinaryStream()); return Nd4j.readComplex(dis); } /** * Save the ndarray * * @param save the ndarray to save */ @Override public void save(INDArray save, String id) throws SQLException, IOException { doSave(save, id); } /** * Save the ndarray * * @param save the ndarray to save */ @Override public void save(IComplexNDArray save, String id) throws IOException, SQLException { doSave(save, id); } private void doSave(INDArray save, String id) throws SQLException, IOException { Connection c = dataSource.getConnection(); ByteArrayOutputStream bos = new ByteArrayOutputStream(); DataOutputStream dos = new DataOutputStream(bos); if (save instanceof IComplexNDArray) { IComplexNDArray c2 = (IComplexNDArray) save; Nd4j.writeComplex(c2, dos); } else Nd4j.write(save, dos); byte[] bytes = bos.toByteArray(); PreparedStatement preparedStatement = c.prepareStatement(insertStatement()); preparedStatement.setString(1, id); preparedStatement.setBytes(2, bytes); preparedStatement.executeUpdate(); preparedStatement.close(); c.close(); } /** * Load an ndarray blob given an id * * @param id the id to load * @return the blob */ @Override public Blob loadForID(String id) throws SQLException { Connection c = dataSource.getConnection(); PreparedStatement preparedStatement = c.prepareStatement(loadStatement()); preparedStatement.setString(1, id); ResultSet r = preparedStatement.executeQuery(); if (r.wasNull() || !r.next()) { c.close(); r.close(); preparedStatement.close(); return null; } else { Blob first = r.getBlob(2); c.close(); r.close(); preparedStatement.close(); return first; } } /** * Delete the given ndarray * * @param id the id of the ndarray to delete */ @Override public void delete(String id) throws SQLException { Connection c = dataSource.getConnection(); PreparedStatement p = c.prepareStatement(deleteStatement()); p.setString(1, id); p.execute(); p.close(); } }