package com.github.witoldsz.ultm.internal;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Optional;
import java.util.function.Consumer;
import javax.sql.DataSource;
import com.github.witoldsz.ultm.TxManager;
import com.github.witoldsz.ultm.UnitOfWork;
import com.github.witoldsz.ultm.UnitOfWorkCall;
import com.github.witoldsz.ultm.UnitOfWorkException;
/**
 * {@link TxManager} that binds at most one transaction to the calling thread through a
 * {@link ThreadLocal}.
 *
 * <p>The JDBC connection is acquired lazily: {@link #begin()} only stores a marker, and the first
 * call to {@link #get()} within the transaction fetches a connection from the raw
 * {@link DataSource}, turns auto-commit off and hands the connection to the tuner supplied at
 * construction time. {@link #commit()} and {@link #rollback()} unbind the transaction from the
 * thread and, if a connection was actually acquired, finish and close it.
 *
 * @author witoldsz
 */
public class ThreadLocalTxManager implements TxManager, ConnectionProvider {

    /** Placeholder bound to the thread while a transaction is active but no connection was fetched yet. */
    private static final WrappedConnection TRANSACTION_MARKER = new WrappedConnection(null);

    private final ThreadLocal<WrappedConnection> connections = new ThreadLocal<>();
    private final DataSource rawDataSource;
    private final Consumer<Connection> connectionTuner;
    private Optional<Runnable> afterRollbackListener = Optional.empty();

    /**
     * @param rawDataSource source of the physical connections used by transactions
     * @param connectionTuner callback applied to every freshly acquired (wrapped) connection,
     *        after auto-commit has been disabled
     */
    public ThreadLocalTxManager(DataSource rawDataSource, Consumer<Connection> connectionTuner) {
        this.rawDataSource = rawDataSource;
        this.connectionTuner = connectionTuner;
    }

    /**
     * Registers the callback run after a rollback of a transaction that had acquired a connection.
     *
     * @param listener the callback, or {@code null} to remove the current one
     */
    @Override
    public void setAfterRollbackListener(Runnable listener) {
        this.afterRollbackListener = Optional.ofNullable(listener);
    }

    /**
     * Returns the connection of the transaction bound to the current thread, acquiring it from the
     * data source on first use.
     *
     * @return the wrapped connection of the current transaction
     * @throws IllegalStateException if no transaction is active on the current thread
     * @throws SQLException if the connection cannot be obtained or configured
     */
    @Override
    public WrappedConnection get() throws SQLException {
        WrappedConnection current = connections.get();
        if (current == null) {
            throw new IllegalStateException("Transaction is not active.");
        }
        if (current != TRANSACTION_MARKER) {
            return current;
        }
        Connection rawConnection = rawDataSource.getConnection();
        WrappedConnection wrapped = new WrappedConnection(rawConnection);
        // Bind before configuring, so that a failure below still leaves the connection
        // reachable for rollback()/commit(), which close it.
        connections.set(wrapped);
        if (wrapped.getAutoCommit()) {
            wrapped.setAutoCommit(false); // just to make sure
        }
        connectionTuner.accept(wrapped);
        return wrapped;
    }

    /**
     * Runs {@code unit} inside a new transaction: commits when it returns normally, rolls back
     * when it throws.
     *
     * <p>Whatever the unit throws (including {@link Error}s) is rethrown as-is; if the rollback
     * itself fails with a runtime exception, that failure is attached to the unit's throwable as a
     * suppressed exception instead of replacing it. A failure of the commit is propagated directly.
     *
     * @param unit the work to execute
     * @return the value returned by {@code unit}
     * @throws IllegalStateException if a transaction is already in progress on the current thread
     * @throws Exception whatever {@code unit} throws
     */
    @Override
    public <T> T txUnwrappedResult(UnitOfWorkCall<T> unit) throws Exception {
        begin();
        T result;
        try {
            result = unit.call();
        } catch (Throwable failure) {
            rollbackKeepingCause(failure);
            throw failure;
        }
        // Deliberately outside the try above: commit() unbinds the transaction before touching
        // the connection, so a rollback() attempted after a failed commit could only throw
        // "Transaction is not active." and hide the real commit failure.
        commit();
        return result;
    }

    /**
     * Same as {@link #txUnwrappedResult(UnitOfWorkCall)}, but checked exceptions are wrapped in
     * {@link UnitOfWorkException}; runtime exceptions propagate unchanged.
     *
     * @param unit the work to execute
     * @return the value returned by {@code unit}
     */
    @Override
    public <T> T txResult(UnitOfWorkCall<T> unit) {
        try {
            return txUnwrappedResult(unit);
        } catch (RuntimeException ex) {
            throw ex;
        } catch (Exception ex) {
            throw new UnitOfWorkException(ex);
        }
    }

    /**
     * Variant of {@link #txUnwrappedResult(UnitOfWorkCall)} for work that produces no result.
     *
     * @param unit the work to execute
     * @throws Exception whatever {@code unit} throws
     */
    @Override
    public void txUnwrapped(UnitOfWork unit) throws Exception {
        txUnwrappedResult(() -> {
            unit.run();
            return null;
        });
    }

    /**
     * Variant of {@link #txResult(UnitOfWorkCall)} for work that produces no result.
     *
     * @param unit the work to execute
     */
    public void tx(UnitOfWork unit) {
        txResult(() -> {
            unit.run();
            return null;
        });
    }

    /**
     * Starts a transaction on the current thread. No connection is acquired yet.
     *
     * @throws IllegalStateException if a transaction is already in progress on the current thread
     */
    @Override
    public void begin() {
        throwIfAlreadyAssigned();
        connections.set(TRANSACTION_MARKER);
    }

    /**
     * Ends the current transaction; if a connection was acquired, commits and closes it.
     * The connection is closed even when the commit fails.
     *
     * @throws IllegalStateException if no transaction is active on the current thread
     * @throws UnitOfWorkException if committing or closing the connection fails
     */
    @Override
    public void commit() {
        pullDelegatedConnection().ifPresent(delegated -> {
            // try-with-resources guarantees close() even if commit() throws
            try (Connection closing = delegated) {
                closing.commit();
            } catch (SQLException ex) {
                throw new UnitOfWorkException(ex);
            }
        });
    }

    /**
     * Ends the current transaction; if a connection was acquired, rolls it back, closes it and
     * then runs the after-rollback listener (the listener runs even if rollback or close failed).
     * The connection is closed even when the rollback fails.
     *
     * @throws IllegalStateException if no transaction is active on the current thread
     * @throws UnitOfWorkException if rolling back or closing the connection fails
     */
    @Override
    public void rollback() {
        pullDelegatedConnection().ifPresent(delegated -> {
            // try-with-resources guarantees close() even if rollback() throws
            try (Connection closing = delegated) {
                closing.rollback();
            } catch (SQLException ex) {
                throw new UnitOfWorkException(ex);
            } finally {
                afterRollbackListener.ifPresent(Runnable::run);
            }
        });
    }

    /** Rolls back after {@code failure}, recording a rollback failure on it instead of masking it. */
    private void rollbackKeepingCause(Throwable failure) {
        try {
            rollback();
        } catch (RuntimeException rollbackFailure) {
            if (rollbackFailure != failure) {
                failure.addSuppressed(rollbackFailure);
            }
        }
    }

    /**
     * Unbinds the transaction from the current thread and returns its underlying connection, or
     * an empty {@link Optional} if no connection was acquired during the transaction.
     */
    private Optional<Connection> pullDelegatedConnection() {
        WrappedConnection current = connections.get();
        if (current == null) {
            throw new IllegalStateException("Transaction is not active.");
        }
        connections.remove();
        return current == TRANSACTION_MARKER ? Optional.empty() : Optional.of(current.getDelegate());
    }

    /** Fails if the current thread already has a transaction bound to it. */
    private void throwIfAlreadyAssigned() {
        if (connections.get() != null) {
            throw new IllegalStateException("Transaction is in progress already.");
        }
    }
}