/*
* Copyright (C) 2015 The Pennsylvania State University and the University of Wisconsin
* Systems and Internet Infrastructure Security Laboratory
*
* Author: Damien Octeau
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.psu.cse.siis.ic3.db;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import edu.psu.cse.siis.coal.Constants;
/**
 * Database table mapping strings to integer ids.
 *
 * <p>A {@code null} string is stored and looked up as {@link Constants#NULL_STRING}. SQL text is
 * built from the table name given to the constructor; the connection, the cached single-row
 * statements and the id helpers are inherited from {@link Table}.
 */
public class StringTable extends Table {
  /** Single-row insert; also the prefix of the multi-row batch insert. */
  private static final String INSERT = "INSERT INTO %s (st) VALUES (?)";
  private static final String FIND = "SELECT id FROM %s WHERE st = ?";
  /** Batch insert prefix; one extra ", (?)" row is appended per additional string. */
  private static final String BATCH_INSERT = INSERT;
  /** Batch lookup prefix; one extra ", ?" per additional string and a closing ")" are appended. */
  private static final String BATCH_FIND = "SELECT id, st FROM %s WHERE st IN (?";

  /**
   * Creates a string table bound to the given database table.
   *
   * @param table name of the underlying table; it is interpolated directly into the SQL text, so
   *     it must be a trusted identifier
   */
  StringTable(String table) {
    insertString = String.format(INSERT, table);
    findString = String.format(FIND, table);
    batchInsertString = String.format(BATCH_INSERT, table);
    batchFindString = String.format(BATCH_FIND, table);
  }

  /**
   * Looks up the ids of several strings with a single {@code IN (...)} query.
   *
   * @param strings strings to look up; a {@code null} element is looked up as
   *     {@link Constants#NULL_STRING}
   * @return map from each stored string (as returned by the database) to its id; strings absent
   *     from the table have no entry. Empty when {@code strings} is {@code null}, empty, or
   *     contains only {@code null}.
   * @throws SQLException if preparing or executing the query fails
   */
  public Map<String, Integer> batchFind(Set<String> strings) throws SQLException {
    Map<String, Integer> result = new HashMap<String, Integer>();
    if (strings == null || strings.isEmpty() || (strings.size() == 1 && strings.contains(null))) {
      return result;
    }
    Set<String> normalized = normalize(strings);
    String query = batchFindString + additionalPlaceholders(normalized.size(), ", ?") + ")";
    PreparedStatement statement = getConnection().prepareStatement(query);
    try {
      bindAll(statement, normalized);
      ResultSet resultSet = statement.executeQuery();
      while (resultSet.next()) {
        result.put(resultSet.getString("st"), resultSet.getInt("id"));
      }
    } finally {
      // The statement is created per call, so release it; this also closes its ResultSet.
      statement.close();
    }
    return result;
  }

  /**
   * Ensures every string in {@code strings} is stored, inserting the missing ones with one
   * multi-row statement.
   *
   * @param strings strings to store; a {@code null} element is stored as
   *     {@link Constants#NULL_STRING}. A {@code null} set, or a set containing only {@code null},
   *     is treated as "nothing to do".
   * @param allThere optional out-parameter; when non-null, element 0 is set to {@code true} if no
   *     insertion was needed, and is left untouched otherwise
   * @return ids of the strings that were already present plus the generated keys of the newly
   *     inserted ones
   * @throws SQLException if a lookup or the insert fails
   */
  public Set<Integer> batchInsert(Set<String> strings, boolean[] allThere) throws SQLException {
    if (strings == null || (strings.size() == 1 && strings.contains(null))) {
      if (allThere != null) {
        allThere[0] = true;
      }
      return new HashSet<Integer>();
    }
    // Normalize before diffing. batchFind reports a null element under NULL_STRING, so diffing
    // against the raw set would leave null behind and insert a duplicate NULL_STRING row.
    Set<String> normalized = normalize(strings);
    Map<String, Integer> alreadyThere = batchFind(normalized);
    Set<Integer> result = new HashSet<Integer>(alreadyThere.values());
    Set<String> toBeInserted = new HashSet<String>(normalized);
    toBeInserted.removeAll(alreadyThere.keySet());
    if (toBeInserted.isEmpty()) {
      if (allThere != null) {
        allThere[0] = true;
      }
      return result;
    }
    String query = batchInsertString + additionalPlaceholders(toBeInserted.size(), ", (?)");
    PreparedStatement statement = getConnection().prepareStatement(query, AUTOGENERATED_ID);
    try {
      bindAll(statement, toBeInserted);
      statement.executeUpdate();
      ResultSet generatedKeys = statement.getGeneratedKeys();
      while (generatedKeys.next()) {
        result.add(generatedKeys.getInt(1));
      }
    } finally {
      // The statement is created per call, so release it; this also closes the keys ResultSet.
      statement.close();
    }
    return result;
  }

  /**
   * Returns the id of {@code st}, inserting it first if it is not stored yet.
   *
   * @param st string to look up or insert; {@code null} is treated as
   *     {@link Constants#NULL_STRING}
   * @return the existing id, or the result of {@link #forceInsert(String)}
   * @throws SQLException if the lookup or the insert fails
   */
  public int insert(String st) throws SQLException {
    if (st == null) {
      st = Constants.NULL_STRING;
    }
    int id = find(st);
    if (id != NOT_FOUND) {
      return id;
    }
    return forceInsert(st);
  }

  /**
   * Inserts {@code st} without checking whether it is already stored.
   *
   * @param st string to insert; {@code null} is stored as {@link Constants#NULL_STRING}
   * @return {@code NOT_FOUND} if no row was inserted, otherwise the value of
   *     {@code findAutoIncrement()} (presumably the new row's id; see {@link Table})
   * @throws SQLException if preparing or executing the insert fails
   */
  public int forceInsert(String st) throws SQLException {
    // The single-row statement is cached on the instance and re-prepared only when missing or
    // closed.
    if (insertStatement == null || insertStatement.isClosed()) {
      insertStatement = getConnection().prepareStatement(insertString);
    }
    if (st == null) {
      st = Constants.NULL_STRING;
    }
    insertStatement.setString(1, st);
    if (insertStatement.executeUpdate() == 0) {
      return NOT_FOUND;
    }
    return findAutoIncrement();
  }

  /**
   * Looks up the id of a single string.
   *
   * @param st string to look up; {@code null} is looked up as {@link Constants#NULL_STRING}
   * @return the value computed by {@code processIntFindQuery} (callers such as
   *     {@link #insert(String)} treat {@code NOT_FOUND} as "absent")
   * @throws SQLException if preparing or executing the query fails
   */
  public int find(String st) throws SQLException {
    // The single-row statement is cached on the instance and re-prepared only when missing or
    // closed.
    if (findStatement == null || findStatement.isClosed()) {
      findStatement = getConnection().prepareStatement(findString);
    }
    if (st == null) {
      st = Constants.NULL_STRING;
    }
    findStatement.setString(1, st);
    return processIntFindQuery(findStatement);
  }

  /** Returns a copy of {@code strings} with a {@code null} element replaced by NULL_STRING. */
  private static Set<String> normalize(Set<String> strings) {
    Set<String> normalized = new HashSet<String>();
    for (String string : strings) {
      normalized.add(string == null ? Constants.NULL_STRING : string);
    }
    return normalized;
  }

  /** Returns {@code count - 1} copies of {@code extra}; the first placeholder is in the prefix. */
  private static String additionalPlaceholders(int count, String extra) {
    StringBuilder builder = new StringBuilder();
    for (int i = 1; i < count; ++i) {
      builder.append(extra);
    }
    return builder.toString();
  }

  /** Binds {@code values}, in iteration order, to parameters 1..n of {@code statement}. */
  private static void bindAll(PreparedStatement statement, Set<String> values)
      throws SQLException {
    int parameterIndex = 1;
    for (String value : values) {
      statement.setString(parameterIndex++, value);
    }
  }
}