/*
* Copyright (c) 2008, PostgreSQL Global Development Group
* See the LICENSE file in the project root for more information.
*/
package org.postgresql.test.jdbc2;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import org.postgresql.PGConnection;
import org.postgresql.copy.CopyIn;
import org.postgresql.copy.CopyManager;
import org.postgresql.copy.CopyOut;
import org.postgresql.copy.PGCopyOutputStream;
import org.postgresql.core.ServerVersion;
import org.postgresql.test.TestUtil;
import org.postgresql.util.PSQLState;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintStream;
import java.io.StringReader;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
/**
 * Tests for the COPY support exposed through {@link CopyManager}: {@code COPY ... FROM STDIN}
 * (row-by-row, via {@link PGCopyOutputStream}, from an {@link InputStream} and from a reader),
 * {@code COPY ... TO STDOUT}, misuse of the API with non-COPY queries, transaction rollback and
 * lock release after the backend connection is killed mid-copy.
 *
 * <p>Each test runs against a {@code copytest} table
 * ({@code stringvalue text, intvalue int, numvalue numeric(5,2)}) created in {@link #setUp()} and
 * dropped in {@link #tearDown()}.
 *
 * @author kato@iki.fi
 */
public class CopyTest {
  private Connection con;
  private CopyManager copyAPI;
  /** COPY options: option-list syntax on 9.0+ servers, the legacy {@code CSV} keyword otherwise. */
  private String copyParams;
  // Trailing 0's are required so the rows match the server's text output for numeric(5,2).
  private final String[] origData =
      {"First Row\t1\t1.10\n",
          "Second Row\t2\t-22.20\n",
          "\\N\t\\N\t\\N\n",
          "\t4\t444.40\n"};
  private final int dataRows = origData.length;

  /**
   * Concatenates the given rows into a single byte array.
   *
   * @param rows the text rows to join, in order
   * @return the rows' bytes in the platform default charset (same encoding as
   *     {@link String#getBytes()} used elsewhere in this class)
   */
  private byte[] getData(String[] rows) {
    StringBuilder all = new StringBuilder();
    for (String row : rows) {
      all.append(row);
    }
    return all.toString().getBytes();
  }

  @Before
  public void setUp() throws Exception {
    con = TestUtil.openDB();
    TestUtil.createTable(con, "copytest", "stringvalue text, intvalue int, numvalue numeric(5,2)");
    copyAPI = ((PGConnection) con).getCopyAPI();
    if (TestUtil.haveMinimumServerVersion(con, ServerVersion.v9_0)) {
      copyParams = "(FORMAT CSV, HEADER false)";
    } else {
      copyParams = "CSV";
    }
  }

  @After
  public void tearDown() throws Exception {
    TestUtil.closeDB(con);

    // testLockReleaseOnCancelFailure terminates the backend and leaves the existing connection
    // broken, so the table is dropped using a fresh connection.
    con = TestUtil.openDB();
    try {
      TestUtil.dropTable(con, "copytest");
    } finally {
      con.close();
    }
  }

  /**
   * Counts the rows currently in {@code copytest}.
   *
   * @return the result of {@code SELECT count(*) FROM copytest}
   * @throws SQLException if the query fails
   */
  private int getCount() throws SQLException {
    Statement stmt = con.createStatement();
    try {
      ResultSet rs = stmt.executeQuery("SELECT count(*) FROM copytest");
      rs.next();
      return rs.getInt(1);
    } finally {
      // Closing the statement also closes its result set.
      stmt.close();
    }
  }

  /**
   * Writes each row with {@link CopyIn#writeToCopy} and checks the reported and stored row counts.
   * It also checks that cancelling an already-finished copy is rejected with OBJECT_NOT_IN_STATE.
   * Other tests call this method to populate the table.
   */
  @Test
  public void testCopyInByRow() throws SQLException {
    String sql = "COPY copytest FROM STDIN";
    CopyIn cp = copyAPI.copyIn(sql);
    for (String anOrigData : origData) {
      byte[] buf = anOrigData.getBytes();
      cp.writeToCopy(buf, 0, buf.length);
    }

    long count1 = cp.endCopy();
    long count2 = cp.getHandledRowCount();
    long expectedResult = dataRows;
    assertEquals(expectedResult, count1);
    assertEquals(expectedResult, count2);

    try {
      cp.cancelCopy();
      fail("cancelCopy on a completed copy should have thrown an exception");
    } catch (SQLException se) { // should fail with obsolete operation
      if (!PSQLState.OBJECT_NOT_IN_STATE.getState().equals(se.getSQLState())) {
        fail("should have thrown object not in state exception.");
      }
    }
    int rowCount = getCount();
    assertEquals(dataRows, rowCount);
  }

  @Test
  public void testCopyInAsOutputStream() throws SQLException, IOException {
    String sql = "COPY copytest FROM STDIN";
    OutputStream os = new PGCopyOutputStream((PGConnection) con, sql, 1000);
    for (String anOrigData : origData) {
      byte[] buf = anOrigData.getBytes();
      os.write(buf);
    }
    // Closing the stream completes the COPY.
    os.close();
    int rowCount = getCount();
    assertEquals(dataRows, rowCount);
  }

  @Test
  public void testCopyInFromInputStream() throws SQLException, IOException {
    String sql = "COPY copytest FROM STDIN";
    copyAPI.copyIn(sql, new ByteArrayInputStream(getData(origData)), 3);
    int rowCount = getCount();
    assertEquals(dataRows, rowCount);
  }

  /**
   * A source stream that throws must abort the copy: the failure surfaces to the caller and
   * no rows are stored.
   */
  @Test
  public void testCopyInFromStreamFail() throws SQLException {
    String sql = "COPY copytest FROM STDIN";
    try {
      copyAPI.copyIn(sql, new InputStream() {
        @Override
        public int read() {
          throw new RuntimeException("COPYTEST");
        }
      }, 3);
      fail("copyIn should have failed reading from our bogus stream.");
    } catch (Exception e) {
      if (!e.toString().contains("COPYTEST")) {
        fail("should have failed trying to read from our bogus stream.");
      }
    }
    int rowCount = getCount();
    assertEquals(0, rowCount);
  }

  @Test
  public void testCopyInFromReader() throws SQLException, IOException {
    String sql = "COPY copytest FROM STDIN";
    copyAPI.copyIn(sql, new StringReader(new String(getData(origData))), 3);
    int rowCount = getCount();
    assertEquals(dataRows, rowCount);
  }

  /**
   * Each round drops one more leading byte of the data and then consumes one byte, before
   * copying the rest. The rounds repeat until the stream is exhausted or the copy fails.
   * A failure is only acceptable when the consumed byte was the field separator, which leaves
   * the first row short a column. Every earlier round must have inserted all rows.
   */
  @Test
  public void testSkipping() {
    String sql = "COPY copytest FROM STDIN";
    String at = "init";
    int rowCount = -1;
    int skip = 0;
    int skipChar = 1;
    try {
      while (skipChar > 0) {
        at = "buffering";
        InputStream ins = new ByteArrayInputStream(getData(origData));
        at = "skipping";
        ins.skip(skip++);
        skipChar = ins.read();
        at = "copying";
        copyAPI.copyIn(sql, ins, 3);
        at = "using connection after writing copy";
        rowCount = getCount();
      }
    } catch (Exception e) {
      if (!(skipChar == '\t')) {
        // error expected when field separator consumed
        fail("testSkipping at " + at + " round " + skip + ": " + e.toString());
      }
    }
    assertEquals(dataRows * (skip - 1), rowCount);
  }

  @Test
  public void testCopyOutByRow() throws SQLException, IOException {
    testCopyInByRow(); // ensure we have some data.
    String sql = "COPY copytest TO STDOUT";
    CopyOut cp = copyAPI.copyOut(sql);
    int count = 0;
    while (cp.readFromCopy() != null) {
      count++;
    }
    assertEquals(false, cp.isActive());
    assertEquals(dataRows, count);

    long rowCount = cp.getHandledRowCount();
    long expectedResult = dataRows;
    assertEquals(expectedResult, rowCount);

    assertEquals(dataRows, getCount());
  }

  /**
   * Round-trips the data through COPY IN and COPY OUT and checks the output byte-for-byte
   * against the original rows, including that nothing extra follows them.
   */
  @Test
  public void testCopyOut() throws SQLException, IOException {
    testCopyInByRow(); // ensure we have some data.
    String sql = "COPY copytest TO STDOUT";
    ByteArrayOutputStream copydata = new ByteArrayOutputStream();
    copyAPI.copyOut(sql, copydata);
    assertEquals(dataRows, getCount());
    // deep comparison of data written and read
    byte[] copybytes = copydata.toByteArray();
    int l = 0; // offset into copybytes, carried across rows
    for (String row : origData) {
      byte[] origBytes = row.getBytes();
      assertTrue("Copy is shorter than original", copybytes.length >= l + origBytes.length);
      for (int j = 0; j < origBytes.length; j++, l++) {
        assertEquals("content changed at byte#" + j + ": " + origBytes[j] + " != " + copybytes[l],
            origBytes[j], copybytes[l]);
      }
    }
    assertEquals("Copy is longer than original", l, copybytes.length);
  }

  @Test
  public void testNonCopyOut() throws SQLException, IOException {
    String sql = "SELECT 1";
    try {
      copyAPI.copyOut(sql, new ByteArrayOutputStream());
      fail("Can't use a non-copy query.");
    } catch (SQLException expected) {
      // expected: copyOut rejects a query that is not a COPY
    }
    // Ensure connection still works.
    assertEquals(0, getCount());
  }

  @Test
  public void testNonCopyIn() throws SQLException, IOException {
    String sql = "SELECT 1";
    try {
      copyAPI.copyIn(sql, new ByteArrayInputStream(new byte[0]));
      fail("Can't use a non-copy query.");
    } catch (SQLException expected) {
      // expected: copyIn rejects a query that is not a COPY
    }
    // Ensure connection still works.
    assertEquals(0, getCount());
  }

  @Test
  public void testStatementCopyIn() throws SQLException {
    Statement stmt = con.createStatement();
    try {
      stmt.execute("COPY copytest FROM STDIN");
      fail("Should have failed because copy doesn't work from a Statement.");
    } catch (SQLException expected) {
      // expected: plain Statements cannot run COPY
    }
    stmt.close();
    assertEquals(0, getCount());
  }

  @Test
  public void testStatementCopyOut() throws SQLException {
    testCopyInByRow(); // ensure we have some data.
    Statement stmt = con.createStatement();
    try {
      stmt.execute("COPY copytest TO STDOUT");
      fail("Should have failed because copy doesn't work from a Statement.");
    } catch (SQLException expected) {
      // expected: plain Statements cannot run COPY
    }
    stmt.close();
    assertEquals(dataRows, getCount());
  }

  @Test
  public void testCopyQuery() throws SQLException, IOException {
    testCopyInByRow(); // ensure we have some data.
    long count = copyAPI.copyOut("COPY (SELECT generate_series(1,1000)) TO STDOUT",
        new ByteArrayOutputStream());
    assertEquals(1000, count);
  }

  @Test
  public void testCopyRollback() throws SQLException {
    con.setAutoCommit(false);
    testCopyInByRow();
    con.rollback();
    assertEquals(0, getCount());
  }

  /**
   * Changes DateStyle inside a transaction and then issues a syntactically invalid COPY. The
   * resulting error must surface with SQLSTATE 42601 (syntax_error).
   */
  @Test
  public void testChangeDateStyle() throws SQLException {
    try {
      con.setAutoCommit(false);
      con.setTransactionIsolation(Connection.TRANSACTION_REPEATABLE_READ);
      CopyManager manager = con.unwrap(PGConnection.class).getCopyAPI();

      Statement stmt = con.createStatement();

      stmt.execute("SET DateStyle = 'ISO, DMY'");

      // I expect an SQLException
      String sql = "COPY copytest FROM STDIN with xxx " + copyParams;
      CopyIn cp = manager.copyIn(sql);
      for (String anOrigData : origData) {
        byte[] buf = anOrigData.getBytes();
        cp.writeToCopy(buf, 0, buf.length);
      }
      cp.endCopy();
      con.commit();
      fail("COPY with invalid syntax should have thrown an SQLException");
    } catch (SQLException ex) {
      // "with xxx" is a syntax error, which should return a state of 42601.
      // If this fails, the 'S' command is not being handled in the copy manager query handler.
      assertEquals("42601", ex.getSQLState());
      con.rollback();
    }
  }

  @Test
  public void testLockReleaseOnCancelFailure() throws SQLException, InterruptedException {
    if (!TestUtil.haveMinimumServerVersion(con, ServerVersion.v8_4)) {
      // pg_backend_pid() requires PostgreSQL 8.4+
      return;
    }

    // This is a fairly complex test because it is testing a
    // deadlock that only occurs when the connection to postgres
    // is broken during a copy operation. We'll start a copy
    // operation, use pg_terminate_backend to rudely break it,
    // and then cancel. The test passes if a subsequent operation
    // on the Connection object fails to deadlock.
    con.setAutoCommit(false);

    Statement stmt = con.createStatement();
    ResultSet rs = stmt.executeQuery("select pg_backend_pid()");
    rs.next();
    int pid = rs.getInt(1);
    rs.close();
    stmt.close();

    CopyManager manager = con.unwrap(PGConnection.class).getCopyAPI();
    CopyIn copyIn = manager.copyIn("COPY copytest FROM STDIN with " + copyParams);
    try {
      killConnection(pid);
      byte[] bunchOfNulls = ",,\n".getBytes();
      // Keep writing until the dead connection makes a write fail.
      while (true) {
        copyIn.writeToCopy(bunchOfNulls, 0, bunchOfNulls.length);
      }
    } catch (SQLException e) {
      acceptIOCause(e);
    } finally {
      if (copyIn.isActive()) {
        try {
          copyIn.cancelCopy();
          fail("cancelCopy should have thrown an exception");
        } catch (SQLException e) {
          acceptIOCause(e);
        }
      }
    }

    // Now we'll execute rollback on another thread so that if the
    // deadlock _does_ occur the testcase doesn't just hang forever.
    Rollback rollback = new Rollback(con);
    rollback.start();
    rollback.join(1000);
    if (rollback.isAlive()) {
      fail("rollback did not terminate");
    }
    SQLException rollbackException = rollback.exception();
    if (rollbackException == null) {
      fail("rollback should have thrown an exception");
    }
    acceptIOCause(rollbackException);
  }

  /**
   * Daemon thread that calls {@link Connection#rollback()} once and records any
   * {@link SQLException} it throws. The exception is read via {@link #exception()} after
   * {@link Thread#join}.
   */
  private static class Rollback extends Thread {
    private final Connection con;
    private SQLException rollbackException;

    public Rollback(Connection con) {
      setName("Asynchronous rollback");
      setDaemon(true);
      this.con = con;
    }

    @Override
    public void run() {
      try {
        con.rollback();
      } catch (SQLException e) {
        rollbackException = e;
      }
    }

    /** Returns the exception thrown by the rollback, or {@code null} if it succeeded. */
    public SQLException exception() {
      return rollbackException;
    }
  }

  /**
   * Terminates the server backend with the given PID, using a separate privileged connection.
   *
   * @param pid backend process id, as returned by {@code pg_backend_pid()}
   * @throws SQLException if executing {@code pg_terminate_backend} fails
   */
  private void killConnection(int pid) throws SQLException {
    Connection killerCon;
    try {
      killerCon = TestUtil.openPrivilegedDB();
    } catch (Exception e) {
      fail("Unable to open secondary connection to terminate copy");
      return; // persuade Java killerCon will not be used uninitialized
    }
    try {
      PreparedStatement stmt = killerCon.prepareStatement("select pg_terminate_backend(?)");
      try {
        stmt.setInt(1, pid);
        stmt.execute();
      } finally {
        stmt.close();
      }
    } finally {
      killerCon.close();
    }
  }

  /**
   * Accepts {@code e} as an expected failure only if its direct cause is an {@link IOException};
   * otherwise rethrows it.
   *
   * @param e the exception to inspect
   * @throws SQLException {@code e} itself, when its cause is not an {@link IOException}
   */
  private void acceptIOCause(SQLException e) throws SQLException {
    if (!(e.getCause() instanceof IOException)) {
      throw e;
    }
  }
}