/***********************************************************************************************************************
 * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 **********************************************************************************************************************/

package eu.stratosphere.pact.runtime.udf;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.FutureTask;

import eu.stratosphere.api.common.accumulators.Accumulator;
import eu.stratosphere.api.common.accumulators.AccumulatorHelper;
import eu.stratosphere.api.common.accumulators.DoubleCounter;
import eu.stratosphere.api.common.accumulators.Histogram;
import eu.stratosphere.api.common.accumulators.IntCounter;
import eu.stratosphere.api.common.accumulators.LongCounter;
import eu.stratosphere.api.common.cache.DistributedCache;
import eu.stratosphere.api.common.functions.RuntimeContext;
import eu.stratosphere.core.fs.Path;

/**
 * A {@link RuntimeContext} implementation that exposes the task name, the degree of parallelism and the index of the
 * current subtask to a user-defined function.
 * <p>
 * It also holds three pieces of per-instance state:
 * <ul>
 *   <li>the accumulators registered by name;</li>
 *   <li>the broadcast variables bound by name;</li>
 *   <li>a {@link DistributedCache} instance.</li>
 * </ul>
 * <p>
 * The internal state is kept in plain {@link HashMap}s and this class performs no synchronization of its own.
 */
public class RuntimeUDFContext implements RuntimeContext {

	private final String name;

	private final int numParallelSubtasks;

	private final int subtaskIndex;

	private final DistributedCache distributedCache = new DistributedCache();

	/** Accumulators keyed by name; exposed directly through {@link #getAllAccumulators()}. */
	private final HashMap<String, Accumulator<?, ?>> accumulators = new HashMap<String, Accumulator<?, ?>>();

	/** Broadcast variables keyed by name, bound via {@link #setBroadcastVariable(String, Collection)}. */
	private final HashMap<String, Collection<?>> broadcastVars = new HashMap<String, Collection<?>>();

	/**
	 * Creates a context without any distributed-cache copy tasks.
	 *
	 * @param name the task name returned by {@link #getTaskName()}
	 * @param numParallelSubtasks the value returned by {@link #getNumberOfParallelSubtasks()}
	 * @param subtaskIndex the value returned by {@link #getIndexOfThisSubtask()}
	 */
	public RuntimeUDFContext(String name, int numParallelSubtasks, int subtaskIndex) {
		this.name = name;
		this.numParallelSubtasks = numParallelSubtasks;
		this.subtaskIndex = subtaskIndex;
	}

	/**
	 * Creates a context and hands the given copy tasks to its {@link DistributedCache}.
	 *
	 * @param name the task name returned by {@link #getTaskName()}
	 * @param numParallelSubtasks the value returned by {@link #getNumberOfParallelSubtasks()}
	 * @param subtaskIndex the value returned by {@link #getIndexOfThisSubtask()}
	 * @param cpTasks tasks passed to {@link DistributedCache#setCopyTasks}; presumably keyed by cache-entry name and
	 *        yielding the local path of each cached file (verify against DistributedCache)
	 */
	public RuntimeUDFContext(String name, int numParallelSubtasks, int subtaskIndex, Map<String, FutureTask<Path>> cpTasks) {
		// Field initializers have run by the time the chained constructor returns, so the cache exists here.
		this(name, numParallelSubtasks, subtaskIndex);
		this.distributedCache.setCopyTasks(cpTasks);
	}

	@Override
	public String getTaskName() {
		return this.name;
	}

	@Override
	public int getNumberOfParallelSubtasks() {
		return this.numParallelSubtasks;
	}

	@Override
	public int getIndexOfThisSubtask() {
		return this.subtaskIndex;
	}

	/**
	 * Returns the {@link IntCounter} registered under {@code name}, creating and registering a new one if absent.
	 */
	@Override
	public IntCounter getIntCounter(String name) {
		return (IntCounter) getAccumulator(name, IntCounter.class);
	}

	/**
	 * Returns the {@link LongCounter} registered under {@code name}, creating and registering a new one if absent.
	 */
	@Override
	public LongCounter getLongCounter(String name) {
		return (LongCounter) getAccumulator(name, LongCounter.class);
	}

	/**
	 * Returns the {@link Histogram} registered under {@code name}, creating and registering a new one if absent.
	 */
	@Override
	public Histogram getHistogram(String name) {
		return (Histogram) getAccumulator(name, Histogram.class);
	}

	/**
	 * Returns the {@link DoubleCounter} registered under {@code name}, creating and registering a new one if absent.
	 */
	@Override
	public DoubleCounter getDoubleCounter(String name) {
		return (DoubleCounter) getAccumulator(name, DoubleCounter.class);
	}

	/**
	 * Registers {@code accumulator} under {@code name}.
	 *
	 * @throws UnsupportedOperationException if an entry with this name is already registered
	 */
	@Override
	public <V, A> void addAccumulator(String name, Accumulator<V, A> accumulator) {
		if (accumulators.containsKey(name)) {
			throw new UnsupportedOperationException("The counter '" + name
					+ "' already exists and cannot be added.");
		}
		accumulators.put(name, accumulator);
	}

	/**
	 * Returns the accumulator registered under {@code name}, or {@code null} if there is none. The cast to
	 * {@code Accumulator<V, A>} is unchecked; the caller is responsible for requesting the right type parameters.
	 */
	@SuppressWarnings("unchecked")
	@Override
	public <V, A> Accumulator<V, A> getAccumulator(String name) {
		return (Accumulator<V, A>) accumulators.get(name);
	}

	/**
	 * Looks up the accumulator registered under {@code name}, or creates and registers a fresh instance of
	 * {@code accumulatorClass} if none exists.
	 * <p>
	 * If an accumulator already exists, its class is checked against {@code accumulatorClass} via
	 * {@link AccumulatorHelper#compareAccumulatorTypes}. That helper presumably throws on a mismatch (verify).
	 *
	 * @throws RuntimeException if {@code accumulatorClass} cannot be instantiated through its no-argument constructor
	 */
	@SuppressWarnings("unchecked")
	private <V, A> Accumulator<V, A> getAccumulator(String name, Class<? extends Accumulator<V, A>> accumulatorClass) {
		Accumulator<?, ?> accumulator = accumulators.get(name);

		if (accumulator != null) {
			AccumulatorHelper.compareAccumulatorTypes(name, accumulator.getClass(), accumulatorClass);
		} else {
			// Only a successfully created instance is registered; a failure propagates instead of storing null.
			accumulator = newAccumulator(accumulatorClass);
			accumulators.put(name, accumulator);
		}
		return (Accumulator<V, A>) accumulator;
	}

	/**
	 * Instantiates {@code accumulatorClass} through its no-argument constructor.
	 *
	 * @throws RuntimeException wrapping the reflective failure, so callers never receive a {@code null} accumulator
	 */
	private static <V, A> Accumulator<V, A> newAccumulator(Class<? extends Accumulator<V, A>> accumulatorClass) {
		try {
			return accumulatorClass.newInstance();
		} catch (InstantiationException e) {
			throw new RuntimeException("Cannot instantiate accumulator class " + accumulatorClass.getName(), e);
		} catch (IllegalAccessException e) {
			throw new RuntimeException("Cannot access the no-argument constructor of accumulator class "
					+ accumulatorClass.getName(), e);
		}
	}

	/**
	 * Returns the internal accumulator map itself (not a copy); modifications are visible to this context.
	 */
	@Override
	public HashMap<String, Accumulator<?, ?>> getAllAccumulators() {
		return this.accumulators;
	}

	/**
	 * Binds {@code value} as the broadcast variable named {@code name}, replacing any previous binding.
	 */
	public void setBroadcastVariable(String name, Collection<?> value) {
		this.broadcastVars.put(name, value);
	}

	/**
	 * Returns the broadcast variable bound under {@code name}. The element-type cast is unchecked.
	 *
	 * @throws IllegalArgumentException if no variable is bound under {@code name}
	 */
	@Override
	@SuppressWarnings("unchecked")
	public <RT> Collection<RT> getBroadcastVariable(String name) {
		if (!this.broadcastVars.containsKey(name)) {
			throw new IllegalArgumentException("Trying to access an unbound broadcast variable '" + name + "'.");
		}
		return (Collection<RT>) this.broadcastVars.get(name);
	}

	@Override
	public DistributedCache getDistributedCache() {
		return this.distributedCache;
	}
}