/* * Copyright (c) 2016 OBiBa. All rights reserved. * * This program and the accompanying materials * are made available under the terms of the GNU Public License v3.0. * * You should have received a copy of the GNU General Public License * along with this program. If not, see <http://www.gnu.org/licenses/>. */ package org.obiba.runtime.upgrade.support.jdbc; import java.io.IOException; import java.sql.SQLException; import javax.annotation.PostConstruct; import javax.sql.DataSource; import org.obiba.runtime.Version; import org.obiba.runtime.jdbc.DatabaseProduct; import org.obiba.runtime.jdbc.DatabaseProductRegistry; import org.obiba.runtime.upgrade.AbstractUpgradeStep; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.core.io.Resource; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.datasource.init.ScriptUtils; import org.springframework.test.jdbc.JdbcTestUtils; import org.springframework.transaction.annotation.Transactional; @Transactional public class SqlScriptUpgradeStep extends AbstractUpgradeStep { private static final Logger log = LoggerFactory.getLogger(SqlScriptUpgradeStep.class); private DataSource dataSource; private Resource scriptPath; private String scriptBasename; private Resource script; public SqlScriptUpgradeStep() { } public SqlScriptUpgradeStep(DataSource dataSource, String scriptBasename, Resource scriptPath) { this.dataSource = dataSource; this.scriptBasename = scriptBasename; this.scriptPath = scriptPath; } public void setDataSource(DataSource dataSource) { this.dataSource = dataSource; } public void setScriptBasename(String scriptBasename) { this.scriptBasename = scriptBasename; } public void setScriptPath(Resource scriptPath) { this.scriptPath = scriptPath; } @PostConstruct public void initialize() throws IOException { log.debug("Identifying database."); DatabaseProduct product = getDatabaseProduct(dataSource); log.debug("Database product is: {}", product); script = scriptPath.createRelative(getProductSpecificScriptName(product)); log.debug("Sql script {} exists {}", script.getDescription(), script.exists()); if(!script.exists()) { script = scriptPath.createRelative(getScriptName()); log.debug("Sql script {} exists {}", script.getDescription(), script.exists()); if(!script.exists()) { throw new IllegalStateException( "Cannot find sql script to execute. Script path '" + scriptPath + "' basename '" + scriptBasename + "' database product '" + product + "'."); } } } @Override public void execute(Version currentVersion) { log.info("Applying script {} to database.", script.getFilename()); executeScript(dataSource, script); } protected void executeScript(DataSource dataSource, Resource script) { try { ScriptUtils.executeSqlScript(dataSource.getConnection(), script); } catch(SQLException e) { log.info("Script execution failed {}.", e.getStackTrace()); } } protected DatabaseProduct getDatabaseProduct(DataSource dataSource) { return new DatabaseProductRegistry().getDatabaseProduct(dataSource); } protected String getProductSpecificScriptName(DatabaseProduct product) { return scriptBasename + "_" + product.getNormalizedName() + ".sql"; } protected String getScriptName() { return scriptBasename + ".sql"; } }