/*
* Copyright 2014 mango.jfaster.org
*
* The Mango Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package org.jfaster.mango.operator;
import org.jfaster.mango.binding.BoundSql;
import org.jfaster.mango.binding.InvocationContext;
import org.jfaster.mango.descriptor.MethodDescriptor;
import org.jfaster.mango.exception.DescriptionException;
import org.jfaster.mango.parser.ASTRootNode;
import org.jfaster.mango.stat.OneExecuteStat;
import org.jfaster.mango.transaction.*;
import org.jfaster.mango.util.Iterables;
import org.jfaster.mango.util.ToStringHelper;
import javax.sql.DataSource;
import java.util.*;
/**
* @author ash
*/
public class BatchUpdateOperator extends AbstractOperator {
protected Transformer transformer;
public BatchUpdateOperator(ASTRootNode rootNode, MethodDescriptor md, Config config) {
super(rootNode, md, config);
transformer = TRANSFORMERS.get(md.getReturnRawType());
if (transformer == null) {
String expected = ToStringHelper.toString(TRANSFORMERS.keySet());
throw new DescriptionException("the return type of batch update " +
"expected one of " + expected + " but " + md.getReturnRawType());
}
}
@Override
public Object execute(Object[] values, OneExecuteStat stat) {
Iterables iterables = getIterables(values);
if (iterables.isEmpty()) {
return transformer.transform(new int[]{});
}
Map<DataSource, Group> gorupMap = new HashMap<DataSource, Group>();
int t = 0;
for (Object obj : iterables) {
InvocationContext context = invocationContextFactory.newInvocationContext(new Object[]{obj});
group(context, gorupMap, t++);
}
int[] ints = executeDb(gorupMap, t, stat);
return transformer.transform(ints);
}
protected void group(InvocationContext context, Map<DataSource, Group> groupMap, int position) {
context.setGlobalTable(tableGenerator.getTable(context));
DataSource ds = dataSourceGenerator.getDataSource(context, daoClass);
Group group = groupMap.get(ds);
if (group == null) {
group = new Group();
groupMap.put(ds, group);
}
rootNode.render(context);
BoundSql boundSql = context.getBoundSql();
invocationInterceptorChain.intercept(boundSql, context, ds); // 拦截器
group.add(boundSql, position);
}
protected Iterables getIterables(Object[] values) {
Object firstValue = values[0];
if (firstValue == null) {
throw new NullPointerException("batchUpdate's parameter can't be null");
}
Iterables iterables = new Iterables(firstValue);
return iterables;
}
protected int[] executeDb(Map<DataSource, Group> groupMap, int batchNum, OneExecuteStat stat) {
int[] r = new int[batchNum];
long now = System.nanoTime();
int t = 0;
try {
for (Map.Entry<DataSource, Group> entry : groupMap.entrySet()) {
DataSource ds = entry.getKey();
List<BoundSql> boundSqls = entry.getValue().getBoundSqls();
List<Integer> positions = entry.getValue().getPositions();
int[] ints = config.isUseTransactionForBatchUpdate() ?
useTransactionBatchUpdate(ds, boundSqls) :
jdbcOperations.batchUpdate(ds, boundSqls);
for (int i = 0; i < ints.length; i++) {
r[positions.get(i)] = ints[i];
}
t++;
}
} finally {
long cost = System.nanoTime() - now;
if (t == groupMap.entrySet().size()) {
stat.recordDatabaseExecuteSuccess(cost);
} else {
stat.recordDatabaseExecuteException(cost);
}
}
return r;
}
private int[] useTransactionBatchUpdate(DataSource ds, List<BoundSql> boundSqls) {
int[] ints;
Transaction transaction = TransactionFactory.newTransaction(ds);
try {
ints = jdbcOperations.batchUpdate(ds, boundSqls);
} catch (RuntimeException e) {
transaction.rollback();
throw e;
}
transaction.commit();
return ints;
}
protected static class Group {
private List<BoundSql> boundSqls = new LinkedList<BoundSql>();
private List<Integer> positions = new LinkedList<Integer>();
public void add(BoundSql boundSql, int position) {
boundSqls.add(boundSql);
positions.add(position);
}
public List<BoundSql> getBoundSqls() {
return boundSqls;
}
public List<Integer> getPositions() {
return positions;
}
}
private final static Map<Class, Transformer> TRANSFORMERS = new LinkedHashMap<Class, Transformer>();
static {
TRANSFORMERS.put(void.class, VoidTransformer.INSTANCE);
TRANSFORMERS.put(int.class, IntegerTransformer.INSTANCE);
TRANSFORMERS.put(int[].class, IntArrayTransformer.INSTANCE);
TRANSFORMERS.put(Void.class, VoidTransformer.INSTANCE);
TRANSFORMERS.put(Integer.class, IntegerTransformer.INSTANCE);
TRANSFORMERS.put(Integer[].class, IntegerArrayTransformer.INSTANCE);
}
public interface Transformer {
Object transform(int[] s);
}
enum IntArrayTransformer implements Transformer {
INSTANCE;
@Override
public Object transform(int[] s) {
return s;
}
}
enum IntegerArrayTransformer implements Transformer {
INSTANCE;
@Override
public Object transform(int[] s) {
Integer[] r = new Integer[s.length];
for (int i = 0; i < s.length; i++) {
r[i] = s[i];
}
return r;
}
}
enum IntegerTransformer implements Transformer {
INSTANCE;
@Override
public Object transform(int[] s) {
int r = 0;
for (int e : s) {
r += e;
}
return r;
}
}
enum VoidTransformer implements Transformer {
INSTANCE;
@Override
public Object transform(int[] s) {
return null;
}
}
}