package edu.stanford.nlp.patterns.surface; import edu.stanford.nlp.patterns.Pattern; import edu.stanford.nlp.patterns.SQLConnection; import edu.stanford.nlp.util.ArgumentParser; import edu.stanford.nlp.util.logging.Redwood; import java.io.*; import java.sql.*; import java.util.*; /** * Created by sonalg on 10/22/14. */ public class PatternsForEachTokenDB<E extends Pattern> extends PatternsForEachToken<E>{ @ArgumentParser.Option(name = "createTable") boolean createTable = false; @ArgumentParser.Option(name = "deleteExisting") boolean deleteExisting = false; @ArgumentParser.Option(name = "tableName") String tableName = null; @ArgumentParser.Option(name = "patternindicesTable") String patternindicesTable = "patternindices"; @ArgumentParser.Option(name="deleteDBResourcesOnExit") boolean deleteDBResourcesOnExit = true; public PatternsForEachTokenDB(Properties props, Map<String, Map<Integer, Set<E>>> pats){ ArgumentParser.fillOptions(this, props); ArgumentParser.fillOptions(SQLConnection.class, props); assert tableName != null : "tableName property is null!"; tableName = tableName.toLowerCase(); if (createTable && !deleteExisting) throw new RuntimeException("Cannot have createTable as true and deleteExisting as false!"); if (createTable){ createTable(); createUpsertFunction(); }else{ assert DBTableExists() : "Table " + tableName + " does not exists. Pass createTable=true to create a new table"; } if(pats != null) addPatterns(pats); } public PatternsForEachTokenDB(Properties props) { this(props, null); } void createTable() { String query =""; try { Connection conn = SQLConnection.getConnection(); if(DBTableExists()){ if (deleteExisting) { System.out.println("deleting table " + tableName); Statement stmt = conn.createStatement(); query = "drop table " + tableName; stmt.execute(query); stmt.close(); Statement stmtindex = conn.createStatement(); query = "DROP INDEX IF EXISTS " + tableName+"_index"; stmtindex.execute(query); stmtindex.close(); } } System.out.println("creating table " + tableName); Statement stmt = conn.createStatement(); //query = "create table IF NOT EXISTS " + tableName + " (\"sentid\" text, \"tokenid\" int, \"patterns\" bytea); "; query = "create table IF NOT EXISTS " + tableName + " (sentid text, patterns bytea); "; stmt.execute(query); stmt.close(); conn.close();} catch (SQLException e) { throw new RuntimeException("Error executing query " + query + "\n" + e); } } @Override public void addPatterns(Map<String, Map<Integer, Set<E>>> pats){ try { Connection conn = null; PreparedStatement pstmt = null; conn = SQLConnection.getConnection(); pstmt = getPreparedStmt(conn); for (Map.Entry<String, Map<Integer, Set<E>>> en : pats.entrySet()) { addPattern(en.getKey(), en.getValue(), pstmt); pstmt.addBatch(); } pstmt.executeBatch(); conn.commit(); pstmt.close(); conn.close(); }catch(SQLException e){ throw new RuntimeException(e); } catch (IOException e) { throw new RuntimeException(e); } } public void addPatterns(String id, Map<Integer, Set<E>> p){ try { PreparedStatement pstmt = null; Connection conn= null; conn = SQLConnection.getConnection(); pstmt = getPreparedStmt(conn); addPattern(id, p, pstmt); pstmt.execute(); conn.commit(); pstmt.close(); conn.close(); } catch (SQLException e) { throw new RuntimeException(e); } catch (IOException e) { throw new RuntimeException(e); } } /* public void addPatterns(String id, Map<Integer, Set<Integer>> p, PreparedStatement pstmt) throws IOException, SQLException { for (Map.Entry<Integer, Set<Integer>> en2 : p.entrySet()) { addPattern(id, en2.getKey(), en2.getValue(), pstmt); if(useDBForTokenPatterns) pstmt.addBatch(); } } */ /* public void addPatterns(String sentId, int tokenId, Set<Integer> patterns) throws SQLException, IOException{ PreparedStatement pstmt = null; Connection conn= null; if(useDBForTokenPatterns) { conn = SQLConnection.getConnection(); pstmt = getPreparedStmt(conn); } addPattern(sentId, tokenId, patterns, pstmt); if(useDBForTokenPatterns){ pstmt.execute(); conn.commit(); pstmt.close(); conn.close(); } } */ /* private void addPattern(String sentId, int tokenId, Set<Integer> patterns, PreparedStatement pstmt) throws SQLException, IOException { if(pstmt != null){ // ByteArrayOutputStream baos = new ByteArrayOutputStream(); // ObjectOutputStream oos = new ObjectOutputStream(baos); // oos.writeObject(patterns); // byte[] patsAsBytes = baos.toByteArray(); // ByteArrayInputStream bais = new ByteArrayInputStream(patsAsBytes); // pstmt.setBinaryStream(1, bais, patsAsBytes.length); // pstmt.setObject(2, sentId); // pstmt.setInt(3, tokenId); // pstmt.setString(4,sentId); // pstmt.setInt(5, tokenId); // ByteArrayOutputStream baos2 = new ByteArrayOutputStream(); // ObjectOutputStream oos2 = new ObjectOutputStream(baos2); // oos2.writeObject(patterns); // byte[] patsAsBytes2 = baos2.toByteArray(); // ByteArrayInputStream bais2 = new ByteArrayInputStream(patsAsBytes2); // pstmt.setBinaryStream(6, bais2, patsAsBytes2.length); // pstmt.setString(7,sentId); // pstmt.setInt(8, tokenId); ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos); oos.writeObject(patterns); byte[] patsAsBytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(patsAsBytes); pstmt.setBinaryStream(3, bais, patsAsBytes.length); pstmt.setObject(1, sentId); pstmt.setInt(2, tokenId); } else{ if(!patternsForEachToken.containsKey(sentId)) patternsForEachToken.put(sentId, new ConcurrentHashMap<Integer, Set<Integer>>()); patternsForEachToken.get(sentId).put(tokenId, patterns); } }*/ private void addPattern(String sentId, Map<Integer, Set<E>> patterns, PreparedStatement pstmt) throws SQLException, IOException { if(pstmt != null){ ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos); oos.writeObject(patterns); byte[] patsAsBytes = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(patsAsBytes); pstmt.setBinaryStream(2, bais, patsAsBytes.length); pstmt.setObject(1, sentId); //pstmt.setInt(2, tokenId); } } public void createUpsertFunction() { try{ Connection conn = SQLConnection.getConnection(); String s = "CREATE OR REPLACE FUNCTION upsert_patterns(sentid1 text, pats1 bytea) RETURNS VOID AS $$\n" + "DECLARE\n" + "BEGIN\n" + " UPDATE " + tableName+ " SET patterns = pats1 WHERE sentid = sentid1;\n" + " IF NOT FOUND THEN\n" + " INSERT INTO " + tableName + " values (sentid1, pats1);\n" + " END IF;\n" + "END;\n" + "$$ LANGUAGE 'plpgsql';\n"; Statement st = conn.createStatement(); st.execute(s); conn.close();}catch(SQLException e){ throw new RuntimeException(e); } } public void createUpsertFunctionPatternIndex() throws SQLException { Connection conn = SQLConnection.getConnection(); String s = "CREATE OR REPLACE FUNCTION upsert_patternindex(tablename1 text, index1 bytea) RETURNS VOID AS $$\n" + "DECLARE\n" + "BEGIN\n" + " UPDATE " + patternindicesTable + " SET index = index1 WHERE tablename = tablename1;\n" + " IF NOT FOUND THEN\n" + " INSERT INTO " + patternindicesTable + " values (tablename1, index1);\n" + " END IF;\n" + "END;\n" + "$$ LANGUAGE 'plpgsql';\n"; Statement st = conn.createStatement(); st.execute(s); conn.close(); } private PreparedStatement getPreparedStmt(Connection conn) throws SQLException { conn.setAutoCommit(false); //return conn.prepareStatement("UPDATE " + tableName + " SET patterns = ? WHERE sentid = ? and tokenid = ?; " + // "INSERT INTO " + tableName + " (sentid, tokenid, patterns) (SELECT ?,?,? WHERE NOT EXISTS (SELECT sentid FROM " + tableName + " WHERE sentid =? and tokenid=?));"); // return conn.prepareStatement("INSERT INTO " + tableName + " (sentid, tokenid, patterns) (SELECT ?,?,? WHERE NOT EXISTS (SELECT sentid FROM " + tableName + " WHERE sentid =? and tokenid=?))"); return conn.prepareStatement("select upsert_patterns(?,?)"); } /* public Set<Integer> getPatterns(String sentId, Integer tokenId) throws SQLException, IOException, ClassNotFoundException { if(useDBForTokenPatterns){ Connection conn = SQLConnection.getConnection(); String query = "Select patterns from " + tableName + " where sentid=\'" + sentId + "\' and tokenid = " + tokenId; Statement stmt = conn.createStatement(); ResultSet rs = stmt.executeQuery(query); Set<Integer> pats = null; if(rs.next()){ byte[] st = (byte[]) rs.getObject(1); ByteArrayInputStream baip = new ByteArrayInputStream(st); ObjectInputStream ois = new ObjectInputStream(baip); pats = (Set<Integer>) ois.readObject(); } conn.close(); return pats; } else return patternsForEachToken.get(sentId).get(tokenId); }*/ @Override public Map<Integer, Set<E>> getPatternsForAllTokens(String sentId){ try{ Connection conn = SQLConnection.getConnection(); //Map<Integer, Set<Integer>> pats = new ConcurrentHashMap<Integer, Set<Integer>>(); String query = "Select patterns from " + tableName + " where sentid=\'" + sentId + "\'"; Statement stmt = conn.createStatement(); ResultSet rs = stmt.executeQuery(query); Map<Integer, Set<E>> patsToken = new HashMap<>(); if(rs.next()){ byte[] st = (byte[]) rs.getObject(1); ByteArrayInputStream baip = new ByteArrayInputStream(st); ObjectInputStream ois = new ObjectInputStream(baip); patsToken = (Map<Integer, Set<E>>) ois.readObject(); //pats.put(rs.getInt("tokenid"), patsToken); } conn.close(); return patsToken; }catch(SQLException e){ throw new RuntimeException(e); } catch (ClassNotFoundException e) { throw new RuntimeException(e); } catch (IOException e) { throw new RuntimeException(e); } } @Override public boolean save(String dir) { //nothing to do return false; } @Override public void setupSearch() { //nothing to do } public boolean containsSentId(String sentId){ try { Connection conn = SQLConnection.getConnection(); String query = "Select tokenid from " + tableName + " where sentid=\'" + sentId + "\' limit 1"; Statement stmt = conn.createStatement(); ResultSet rs = stmt.executeQuery(query); boolean contains = false; while (rs.next()) { contains = true; break; } conn.close(); return contains; }catch(SQLException e){ throw new RuntimeException(e); } } @Override public void createIndexIfUsingDBAndNotExists(){ try { Redwood.log(Redwood.DBG, "Creating index for " + tableName); Connection conn = SQLConnection.getConnection(); Statement stmt = conn.createStatement(); boolean doesnotexist = false; //check if the index already exists try{ Statement stmt2 = conn.createStatement(); String query = "SELECT '"+tableName+"_index'::regclass"; stmt2.execute(query); }catch (SQLException e){ doesnotexist = true; } if(doesnotexist){ String indexquery ="create index CONCURRENTLY " + tableName +"_index on " + tableName+ " using hash(\"sentid\") "; stmt.execute(indexquery); Redwood.log(Redwood.DBG, "Done creating index for " + tableName); } } catch (SQLException e) { throw new RuntimeException(e); } } // /** // * not yet supported if backed by DB // * @return // */ // public Set<Map.Entry<String, Map<Integer, Set<Integer>>>> entrySet() { // if(!useDBForTokenPatterns) // return patternsForEachToken.entrySet(); // else // //not yet supported if backed by DB // throw new UnsupportedOperationException(); // } public boolean DBTableExists() { try { Connection conn = null; conn = SQLConnection.getConnection(); DatabaseMetaData dbm = conn.getMetaData(); ResultSet tables = dbm.getTables(null, null, tableName, null); if (tables.next()) { System.out.println("Found table " + tableName); conn.close(); return true; } conn.close(); return false; }catch(SQLException e){ throw new RuntimeException(e); } } // // @Override // public ConcurrentHashIndex<SurfacePattern> readPatternIndex(String dir){ // //dir parameter is not used! // try{ // Connection conn = SQLConnection.getConnection(); // //Map<Integer, Set<Integer>> pats = new ConcurrentHashMap<Integer, Set<Integer>>(); // String query = "Select index from " + patternindicesTable + " where tablename=\'" + tableName + "\'"; // Statement stmt = conn.createStatement(); // ResultSet rs = stmt.executeQuery(query); // ConcurrentHashIndex<SurfacePattern> index = null; // if(rs.next()){ // byte[] st = (byte[]) rs.getObject(1); // ByteArrayInputStream baip = new ByteArrayInputStream(st); // ObjectInputStream ois = new ObjectInputStream(baip); // index = (ConcurrentHashIndex<SurfacePattern>) ois.readObject(); // } // assert index != null; // return index; // }catch(SQLException e){ // throw new RuntimeException(e); // } catch (ClassNotFoundException e) { // throw new RuntimeException(e); // } catch (IOException e) { // throw new RuntimeException(e); // } // } // // @Override // public void savePatternIndex(ConcurrentHashIndex<SurfacePattern> index, String file) { // try { // createUpsertFunctionPatternIndex(); // Connection conn = SQLConnection.getConnection(); // PreparedStatement st = conn.prepareStatement("select upsert_patternindex(?,?)"); // st.setString(1,tableName); // ByteArrayOutputStream baos = new ByteArrayOutputStream(); // ObjectOutputStream oos = new ObjectOutputStream(baos); // oos.writeObject(index); // byte[] patsAsBytes = baos.toByteArray(); // ByteArrayInputStream bais = new ByteArrayInputStream(patsAsBytes); // st.setBinaryStream(2, bais, patsAsBytes.length); // st.execute(); // st.close(); // conn.close(); // System.out.println("Saved the pattern hash index for " + tableName + " in DB table " + patternindicesTable); // }catch (SQLException e){ // throw new RuntimeException(e); // } catch (IOException e) { // throw new RuntimeException(e); // } // } //batch processing below is copied from Java Ranch public static final int SINGLE_BATCH = 1; public static final int SMALL_BATCH = 4; public static final int MEDIUM_BATCH = 11; public static final int LARGE_BATCH = 51; //TODO: make this into an iterator!! @Override public Map<String, Map<Integer, Set<E>>> getPatternsForAllTokens(Collection<String> sampledSentIds) { try{ Map<String, Map<Integer, Set<E>>> pats = new HashMap<>(); Connection conn = SQLConnection.getConnection(); Iterator<String> iter = sampledSentIds.iterator(); int totalNumberOfValuesLeftToBatch = sampledSentIds.size(); while ( totalNumberOfValuesLeftToBatch > 0 ) { int batchSize = SINGLE_BATCH; if (totalNumberOfValuesLeftToBatch >= LARGE_BATCH) { batchSize = LARGE_BATCH; } else if (totalNumberOfValuesLeftToBatch >= MEDIUM_BATCH) { batchSize = MEDIUM_BATCH; } else if (totalNumberOfValuesLeftToBatch >= SMALL_BATCH) { batchSize = SMALL_BATCH; } totalNumberOfValuesLeftToBatch -= batchSize; StringBuilder inClause = new StringBuilder(); for (int i = 0; i < batchSize; i++) { inClause.append('?'); if (i != batchSize - 1) { inClause.append(','); } } PreparedStatement stmt = conn.prepareStatement( "select sentid, patterns from " + tableName + " where sentid in (" + inClause.toString() + ")"); for (int i=0; i < batchSize && iter.hasNext(); i++) { stmt.setString(i+1, iter.next()); // or whatever values you are trying to query by } stmt.execute(); ResultSet rs = stmt.getResultSet(); while(rs.next()){ String sentid = rs.getString(1); byte[] st = (byte[]) rs.getObject(2); ByteArrayInputStream baip = new ByteArrayInputStream(st); ObjectInputStream ois = new ObjectInputStream(baip); pats.put(sentid, (Map<Integer, Set<E>>) ois.readObject()); } } conn.close(); return pats; }catch(SQLException e){ throw new RuntimeException(e); } catch (ClassNotFoundException e) { throw new RuntimeException(e); } catch (IOException e) { throw new RuntimeException(e); } } @Override public void close() { //nothing to do } @Override public void load(String allPatternsDir) { //nothing to do } @Override public int size(){ //TODO: NOT IMPLEMENTED return Integer.MAX_VALUE; } }