/* * Copyright (C) 2015 The Pennsylvania State University and the University of Wisconsin * Systems and Internet Infrastructure Security Laboratory * * Author: Damien Octeau * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package edu.psu.cse.siis.ic3.db; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; import edu.psu.cse.siis.coal.Constants; public class StringTable extends Table { private static final String INSERT = "INSERT INTO %s (st) VALUES (?)"; private static final String FIND = "SELECT id FROM %s WHERE st = ?"; private static final String BATCH_INSERT = INSERT; private static final String BATCH_FIND = "SELECT id, st FROM %s WHERE st IN (?"; StringTable(String table) { insertString = String.format(INSERT, table); findString = String.format(FIND, table); batchInsertString = String.format(BATCH_INSERT, table); batchFindString = String.format(BATCH_FIND, table); } public Map<String, Integer> batchFind(Set<String> strings) throws SQLException { Map<String, Integer> result = new HashMap<String, Integer>(); if (strings == null || strings.size() == 0 || (strings.size() == 1 && strings.contains(null))) { return result; } StringBuilder queryBuilder = new StringBuilder(batchFindString); for (int i = 1; i < strings.size(); ++i) { queryBuilder.append(", ?"); } queryBuilder.append(")"); PreparedStatement batchFindStatement = getConnection().prepareStatement(queryBuilder.toString()); int parameterIndex = 1; for (String string : strings) { if (string == null) { string = Constants.NULL_STRING; } batchFindStatement.setString(parameterIndex++, string); } ResultSet resultSet = batchFindStatement.executeQuery(); while (resultSet.next()) { result.put(resultSet.getString("st"), resultSet.getInt("id")); } return result; } public Set<Integer> batchInsert(Set<String> strings, boolean[] allThere) throws SQLException { if (strings == null || (strings.size() == 1 && strings.contains(null))) { if (allThere != null) { // System.out.println("TEST1"); allThere[0] = true; } return new HashSet<Integer>(); } Map<String, Integer> alreadyThere = batchFind(strings); // System.out.println("Found " + alreadyThere + " for " + strings); Set<String> toBeInserted = new HashSet<String>(strings); toBeInserted.removeAll(alreadyThere.keySet()); Set<Integer> result = new HashSet<Integer>(alreadyThere.values()); if (toBeInserted.size() == 0) { if (allThere != null) { // System.out.println("TEST2"); allThere[0] = true; } return result; } StringBuilder queryBuilder = new StringBuilder(batchInsertString); for (int i = 1; i < toBeInserted.size(); ++i) { queryBuilder.append(", (?)"); } PreparedStatement batchInsertStatement = getConnection().prepareStatement(queryBuilder.toString(), AUTOGENERATED_ID); int parameterIndex = 1; for (String string : toBeInserted) { if (string == null) { string = Constants.NULL_STRING; } batchInsertStatement.setString(parameterIndex++, string); } batchInsertStatement.executeUpdate(); ResultSet resultSet = batchInsertStatement.getGeneratedKeys(); while (resultSet.next()) { result.add(resultSet.getInt(1)); } // System.out.println("TEST3"); return result; } public int insert(String st) throws SQLException { if (st == null) { st = Constants.NULL_STRING; } int id = find(st); if (id != NOT_FOUND) { return id; } return forceInsert(st); } public int forceInsert(String st) throws SQLException { if (insertStatement == null || insertStatement.isClosed()) { insertStatement = getConnection().prepareStatement(insertString); } if (st == null) { st = Constants.NULL_STRING; } insertStatement.setString(1, st); if (insertStatement.executeUpdate() == 0) { return NOT_FOUND; } return findAutoIncrement(); } public int find(String st) throws SQLException { if (findStatement == null || findStatement.isClosed()) { findStatement = getConnection().prepareStatement(findString); } if (st == null) { st = Constants.NULL_STRING; } findStatement.setString(1, st); return processIntFindQuery(findStatement); } }