/**
* Copyright 2014 Duan Bingnan
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.pinus4j.task;
import java.util.ArrayList;
import java.util.List;
import org.pinus4j.api.query.IQuery;
import org.pinus4j.cluster.IDBCluster;
import org.pinus4j.cluster.resources.GlobalDBResource;
import org.pinus4j.cluster.resources.IDBResource;
import org.pinus4j.cluster.resources.ShardingDBResource;
import org.pinus4j.datalayer.IRecordIterator;
import org.pinus4j.datalayer.iterator.GlobalRecordIterator;
import org.pinus4j.datalayer.iterator.ShardingRecordIterator;
import org.pinus4j.entity.DefaultEntityMetaManager;
import org.pinus4j.entity.IEntityMetaManager;
import org.pinus4j.exceptions.DBClusterException;
import org.pinus4j.exceptions.DBOperationException;
import org.pinus4j.exceptions.TaskException;
import org.pinus4j.utils.ThreadPool;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Executes a data-processing {@link ITask} over the records of one entity type.
 *
 * <p>For a sharded entity, one {@link RecrodReaderThread} per master shard resource is
 * submitted to the thread pool and each of them submits a {@link RecrodThread} per batch.
 * For a global entity, the batches are read on the calling thread and only the
 * {@link RecrodThread} jobs are submitted to the pool.
 *
 * @author duanbn
 * @param <E> entity type being processed
 */
public class TaskExecutor<E> {

    public static final Logger LOG = LoggerFactory.getLogger(TaskExecutor.class);

    /**
     * Name passed to the processing thread pool.
     */
    private static final String THREADPOOL_NAME = "pinus";

    /**
     * Entity class handled by this executor.
     */
    private final Class<E> clazz;

    /**
     * Database cluster the records are read from.
     */
    private final IDBCluster dbCluster;

    private final IEntityMetaManager entityMetaManager = DefaultEntityMetaManager.getInstance();

    /**
     * Creates an executor for one entity type.
     *
     * @param clazz     entity class whose records will be processed
     * @param dbCluster cluster used to look up the master database resources
     */
    public TaskExecutor(Class<E> clazz, IDBCluster dbCluster) {
        this.clazz = clazz;
        this.dbCluster = dbCluster;
    }

    /**
     * Runs the task without a query; equivalent to {@code execute(task, null)}.
     *
     * @param task task to run
     * @return future tracking the submitted work
     */
    public TaskFuture execute(ITask<E> task) {
        return execute(task, null);
    }

    /**
     * Initializes the task and submits its work to a newly created thread pool.
     *
     * @param task  task to run; {@code init()} is invoked before any record is read
     * @param query query handed to every record iterator, may be {@code null}
     * @return future created with the total record count, the pool and the task
     * @throws TaskException        if {@code task.init()} throws
     * @throws DBOperationException if the master database resource lookup fails
     */
    public TaskFuture execute(ITask<E> task, IQuery query) {
        // Initialize the task.
        try {
            task.init();
        } catch (Exception e) {
            throw new TaskException(e);
        }

        if (entityMetaManager.isShardingEntity(clazz)) {
            return executeSharding(task, query);
        }
        return executeGlobal(task, query);
    }

    /**
     * Sharded case: builds one iterator per master shard resource, sums their counts and
     * submits one reader job per iterator.
     */
    private TaskFuture executeSharding(ITask<E> task, IQuery query) {
        List<IDBResource> dbResources;
        try {
            dbResources = this.dbCluster.getAllMasterShardingDBResource(clazz);
        } catch (Exception e) {
            throw new DBOperationException(e);
        }

        List<IRecordIterator<E>> readers = new ArrayList<IRecordIterator<E>>(dbResources.size());
        // Compute the total number of records across all shards.
        long total = 0;
        for (IDBResource dbResource : dbResources) {
            IRecordIterator<E> reader = new ShardingRecordIterator<E>((ShardingDBResource) dbResource, clazz);
            configureReader(reader, task, query);
            readers.add(reader);
            total += reader.getCount();
        }

        // The pool is created only after all fallible setup above has succeeded, so a
        // failure there does not leave behind a pool that no TaskFuture refers to.
        ThreadPool threadPool = ThreadPool.newInstance(THREADPOOL_NAME);
        TaskFuture future = new TaskFuture(total, threadPool, task);
        future.addDBResource(dbResources);
        for (IRecordIterator<E> reader : readers) {
            threadPool.submit(new RecrodReaderThread<E>(reader, threadPool, task, future));
        }
        return future;
    }

    /**
     * Global case: reads the batches on the calling thread and submits one processing job
     * per batch.
     */
    private TaskFuture executeGlobal(ITask<E> task, IQuery query) {
        String clusterName = entityMetaManager.getClusterName(clazz);
        IDBResource dbResource;
        try {
            dbResource = this.dbCluster.getMasterGlobalDBResource(clusterName,
                    entityMetaManager.getTableName(clazz));
        } catch (DBClusterException e) {
            throw new DBOperationException(e);
        }

        IRecordIterator<E> reader = new GlobalRecordIterator<E>((GlobalDBResource) dbResource, clazz);
        configureReader(reader, task, query);
        long total = reader.getCount();

        // As in the sharded case, create the pool only once nothing before it can fail.
        ThreadPool threadPool = ThreadPool.newInstance(THREADPOOL_NAME);
        TaskFuture future = new TaskFuture(total, threadPool, task);
        future.addDBResource(dbResource);
        while (reader.hasNext()) {
            List<E> record = reader.nextMore();
            threadPool.submit(new RecrodThread<E>(record, task, future));
        }
        return future;
    }

    /**
     * Applies the task's buffer size as the iterator step when it is positive, then sets
     * the query on the iterator.
     */
    private void configureReader(IRecordIterator<E> reader, ITask<E> task, IQuery query) {
        if (task.taskBuffer() > 0) {
            reader.setStep(task.taskBuffer());
        }
        reader.setQuery(query);
    }

    /**
     * Reader job, used only in the sharded case: drains one record iterator and submits a
     * {@link RecrodThread} to the pool for every batch it yields.
     *
     * @author duanbn
     * @param <E> entity type being processed
     */
    public static class RecrodReaderThread<E> implements Runnable {

        private final IRecordIterator<E> recordReader;
        private final ThreadPool threadPool;
        private final ITask<E> task;
        private final TaskFuture future;

        public RecrodReaderThread(IRecordIterator<E> recordReader, ThreadPool threadPool, ITask<E> task,
                                  TaskFuture future) {
            this.recordReader = recordReader;
            this.threadPool = threadPool;
            this.task = task;
            this.future = future;
        }

        @Override
        public void run() {
            while (recordReader.hasNext()) {
                List<E> record = recordReader.nextMore();
                threadPool.submit(new RecrodThread<E>(record, task, future));
            }
        }
    }

    /**
     * Processing job: hands one batch of records to the task.
     *
     * @author duanbn
     * @param <E> entity type being processed
     */
    public static class RecrodThread<E> implements Runnable {

        public static final Logger LOG = LoggerFactory.getLogger(RecrodThread.class);

        private final List<E> record;
        private final ITask<E> task;
        private final TaskFuture future;

        public RecrodThread(List<E> record, ITask<E> task, TaskFuture future) {
            this.record = record;
            this.task = task;
            this.future = future;
        }

        /**
         * Calls {@code batchRecord} then {@code afterBatch} on the task. Any exception is
         * logged at warn level and not rethrown; the batch size is reported to the future
         * through {@code down} and {@code incrCount} whether or not the task failed.
         */
        @Override
        public void run() {
            try {
                this.task.batchRecord(record);
                this.task.afterBatch();
            } catch (Exception e) {
                LOG.warn("do task failure " + record, e);
            } finally {
                int size = record.size();
                this.future.down(size);
                this.future.incrCount(size);
            }
        }
    }
}