package org.wikibrain.core.dao.sql; import org.apache.commons.lang3.StringUtils; import org.jooq.Table; import org.jooq.TableField; import org.jooq.tools.jdbc.JDBCUtils; import org.wikibrain.core.dao.DaoException; import org.wikibrain.utils.WpThreadUtils; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.SQLException; import java.sql.Timestamp; import java.text.DateFormat; import java.text.SimpleDateFormat; import java.util.Date; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; import java.util.concurrent.TimeUnit; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * Bulk loads data in batch form to speed up insertions. * * @author Shilad Sen */ public class FastLoader { private static final int NUM_INSERTERS = Math.min(WpThreadUtils.getMaxThreads(), 4); private static final Object POSION_PILL = new Object(); private boolean isPostGisLoader = false; static final Logger LOG = LoggerFactory.getLogger(FastLoader.class); static final int BATCH_SIZE = 1000; private final WpDataSource ds; private final String table; private final String[] fields; private BlockingQueue<Object[]> rowBuffer = new ArrayBlockingQueue<Object[]>(BATCH_SIZE * NUM_INSERTERS * 2); static enum InserterState { RUNNING, // In normal working mode FAILED, // Loader failed, it cannot be used anymore SHUTTING_DOWN, // Parent thread triggered a shutdown SHUTDOWN // Already shutdown } private Thread [] inserters = new Thread[NUM_INSERTERS]; private volatile InserterState inserterState = null; public FastLoader(WpDataSource ds, TableField[] fields) throws DaoException { this(ds, fields[0].getTable().getName(), getFieldNames(fields)); } public FastLoader(WpDataSource ds, String table, String[] fields, boolean isPostGisLoader) throws DaoException { this(ds,table, fields); this.isPostGisLoader = isPostGisLoader; } public FastLoader(WpDataSource ds, String table, String[] fields) throws DaoException { this.ds = ds; this.table = table; this.fields = fields; for (int i = 0; i < inserters.length; i++) { inserters[i] = new Thread(new Runnable() { public void run() { try { insertBatches(); } catch (DaoException e) { LOG.error("inserter failed", e); inserterState = InserterState.FAILED; rowBuffer.clear(); // allow any existing puts to go through } catch (SQLException e) { LOG.error("inserter failed", e); inserterState = InserterState.FAILED; rowBuffer.clear(); // allow any existing puts to go through } catch (InterruptedException e) { LOG.error("inserter interrupted", e); inserterState = InserterState.FAILED; rowBuffer.clear(); // allow any existing puts to go through } } }); inserters[i].start(); } inserterState = InserterState.RUNNING; } private static String[] getFieldNames(TableField[] fields) { String names[] = new String[fields.length]; for (int i = 0; i < fields.length; i++) { names[i] = fields[i].getName(); } return names; } /** * Saves a value to the datastore. * @param values * @throws DaoException */ private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd hh:mm:ss"); public void load(Object ... values) throws DaoException { if (inserters == null || inserterState != InserterState.RUNNING) { throw new IllegalStateException("inserter thread in state " + inserterState); } // Hack convert dates to Timestamps for (int i = 0; i < values.length; i++) { if (values[i] instanceof Date && !(values[i] instanceof Timestamp)) { values[i] = new Timestamp(((Date)values[i]).getTime()); } } if (values.length != fields.length) { throw new IllegalArgumentException(); } try { rowBuffer.put(values); } catch (InterruptedException e) { throw new DaoException(e); } } private void insertBatches() throws DaoException, SQLException, InterruptedException { boolean finished = false; Connection cnx = ds.getConnection(); if (isPostGisLoader){ try { ((org.postgresql.PGConnection) cnx).addDataType("geometry", Class.forName("org.postgis.PGgeometry")); // ((org.postgresql.PGConnection) cnx).addDataType("geometry", Class.forName("org.postgis.Multipolygon")); }catch(ClassNotFoundException e){ throw new DaoException("Could not find PostGIS geometry type. Is the PostGIS library in the class path?: " + e.getMessage()); } } PreparedStatement statement = null; try { String [] names = new String[fields.length]; String [] questions = new String[fields.length]; for (int i = 0; i < fields.length; i++) { names[i] = fields[i]; questions[i] = "?"; } String sql = "INSERT INTO " + table + "(" + StringUtils.join(names, ",") + ") " + "VALUES (" + StringUtils.join(questions, ",") + ");"; statement = cnx.prepareStatement(sql); while (!finished && inserterState != InserterState.FAILED) { // accumulate batch int batchSize = 0; while (!finished && batchSize < BATCH_SIZE && inserterState != InserterState.FAILED) { Object row[] = rowBuffer.poll(100, TimeUnit.MILLISECONDS); if (row == null) { // do nothing } else if (row[0] == POSION_PILL) { rowBuffer.put(new Object[]{POSION_PILL}); finished = true; } else { batchSize++; for (int i = 0; i < row.length; i++) { if(row[i] != null && row[i].getClass().equals(java.lang.Character.class)) statement.setObject(i + 1, row[i].toString()); else statement.setObject(i + 1, row[i]); } statement.addBatch(); } } try { statement.executeBatch(); cnx.commit(); } catch (SQLException e) { cnx.rollback(); while (e != null) { LOG.error("insert batch failed, attempting to continue:", e); e = e.getNextException(); } } statement.clearBatch(); } } finally { if (statement != null) { JDBCUtils.safeClose(statement); } AbstractSqlDao.quietlyCloseConn(cnx); } } public void endLoad() throws DaoException { try { if (inserterState == InserterState.RUNNING) { rowBuffer.put(new Object[]{POSION_PILL}); } inserterState = InserterState.SHUTTING_DOWN; } catch (InterruptedException e) { throw new DaoException(e); } for (Thread inserter : inserters) { if (inserter != null) { try { inserter.join(60000); } catch (InterruptedException e) { throw new DaoException(e); } } } inserterState = InserterState.SHUTDOWN; } public void close() throws DaoException { endLoad(); } }