package org.hsweb.concurrent.lock.support.redis; import org.springframework.core.io.ClassPathResource; import org.springframework.data.redis.connection.StringRedisConnection; import org.springframework.data.redis.core.RedisCallback; import org.springframework.data.redis.core.RedisTemplate; import org.springframework.data.redis.core.StringRedisTemplate; import org.springframework.data.redis.core.script.DefaultRedisScript; import org.springframework.scripting.support.ResourceScriptSource; import org.springframework.util.Assert; import java.util.*; import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReadWriteLock; /** * Created by zhouhao on 16-5-27. */ public class RedisReadWriteLock implements ReadWriteLock { static final String PREFIX = "lock:"; static final long DEFAULT_EXPIRE = 60; private ReadLock readLock; private WriteLock writeLock; private long lockKeyExpireTime = DEFAULT_EXPIRE; private long waitTime = 30; protected String lockValue; private String readLockKey, writeLockKey; private static DefaultRedisScript<Boolean> redisScriptRead; private static DefaultRedisScript<Boolean> redisScriptWrite; static { //初始化脚本 redisScriptRead = new DefaultRedisScript<>(); redisScriptWrite = new DefaultRedisScript<>(); redisScriptRead.setScriptSource(new ResourceScriptSource(new ClassPathResource("scripts/vcheckAndsAdd.lua", RedisReadWriteLock.class))); redisScriptRead.setResultType(Boolean.class); redisScriptWrite.setScriptSource(new ResourceScriptSource(new ClassPathResource("scripts/scheckAndVset.lua", RedisReadWriteLock.class))); redisScriptWrite.setResultType(Boolean.class); } private StringRedisTemplate redisTemplate; public RedisReadWriteLock(String key, RedisTemplate redisTemplate) { Assert.notNull(key); Assert.notNull(redisTemplate); this.redisTemplate = new StringRedisTemplate(redisTemplate.getConnectionFactory()); readLockKey = PREFIX + key + ".read.lock"; writeLockKey = PREFIX + key + ".write.lock"; readLock = new ReadLock(); writeLock = new WriteLock(); lockValue = UUID.randomUUID().toString(); } @Override public Lock readLock() { return readLock; } @Override public Lock writeLock() { return writeLock; } private String getReadKey() { return readLockKey; } private String getWriteKey() { return writeLockKey; } protected void sleep() { try { Thread.sleep(waitTime); } catch (InterruptedException e) { } } public void setWaitTime(long waitTime) { this.waitTime = waitTime; } public void setLockKeyExpireTime(long lockKeyExpireTime) { this.lockKeyExpireTime = lockKeyExpireTime; } class ReadLock implements Lock { private List<String> keys = new ArrayList<>(); public ReadLock() { super(); keys.add(getWriteKey().toString()); keys.add(getReadKey().toString()); } public String lockValue() { return new String(lockValue).concat(Thread.currentThread().getId() + ""); } @Override public void lock() { while (true) { Boolean locked = redisTemplate.execute(redisScriptRead, keys, lockValue()); if (!locked) { sleep(); } else { /* * 此处增加对所有读锁的过期 * 1、防止项目停止,导致读锁一直存在 * * @TODO 后期可以抽出到 redisScriptRead脚本中 * */ expire(); break; } } } @Override public void lockInterruptibly() throws InterruptedException { boolean locked = redisTemplate.execute(redisScriptRead, keys, lockValue()); if (locked) { expire(); } else { throw new InterruptedException("could not get the read lock!"); } } @Override public boolean tryLock() { boolean locked = redisTemplate.execute(redisScriptRead, keys, lockValue()); if (locked) { expire(); } return locked; } @Override public boolean tryLock(long time, TimeUnit unit) throws InterruptedException { byte[] error = new byte[1]; boolean locked; long startWith = System.nanoTime(); do { locked = redisTemplate.execute(redisScriptRead, keys, lockValue()); if (locked) { expire(); break; } long now = System.nanoTime(); if (now - startWith > unit.toNanos(time)) { error[0] = 1; break; } sleep(); } while (!locked); if (error[0] == 1) { throw new InterruptedException("try lock time out!"); } return locked; } @Override public void unlock() { redisTemplate.execute((RedisCallback) conn -> { StringRedisConnection strConn = (StringRedisConnection) conn; Set<String> locks = strConn.sMembers(getReadKey()); if (locks == null || locks.size() == 0) return null; //当前读锁为自己持有 才解锁 if (locks.contains(lockValue())) { strConn.sRem(getReadKey(), lockValue()); } return null; }); } @Override public Condition newCondition() { throw new UnsupportedOperationException(); } private void expire() { redisTemplate.expire(getReadKey(), lockKeyExpireTime, TimeUnit.SECONDS); } } class WriteLock implements Lock { private List<String> keys = new ArrayList<>(); public WriteLock() { super(); keys.add(getReadKey()); keys.add(getWriteKey()); } @Override public void lock() { boolean locked; do { locked = redisTemplate.execute(redisScriptWrite, keys, lockValue); if (locked) { expire(); } else { sleep(); } } while (!locked); } @Override public void lockInterruptibly() throws InterruptedException { boolean locked = redisTemplate.execute(redisScriptWrite, keys, lockValue); if (locked) { expire(); } else { throw new InterruptedException(""); } } @Override public boolean tryLock() { boolean locked = redisTemplate.execute(redisScriptWrite, keys, lockValue); if (locked) { expire(); } return locked; } @Override public boolean tryLock(long time, TimeUnit unit) throws InterruptedException { byte[] error = new byte[1]; boolean locked; long startWith = System.nanoTime(); do { locked = redisTemplate.execute(redisScriptWrite, keys, lockValue); long now = System.nanoTime(); if (now - startWith > unit.toNanos(time)) { error[0] = 1; break; } sleep(); } while (!locked); if (locked) { expire(); } if (error[0] == 1) { throw new InterruptedException("lock time out!"); } return locked; } @Override public void unlock() { redisTemplate.execute((RedisCallback) conn -> { StringRedisConnection strConn = (StringRedisConnection) conn; strConn.del(getWriteKey()); return null; }); } @Override public Condition newCondition() { throw new UnsupportedOperationException(); } private void expire() { redisTemplate.expire(getWriteKey(), lockKeyExpireTime, TimeUnit.SECONDS); } } }