/**
*
*/
package com.ganji.as.thrift.protocol.cluster.load.balance;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List;
import com.ganji.as.thrift.protocol.client.request.ThriftClientInvocation;
import com.ganji.as.thrift.protocol.server.nodes.discovery.ServerNode;
import com.ganji.as.thrift.protocol.server.nodes.discovery.ServerNodeInfo;
import com.ganji.as.thrift.protocol.server.nodes.discovery.ServiceEndpoint;
/**
* @author yikangfeng
* @date 2015年7月22日
*/
public class WeightedRoundRobinLoadBalance extends AbstractLoadBalance {
private int currentIndex = -1;// 上一次选择的服务器
private int currentWeight = 0;// 当前调度的权值
private int maxWeight = 0; // 最大权重
private int gcdWeight = 0; // 所有服务器权重的最大公约数
private int serverCount = 0; // 服务器数量
private List<ServerNode> serverNodes_; // 服务器集合
public WeightedRoundRobinLoadBalance() {
currentIndex = -1;
currentWeight = 0;
}
/**
* 返回最大公约数
*
* @param a
* @param b
* @return
*/
private static int gcd(int a, int b) {
BigInteger b1 = new BigInteger(String.valueOf(a));
BigInteger b2 = new BigInteger(String.valueOf(b));
BigInteger gcd = b1.gcd(b2);
return gcd.intValue();
}
/**
* 返回所有服务器权重的最大公约数 37 * @param serverList 38 * @return 39
*/
private static int getGCDForServers(List<ServerNode> serverList) {
int w = 0;
for (int i = 0, len = serverList.size(); i < len - 1; i++) {
if (w == 0) {
w = gcd(serverList.get(i).getWeight(), serverList.get(i + 1)
.getWeight());
} else {
w = gcd(w, serverList.get(i + 1).getWeight());
}
}
return w;
}
/**
* 返回所有服务器中的最大权重
*
* @param serverList
* @return
*/
public static int getMaxWeightForServers(List<ServerNode> serverList) {
int w = 0;
for (int i = 0, len = serverList.size(); i < len - 1; i++) {
if (w == 0) {
w = Math.max(serverList.get(i).getWeight(),
serverList.get(i + 1).getWeight());
} else {
w = Math.max(w, serverList.get(i + 1).getWeight());
}
}
return w;
}
/**
* 算法流程: 假设有一组服务器 S = {S0, S1, …, Sn-1} 有相应的权重,变量currentIndex表示上次选择的服务器
* 权值currentWeight初始化为0,currentIndex初始化为-1 ,当第一次的时候返回 权值取最大的那个服务器, 通过权重的不断递减
* 寻找 适合的服务器返回,直到轮询结束,权值返回为0
*/
public ServerNode getAvailableServerNode() {
while (true) {
currentIndex = (currentIndex + 1) % serverCount;
if (currentIndex == 0) {
currentWeight = currentWeight - gcdWeight;
if (currentWeight <= 0) {
currentWeight = maxWeight;
if (currentWeight == 0)
return null;
}
}
if (serverNodes_.get(currentIndex).getWeight() >= currentWeight) {
return serverNodes_.get(currentIndex);
}
}
}
@Override
ServerNode doSelect(final List<ServerNode> serverNodes,final ThriftClientInvocation clientInvocation) {
// TODO Auto-generated method stub
synchronized (this) {
this.serverNodes_ = serverNodes;
this.serverCount = this.serverNodes_.size();
this.maxWeight = getMaxWeightForServers(this.serverNodes_);
this.gcdWeight = getGCDForServers(this.serverNodes_);
return this.getAvailableServerNode();
}
}
static public void main(String[] args) {
WeightedRoundRobinLoadBalance loadBalance = new WeightedRoundRobinLoadBalance();
List<ServerNode> serverNodes = new ArrayList<>();
ServerNodeInfo serverNodeInfo1 = new ServerNodeInfo();
ServiceEndpoint serviceEndpoint1 = new ServiceEndpoint();
serviceEndpoint1.setHost("host1");
serviceEndpoint1.setPort(8081);
serverNodeInfo1.setServiceEndpoint(serviceEndpoint1);
ServerNodeInfo serverNodeInfo2 = new ServerNodeInfo();
ServiceEndpoint serviceEndpoint2 = new ServiceEndpoint();
serviceEndpoint2.setHost("host2");
serviceEndpoint2.setPort(8082);
serverNodeInfo2.setServiceEndpoint(serviceEndpoint2);
ServerNodeInfo serverNodeInfo3 = new ServerNodeInfo();
ServiceEndpoint serviceEndpoint3 = new ServiceEndpoint();
serviceEndpoint3.setHost("host3");
serviceEndpoint3.setPort(8083);
serverNodeInfo3.setServiceEndpoint(serviceEndpoint3);
serverNodes.add(serverNodeInfo1);
serverNodes.add(serverNodeInfo2);
serverNodes.add(serverNodeInfo3);
for (int i = 0; i < 1000; ++i) {
try {
System.out.println(loadBalance.select(serverNodes,null));
} catch (Throwable e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
}
}