/*
 * Copyright [2013-2014] PayPal Software Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package ml.shifu.guagua.worker;

import java.io.IOException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import ml.shifu.guagua.io.Bytable;
import ml.shifu.guagua.io.GuaguaFileSplit;
import ml.shifu.guagua.io.GuaguaRecordReader;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Abstract {@link WorkerComputable} implementation to load data record by record, and only in the very first
 * iteration.
 * 
 * <p>
 * To load data successfully, make sure {@link GuaguaRecordReader} is initialized first.
 * 
 * <p>
 * After data is loaded in the first iteration, one can store it in collections (in memory or on disk) for the logic
 * of later iterations.
 * 
 * <p>
 * TODO add multi-thread version to load data.
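 * 
 * <p>
 * A minimal, illustrative sketch of a subclass that caches records in memory during the first load and reuses them
 * in later iterations; the key/value and result types and the record reader shown here are hypothetical placeholders,
 * not part of the framework:
 * 
 * <pre>{@code
 * public class MyWorker extends AbstractWorkerComputable<MyMasterResult, MyWorkerResult, MyKey, MyValue> {
 * 
 *     private final java.util.List<MyValue> records = new java.util.ArrayList<MyValue>();
 * 
 *     public void init(WorkerContext<MyMasterResult, MyWorkerResult> context) {
 *         // one-time setup for the whole computation
 *     }
 * 
 *     public void initRecordReader(GuaguaFileSplit fileSplit) throws IOException {
 *         // build a record reader for this split and register it, e.g.
 *         // setRecordReader(new MyRecordReader(fileSplit));
 *     }
 * 
 *     public void load(MyKey key, MyValue value, WorkerContext<MyMasterResult, MyWorkerResult> context) {
 *         // called once per record, only in the first iteration
 *         this.records.add(value);
 *     }
 * 
 *     public MyWorkerResult doCompute(WorkerContext<MyMasterResult, MyWorkerResult> context) {
 *         // iterate over this.records and build the result sent to the master
 *         return new MyWorkerResult();
 *     }
 * }
 * }</pre>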
 * 
 * @param <MASTER_RESULT>
 *            master result for computation in each iteration.
 * @param <WORKER_RESULT>
 *            worker result for computation in each iteration.
 * @param <KEY>
 *            key type for each record
 * @param <VALUE>
 *            value type for each record
 */
public abstract class AbstractWorkerComputable<MASTER_RESULT extends Bytable, WORKER_RESULT extends Bytable, KEY extends Bytable, VALUE extends Bytable>
        implements WorkerComputable<MASTER_RESULT, WORKER_RESULT> {

    private static final Logger LOG = LoggerFactory.getLogger(AbstractWorkerComputable.class);

    private AtomicBoolean isLoaded = new AtomicBoolean(false);

    private GuaguaRecordReader<KEY, VALUE> recordReader;

    /*
     * (non-Javadoc)
     * 
     * @see ml.shifu.guagua.worker.WorkerComputable#compute(ml.shifu.guagua.worker.WorkerContext)
     */
    @Override
    public WORKER_RESULT compute(WorkerContext<MASTER_RESULT, WORKER_RESULT> context) throws IOException {
        // Load data only once: the first call flips the flag and reads all file splits of this worker.
        if(this.isLoaded.compareAndSet(false, true)) {
            init(context);

            long start = System.nanoTime();
            preLoad(context);
            long count = 0;
            for(GuaguaFileSplit fileSplit: context.getFileSplits()) {
                LOG.info("Loading filesplit: {}", fileSplit);
                try {
                    initRecordReader(fileSplit);
                    LOG.info("file extension: {}, split: {}", fileSplit.getExtension(), fileSplit);
                    context.setAttachment(fileSplit.getExtension());
                    while(getRecordReader().nextKeyValue()) {
                        load(getRecordReader().getCurrentKey(), getRecordReader().getCurrentValue(), context);
                        count += 1L;
                    }
                } finally {
                    if(getRecordReader() != null) {
                        getRecordReader().close();
                    }
                }
            }
            if(count == 0L) {
                throw new IllegalStateException(
                        "Record count in this worker is zero, please check whether there are any exceptions in your input data.");
            }
            postLoad(context);
            LOG.info("Load {} records.", count);
            LOG.info("Data loading time: {}ms", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start));
        }

        long start = System.nanoTime();
        try {
            return doCompute(context);
        } finally {
            LOG.info("Computation time for application {} container {} iteration {}: {}ms.", context.getAppId(),
                    context.getContainerId(), context.getCurrentIteration(),
                    TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start));
        }
    }

    /**
     * Do some pre work before loading data.
     */
    protected void preLoad(WorkerContext<MASTER_RESULT, WORKER_RESULT> context) {
    }

    /**
     * Do some post work after loading data.
     */
    protected void postLoad(WorkerContext<MASTER_RESULT, WORKER_RESULT> context) {
    }

    /**
     * Each {@link GuaguaFileSplit} must be initialized before loading data.
     */
    public abstract void initRecordReader(GuaguaFileSplit fileSplit) throws IOException;

    /**
     * Initialization work for the whole computation.
     */
    public abstract void init(WorkerContext<MASTER_RESULT, WORKER_RESULT> context);

    /**
     * Real computation logic after data loading.
     */
    public abstract WORKER_RESULT doCompute(WorkerContext<MASTER_RESULT, WORKER_RESULT> context);

    /**
     * Load data record by record before computation.
     */
    public abstract void load(KEY currentKey, VALUE currentValue, WorkerContext<MASTER_RESULT, WORKER_RESULT> context);

    public GuaguaRecordReader<KEY, VALUE> getRecordReader() {
        return recordReader;
    }

    public void setRecordReader(GuaguaRecordReader<KEY, VALUE> recordReader) {
        this.recordReader = recordReader;
    }
}