package org.mariadb.jdbc; import org.junit.Assume; import org.junit.BeforeClass; import org.junit.Test; import java.io.*; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import static org.junit.Assert.*; public class LocalInfileInputStreamTest extends BaseTest { /** * Initialisation. * * @throws SQLException exception */ @BeforeClass() public static void initClass() throws SQLException { createTable("LocalInfileInputStreamTest", "id int, test varchar(100)"); createTable("ttlocal", "id int, test varchar(100)"); createTable("ldinfile", "a varchar(10)"); createTable("`infile`", "`a` varchar(50) DEFAULT NULL, `b` varchar(50) DEFAULT NULL", "ENGINE=InnoDB DEFAULT CHARSET=latin1"); } @Test public void testLocalInfileInputStream() throws SQLException { try (Statement st = sharedConnection.createStatement()) { // Build a tab-separated record file StringBuilder builder = new StringBuilder(); builder.append("1\thello\n"); builder.append("2\tworld\n"); InputStream inputStream = new ByteArrayInputStream(builder.toString().getBytes()); ((MariaDbStatement) st).setLocalInfileInputStream(inputStream); st.executeUpdate("LOAD DATA LOCAL INFILE 'dummy.tsv' INTO TABLE LocalInfileInputStreamTest (id, test)"); ResultSet rs = st.executeQuery("SELECT COUNT(*) FROM LocalInfileInputStreamTest"); boolean next = rs.next(); assertTrue(next); int count = rs.getInt(1); assertEquals(2, count); rs = st.executeQuery("SELECT * FROM LocalInfileInputStreamTest"); validateRecord(rs, 1, "hello"); validateRecord(rs, 2, "world"); } } @Test public void testLocalInfileValidInterceptor() throws Exception { File temp = File.createTempFile("validateInfile", ".txt"); StringBuilder builder = new StringBuilder(); builder.append("1,hello\n"); builder.append("2,world\n"); try (BufferedWriter bw = new BufferedWriter(new FileWriter(temp))) { bw.write(builder.toString()); } testLocalInfile(temp.getAbsolutePath().replace("\\", "/")); } @Test public void testLocalInfileUnValidInterceptor() throws Exception { File temp = File.createTempFile("localInfile", ".txt"); StringBuilder builder = new StringBuilder(); builder.append("1,hello\n"); builder.append("2,world\n"); try (BufferedWriter bw = new BufferedWriter(new FileWriter(temp))) { bw.write(builder.toString()); } try { testLocalInfile(temp.getAbsolutePath().replace("\\", "/")); fail("Must have been intercepted"); } catch (SQLException sqle) { assertTrue(sqle.getMessage().contains("LOCAL DATA LOCAL INFILE request to send local file named") && sqle.getMessage().contains("not validated by interceptor \"org.mariadb.jdbc.LocalInfileInterceptorImpl\"")); } //check that connection state is correct Statement st = sharedConnection.createStatement(); ResultSet rs = st.executeQuery("SELECT 1"); rs.next(); assertEquals(1, rs.getInt(1)); } private void testLocalInfile(String file) throws SQLException { try (Statement st = sharedConnection.createStatement()) { st.executeUpdate("LOAD DATA LOCAL INFILE '" + file + "' INTO TABLE ttlocal " + " FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"'" + " (id, test)"); ResultSet rs = st.executeQuery("SELECT COUNT(*) FROM ttlocal"); boolean next = rs.next(); assertTrue(next); assertEquals(2, rs.getInt(1)); rs = st.executeQuery("SELECT * FROM ttlocal"); validateRecord(rs, 1, "hello"); validateRecord(rs, 2, "world"); } } @Test public void loadDataInfileEmpty() throws SQLException, IOException { // Create temp file. File temp = File.createTempFile("validateInfile", ".tmp"); try { Statement st = sharedConnection.createStatement(); st.execute("LOAD DATA LOCAL INFILE '" + temp.getAbsolutePath().replace('\\', '/') + "' INTO TABLE ldinfile"); try (ResultSet rs = st.executeQuery("SELECT * FROM ldinfile")) { assertFalse(rs.next()); } } finally { temp.delete(); } } @Test public void testPrepareLocalInfileWithoutInputStream() throws SQLException { try { PreparedStatement st = sharedConnection.prepareStatement("LOAD DATA LOCAL INFILE 'validateInfile.tsv' " + "INTO TABLE ldinfile"); st.execute(); fail(); } catch (SQLException e) { assertTrue(e.getMessage().contains("Could not send file")); //check that connection is alright try { assertFalse(sharedConnection.isClosed()); Statement st = sharedConnection.createStatement(); st.execute("SELECT 1"); } catch (SQLException eee) { fail(); } } } private void validateRecord(ResultSet rs, int expectedId, String expectedTest) throws SQLException { boolean next = rs.next(); assertTrue(next); int id = rs.getInt(1); String test = rs.getString(2); assertEquals(expectedId, id); assertEquals(expectedTest, test); } private File createTmpData(long recordNumber) throws Exception { File file = File.createTempFile("./infile" + recordNumber, ".tmp"); //write it try (BufferedWriter writer = new BufferedWriter(new FileWriter(file))) { // Every row is 8 bytes to make counting easier for (long i = 0; i < recordNumber; i++) { writer.write("\"a\",\"b\""); writer.write("\n"); } } return file; } private void checkBigLocalInfile(long fileSize) throws Exception { long recordNumber = fileSize / 8; try (Statement statement = sharedConnection.createStatement()) { statement.execute("truncate `infile`"); File file = createTmpData(recordNumber); try (InputStream is = new BufferedInputStream(new FileInputStream(file))) { MariaDbStatement stmt = statement.unwrap(MariaDbStatement.class); stmt.setLocalInfileInputStream(is); int insertNumber = stmt.executeUpdate("LOAD DATA LOCAL INFILE 'ignoredFileName' " + "INTO TABLE `infile` " + "COLUMNS TERMINATED BY ',' ENCLOSED BY '\\\"' ESCAPED BY '\\\\' " + "LINES TERMINATED BY '\\n' (`a`, `b`)"); assertEquals(insertNumber, recordNumber); } statement.setFetchSize(1000); //to avoid using too much memory for tests try (ResultSet rs = statement.executeQuery("SELECT * FROM `infile`")) { for (int i = 0; i < recordNumber; i++) { assertTrue("record " + i + " doesn't exist", rs.next()); assertEquals("a", rs.getString(1)); assertEquals("b", rs.getString(2)); } assertFalse(rs.next()); } } } /** * CONJ-375 : error with local infile with size > 16mb. * * @throws Exception if error occus */ @Test public void testSmallBigLocalInfileInputStream() throws Exception { checkBigLocalInfile(256); } @Test public void test2xBigLocalInfileInputStream() throws Exception { Assume.assumeTrue(checkMaxAllowedPacketMore40m("test2xBigLocalInfileInputStream")); checkBigLocalInfile(16777216 * 2); } @Test public void test2xMaxAllowedPacketLocalInfileInputStream() throws Exception { ResultSet rs = sharedConnection.createStatement().executeQuery("select @@max_allowed_packet"); rs.next(); long maxAllowedPacket = rs.getLong(1); try { checkBigLocalInfile(maxAllowedPacket * 2); fail("must have fail"); } catch (SQLException sqle) { assertTrue(sqle.getMessage().contains("Could not send query: query size is >= to max_allowed_packet")); } } }