/** * CopyRight by Chinamobile * * WorkerAgentForJob.java */ package com.chinamobile.bcbsp.workermanager; import java.io.IOException; import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.ipc.RPC; import org.apache.hadoop.ipc.RPC.Server; import com.chinamobile.bcbsp.Constants; import com.chinamobile.bcbsp.api.AggregateValue; import com.chinamobile.bcbsp.api.Aggregator; import com.chinamobile.bcbsp.sync.SuperStepReportContainer; import com.chinamobile.bcbsp.sync.WorkerSSController; import com.chinamobile.bcbsp.sync.WorkerSSControllerInterface; import com.chinamobile.bcbsp.util.BSPJob; import com.chinamobile.bcbsp.util.BSPJobID; import com.chinamobile.bcbsp.util.StaffAttemptID; /** * WorkerAgentForJob. * * It is create by WorkerManager for every job that running * on it. This class manages all staffs which belongs to the same job, * maintains public information, completes the local synchronization * and aggregation. * * @author * @version */ public class WorkerAgentForJob implements WorkerAgentInterface { private Map<StaffAttemptID, SuperStepReportContainer> runningStaffInformation = new HashMap<StaffAttemptID, SuperStepReportContainer>(); private volatile Integer staffReportCounter = 0; private WorkerSSControllerInterface wssc; private String workerManagerName; private int workerManagerNum = 0; private HashMap<Integer, String> partitionToWorkerManagerName = new HashMap<Integer, String>(); private final Map<InetSocketAddress, WorkerAgentInterface> workers = new ConcurrentHashMap<InetSocketAddress, WorkerAgentInterface>(); private int portForJob; private InetSocketAddress workAddress; private Server server = null; private Configuration conf; private static final Log LOG = LogFactory.getLog(WorkerAgentForJob.class); private BSPJobID jobId; private BSPJob jobConf; private WorkerManager workerManager; // For Aggregation /** Map for user registered aggregate values. */ private HashMap<String, Class<? extends AggregateValue<?>>> nameToAggregateValue = new HashMap<String, Class<? extends AggregateValue<?>>>(); /** Map for user registered aggregatros. */ private HashMap<String, Class<? extends Aggregator<?>>> nameToAggregator = new HashMap<String, Class<? extends Aggregator<?>>>(); @SuppressWarnings("unchecked") private HashMap<String, ArrayList<AggregateValue>> aggregateValues = new HashMap<String, ArrayList<AggregateValue>>(); @SuppressWarnings("unchecked") private HashMap<String, AggregateValue> aggregateResults = new HashMap<String, AggregateValue>(); public WorkerAgentForJob(Configuration conf, BSPJobID jobId, BSPJob jobConf, WorkerManager workerManager) throws IOException { this.jobId = jobId; this.jobConf = jobConf; this.workerManager = workerManager; this.workerManagerName = conf.get(Constants.BC_BSP_WORKERAGENT_HOST, Constants.BC_BSP_WORKERAGENT_HOST); this.wssc = new WorkerSSController(jobId, this.workerManagerName); this.conf = conf; String bindAddress = conf.get(Constants.BC_BSP_WORKERAGENT_HOST, Constants.DEFAULT_BC_BSP_WORKERAGENT_HOST); int bindPort = conf.getInt(Constants.BC_BSP_WORKERAGENT_PORT, Constants.DEFAULT_BC_BSP_WORKERAGENT_PORT); bindPort = bindPort + Integer.parseInt(jobId.toString().substring(17)); portForJob = bindPort; // network e.g. ip address, port. workAddress = new InetSocketAddress(bindAddress, bindPort); reinitialize(); // For Aggregation loadAggregators(); } public void reinitialize() { try { LOG.info("reinitialize() the WorkerAgentForJob: " + jobId.toString()); server = RPC.getServer(this, workAddress.getHostName(), workAddress.getPort(), conf); server.start(); LOG.info("WorkerAgent address:" + workAddress.getHostName() + " port:" + workAddress.getPort()); } catch (IOException e) { LOG.error("[reinitialize]", e); } } public WorkerAgentForJob(WorkerSSControllerInterface wssci) { this.wssc = wssci; } /** * Prepare to local synchronization, including computing all kinds of * information. * * @return SupterStepReportContainer */ @SuppressWarnings("unchecked") private SuperStepReportContainer prepareLocalBarrier() { int stageFlag = 0; long judgeFlag = 0; String[] dirFlag = { "1" }; String[] aggValues; for (Entry<StaffAttemptID, SuperStepReportContainer> e : this.runningStaffInformation .entrySet()) { SuperStepReportContainer tmp = e.getValue(); stageFlag = tmp.getStageFlag(); dirFlag = tmp.getDirFlag(); judgeFlag += tmp.getJudgeFlag(); // Get the aggregation values from the ssrcs. aggValues = tmp.getAggValues(); decapsulateAggregateValues(aggValues); }//end-for // Compute the aggregations for all staffs in the worker. localAggregate(); // Encapsulate the aggregation values to String[] for the ssrc. String[] newAggValues = encapsulateAggregateValues(); SuperStepReportContainer ssrc = new SuperStepReportContainer(stageFlag, dirFlag, judgeFlag, newAggValues); // newAggValues into the new ssrc. return ssrc; } public void addStaffReportCounter() { this.staffReportCounter++; } private void clearStaffReportCounter() { this.staffReportCounter = 0; } /** * All staffs belongs to the same job will use this to complete the local * synchronization and aggregation. * * @param staffId * @param superStepCounter * @param args * @return */ @Override public boolean localBarrier(BSPJobID jobId, StaffAttemptID staffId, int superStepCounter, SuperStepReportContainer ssrc) { this.runningStaffInformation.put(staffId, ssrc); synchronized (this.staffReportCounter) { addStaffReportCounter(); LOG.info(staffId.toString() + " [staffReportCounter]" + this.staffReportCounter); LOG.info(staffId.toString() + " [staffCounter]" + ssrc.getLocalBarrierNum()); if (this.staffReportCounter == ssrc.getLocalBarrierNum()) { clearStaffReportCounter(); switch (ssrc.getStageFlag()) { case Constants.SUPERSTEP_STAGE.FIRST_STAGE: wssc.firstStageSuperStepBarrier(superStepCounter, ssrc); break; case Constants.SUPERSTEP_STAGE.SECOND_STAGE: wssc.secondStageSuperStepBarrier(superStepCounter, prepareLocalBarrier()); break; case Constants.SUPERSTEP_STAGE.WRITE_CHECKPOINT_SATGE: case Constants.SUPERSTEP_STAGE.READ_CHECKPOINT_STAGE: if (ssrc.getStageFlag() == Constants.SUPERSTEP_STAGE.READ_CHECKPOINT_STAGE) { String workerNameAndPort = ssrc.getPartitionId() + ":" + this.workerManagerName + ":" + ssrc.getPort2(); for (Entry<StaffAttemptID, SuperStepReportContainer> e : this.runningStaffInformation.entrySet()) { if (e.getKey().equals(staffId)) { continue; } else { String str = e.getValue().getPartitionId() + ":" + this.workerManagerName + ":" + e.getValue().getPort2(); workerNameAndPort += Constants.KV_SPLIT_FLAG + str; } } ssrc.setActiveMQWorkerNameAndPorts(workerNameAndPort); wssc.checkPointStageSuperStepBarrier(superStepCounter, ssrc); } else { wssc.checkPointStageSuperStepBarrier(superStepCounter, ssrc); } break; case Constants.SUPERSTEP_STAGE.SAVE_RESULT_STAGE: wssc.saveResultStageSuperStepBarrier(superStepCounter, ssrc); break; default: LOG.error("The SUPERSTEP of " + ssrc.getStageFlag() + " is not known"); } return true; } else { return false; } } } @Override public int getNumberWorkers(BSPJobID jobId, StaffAttemptID staffId) { return this.workerManagerNum; } @Override public void setNumberWorkers(BSPJobID jobId, StaffAttemptID staffId, int num) { this.workerManagerNum = num; } @Override public void close() throws IOException { this.server.stop(); } @Override public long getProtocolVersion(String arg0, long arg1) throws IOException { return WorkerAgentInterface.versionID; } @Override public String getWorkerManagerName(BSPJobID jobId, StaffAttemptID staffId) { return this.workerManagerName; } /** * Add a new task to the job * * @param currentStaffStatus */ public void addStaffCounter(StaffAttemptID staffId) { SuperStepReportContainer ssrc = new SuperStepReportContainer(); this.runningStaffInformation.put(staffId, ssrc); } /** * Sets the job configuration * * @param jobConf */ public void setJobConf(BSPJob jobConf) { } /** * Get WorkerAgent BSPJobID */ @Override public BSPJobID getBSPJobID() { return this.jobId; } protected WorkerAgentInterface getWorkerAgentConnection( InetSocketAddress addr) { WorkerAgentInterface worker; synchronized (this.workers) { worker = workers.get(addr); if (worker == null) { try { worker = ( WorkerAgentInterface ) RPC.getProxy( WorkerAgentInterface.class, WorkerAgentInterface.versionID, addr, this.conf); } catch (IOException e) { } this.workers.put(addr, worker); } } return worker; } /** * This method is used to set mapping table that shows the partition to the * worker. * * @param jobId * @param partitionId * @param hostName */ public void setWorkerNametoPartitions(BSPJobID jobId, int partitionId, String hostName) { this.partitionToWorkerManagerName.put(partitionId, hostName + ":" + this.portForJob); } @SuppressWarnings("unchecked") private void loadAggregators() { int aggregateNum = this.jobConf.getAggregateNum(); String[] aggregateNames = this.jobConf.getAggregateNames(); for (int i = 0; i < aggregateNum; i ++) { String name = aggregateNames[i]; this.nameToAggregator.put(name, this.jobConf.getAggregatorClass(name)); this.nameToAggregateValue.put(name, jobConf.getAggregateValueClass(name)); this.aggregateValues.put(name, new ArrayList<AggregateValue>()); } } /** * To decapsulate the aggregation values from the String[]. * * The aggValues should be in form as follows: * [ AggregateName \t AggregateValue.toString() ] * * @param aggValues * String[] */ @SuppressWarnings("unchecked") private void decapsulateAggregateValues(String[] aggValues) { for (int i = 0; i < aggValues.length; i ++) { String[] aggValueRecord = aggValues[i].split(Constants.KV_SPLIT_FLAG); String aggName = aggValueRecord[0]; String aggValueString = aggValueRecord[1]; AggregateValue aggValue = null; try { aggValue = this.nameToAggregateValue.get(aggName).newInstance(); aggValue.initValue(aggValueString); // init the aggValue from its string form. } catch (InstantiationException e1) { LOG.error("InstantiationException", e1); } catch (IllegalAccessException e1) { LOG.error("IllegalAccessException", e1); }//end-try if (aggValue != null) { ArrayList<AggregateValue> list = this.aggregateValues.get(aggName); list.add(aggValue); // put the value to the values' list for aggregation ahead. }//end-if }//end-for } /** * To aggregate the values from the running staffs. */ @SuppressWarnings("unchecked") private void localAggregate() { // Clear the results' container before the calculation of a new super step. this.aggregateResults.clear(); // To calculate the aggregations. for (Entry<String, Class<? extends Aggregator<?>>> entry : this.nameToAggregator.entrySet()) { Aggregator<AggregateValue> aggregator = null; try { aggregator = ( Aggregator<AggregateValue> ) entry.getValue().newInstance(); } catch (InstantiationException e1) { LOG.error("InstantiationException", e1); } catch (IllegalAccessException e1) { LOG.error("IllegalAccessException", e1); } if (aggregator != null) { ArrayList<AggregateValue> aggVals = this.aggregateValues.get(entry.getKey()); AggregateValue resultValue = aggregator.aggregate(aggVals); this.aggregateResults.put(entry.getKey(), resultValue); aggVals.clear();// Clear the initial aggregate values after aggregation completes. } } } /** * To encapsulate the aggregation values to the String[]. * * The aggValues should be in form as follows: * [ AggregateName \t AggregateValue.toString() ] * * @return String[] */ @SuppressWarnings("unchecked") private String[] encapsulateAggregateValues() { int aggSize = this.aggregateResults.size(); String[] aggValues = new String[aggSize]; int i_a = 0; for (Entry<String, AggregateValue> entry : this.aggregateResults.entrySet()) { aggValues[i_a] = entry.getKey() + Constants.KV_SPLIT_FLAG + entry.getValue().toString(); i_a ++; } return aggValues; } public synchronized int getFreePort() { return this.workerManager.getFreePort(); } @Override public void setStaffAgentAddress(StaffAttemptID staffID, String addr) { this.workerManager.setStaffAgentAddress(staffID, addr); } }