package org.hibernate.test.cache.infinispan.util;
import java.util.concurrent.Callable;
import javax.transaction.SystemException;
import javax.transaction.TransactionManager;
import org.hibernate.Session;
import org.hibernate.SessionBuilder;
import org.hibernate.SessionFactory;
import org.hibernate.Transaction;
import org.hibernate.cache.infinispan.util.Caches;
import org.hibernate.engine.transaction.jta.platform.spi.JtaPlatform;
import org.hibernate.resource.transaction.spi.TransactionStatus;
/**
 * Static test helpers that open a Hibernate {@link Session}, run a callback against it inside a
 * transaction and close the session afterwards. When a {@link JtaPlatform} is supplied the work is
 * handed to {@link Caches#withinTx} together with the platform's {@link TransactionManager};
 * otherwise a resource-local {@link Transaction} obtained from the session is used.
 *
 * @author Radim Vansa &lt;rvansa@redhat.com&gt;
 */
public final class TxUtil {

    private TxUtil() {
        // static helpers only, not meant to be instantiated
    }

    /**
     * Runs {@code consumer} against a fresh session inside a transaction.
     *
     * @param useJta         if {@code true} the {@link JtaPlatform} service is looked up in the factory's
     *                       service registry and used; if {@code false} a resource-local transaction is used
     * @param sessionFactory factory the session is opened from (via {@code withOptions()})
     * @param consumer       work to execute with the open session
     * @throws Exception whatever the callback or the transaction handling throws
     */
    public static void withTxSession(boolean useJta, SessionFactory sessionFactory, ThrowingConsumer<Session, Exception> consumer) throws Exception {
        JtaPlatform jtaPlatform = useJta ? sessionFactory.getSessionFactoryOptions().getServiceRegistry().getService(JtaPlatform.class) : null;
        withTxSession(jtaPlatform, sessionFactory.withOptions(), consumer);
    }

    /**
     * Runs {@code consumer} against a session opened from {@code sessionBuilder} inside a transaction.
     *
     * @param jtaPlatform    platform providing the {@link TransactionManager}; {@code null} selects a
     *                       resource-local transaction instead
     * @param sessionBuilder builder used to open the session
     * @param consumer       work to execute with the open session
     * @throws Exception whatever the callback or the transaction handling throws
     */
    public static void withTxSession(JtaPlatform jtaPlatform, SessionBuilder sessionBuilder, ThrowingConsumer<Session, Exception> consumer) throws Exception {
        if (jtaPlatform != null) {
            TransactionManager tm = jtaPlatform.retrieveTransactionManager();
            Caches.withinTx(tm, () -> {
                withSession(sessionBuilder, s -> {
                    consumer.accept(s);
                    // we need to flush the session before close when running with JTA transactions
                    s.flush();
                });
                return null;
            });
        } else {
            withSession(sessionBuilder, s -> withResourceLocalTx(s, consumer));
        }
    }

    /**
     * Value-returning variant of {@link #withTxSession(boolean, SessionFactory, ThrowingConsumer)}.
     *
     * @param useJta         if {@code true} the {@link JtaPlatform} service is looked up in the factory's
     *                       service registry and used; if {@code false} a resource-local transaction is used
     * @param sessionFactory factory the session is opened from (via {@code withOptions()})
     * @param function       work to execute with the open session
     * @param <T>            type of the value produced by {@code function}
     * @return the value returned by {@code function}
     * @throws Exception whatever the callback or the transaction handling throws
     */
    public static <T> T withTxSessionApply(boolean useJta, SessionFactory sessionFactory, ThrowingFunction<Session, T, Exception> function) throws Exception {
        JtaPlatform jtaPlatform = useJta ? sessionFactory.getSessionFactoryOptions().getServiceRegistry().getService(JtaPlatform.class) : null;
        return withTxSessionApply(jtaPlatform, sessionFactory.withOptions(), function);
    }

    /**
     * Value-returning variant of {@link #withTxSession(JtaPlatform, SessionBuilder, ThrowingConsumer)}.
     *
     * @param jtaPlatform    platform providing the {@link TransactionManager}; {@code null} selects a
     *                       resource-local transaction instead
     * @param sessionBuilder builder used to open the session
     * @param function       work to execute with the open session
     * @param <T>            type of the value produced by {@code function}
     * @return the value returned by {@code function}
     * @throws Exception whatever the callback or the transaction handling throws
     */
    public static <T> T withTxSessionApply(JtaPlatform jtaPlatform, SessionBuilder sessionBuilder, ThrowingFunction<Session, T, Exception> function) throws Exception {
        if (jtaPlatform != null) {
            TransactionManager tm = jtaPlatform.retrieveTransactionManager();
            Callable<T> callable = () -> withSessionApply(sessionBuilder, s -> {
                T t = function.apply(s);
                // as in withTxSession: flush before the session is closed when running with JTA
                s.flush();
                return t;
            });
            return Caches.withinTx(tm, callable);
        } else {
            return withSessionApply(sessionBuilder, s -> withResourceLocalTx(s, function));
        }
    }

    /**
     * Opens a session, passes it to {@code consumer} and closes it again. No transaction is started here.
     * If the callback fails and closing the session fails as well, the close failure is recorded as a
     * suppressed exception of the callback's failure instead of replacing it.
     *
     * @param sessionBuilder builder used to open the session
     * @param consumer       work to execute with the open session
     * @param <E>            type of exception the callback may throw
     * @throws E if the callback throws
     */
    public static <E extends Throwable> void withSession(SessionBuilder sessionBuilder, ThrowingConsumer<Session, E> consumer) throws E {
        Session s = sessionBuilder.openSession();
        try {
            consumer.accept(s);
        } catch (Throwable failure) {
            closeAfterFailure(s, failure);
            throw failure;
        }
        s.close();
    }

    /**
     * Opens a session, applies {@code function} to it, closes the session and returns the function's result.
     * No transaction is started here. If the callback fails and closing the session fails as well, the
     * close failure is recorded as a suppressed exception of the callback's failure instead of replacing it.
     *
     * @param sessionBuilder builder used to open the session
     * @param function       work to execute with the open session
     * @param <R>            type of the value produced by {@code function}
     * @param <E>            type of exception the callback may throw
     * @return the value returned by {@code function}
     * @throws E if the callback throws
     */
    public static <R, E extends Throwable> R withSessionApply(SessionBuilder sessionBuilder, ThrowingFunction<Session, R, E> function) throws E {
        Session s = sessionBuilder.openSession();
        R result;
        try {
            result = function.apply(s);
        } catch (Throwable failure) {
            closeAfterFailure(s, failure);
            throw failure;
        }
        s.close();
        return result;
    }

    /**
     * Runs {@code consumer} inside a transaction begun on {@code session}; see
     * {@link #withResourceLocalTx(Session, ThrowingFunction)} for the commit/rollback rules.
     *
     * @param session  session on which the transaction is begun
     * @param consumer work to execute inside the transaction
     * @throws Exception whatever the callback, the commit or the rollback throws
     */
    public static void withResourceLocalTx(Session session, ThrowingConsumer<Session, Exception> consumer) throws Exception {
        // Adapt to the value-returning overload so the commit/rollback logic lives in one place.
        // The explicitly typed local also keeps overload selection unambiguous.
        ThrowingFunction<Session, Void, Exception> adapter = s -> {
            consumer.accept(s);
            return null;
        };
        withResourceLocalTx(session, adapter);
    }

    /**
     * Begins a transaction on {@code session} and applies {@code consumer} to the session. Afterwards the
     * transaction is committed if its status is still {@link TransactionStatus#ACTIVE}, and rolled back
     * otherwise. If the callback or the commit throws anything (including an {@link Error} such as a
     * failed assertion), a rollback is attempted and the original failure is rethrown; a failure of
     * that rollback is attached to it as a suppressed exception.
     *
     * @param session  session on which the transaction is begun
     * @param consumer work to execute inside the transaction
     * @param <T>      type of the value produced by the callback
     * @return the value returned by the callback
     * @throws Exception whatever the callback, the commit or the rollback throws
     */
    public static <T> T withResourceLocalTx(Session session, ThrowingFunction<Session, T, Exception> consumer) throws Exception {
        Transaction transaction = session.beginTransaction();
        // set once the regular (non-exceptional) rollback has been started, so it is not attempted twice
        boolean rollingBack = false;
        try {
            T t = consumer.apply(session);
            if (transaction.getStatus() == TransactionStatus.ACTIVE) {
                transaction.commit();
            } else {
                rollingBack = true;
                transaction.rollback();
            }
            return t;
        } catch (Throwable e) {
            // Throwable rather than Exception: Errors thrown by the callback must not leave the
            // transaction open either.
            if (!rollingBack) {
                try {
                    transaction.rollback();
                } catch (Throwable suppressed) {
                    attachSuppressed(e, suppressed);
                }
            }
            throw e;
        }
    }

    /**
     * Marks the current transaction so that it can only be rolled back.
     *
     * @param useJta if {@code true}, {@code setRollbackOnly()} is invoked on the {@link TransactionManager}
     *               retrieved from the session factory's {@link JtaPlatform} service; if {@code false},
     *               the session's own transaction is marked rollback-only
     * @param s      session whose factory / transaction is used
     * @throws RuntimeException wrapping a {@link SystemException} raised by the transaction manager
     */
    public static void markRollbackOnly(boolean useJta, Session s) {
        if (useJta) {
            JtaPlatform jtaPlatform = s.getSessionFactory().getSessionFactoryOptions().getServiceRegistry().getService(JtaPlatform.class);
            TransactionManager tm = jtaPlatform.retrieveTransactionManager();
            try {
                tm.setRollbackOnly();
            } catch (SystemException e) {
                throw new RuntimeException(e);
            }
        } else {
            s.getTransaction().markRollbackOnly();
        }
    }

    /** Closes {@code session} after {@code failure} occurred, recording a close failure as suppressed. */
    private static void closeAfterFailure(Session session, Throwable failure) {
        try {
            session.close();
        } catch (Throwable closeFailure) {
            attachSuppressed(failure, closeFailure);
        }
    }

    /** Adds {@code secondary} to {@code primary} as suppressed; self-suppression is not permitted, so it is skipped. */
    private static void attachSuppressed(Throwable primary, Throwable secondary) {
        if (secondary != primary) {
            primary.addSuppressed(secondary);
        }
    }

    /**
     * Like {@link java.util.function.Consumer}, but allowed to throw {@code E}.
     *
     * @param <T> type of the accepted argument
     * @param <E> type of exception that may be thrown
     */
    @FunctionalInterface
    public interface ThrowingConsumer<T, E extends Throwable> {
        void accept(T t) throws E;
    }

    /**
     * Like {@link java.util.function.Function}, but allowed to throw {@code E}.
     *
     * @param <T> type of the argument
     * @param <R> type of the result
     * @param <E> type of exception that may be thrown
     */
    @FunctionalInterface
    public interface ThrowingFunction<T, R, E extends Throwable> {
        R apply(T t) throws E;
    }
}