/*
* Copyright (C) 2015 The Pennsylvania State University and the University of Wisconsin
* Systems and Internet Infrastructure Security Laboratory
*
* Author: Damien Octeau
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.psu.cse.siis.ic3.db;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Types;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Base class for tables whose rows are made of two integer columns, looked up and inserted
 * through JDBC prepared statements.
 *
 * <p>
 * The SQL strings are built once in the constructor from the table and column names. The single
 * row operations ({@link #insert}, {@link #forceInsert}, {@link #find}) reuse the prepared
 * statements held by the superclass, while the batch operations build a new statement for every
 * call because the number of placeholders depends on the size of the input.
 */
public abstract class TwoIntTable extends Table {
  private static final String INSERT = "INSERT INTO %s (%s, %s) VALUES (?, ?)";
  private static final String FIND = "SELECT id FROM %s WHERE %s = ? AND %s = ?";
  private static final String BATCH_INSERT = INSERT;
  // "WHERE 1 = 0" is a neutral starting point so that every pair can be appended uniformly as
  // " OR (first = ? AND second = ?)".
  private static final String BATCH_FIND = "SELECT %s, %s, %s FROM %s WHERE 1 = 0";

  private final String firstColumn;
  private final String secondColumn;
  /** Fragment appended to the insert query for each row after the first one. */
  private final String batchInsertPattern;
  /** Fragment appended to the batch find query for each pair that is looked up. */
  private final String batchFindPattern;

  /**
   * Builds the SQL strings for the given table and columns.
   *
   * @param table The name of the table.
   * @param firstColumn The name of the first integer column.
   * @param secondColumn The name of the second integer column.
   */
  TwoIntTable(String table, String firstColumn, String secondColumn) {
    insertString = String.format(INSERT, table, firstColumn, secondColumn);
    findString = String.format(FIND, table, firstColumn, secondColumn);
    batchInsertString = String.format(BATCH_INSERT, table, firstColumn, secondColumn);
    batchFindString = String.format(BATCH_FIND, ID, firstColumn, secondColumn, table);
    this.firstColumn = firstColumn;
    this.secondColumn = secondColumn;
    this.batchInsertPattern = ", (?, ?)";
    this.batchFindPattern = String.format(" OR (%s = ? AND %s = ?)", firstColumn, secondColumn);
  }

  /**
   * Looks up several pairs of values with a single query.
   *
   * @param values The pairs to look up. May be null or empty, in which case no query is issued.
   * @return A map from each pair that was found to the value of its id column. Pairs that are not
   *         in the table are absent from the map.
   * @throws SQLException if something goes wrong with the SQL query.
   */
  public Map<Pair<Integer, Integer>, Integer> batchFind(Set<Pair<Integer, Integer>> values)
      throws SQLException {
    Map<Pair<Integer, Integer>, Integer> found = new HashMap<Pair<Integer, Integer>, Integer>();
    if (values == null || values.size() == 0) {
      return found;
    }

    StringBuilder queryBuilder = new StringBuilder(batchFindString);
    for (int i = 0; i < values.size(); ++i) {
      queryBuilder.append(batchFindPattern);
    }

    PreparedStatement batchFindStatement =
        getConnection().prepareStatement(queryBuilder.toString());
    // The statement is specific to this call, so it has to be released here, including when an
    // exception is thrown.
    try {
      int parameterIndex = 1;
      for (Pair<Integer, Integer> value : values) {
        batchFindStatement.setInt(parameterIndex++, value.getO1());
        batchFindStatement.setInt(parameterIndex++, value.getO2());
      }

      ResultSet resultSet = batchFindStatement.executeQuery();
      try {
        while (resultSet.next()) {
          Pair<Integer, Integer> key =
              new Pair<Integer, Integer>(resultSet.getInt(firstColumn),
                  resultSet.getInt(secondColumn));
          found.put(key, resultSet.getInt(ID));
        }
      } finally {
        resultSet.close();
      }
    } finally {
      batchFindStatement.close();
    }

    return found;
  }

  /**
   * Inserts the pairs that are not in the table yet.
   *
   * @param values The pairs which should be in the table. May be null or empty. The argument
   *        itself is not modified.
   * @return The ids found for the pairs that were already present, together with the keys
   *         generated for the pairs that were inserted.
   * @throws SQLException if something goes wrong with the SQL queries.
   */
  public Set<Integer> batchInsert(Set<Pair<Integer, Integer>> values) throws SQLException {
    if (values == null || values.isEmpty()) {
      // Same tolerance as batchFind: nothing to look up and nothing to insert.
      return new HashSet<Integer>();
    }

    Map<Pair<Integer, Integer>, Integer> found = batchFind(values);
    Set<Pair<Integer, Integer>> toBeInserted = new HashSet<Pair<Integer, Integer>>(values);
    // Take the set difference. Obtain the values which have not been found.
    toBeInserted.removeAll(found.keySet());

    Set<Integer> result = batchForceInsert(toBeInserted);
    result.addAll(found.values());
    return result;
  }

  /**
   * Inserts all the given pairs with a single multi-row INSERT statement, without checking
   * whether they are already in the table.
   *
   * @param values The pairs to insert. May be null or empty, in which case no query is issued.
   * @return The generated keys reported by the statement after the insertion.
   * @throws SQLException if something goes wrong with the SQL query.
   */
  public Set<Integer> batchForceInsert(Set<Pair<Integer, Integer>> values) throws SQLException {
    Set<Integer> result = new HashSet<Integer>();
    if (values == null || values.isEmpty()) {
      return result;
    }

    // The base insert string already contains the placeholders for one row, hence the loop
    // starting at 1.
    StringBuilder queryBuilder = new StringBuilder(batchInsertString);
    for (int i = 1; i < values.size(); ++i) {
      queryBuilder.append(batchInsertPattern);
    }

    PreparedStatement batchInsertStatement =
        getConnection().prepareStatement(queryBuilder.toString(), AUTOGENERATED_ID);
    // The statement is specific to this call, so it has to be released here, including when an
    // exception is thrown.
    try {
      int parameterIndex = 1;
      for (Pair<Integer, Integer> value : values) {
        batchInsertStatement.setInt(parameterIndex++, value.getO1());
        batchInsertStatement.setInt(parameterIndex++, value.getO2());
      }
      batchInsertStatement.executeUpdate();

      ResultSet resultSet = batchInsertStatement.getGeneratedKeys();
      try {
        while (resultSet.next()) {
          result.add(resultSet.getInt(1));
        }
      } finally {
        resultSet.close();
      }
    } finally {
      batchInsertStatement.close();
    }

    return result;
  }

  /**
   * Inserts one row per element of {@code values}, all of them with the same first value. No
   * check is made for rows which already exist. Since the pairs are collected in a set, duplicate
   * elements of {@code values} yield a single row.
   *
   * @param firstValue The value for the first column of every row.
   * @param values The values for the second column. May be null or empty, in which case nothing
   *        is inserted.
   * @return The generated keys reported by the insertion.
   * @throws SQLException if something goes wrong with the SQL query.
   */
  public Set<Integer> batchForceInsert(Integer firstValue, List<Integer> values)
      throws SQLException {
    Set<Pair<Integer, Integer>> newValues = new HashSet<Pair<Integer, Integer>>();
    if (values == null || values.size() == 0) {
      return new HashSet<Integer>();
    }
    for (int value : values) {
      newValues.add(new Pair<Integer, Integer>(firstValue, value));
    }
    return batchForceInsert(newValues);
  }

  /**
   * Inserts a pair of values, unless {@link #find} reports that it is already in the table.
   *
   * @param firstValue The value for the first column. May be null.
   * @param secondValue The value for the second column. May be null.
   * @return The result of {@link #find} if it is not {@code NOT_FOUND}, otherwise the result of
   *         {@link #forceInsert}.
   * @throws SQLException if something goes wrong with the SQL queries.
   */
  public int insert(Integer firstValue, Integer secondValue) throws SQLException {
    int id = find(firstValue, secondValue);
    if (id != NOT_FOUND) {
      return id;
    }
    return forceInsert(firstValue, secondValue);
  }

  /**
   * Inserts a pair of values without checking whether it is already in the table.
   *
   * @param firstValue The value for the first column. A null reference is stored as SQL NULL.
   * @param secondValue The value for the second column. A null reference is stored as SQL NULL.
   * @return {@code NOT_FOUND} if the update reports that no row was affected, otherwise the
   *         result of {@code findAutoIncrement()}.
   * @throws SQLException if something goes wrong with the SQL query.
   */
  public int forceInsert(Integer firstValue, Integer secondValue) throws SQLException {
    // The statement is kept in the superclass field and prepared again only when needed.
    if (insertStatement == null || insertStatement.isClosed()) {
      insertStatement = getConnection().prepareStatement(insertString);
    }
    setNullableInt(insertStatement, 1, firstValue);
    setNullableInt(insertStatement, 2, secondValue);
    if (insertStatement.executeUpdate() == 0) {
      return NOT_FOUND;
    }
    return findAutoIncrement();
  }

  /**
   * Looks up a pair of values.
   *
   * <p>
   * NOTE(review): a null argument is bound as SQL NULL, but the query compares with
   * {@code column = ?}. Under standard SQL comparison semantics this presumably never matches a
   * row, which would make {@link #insert} add a new row every time for such arguments. Confirm
   * against the target database.
   *
   * @param firstValue The value to look for in the first column. May be null.
   * @param secondValue The value to look for in the second column. May be null.
   * @return The result of {@code processIntFindQuery} for the find statement.
   * @throws SQLException if something goes wrong with the SQL query.
   */
  public int find(Integer firstValue, Integer secondValue) throws SQLException {
    // The statement is kept in the superclass field and prepared again only when needed.
    if (findStatement == null || findStatement.isClosed()) {
      findStatement = getConnection().prepareStatement(findString);
    }
    setNullableInt(findStatement, 1, firstValue);
    setNullableInt(findStatement, 2, secondValue);
    return processIntFindQuery(findStatement);
  }

  /**
   * Binds an integer parameter, using SQL NULL when the value is a null reference.
   *
   * @param statement The statement to bind the parameter on.
   * @param index The 1-based index of the parameter.
   * @param value The value to bind. May be null.
   * @throws SQLException if the parameter cannot be set.
   */
  private static void setNullableInt(PreparedStatement statement, int index, Integer value)
      throws SQLException {
    if (value != null) {
      statement.setInt(index, value);
    } else {
      statement.setNull(index, Types.INTEGER);
    }
  }
}