package edu.stanford.nlp.util;
import edu.stanford.nlp.util.logging.Redwood;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.util.ArgumentParser.Option;
import java.io.File;
import java.sql.*;
import java.util.*;
import java.util.zip.GZIPInputStream;
/** Queries Google Ngrams counts from a SQL database in a memory-efficient manner.
 * To get the count of a phrase, use GoogleNGramsSQLBacked.getCount(phrase). Set this class's options using
 * ArgumentParser.fillOptions(GoogleNGramsSQLBacked.class, props);
 * Created by Sonal Gupta
 */
public class GoogleNGramsSQLBacked {

  /** A logger for this class */
  private static final Redwood.RedwoodChannels log = Redwood.channels(GoogleNGramsSQLBacked.class);

  @Option(name="populateTables")
  static boolean populateTables = false;

  @Option(name="ngramsToPopulate")
  static Set<Integer> ngramsToPopulate = null;

  @Option(name="dataDir")
  static String dataDir ="/u/nlp/scr/data/google-ngrams/data";

  @Option(name="googleNgram_hostname", gloss="where psql is located.")
  static String googleNgram_hostname = "jonsson";

  @Option(name="googleNgram_dbname", gloss="the database name")
  static String googleNgram_dbname;

  @Option(name="googleNgram_username")
  static String googleNgram_username="nlp";

  @Option(name="tablenamePrefix")
  static String tablenamePrefix ="googlengrams_";

  @Option(name="escapetag")
  static String escapetag = "tag";

  /** Lower-cased names of the tables visible on the current connection; filled lazily by {@link #existsTable}. */
  static Set<String> existingTablenames = null;

  /** The single shared connection. Access is not synchronized, so this class is not thread-safe. */
  static Connection connection = null;

  /** Table holding the rank of the most frequent 1-grams, queried by {@link #get1GramRank}. */
  private static final String RANKED_1GRAM_TABLE = "googlengrams_1_ranked20k";

  /**
   * Opens the shared connection if there is no usable one yet.
   *
   * @throws IllegalStateException if {@code googleNgram_dbname} has not been set
   * @throws SQLException if the connection cannot be established
   */
  static void connect() throws SQLException {
    if (connection != null) {
      if (!connection.isClosed()) {
        return;
      }
      // The connection was closed without going through closeConnection():
      // drop it together with the table names that were read from it.
      connection = null;
      existingTablenames = null;
    }
    // An explicit check instead of an assert, which is skipped unless the JVM runs with -ea.
    if (googleNgram_dbname == null) {
      throw new IllegalStateException("set googleNgram_dbname variable through the properties file");
    }
    connection = DriverManager.getConnection(
        "jdbc:postgresql://" + googleNgram_hostname + "/" + googleNgram_dbname, googleNgram_username, "");
  }

  /**
   * Wraps {@code str} in dollar quotes built from {@code escapetag}, e.g. {@code $tag$str$tag$}.
   * The queries in this class bind their values as statement parameters and do not use this
   * method; it is kept because it is visible to the rest of the package.
   *
   * @param str text to quote
   * @return {@code str} surrounded by {@code $escapetag$} markers
   */
  static String escapeString(String str){
    return "$"+escapetag+"$"+ str + "$"+escapetag+"$" ;
  }

  /**
   * Tells whether a table exists, ignoring case. The table list is read from the database
   * metadata on first use and cached until the connection is closed.
   *
   * @param tablename name of the table to look for
   * @return true if the table is present
   * @throws SQLException if connecting or reading the metadata fails
   */
  public static boolean existsTable(String tablename) throws SQLException {
    if (existingTablenames == null) {
      connect();
      Set<String> names = new HashSet<>();
      DatabaseMetaData md = connection.getMetaData();
      try (ResultSet rs = md.getTables(null, null, "%", null)) {
        while (rs.next()) {
          // column 3 of getTables() is TABLE_NAME
          names.add(rs.getString(3).toLowerCase(Locale.ROOT));
        }
      }
      // Publish only after a complete read so a failure cannot leave a partial cache behind.
      existingTablenames = names;
    }
    return existingTablenames.contains(tablename.toLowerCase(Locale.ROOT));
  }

  /**
   * Runs {@code select <column> from <table> where phrase = ?} with the phrase bound as a
   * parameter, so quotes, dollar signs and question marks in the phrase need no escaping.
   * {@code table} and {@code column} are only ever built from configuration and constants.
   *
   * @return the first column of the first row, or -1 if no row matches
   */
  private static long queryPhraseValue(String table, String column, String phrase) throws SQLException {
    String sql = "select " + column + " from " + table + " where phrase = ?";
    try (PreparedStatement stmt = connection.prepareStatement(sql)) {
      stmt.setString(1, phrase);
      try (ResultSet rs = stmt.executeQuery()) {
        return rs.next() ? rs.getLong(1) : -1;
      }
    }
  }

  /**
   * Looks up the count of one phrase in the table that matches its number of
   * whitespace-separated tokens.
   *
   * @return the count, or -1 if the table or the phrase does not exist
   */
  private static long lookupCount(String str) throws SQLException {
    connect();
    String phrase = str.trim();
    int ngram = phrase.split("\\s+").length;
    String table = tablenamePrefix + ngram;
    if (!existsTable(table)) {
      return -1;
    }
    return queryPhraseValue(table, "count", phrase);
  }

  /**
   * Queries the SQL tables for the count of the phrase.
   *
   * @param str the phrase; leading and trailing whitespace is ignored
   * @return the count if the phrase exists, -1 if the phrase or its n-gram table does not
   * @throws RuntimeException wrapping the {@link SQLException} if the lookup fails
   */
  public static long getCount(String str) {
    try {
      return lookupCount(str);
    } catch (SQLException e) {
      log.info("Error getting count for " + str);
      throw new RuntimeException("Error getting count for " + str, e);
    }
  }

  /**
   * Queries the counts of several phrases. The returned list has one entry per input phrase,
   * in the iteration order of {@code strs}, each pairing the phrase as given by the caller
   * with its count (-1 if the phrase or its n-gram table does not exist).
   * One query is issued per phrase.
   *
   * @param strs the phrases to look up
   * @return phrase/count pairs, aligned with {@code strs}
   * @throws SQLException if connecting or querying fails
   */
  public static List<Pair<String, Long>> getCounts(Collection<String> strs) throws SQLException {
    connect();
    List<Pair<String, Long>> counts = new ArrayList<>(strs.size());
    for (String str : strs) {
      counts.add(new Pair<>(str, Long.valueOf(lookupCount(str))));
    }
    return counts;
  }

  /**
   * Adds the Google ngrams to the tables for the first time. For every {@code n} in
   * {@code typesOfPhrases}, reads {@code <dir>/<n>gms/vocab_cs.gz} (tab-separated phrase and
   * count per line) and inserts each line into the table {@code tablenamePrefix + n}.
   *
   * @param dir directory containing the {@code <n>gms} sub-directories
   * @param typesOfPhrases the n-gram orders to load
   * @throws SQLException if connecting or inserting fails
   * @throws RuntimeException if a target table does not exist
   * @throws IllegalArgumentException if a line does not hold a phrase and a numeric count
   */
  public static void populateTablesInSQL(String dir, Collection<Integer> typesOfPhrases) throws SQLException{
    connect();
    for (Integer n : typesOfPhrases) {
      String table = tablenamePrefix + n;
      if (!existsTable(table)) {
        throw new RuntimeException("Table " + table + " does not exist in the database! Run the following commands in the psql prompt:" +
            "create table GoogleNgrams_<NGRAM> (phrase text primary key not null, count bigint not null); create index phrase_<NGRAM> on GoogleNgrams_<NGRAM>(phrase);");
      }
      String sql = "INSERT INTO " + table + " (phrase, count) VALUES (?, ?)";
      // One statement per table, reused for every row; values are bound, never concatenated.
      try (PreparedStatement stmt = connection.prepareStatement(sql)) {
        for (String line : IOUtils.readLines(new File(dir + "/" + n + "gms/vocab_cs.gz"), GZIPInputStream.class)) {
          String[] tok = line.split("\t");
          if (tok.length < 2) {
            throw new IllegalArgumentException("Expected <phrase>TAB<count> but got: " + line);
          }
          stmt.setString(1, tok[0]);
          stmt.setLong(2, Long.parseLong(tok[1].trim()));
          stmt.executeUpdate();
        }
      }
    }
  }

  /**
   * Returns the number of rows in the table for the given n-gram order.
   * Note that this is really really slow for ngram &gt; 1.
   * TODO: make this fast (if we had been using mysql we could have)
   *
   * @param ngram the n-gram order
   * @return the row count of {@code tablenamePrefix + ngram}
   * @throws RuntimeException if the query fails or returns no row
   */
  static public int getTotalCount(int ngram){
    String table = tablenamePrefix + ngram;
    try {
      connect();
      try (Statement stmt = connection.createStatement();
           ResultSet rs = stmt.executeQuery("select count(*) from " + table + ";")) {
        if (rs.next()) {
          return rs.getInt(1);
        }
        throw new RuntimeException("getting table count is not working!");
      }
    } catch (SQLException e) {
      // keep the original message but also chain the cause
      throw new RuntimeException("getting table count is not working! " + e, e);
    }
  }

  /**
   * Returns the rank of a 1-gram as stored in the ranked 1-gram table.
   *
   * @param str the word; leading and trailing whitespace is ignored
   * @return the rank, or -1 if {@code str} has more than one token, the ranked table does not
   *         exist, or the word is not in it
   * @throws RuntimeException wrapping the {@link SQLException} if the lookup fails
   */
  public static int get1GramRank(String str){
    try {
      connect();
      String phrase = str.trim();
      if (phrase.split("\\s+").length > 1) {
        return -1;
      }
      if (!existsTable(RANKED_1GRAM_TABLE)) {
        return -1;
      }
      return (int) queryPhraseValue(RANKED_1GRAM_TABLE, "rank", phrase);
    } catch (SQLException e) {
      log.info("Error getting rank for " + str);
      throw new RuntimeException("Error getting rank for " + str, e);
    }
  }

  /**
   * Closes the shared connection, if any, and forgets the cached table names so that a later
   * connection (possibly to another database) reads them afresh.
   *
   * @throws SQLException if closing the connection fails; the reference is cleared regardless
   */
  static public void closeConnection() throws SQLException {
    Connection open = connection;
    connection = null;
    existingTablenames = null;
    if (open != null) {
      open.close();
    }
  }

  /**
   * Small command-line smoke test: fills the options from {@code args}, runs a few fixed
   * lookups and, if given, looks up the {@code phrase} and {@code rank} properties.
   */
  public static void main(String[] args){
    try {
      Properties props = StringUtils.argsToPropertiesWithResolve(args);
      ArgumentParser.fillOptions(GoogleNGramsSQLBacked.class, props);
      connect();
      //if(populateTables)
      //  populateTablesInSQL(dataDir, ngramsToPopulate);
      //testing
      System.out.println("For head,the count is " + getCount("head"));
      //System.out.println(getCount("what the heck"));
      //System.out.println(getCount("my name is john"));
      System.out.println(getCounts(Arrays.asList("cancer","disease")));
      System.out.println("Get count 1 gram " + getTotalCount(1));
      String phrase = props.getProperty("phrase");
      if (phrase != null) {
        System.out.println("count for phrase " + phrase + " is " + getCount(phrase));
      }
      String word = props.getProperty("rank");
      if (word != null) {
        System.out.println("Rank of " + word + " is " + get1GramRank(word));
      }
    } catch (Exception e) {
      e.printStackTrace();
    } finally {
      // release the connection even when one of the lookups above failed
      try {
        closeConnection();
      } catch (SQLException e) {
        e.printStackTrace();
      }
    }
  }

  /**
   * Sets the database name used by the next call to {@link #connect()}.
   * An already open connection is left untouched.
   *
   * @param DBName the database name
   */
  public static void setDBName(String DBName) {
    googleNgram_dbname = DBName;
  }
}