/* * Copyright 2014 mango.jfaster.org * * The Mango Project licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package org.jfaster.mango.sharding; import com.google.common.collect.HashMultiset; import com.google.common.collect.Multiset; import org.jfaster.mango.annotation.*; import org.jfaster.mango.datasource.SimpleDataSourceFactory; import org.jfaster.mango.operator.Mango; import org.jfaster.mango.support.DataSourceConfig; import org.jfaster.mango.support.Randoms; import org.jfaster.mango.support.Table; import org.jfaster.mango.support.model4table.Msg; import org.junit.Before; import org.junit.Test; import javax.sql.DataSource; import java.sql.Connection; import java.util.ArrayList; import java.util.List; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.*; /** * 测试数据源路由 * * @author ash */ public class DatabaseSharding2Test { private static Mango mango; private static String[] dsns = new String[]{"ds1", "ds2", "ds3"}; @Before public void before() throws Exception { mango = Mango.newInstance(); for (int i = 0; i < 3; i++) { DataSource ds = DataSourceConfig.getDataSource(i + 1); Connection conn = ds.getConnection(); Table.MSG.load(conn); conn.close(); mango.addDataSourceFactory(new SimpleDataSourceFactory(dsns[i], ds)); } } @Test public void testRandomPartition() { MsgDao dao = mango.create(MsgDao.class); int num = 100; List<Msg> msgs = Msg.createRandomMsgs(num); for (Msg msg : msgs) { int id = dao.insert(msg); assertThat(id, greaterThan(0)); msg.setId(id); } check(msgs, dao); for (Msg msg : msgs) { msg.setContent(Randoms.randomString(20)); } dao.batchUpdate(msgs); check(msgs, dao); } @Test public void testOnePartition() { MsgDao dao = mango.create(MsgDao.class); int num = 10; int uid = 100; List<Msg> msgs = new ArrayList<Msg>(); for (int i = 0; i < num; i++) { Msg msg = new Msg(); msg.setUid(uid); msg.setContent(Randoms.randomString(20)); msgs.add(msg); int id = dao.insert(msg); msg.setId(id); } check(msgs, dao); for (Msg msg : msgs) { msg.setContent(Randoms.randomString(20)); } dao.batchUpdate(msgs); check(msgs, dao); } private void check(List<Msg> msgs, MsgDao dao) { List<Msg> dbMsgs = new ArrayList<Msg>(); Multiset<Integer> ms = HashMultiset.create(); for (Msg msg : msgs) { ms.add(msg.getUid()); } for (Multiset.Entry<Integer> entry : ms.entrySet()) { dbMsgs.addAll(dao.getMsgs(entry.getElement())); } assertThat(dbMsgs, hasSize(msgs.size())); assertThat(dbMsgs, containsInAnyOrder(msgs.toArray())); } @DB(table = "msg") @Sharding(databaseShardingStrategy = MyDatabaseShardingStrategy.class) interface MsgDao { @ReturnGeneratedId @SQL("insert into #table(uid, content) values(:1.uid, :1.content)") int insert(@ShardingBy("uid") Msg msg); @SQL("update #table set content=:1.content where id=:1.id and uid=:1.uid") public int[] batchUpdate(@ShardingBy("uid") List<Msg> msgs); @SQL("select id, uid, content from #table where uid=:1") public List<Msg> getMsgs(@ShardingBy int uid); } public static class MyDatabaseShardingStrategy implements DatabaseShardingStrategy<Integer> { @Override public String getDataSourceFactoryName(Integer uid) { int tail = uid % 10; if (tail >= 0 && tail <= 2) { return dsns[0]; } else if (tail >= 3 && tail <= 5) { return dsns[1]; } else { return dsns[2]; } } } }