/* * Copyright (C) 2015 The Pennsylvania State University and the University of Wisconsin * Systems and Internet Infrastructure Security Laboratory * * Author: Damien Octeau * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package edu.psu.cse.siis.ic3.db; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Types; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; public abstract class TwoIntTable extends Table { private static final String INSERT = "INSERT INTO %s (%s, %s) VALUES (?, ?)"; private static final String FIND = "SELECT id FROM %s WHERE %s = ? AND %s = ?"; private static final String BATCH_INSERT = INSERT; private static final String BATCH_FIND = "SELECT %s, %s, %s FROM %s WHERE 1 = 0"; private final String firstColumn; private final String secondColumn; private final String batchInsertPattern; private final String batchFindPattern; TwoIntTable(String table, String firstColumn, String secondColumn) { insertString = String.format(INSERT, table, firstColumn, secondColumn); findString = String.format(FIND, table, firstColumn, secondColumn); batchInsertString = String.format(BATCH_INSERT, table, firstColumn, secondColumn); batchFindString = String.format(BATCH_FIND, ID, firstColumn, secondColumn, table); this.firstColumn = firstColumn; this.secondColumn = secondColumn; this.batchInsertPattern = String.format(", (?, ?)"); this.batchFindPattern = String.format(" OR (%s = ? AND %s = ?)", firstColumn, secondColumn); } public Map<Pair<Integer, Integer>, Integer> batchFind(Set<Pair<Integer, Integer>> values) throws SQLException { Map<Pair<Integer, Integer>, Integer> found = new HashMap<Pair<Integer, Integer>, Integer>(); if (values == null || values.size() == 0) { return found; } StringBuilder queryBuilder = new StringBuilder(batchFindString); for (int i = 0; i < values.size(); ++i) { queryBuilder.append(batchFindPattern); } PreparedStatement batchFindStatement = getConnection().prepareStatement(queryBuilder.toString()); int parameterIndex = 1; for (Pair<Integer, Integer> value : values) { batchFindStatement.setInt(parameterIndex++, value.getO1()); batchFindStatement.setInt(parameterIndex++, value.getO2()); } ResultSet resultSet = batchFindStatement.executeQuery(); while (resultSet.next()) { found .put( new Pair<Integer, Integer>(resultSet.getInt(firstColumn), resultSet .getInt(secondColumn)), resultSet.getInt(ID)); } return found; } public Set<Integer> batchInsert(Set<Pair<Integer, Integer>> values) throws SQLException { Map<Pair<Integer, Integer>, Integer> found = batchFind(values); Set<Pair<Integer, Integer>> toBeInserted = new HashSet<Pair<Integer, Integer>>(values); // Take the set difference. Obtain the values which have not been found; toBeInserted.removeAll(found.keySet()); Set<Integer> result = batchForceInsert(toBeInserted); result.addAll(found.values()); return result; } public Set<Integer> batchForceInsert(Set<Pair<Integer, Integer>> values) throws SQLException { Set<Integer> result = new HashSet<Integer>(); if (values.size() > 0) { StringBuilder queryBuilder = new StringBuilder(batchInsertString); for (int i = 1; i < values.size(); ++i) { queryBuilder.append(batchInsertPattern); } PreparedStatement batchInsertStatement = getConnection().prepareStatement(queryBuilder.toString(), AUTOGENERATED_ID); int parameterIndex = 1; for (Pair<Integer, Integer> value : values) { batchInsertStatement.setInt(parameterIndex++, value.getO1()); batchInsertStatement.setInt(parameterIndex++, value.getO2()); } batchInsertStatement.executeUpdate(); ResultSet resultSet = batchInsertStatement.getGeneratedKeys(); while (resultSet.next()) { result.add(resultSet.getInt(1)); } } return result; } public Set<Integer> batchForceInsert(Integer firstValue, List<Integer> values) throws SQLException { Set<Pair<Integer, Integer>> newValues = new HashSet<Pair<Integer, Integer>>(); if (values == null || values.size() == 0) { return new HashSet<Integer>(); } for (int value : values) { newValues.add(new Pair<Integer, Integer>(firstValue, value)); } return batchForceInsert(newValues); } public int insert(Integer firstValue, Integer secondValue) throws SQLException { int id = find(firstValue, secondValue); if (id != NOT_FOUND) { return id; } return forceInsert(firstValue, secondValue); } public int forceInsert(Integer firstValue, Integer secondValue) throws SQLException { if (insertStatement == null || insertStatement.isClosed()) { insertStatement = getConnection().prepareStatement(insertString); } if (firstValue != null) { insertStatement.setInt(1, firstValue); } else { insertStatement.setNull(1, Types.INTEGER); } if (secondValue != null) { insertStatement.setInt(2, secondValue); } else { insertStatement.setNull(2, Types.INTEGER); } if (insertStatement.executeUpdate() == 0) { return NOT_FOUND; } return findAutoIncrement(); } public int find(Integer firstValue, Integer secondValue) throws SQLException { if (findStatement == null || findStatement.isClosed()) { findStatement = getConnection().prepareStatement(findString); } if (firstValue != null) { findStatement.setInt(1, firstValue); } else { findStatement.setNull(1, Types.INTEGER); } if (secondValue != null) { findStatement.setInt(2, secondValue); } else { findStatement.setNull(2, Types.INTEGER); } return processIntFindQuery(findStatement); } }