package org.zstack.core.db; import org.springframework.beans.factory.annotation.Autowire; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Configurable; import org.zstack.header.exception.CloudRuntimeException; import org.zstack.utils.DebugUtils; import org.zstack.utils.Utils; import org.zstack.utils.logging.CLogger; import javax.sql.DataSource; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.locks.ReentrantLock; /** */ @Configurable(preConstruction = true, autowire = Autowire.BY_TYPE) public class GLock { private static final CLogger logger = Utils.getLogger(GLock.class); private static final Map<String, ReentrantLock> memLocks = new HashMap<String, ReentrantLock>(); private DataSource dataSource; private Connection conn; private final String name; private final long timeout; private boolean success = false; private static final ThreadLocal<List<String>> isLocked = new ThreadLocal<List<String>>() { @Override protected List<String> initialValue() { return new ArrayList<String>(); } }; private boolean separateThreadEnabled; @Autowired private DatabaseFacade dbf; public GLock(String name, long timeout) { this.name = name; this.timeout = timeout; dataSource = dbf.getDataSource(); } public boolean isSeparateThreadEnabled() { return separateThreadEnabled; } public void setSeparateThreadEnabled(boolean separateThreadEnabled) { this.separateThreadEnabled = separateThreadEnabled; } private void checkInThread() { List<String> locks = isLocked.get(); if (locks.contains(name)) { throw new CloudRuntimeException(String.format("Thread[%s] has acquired lock[%s], you can NOT acquire the lock again before unlock, GLock is non reentrant", Thread.currentThread().getName(), name)); } locks.add(name); } private void checkOutThread() { List<String> locks = isLocked.get(); locks.remove(name); } public void lock() { if (separateThreadEnabled) { checkInThread(); } ReentrantLock mlock = null; if (separateThreadEnabled) { synchronized (memLocks) { mlock = memLocks.get(name); if (mlock == null) { mlock = new ReentrantLock(); memLocks.put(name, mlock); } if (memLocks.size() > 100) { logger.warn(String.format("there are more than 100 GLocks[num:%s] are created, something must be wrong in our program", memLocks.size())); } } } try { if (logger.isTraceEnabled()) { logger.trace(String.format("[GLock]: thread[%s] is acquiring lock[%s]", Thread.currentThread().getName(), name)); } if (separateThreadEnabled) { mlock.lock(); if (logger.isTraceEnabled()) { logger.trace(String.format("[GLock Memory Lock]: thread[%s] got memory lock[%s]", Thread.currentThread().getName(), name)); } } PreparedStatement pstmt = null; try { conn = dataSource.getConnection(); conn.setAutoCommit(true); pstmt = conn.prepareStatement(String.format("select get_lock('%s', %s)", name, timeout)); ResultSet rs = pstmt.executeQuery(); if (rs == null) { String err = "Unable to get DB lock: " + name + ", internal database error happened"; throw new CloudRuntimeException(err); } else if (rs.first() && rs.getInt(1) == 0) { throw new CloudRuntimeException(String.format("lock[%s] failed, timeout after %s seconds", name, timeout)); } if (logger.isTraceEnabled()) { logger.trace(String.format("[GLock DB Lock]: thread: %s got DB lock[%s], during timeout[%s secs]", Thread.currentThread().getName(), name, timeout)); } } catch (SQLException e) { throw new CloudRuntimeException(String.format("[GLock Error]: cannon get DB connection for lock[%s]", name), e); } finally { if (pstmt != null) { try { pstmt.close(); } catch (SQLException e) { logger.warn("Unable to close PreparedStatement for lock: " + name, e); } } } success = true; } catch (Throwable t) { if (conn != null) { try { conn.close(); } catch (SQLException e) { logger.warn(e.getMessage(), e); } } if (separateThreadEnabled) { mlock.unlock(); } success = false; if (separateThreadEnabled) { checkOutThread(); } if (!(t instanceof CloudRuntimeException)) { throw new CloudRuntimeException(t); } else { throw (CloudRuntimeException)t; } } } public void unlock() { if (!success) { if (logger.isTraceEnabled()) { logger.trace(String.format("[GLock]: skip unlock for thread[%s] on lock[%s], because previous lock() is not success", Thread.currentThread().getName(), name)); } return; } ReentrantLock lock = null; if (separateThreadEnabled) { synchronized (memLocks) { lock = memLocks.get(name); } } try { if (separateThreadEnabled) { DebugUtils.Assert(lock != null, String.format("cannot find LockWrapper for GLock[%s], is unlock mistakenly called twice???", name)); } if (logger.isTraceEnabled()) { logger.trace(String.format("[GLock]: thread[%s] is releasing lock[%s]", Thread.currentThread().getName(), name)); } PreparedStatement pstmt = null; try { pstmt = conn.prepareStatement(String.format("select release_lock('%s')", name)); ResultSet rs = pstmt.executeQuery(); if (rs == null) { throw new CloudRuntimeException("Mysql cannot find lock: " + name); } else if (rs.first() && rs.getInt(1) == 0) { String err = "Unable to release DB lock: " + name + ", lock: " + name + " is not held by this connection, internal error"; throw new CloudRuntimeException(err); } if (logger.isTraceEnabled()) { logger.trace(String.format("[GLock Release DB Lock] thread[%s] released DB lock[%s]", Thread.currentThread().getName(), name)); } } catch (SQLException e) { throw new CloudRuntimeException("Unable to release lock: " + name, e); } finally { if (pstmt != null) { try { pstmt.close(); } catch (SQLException e) { logger.warn("Unable to close PreparedStatement for lock: " + name, e); } } try { conn.close(); } catch (SQLException e) { logger.warn(e.getMessage(), e); } } } finally { if (separateThreadEnabled) { if (lock != null) { lock.unlock(); } } if (separateThreadEnabled) { checkOutThread(); } if (logger.isTraceEnabled()) { logger.trace(String.format("[GLock Release Memory Lock]: thread[%s] released memory lock[%s]", Thread.currentThread().getName(), name)); } } } }