/*
* Copyright 1999-2012 Alibaba Group.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* (created at 2011-12-19)
*/
package com.alibaba.cobar.mysql.bio.executor;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
import org.apache.log4j.Logger;
import com.alibaba.cobar.config.ErrorCode;
import com.alibaba.cobar.exception.UnknownPacketException;
import com.alibaba.cobar.mysql.bio.Channel;
import com.alibaba.cobar.mysql.bio.MySQLChannel;
import com.alibaba.cobar.net.mysql.BinaryPacket;
import com.alibaba.cobar.net.mysql.ErrorPacket;
import com.alibaba.cobar.net.mysql.OkPacket;
import com.alibaba.cobar.route.RouteResultsetNode;
import com.alibaba.cobar.server.ServerConnection;
import com.alibaba.cobar.server.session.BlockingSession;
/**
 * Issues a commit on every back-end channel bound to a session and lets the
 * task that finishes last answer the front-end connection.
 * <p>
 * The number of outstanding nodes is kept in {@link #nodeCount}, which is only
 * read or written while {@link #lock} is held. {@link #terminate()} blocks
 * until that counter is no longer positive.
 *
 * @author <a href="mailto:shuo.qius@alibaba-inc.com">QIU Shuo</a>
 */
public class DefaultCommitExecutor extends NodeExecutor {
    private static final Logger LOGGER = Logger.getLogger(DefaultCommitExecutor.class);

    /** Raised as soon as one node fails; tasks that start afterwards cancel themselves. */
    private final AtomicBoolean isFail = new AtomicBoolean(false);
    /** Nodes that have not reported yet; accessed only while {@link #lock} is held. */
    private int nodeCount;
    private final ReentrantLock lock = new ReentrantLock();
    /** Signalled each time {@link #nodeCount} is decreased. */
    private final Condition taskFinished = lock.newCondition();
    /** Packet written to the front-end on success; when null the back-end reply is forwarded. */
    private volatile OkPacket indicatedOK;

    /**
     * @return the logger used for all messages of this executor
     */
    protected Logger getLogger() {
        return LOGGER;
    }

    /**
     * @return the operation name used when building log and error messages
     */
    protected String getErrorMessage() {
        return "commit";
    }

    /**
     * Blocks the caller until no node is outstanding any more.
     *
     * @throws InterruptedException if the waiting thread is interrupted
     */
    @Override
    public void terminate() throws InterruptedException {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            // loop guards against spurious wake-ups and partial decrements
            while (nodeCount > 0) {
                taskFinished.await();
            }
        } finally {
            lock.unlock();
        }
    }

    /**
     * Marks every node as finished and wakes up all threads waiting in
     * {@link #terminate()}.
     */
    private void decrementCountToZero() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            nodeCount = 0;
            taskFinished.signalAll();
        } finally {
            lock.unlock();
        }
    }

    /**
     * Subtracts finished tasks from the outstanding count and wakes up waiters.
     *
     * @param finished how many tasks finished
     * @return true if no task is outstanding after this decrement, i.e. the
     *         caller accounted for the last task
     */
    private boolean decrementCountBy(int finished) {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            nodeCount -= finished;
            boolean last = nodeCount <= 0;
            taskFinished.signalAll();
            return last;
        } finally {
            lock.unlock();
        }
    }

    /**
     * Commits the transaction on every channel bound to the session.
     * <p>
     * One task per bound channel is handed to the processor's committer
     * executor. If the executor refuses a task, that node is settled
     * immediately as a failure so the outstanding count still reaches zero.
     *
     * @param packet OK packet to write to the front-end when every node
     *        succeeds; may be null, in which case the reply of the node that
     *        finishes last is forwarded instead
     * @param session session whose bound channels are committed
     * @param initCount number of nodes expected to take part in the commit
     */
    public void commit(final OkPacket packet, final BlockingSession session, final int initCount) {
        // initialize the per-commit state
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            this.isFail.set(false);
            this.nodeCount = initCount;
            this.indicatedOK = packet;
        } finally {
            lock.unlock();
        }

        if (session.getSource().isClosed()) {
            // nothing will be dispatched, so release anyone waiting in terminate()
            decrementCountToZero();
            return;
        }

        // dispatch one commit task per bound channel
        final ConcurrentMap<RouteResultsetNode, Channel> target = session.getTarget();
        Executor committer = session.getSource().getProcessor().getCommitter();
        int started = 0;
        for (RouteResultsetNode rrn : target.keySet()) {
            if (rrn == null) {
                try {
                    getLogger().error(
                            "null is contained in RoutResultsetNodes, source = " + session.getSource()
                                    + ", bindChannel = " + target);
                } catch (Exception ignored) {
                    // best effort only: a failure while building or writing the
                    // log message must not abort the remaining commits
                }
                continue;
            }
            final MySQLChannel mc = (MySQLChannel) target.get(rrn);
            if (mc != null) {
                mc.setRunning(true);
                try {
                    committer.execute(new Runnable() {
                        @Override
                        public void run() {
                            _commit(mc, session);
                        }
                    });
                } catch (RuntimeException e) {
                    // The executor refused the task (e.g. RejectedExecutionException),
                    // so _commit will never run for this channel. Settle its slot
                    // here; otherwise nodeCount would never reach zero and
                    // terminate() would wait forever.
                    mc.setRunning(false);
                    handleException(mc, session, e);
                }
                // counted in both cases: the slot is consumed either by the task
                // or by the failure handling above
                ++started;
            }
        }

        if (started < initCount && decrementCountBy(initCount - started)) {
            /**
             * assumption: only caused by front-end connection close. <br/>
             * Otherwise, packet must be returned to front-end
             */
            session.clear();
        }
    }

    /**
     * Commits a single channel and, if this is the last outstanding node,
     * writes the overall result to the front-end connection.
     *
     * @param mc channel to commit
     * @param session session the channel belongs to
     */
    private void _commit(MySQLChannel mc, BlockingSession session) {
        ServerConnection source = session.getSource();
        if (isFail.get() || source.isClosed()) {
            // another node already failed or the client is gone: skip the commit
            // and account for this node through the common failure path
            mc.setRunning(false);
            handleException(mc, session, new Exception("other task fails, commit cancel"));
            return;
        }
        try {
            BinaryPacket bin = mc.commit();
            switch (bin.data[0]) {
            case OkPacket.FIELD_COUNT:
                mc.setRunning(false);
                if (decrementCountBy(1)) {
                    try {
                        if (isFail.get()) { // some other tasks failed
                            session.clear();
                            source.writeErrMessage(ErrorCode.ER_YES, getErrorMessage() + " error!");
                        } else { // all tasks are successful
                            session.release();
                            if (indicatedOK != null) {
                                indicatedOK.write(source);
                            } else {
                                // no preset packet: forward this node's OK reply
                                ByteBuffer buffer = source.allocate();
                                source.write(bin.write(buffer, source));
                            }
                        }
                    } catch (Exception e) {
                        getLogger().warn("exception happens in success notification: " + source, e);
                    }
                }
                break;
            case ErrorPacket.FIELD_COUNT:
                mc.setRunning(false);
                // flag the failure before decrementing so whoever finishes last sees it
                isFail.set(true);
                if (decrementCountBy(1)) {
                    try {
                        session.clear();
                        getLogger().warn(mc.getErrLog(getErrorMessage(), mc.getErrMessage(bin), source));
                        // forward the back-end error packet to the front-end
                        ByteBuffer buffer = source.allocate();
                        source.write(bin.write(buffer, source));
                    } catch (Exception e) {
                        getLogger().warn("exception happens in failure notification: " + source, e);
                    }
                }
                break;
            default:
                throw new UnknownPacketException(bin.toString());
            }
        } catch (IOException e) {
            mc.close();
            handleException(mc, session, e);
        } catch (RuntimeException e) {
            mc.close();
            handleException(mc, session, e);
        }
    }

    /**
     * Records a failed node and, if it was the last outstanding one, clears the
     * session and writes an error message to the front-end connection.
     *
     * @param mc channel on which the failure occurred
     * @param session session the channel belongs to
     * @param e cause of the failure; its message (or simple class name when the
     *        message is null) is sent to the front-end
     */
    private void handleException(Channel mc, BlockingSession session, Exception e) {
        // flag the failure before decrementing so whoever finishes last sees it
        isFail.set(true);
        if (decrementCountBy(1)) {
            try {
                session.clear();
                ServerConnection sc = session.getSource();
                getLogger().warn(new StringBuilder().append(sc).append(mc).append(getErrorMessage()).toString(), e);
                String msg = e.getMessage();
                sc.writeErrMessage(ErrorCode.ER_YES, msg == null ? e.getClass().getSimpleName() : msg);
            } catch (Exception e2) {
                getLogger().warn("exception happens in failure notification: " + session.getSource(), e2);
            }
        }
    }
}