package edu.stanford.nlp.patterns.surface;
import edu.stanford.nlp.patterns.ConstantsAndVariables;
import java.lang.reflect.InvocationTargetException;
import java.util.*;
/**
* Created by Sonal Gupta on 10/8/14.
*/
public abstract class PatternsForEachToken<E> {
private static ConstantsAndVariables.PatternForEachTokenWay storeWay;
abstract public void addPatterns(Map<String, Map<Integer, Set<E>>> pats);
abstract public void addPatterns(String id, Map<Integer, Set<E>> p);
abstract public void createIndexIfUsingDBAndNotExists();
abstract public Map<Integer, Set<E>> getPatternsForAllTokens(String sentId);
abstract public boolean save(String dir);
// /**
// * Only for Lucene and DB
// * @return
// */
// abstract public PatternIndex readPatternIndex(String dir) throws IOException, ClassNotFoundException;
abstract public void setupSearch();
abstract int size();
//abstract public void savePatternIndex(PatternIndex index, String dir) throws IOException;
public void updatePatterns(Map<String, Map<Integer, Set<E>>> tempPatsForSents) {
for(Map.Entry<String, Map<Integer, Set<E>>> en :tempPatsForSents.entrySet()){
Map<Integer, Set<E>> m = getPatternsForAllTokens(en.getKey());
if(m == null)
m = new HashMap<>();
tempPatsForSents.get(en.getKey()).putAll(m);
}
this.addPatterns(tempPatsForSents);
close();
}
public ConstantsAndVariables.PatternForEachTokenWay getStoreWay() {
return storeWay;
}
public static PatternsForEachToken getPatternsInstance(Properties props, ConstantsAndVariables.PatternForEachTokenWay storePatsForEachToken)
{
storeWay = storePatsForEachToken;
PatternsForEachToken p = null;
switch(storePatsForEachToken){
case MEMORY:{
p = new PatternsForEachTokenInMemory(props);
break;
}
case DB:{
p = new PatternsForEachTokenDB(props);
break;
}
case LUCENE:
{
try{
Class c = Class.forName("edu.stanford.nlp.patterns.surface.PatternsForEachTokenLucene");
p = (PatternsForEachToken) c.getDeclaredConstructor(Properties.class).newInstance(props);
break;
}catch (ClassNotFoundException e) {
throw new RuntimeException("Lucene option is not distributed (license clash). Email us if you really want it.");
} catch (InvocationTargetException e) {
throw new RuntimeException(e);
} catch (NoSuchMethodException e) {
throw new RuntimeException(e);
} catch (InstantiationException e) {
throw new RuntimeException(e);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
}
}
}
return p;
//if(storePatsForEachToken.equals(DB)){}
}
public abstract Map<String,Map<Integer,Set<E>>> getPatternsForAllTokens(Collection<String> sampledSentIds);
public abstract void close();
public abstract void load(String allPatternsDir);
// @Option(name="allPatternsFile")
// String allPatternsFile = null;
//
// /**
// * If all patterns should be computed. Otherwise patterns are read from
// * allPatternsFile
// */
// @Option(name = "computeAllPatterns")
// public boolean computeAllPatterns = true;
//Connection conn;
// public PatternsForEachToken(Properties props, Map<String, Map<Integer, Set<Integer>>> pats) throws SQLException, ClassNotFoundException, IOException {
// Execution.fillOptions(this, props);
//
// if (useDBForTokenPatterns) {
// Execution.fillOptions(SQLConnection.class, props);
//
// assert tableName != null : "tableName property is null!";
// tableName = tableName.toLowerCase();
// if (createTable && !deleteExisting)
// throw new RuntimeException("Cannot have createTable as true and deleteExisting as false!");
// if (createTable){
// createTable();
// createUpsertFunction();
// }else{
// assert DBTableExists() : "Table " + tableName + " does not exists. Pass createTable=true to create a new table";
// }
// }else
// patternsForEachToken = new ConcurrentHashMap<String, Map<Integer, Set<Integer>>>();
//
// if(pats != null)
// addPatterns(pats);
// }
//
// public PatternsForEachToken(){}
//
// public PatternsForEachToken(Properties props) throws SQLException, IOException, ClassNotFoundException {
// this(props, null);
// }
// public void addPatterns(Map<String, Map<Integer, Set<Integer>>> pats) throws IOException, SQLException {
// Connection conn = null;
// PreparedStatement pstmt = null;
//
// if(useDBForTokenPatterns) {
// conn = SQLConnection.getConnection();
// pstmt =getPreparedStmt(conn);
// }
//
// for (Map.Entry<String, Map<Integer, Set<Integer>>> en : pats.entrySet()) {
// addPattern(en.getKey(), en.getValue(), pstmt);
// if(useDBForTokenPatterns)
// pstmt.addBatch();
// }
//
// if(useDBForTokenPatterns){
// pstmt.executeBatch();
// conn.commit();
// pstmt.close();
// conn.close();
// }
// }
//
// public void addPatterns(String id, Map<Integer, Set<Integer>> p) throws IOException, SQLException {
// PreparedStatement pstmt = null;
// Connection conn= null;
//
// if(useDBForTokenPatterns) {
// conn = SQLConnection.getConnection();
// pstmt = getPreparedStmt(conn);
// }
//
// addPattern(id, p, pstmt);
//
// if(useDBForTokenPatterns){
// pstmt.execute();
// conn.commit();
// pstmt.close();
// conn.close();
// }
// }
/*
public void addPatterns(String id, Map<Integer, Set<Integer>> p, PreparedStatement pstmt) throws IOException, SQLException {
for (Map.Entry<Integer, Set<Integer>> en2 : p.entrySet()) {
addPattern(id, en2.getKey(), en2.getValue(), pstmt);
if(useDBForTokenPatterns)
pstmt.addBatch();
}
}
*/
/*
public void addPatterns(String sentId, int tokenId, Set<Integer> patterns) throws SQLException, IOException{
PreparedStatement pstmt = null;
Connection conn= null;
if(useDBForTokenPatterns) {
conn = SQLConnection.getConnection();
pstmt = getPreparedStmt(conn);
}
addPattern(sentId, tokenId, patterns, pstmt);
if(useDBForTokenPatterns){
pstmt.execute();
conn.commit();
pstmt.close();
conn.close();
}
}
*/
/*
private void addPattern(String sentId, int tokenId, Set<Integer> patterns, PreparedStatement pstmt) throws SQLException, IOException {
if(pstmt != null){
// ByteArrayOutputStream baos = new ByteArrayOutputStream();
// ObjectOutputStream oos = new ObjectOutputStream(baos);
// oos.writeObject(patterns);
// byte[] patsAsBytes = baos.toByteArray();
// ByteArrayInputStream bais = new ByteArrayInputStream(patsAsBytes);
// pstmt.setBinaryStream(1, bais, patsAsBytes.length);
// pstmt.setObject(2, sentId);
// pstmt.setInt(3, tokenId);
// pstmt.setString(4,sentId);
// pstmt.setInt(5, tokenId);
// ByteArrayOutputStream baos2 = new ByteArrayOutputStream();
// ObjectOutputStream oos2 = new ObjectOutputStream(baos2);
// oos2.writeObject(patterns);
// byte[] patsAsBytes2 = baos2.toByteArray();
// ByteArrayInputStream bais2 = new ByteArrayInputStream(patsAsBytes2);
// pstmt.setBinaryStream(6, bais2, patsAsBytes2.length);
// pstmt.setString(7,sentId);
// pstmt.setInt(8, tokenId);
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ObjectOutputStream oos = new ObjectOutputStream(baos);
oos.writeObject(patterns);
byte[] patsAsBytes = baos.toByteArray();
ByteArrayInputStream bais = new ByteArrayInputStream(patsAsBytes);
pstmt.setBinaryStream(3, bais, patsAsBytes.length);
pstmt.setObject(1, sentId);
pstmt.setInt(2, tokenId);
} else{
if(!patternsForEachToken.containsKey(sentId))
patternsForEachToken.put(sentId, new ConcurrentHashMap<Integer, Set<Integer>>());
patternsForEachToken.get(sentId).put(tokenId, patterns);
}
}*/
// private void addPattern(String sentId, Map<Integer, Set<Integer>> patterns, PreparedStatement pstmt) throws SQLException, IOException {
//
// if(pstmt != null){
// ByteArrayOutputStream baos = new ByteArrayOutputStream();
// ObjectOutputStream oos = new ObjectOutputStream(baos);
// oos.writeObject(patterns);
// byte[] patsAsBytes = baos.toByteArray();
// ByteArrayInputStream bais = new ByteArrayInputStream(patsAsBytes);
// pstmt.setBinaryStream(2, bais, patsAsBytes.length);
// pstmt.setObject(1, sentId);
// //pstmt.setInt(2, tokenId);
//
//
// } else{
// if(!patternsForEachToken.containsKey(sentId))
// patternsForEachToken.put(sentId, new ConcurrentHashMap<Integer, Set<Integer>>());
// patternsForEachToken.get(sentId).putAll(patterns);
// }
// }
//
//
// public void createUpsertFunction() throws SQLException {
// Connection conn = SQLConnection.getConnection();
// String s = "CREATE OR REPLACE FUNCTION upsert_patterns(sentid1 text, pats1 bytea) RETURNS VOID AS $$\n" +
// "DECLARE\n" +
// "BEGIN\n" +
// " UPDATE " + tableName+ " SET patterns = pats1 WHERE sentid = sentid1;\n" +
// " IF NOT FOUND THEN\n" +
// " INSERT INTO " + tableName + " values (sentid1, pats1);\n" +
// " END IF;\n" +
// "END;\n" +
// "$$ LANGUAGE 'plpgsql';\n";
// Statement st = conn.createStatement();
// st.execute(s);
// conn.close();
// }
//
// public void createUpsertFunctionPatternIndex() throws SQLException {
// Connection conn = SQLConnection.getConnection();
// String s = "CREATE OR REPLACE FUNCTION upsert_patternindex(tablename1 text, index1 bytea) RETURNS VOID AS $$\n" +
// "DECLARE\n" +
// "BEGIN\n" +
// " UPDATE " + patternindicesTable + " SET index = index1 WHERE tablename = tablename;\n" +
// " IF NOT FOUND THEN\n" +
// " INSERT INTO " + patternindicesTable + " values (tablename1, index1);\n" +
// " END IF;\n" +
// "END;\n" +
// "$$ LANGUAGE 'plpgsql';\n";
// Statement st = conn.createStatement();
// st.execute(s);
// conn.close();
// }
//
//
//
//
//
//
// private PreparedStatement getPreparedStmt(Connection conn) throws SQLException {
// conn.setAutoCommit(false);
// //return conn.prepareStatement("UPDATE " + tableName + " SET patterns = ? WHERE sentid = ? and tokenid = ?; " +
// // "INSERT INTO " + tableName + " (sentid, tokenid, patterns) (SELECT ?,?,? WHERE NOT EXISTS (SELECT sentid FROM " + tableName + " WHERE sentid =? and tokenid=?));");
// // return conn.prepareStatement("INSERT INTO " + tableName + " (sentid, tokenid, patterns) (SELECT ?,?,? WHERE NOT EXISTS (SELECT sentid FROM " + tableName + " WHERE sentid =? and tokenid=?))");
// return conn.prepareStatement("select upsert_patterns(?,?)");
// }
//
//
//
//
///*
// public Set<Integer> getPatterns(String sentId, Integer tokenId) throws SQLException, IOException, ClassNotFoundException {
// if(useDBForTokenPatterns){
// Connection conn = SQLConnection.getConnection();
//
// String query = "Select patterns from " + tableName + " where sentid=\'" + sentId + "\' and tokenid = " + tokenId;
// Statement stmt = conn.createStatement();
// ResultSet rs = stmt.executeQuery(query);
// Set<Integer> pats = null;
// if(rs.next()){
// byte[] st = (byte[]) rs.getObject(1);
// ByteArrayInputStream baip = new ByteArrayInputStream(st);
// ObjectInputStream ois = new ObjectInputStream(baip);
// pats = (Set<Integer>) ois.readObject();
//
// }
// conn.close();
// return pats;
// }
// else
// return patternsForEachToken.get(sentId).get(tokenId);
// }*/
//
//
//
// public Map<Integer, Set<Integer>> getPatternsForAllTokens(String sentId) throws SQLException, IOException, ClassNotFoundException {
// if(useDBForTokenPatterns){
// Connection conn = SQLConnection.getConnection();
// //Map<Integer, Set<Integer>> pats = new ConcurrentHashMap<Integer, Set<Integer>>();
// String query = "Select patterns from " + tableName + " where sentid=\'" + sentId + "\'";
// Statement stmt = conn.createStatement();
// ResultSet rs = stmt.executeQuery(query);
// Map<Integer, Set<Integer>> patsToken = new HashMap<Integer, Set<Integer>>();
// if(rs.next()){
// byte[] st = (byte[]) rs.getObject(1);
// ByteArrayInputStream baip = new ByteArrayInputStream(st);
// ObjectInputStream ois = new ObjectInputStream(baip);
// patsToken = (Map<Integer, Set<Integer>>) ois.readObject();
// //pats.put(rs.getInt("tokenid"), patsToken);
// }
// conn.close();
// return patsToken;
// }
// else
// return patternsForEachToken.containsKey(sentId) ? patternsForEachToken.get(sentId): Collections.emptyMap();
// }
//
//
//
// boolean getUseDBForTokenPatterns(){
// return useDBForTokenPatterns;
// }
//
// public boolean writePatternsIfInMemory(String allPatternsFile) throws IOException {
// if(!useDBForTokenPatterns)
// {
// IOUtils.writeObjectToFile(this.patternsForEachToken, allPatternsFile);
// return true;
// }
// return false;
// }
//
//
// public boolean containsSentId(String sentId) throws SQLException {
// if(!useDBForTokenPatterns)
// return this.patternsForEachToken.containsKey(sentId);
// else {
// Connection conn = SQLConnection.getConnection();
// String query = "Select tokenid from " + tableName + " where sentid=\'" + sentId + "\' limit 1";
// Statement stmt = conn.createStatement();
// ResultSet rs = stmt.executeQuery(query);
//
// boolean contains = false;
//
// while(rs.next()){
// contains = true;
// break;
// }
//
// conn.close();
// return contains;
// }
// }
//
// public void createIndexIfUsingDBAndNotExists(){
// if(useDBForTokenPatterns){
// try {
// Redwood.log(Redwood.DBG, "Creating index for " + tableName);
// Connection conn = SQLConnection.getConnection();
// Statement stmt = conn.createStatement();
// boolean doesnotexist = false;
//
// //check if the index already exists
// try{
// Statement stmt2 = conn.createStatement();
// String query = "SELECT '"+tableName+"_index'::regclass";
// stmt2.execute(query);
// }catch (SQLException e){
// doesnotexist = true;
// }
//
// if(doesnotexist){
// String indexquery ="create index CONCURRENTLY " + tableName +"_index on " + tableName+ " using hash(\"sentid\") ";
// stmt.execute(indexquery);
// Redwood.log(Redwood.DBG, "Done creating index for " + tableName);
// }
// } catch (SQLException e) {
// throw new RuntimeException(e);
// }
// }
// }
//
// /**
// * not yet supported if backed by DB
// * @return
// */
// public Set<Map.Entry<String, Map<Integer, Set<Integer>>>> entrySet() {
// if(!useDBForTokenPatterns)
// return patternsForEachToken.entrySet();
// else
// //not yet supported if backed by DB
// throw new UnsupportedOperationException();
// }
//
// public void updatePatterns(Map<String, Map<Integer, Set<Integer>>> tempPatsForSents) {
// try {
// for(Map.Entry<String, Map<Integer, Set<Integer>>> en :tempPatsForSents.entrySet()){
// Map<Integer, Set<Integer>> m = getPatternsForAllTokens(en.getKey());
// if(m == null)
// m = new HashMap<Integer, Set<Integer>>();
// //m.putAll(en.getValue());
// tempPatsForSents.get(en.getKey()).putAll(m);
// }
// this.addPatterns(tempPatsForSents);
// } catch (IOException e) {
// e.printStackTrace();
// } catch (SQLException e) {
// e.printStackTrace();
// } catch (ClassNotFoundException e) {
// e.printStackTrace();
// }
// }
//
// public boolean DBTableExists() {
// try {
// Connection conn = null;
//
// conn = SQLConnection.getConnection();
//
// DatabaseMetaData dbm = conn.getMetaData();
// ResultSet tables = dbm.getTables(null, null, tableName, null);
// if (tables.next()) {
// System.out.println("Found table " + tableName);
// conn.close();
// return true;
// }
// conn.close();
// return false;
// }catch(SQLException e){
// throw new RuntimeException(e);
//
// }
// }
//
// public ConcurrentHashIndex<SurfacePattern> readPatternIndexFromDB(){
// try{
// Connection conn = SQLConnection.getConnection();
// //Map<Integer, Set<Integer>> pats = new ConcurrentHashMap<Integer, Set<Integer>>();
// String query = "Select * from " + patternindicesTable + " where tablename=\'" + tableName + "\'";
// Statement stmt = conn.createStatement();
// ResultSet rs = stmt.executeQuery(query);
// ConcurrentHashIndex<SurfacePattern> index = null;
// if(rs.next()){
// byte[] st = (byte[]) rs.getObject(1);
// ByteArrayInputStream baip = new ByteArrayInputStream(st);
// ObjectInputStream ois = new ObjectInputStream(baip);
// index = (ConcurrentHashIndex<SurfacePattern>) ois.readObject();
// }
// assert index != null;
// return index;
// }catch(SQLException e){
// throw new RuntimeException(e);
// } catch (ClassNotFoundException e) {
// throw new RuntimeException(e);
// } catch (IOException e) {
// throw new RuntimeException(e);
// }
// }
//
// public void savePatternIndexInDB(ConcurrentHashIndex<SurfacePattern> index) {
// try {
// createUpsertFunctionPatternIndex();
// Connection conn = SQLConnection.getConnection();
// PreparedStatement st = conn.prepareStatement("select upsert_patternindex(?,?)");
// st.setString(1,tableName);
// ByteArrayOutputStream baos = new ByteArrayOutputStream();
// ObjectOutputStream oos = new ObjectOutputStream(baos);
// oos.writeObject(index);
// byte[] patsAsBytes = baos.toByteArray();
// ByteArrayInputStream bais = new ByteArrayInputStream(patsAsBytes);
// st.setBinaryStream(2, bais, patsAsBytes.length);
//
// st.execute();
// st.close();
// conn.close();
// }catch (SQLException e){
// throw new RuntimeException(e);
// } catch (IOException e) {
// throw new RuntimeException(e);
// }
// }
}