/***********************************************************************************************************************
 * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 **********************************************************************************************************************/

package eu.stratosphere.pact.runtime.test.util;

import eu.stratosphere.configuration.Configuration;
import eu.stratosphere.core.fs.Path;
import eu.stratosphere.core.io.IOReadableWritable;
import eu.stratosphere.core.memory.MemorySegment;
import eu.stratosphere.nephele.execution.Environment;
import eu.stratosphere.runtime.io.gates.InputChannelResult;
import eu.stratosphere.runtime.io.gates.RecordAvailabilityListener;
import eu.stratosphere.runtime.io.serialization.AdaptiveSpanningRecordDeserializer;
import eu.stratosphere.runtime.io.Buffer;
import eu.stratosphere.runtime.io.channels.ChannelID;
import eu.stratosphere.runtime.io.gates.GateID;
import eu.stratosphere.runtime.io.gates.InputGate;
import eu.stratosphere.runtime.io.gates.OutputGate;
import eu.stratosphere.nephele.jobgraph.JobID;
import eu.stratosphere.nephele.protocols.AccumulatorProtocol;
import eu.stratosphere.nephele.services.iomanager.IOManager;
import eu.stratosphere.nephele.services.memorymanager.MemoryManager;
import eu.stratosphere.nephele.services.memorymanager.spi.DefaultMemoryManager;
import eu.stratosphere.runtime.io.network.bufferprovider.BufferAvailabilityListener;
import eu.stratosphere.runtime.io.network.bufferprovider.BufferProvider;
import eu.stratosphere.runtime.io.network.bufferprovider.GlobalBufferPool;
import eu.stratosphere.runtime.io.network.bufferprovider.LocalBufferPoolOwner;
import eu.stratosphere.nephele.template.InputSplitProvider;
import eu.stratosphere.runtime.io.serialization.RecordDeserializer;
import eu.stratosphere.runtime.io.serialization.RecordDeserializer.DeserializationResult;
import eu.stratosphere.types.Record;
import eu.stratosphere.util.MutableObjectIterator;

import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.FutureTask;

/**
 * Test double for {@link Environment} that also serves as its own {@link BufferProvider} and
 * {@link LocalBufferPoolOwner}.
 * <p>
 * Inputs are backed by {@link MutableObjectIterator}s registered through {@link #addInput(MutableObjectIterator)} and
 * outputs are collected into plain lists registered through {@link #addOutput(List)}. Every buffer request is answered
 * with one and the same pre-allocated buffer.
 */
public class MockEnvironment implements Environment, BufferProvider, LocalBufferPoolOwner {

	private final MemoryManager memManager;

	private final IOManager ioManager;

	private final InputSplitProvider inputSplitProvider;

	private final Configuration jobConfiguration;

	private final Configuration taskConfiguration;

	/** Input gates that have been added but not yet handed out by {@link #createAndRegisterInputGate()}. */
	private final List<InputGate<Record>> inputs;

	/** Output gates that have been added but not yet handed out by {@link #createAndRegisterOutputGate()}. */
	private final List<OutputGate> outputs;

	private final JobID jobID = new JobID();

	/** The single buffer returned by every buffer request and read back by the mock output gates. */
	private final Buffer mockBuffer;

	/**
	 * Creates a mock environment with empty job and task configurations and no gates.
	 *
	 * @param memorySize
	 *        the size handed to the {@link DefaultMemoryManager} created for this environment
	 * @param inputSplitProvider
	 *        the provider returned by {@link #getInputSplitProvider()}
	 * @param bufferSize
	 *        the length of the byte array backing the shared mock buffer, also passed as the buffer's size
	 */
	public MockEnvironment(long memorySize, MockInputSplitProvider inputSplitProvider, int bufferSize) {
		this.jobConfiguration = new Configuration();
		this.taskConfiguration = new Configuration();
		this.inputs = new LinkedList<InputGate<Record>>();
		this.outputs = new LinkedList<OutputGate>();

		this.memManager = new DefaultMemoryManager(memorySize);
		this.ioManager = new IOManager(System.getProperty("java.io.tmpdir"));
		this.inputSplitProvider = inputSplitProvider;
		this.mockBuffer = new Buffer(new MemorySegment(new byte[bufferSize]), bufferSize, null);
	}

	/**
	 * Appends a mock input gate that reads its records from the given iterator. The gate's index is the number of
	 * inputs pending at the time of the call.
	 *
	 * @param inputIterator
	 *        the source of the records delivered by the new gate
	 */
	public void addInput(MutableObjectIterator<Record> inputIterator) {
		final int id = this.inputs.size();
		this.inputs.add(new MockInputGate(id, inputIterator));
	}

	/**
	 * Appends a mock output gate that stores the records it receives in the given list. The gate's index is the
	 * number of outputs pending at the time of the call.
	 *
	 * @param outputList
	 *        the list that receives copies of the emitted records
	 */
	public void addOutput(List<Record> outputList) {
		final int id = this.outputs.size();
		this.outputs.add(new MockOutputGate(id, outputList));
	}

	@Override
	public Configuration getTaskConfiguration() {
		return this.taskConfiguration;
	}

	@Override
	public MemoryManager getMemoryManager() {
		return this.memManager;
	}

	@Override
	public IOManager getIOManager() {
		return this.ioManager;
	}

	@Override
	public JobID getJobID() {
		return this.jobID;
	}

	/**
	 * Returns the shared mock buffer; {@code minBufferSize} is ignored.
	 */
	@Override
	public Buffer requestBuffer(int minBufferSize) throws IOException {
		return this.mockBuffer;
	}

	/**
	 * Returns the shared mock buffer immediately; {@code minBufferSize} is ignored.
	 */
	@Override
	public Buffer requestBufferBlocking(int minBufferSize) throws IOException, InterruptedException {
		return this.mockBuffer;
	}

	/**
	 * Returns the size reported by the shared mock buffer.
	 */
	@Override
	public int getBufferSize() {
		return this.mockBuffer.size();
	}

	/**
	 * Never registers the listener; always answers {@code FAILED_BUFFER_POOL_DESTROYED}.
	 */
	@Override
	public BufferAvailabilityRegistration registerBufferAvailabilityListener(BufferAvailabilityListener listener) {
		return BufferAvailabilityRegistration.FAILED_BUFFER_POOL_DESTROYED;
	}

	@Override
	public int getNumberOfChannels() {
		return 1;
	}

	@Override
	public void setDesignatedNumberOfBuffers(int numBuffers) {
		// Intentionally empty in the mock
	}

	@Override
	public void clearLocalBufferPool() {
		// Intentionally empty in the mock
	}

	@Override
	public void registerGlobalBufferPool(GlobalBufferPool globalBufferPool) {
		// Intentionally empty in the mock
	}

	@Override
	public void logBufferUtilization() {
		// Intentionally empty in the mock
	}

	@Override
	public void reportAsynchronousEvent() {
		// Intentionally empty in the mock
	}

	/**
	 * Input gate that serves records straight from a {@link MutableObjectIterator}.
	 */
	private static class MockInputGate extends InputGate<Record> {

		private final MutableObjectIterator<Record> it;

		public MockInputGate(int id, MutableObjectIterator<Record> it) {
			super(new JobID(), new GateID(), id);
			this.it = it;
		}

		/**
		 * Registers the listener with the superclass and then immediately signals availability for index 0.
		 */
		@Override
		public void registerRecordAvailabilityListener(final RecordAvailabilityListener<Record> listener) {
			super.registerRecordAvailabilityListener(listener);
			this.notifyRecordIsAvailable(0);
		}

		/**
		 * Pulls the next record from the iterator, offering {@code target} for reuse.
		 *
		 * @return {@code INTERMEDIATE_RECORD_FROM_BUFFER} if the iterator returned a record, {@code END_OF_STREAM}
		 *         if it returned {@code null}
		 */
		@Override
		public InputChannelResult readRecord(Record target) throws IOException, InterruptedException {
			final Record next = this.it.next(target);
			if (next == null) {
				return InputChannelResult.END_OF_STREAM;
			}

			// everything comes from the same source channel and buffer in this mock
			notifyRecordIsAvailable(0);
			return InputChannelResult.INTERMEDIATE_RECORD_FROM_BUFFER;
		}
	}

	/**
	 * Output gate that deserializes the environment's shared mock buffer and appends the resulting records to a list.
	 * It is an inner (non-static) class because it reads the enclosing environment's {@code mockBuffer}.
	 */
	private class MockOutputGate extends OutputGate {

		private final List<Record> out;

		private final RecordDeserializer<Record> deserializer;

		/** Reused deserialization target; a copy of it is stored for every complete record. */
		private final Record record;

		public MockOutputGate(int index, List<Record> outList) {
			super(new JobID(), new GateID(), index);
			this.out = outList;
			this.deserializer = new AdaptiveSpanningRecordDeserializer<Record>();
			this.record = new Record();
		}

		/**
		 * Feeds the shared mock buffer to the deserializer and adds a copy of every full record to the output list.
		 * Stops once the deserializer reports the last record of the buffer or a partial record.
		 */
		@Override
		public void sendBuffer(Buffer buffer, int targetChannel) throws IOException, InterruptedException {
			// NOTE(review): the 'buffer' argument is not used; the data is always taken from the environment's
			// shared mockBuffer, which is what requestBuffer()/requestBufferBlocking() hand out.
			final Buffer shared = MockEnvironment.this.mockBuffer;
			this.deserializer.setNextMemorySegment(shared.getMemorySegment(), shared.size());

			while (this.deserializer.hasUnfinishedData()) {
				final DeserializationResult result = this.deserializer.getNextRecord(this.record);

				if (result.isFullRecord()) {
					// the record instance is reused, so the list must receive a copy
					this.out.add(this.record.createCopy());
				}

				if (result == DeserializationResult.LAST_RECORD_FROM_BUFFER
						|| result == DeserializationResult.PARTIAL_RECORD) {
					break;
				}
			}
		}

		@Override
		public int getNumChannels() {
			return 1;
		}
	}

	@Override
	public Configuration getJobConfiguration() {
		return this.jobConfiguration;
	}

	@Override
	public int getCurrentNumberOfSubtasks() {
		return 1;
	}

	@Override
	public int getIndexInSubtaskGroup() {
		return 0;
	}

	@Override
	public void userThreadStarted(final Thread userThread) {
		// Nothing to do here
	}

	@Override
	public void userThreadFinished(final Thread userThread) {
		// Nothing to do here
	}

	@Override
	public InputSplitProvider getInputSplitProvider() {
		return this.inputSplitProvider;
	}

	/**
	 * The mock has no task name.
	 *
	 * @return always {@code null}
	 */
	@Override
	public String getTaskName() {
		return null;
	}

	/**
	 * @return always {@code null}
	 */
	@Override
	public GateID getNextUnboundInputGateID() {
		return null;
	}

	/**
	 * @return the number of output gates added and not yet removed by {@link #createAndRegisterOutputGate()}
	 */
	@Override
	public int getNumberOfOutputGates() {
		return this.outputs.size();
	}

	/**
	 * @return the number of input gates added and not yet removed by {@link #createAndRegisterInputGate()}
	 */
	@Override
	public int getNumberOfInputGates() {
		return this.inputs.size();
	}

	/**
	 * Not available on the mock.
	 *
	 * @throws IllegalStateException
	 *         always
	 */
	@Override
	public Set<ChannelID> getOutputChannelIDs() {
		throw new IllegalStateException("getOutputChannelIDs called on MockEnvironment");
	}

	/**
	 * Not available on the mock.
	 *
	 * @throws IllegalStateException
	 *         always
	 */
	@Override
	public Set<ChannelID> getInputChannelIDs() {
		throw new IllegalStateException("getInputChannelIDs called on MockEnvironment");
	}

	/**
	 * Not available on the mock.
	 *
	 * @throws IllegalStateException
	 *         always
	 */
	@Override
	public Set<GateID> getOutputGateIDs() {
		throw new IllegalStateException("getOutputGateIDs called on MockEnvironment");
	}

	/**
	 * Not available on the mock.
	 *
	 * @throws IllegalStateException
	 *         always
	 */
	@Override
	public Set<GateID> getInputGateIDs() {
		throw new IllegalStateException("getInputGateIDs called on MockEnvironment");
	}

	/**
	 * Not available on the mock.
	 *
	 * @throws IllegalStateException
	 *         always
	 */
	@Override
	public Set<ChannelID> getOutputChannelIDsOfGate(final GateID gateID) {
		throw new IllegalStateException("getOutputChannelIDsOfGate called on MockEnvironment");
	}

	/**
	 * Not available on the mock.
	 *
	 * @throws IllegalStateException
	 *         always
	 */
	@Override
	public Set<ChannelID> getInputChannelIDsOfGate(final GateID gateID) {
		throw new IllegalStateException("getInputChannelIDsOfGate called on MockEnvironment");
	}

	/**
	 * Removes and returns the first pending output gate, i.e. gates are handed out in the order in which
	 * {@link #addOutput(List)} was called.
	 */
	@Override
	public OutputGate createAndRegisterOutputGate() {
		return this.outputs.remove(0);
	}

	/**
	 * Removes and returns the first pending input gate, i.e. gates are handed out in the order in which
	 * {@link #addInput(MutableObjectIterator)} was called.
	 */
	@Override
	public <T extends IOReadableWritable> InputGate<T> createAndRegisterInputGate() {
		// Every gate stored in 'inputs' is an InputGate<Record>. The cast cannot be checked at runtime, so this
		// is only type-safe when the caller requests T = Record; the suppression is limited to this one local.
		@SuppressWarnings("unchecked")
		final InputGate<T> gate = (InputGate<T>) this.inputs.remove(0);
		return gate;
	}

	/**
	 * @return the number of pending output gates (same value as {@link #getNumberOfOutputGates()})
	 */
	@Override
	public int getNumberOfOutputChannels() {
		return this.outputs.size();
	}

	/**
	 * @return the number of pending input gates (same value as {@link #getNumberOfInputGates()})
	 */
	@Override
	public int getNumberOfInputChannels() {
		return this.inputs.size();
	}

	/**
	 * Not available on the mock.
	 *
	 * @throws UnsupportedOperationException
	 *         always
	 */
	@Override
	public AccumulatorProtocol getAccumulatorProtocolProxy() {
		throw new UnsupportedOperationException(
				"getAccumulatorProtocolProxy() is not supported by MockEnvironment");
	}

	/**
	 * @return this environment, which implements {@link BufferProvider} itself
	 */
	@Override
	public BufferProvider getOutputBufferProvider() {
		return this;
	}

	/**
	 * @return always {@code null}
	 */
	@Override
	public Map<String, FutureTask<Path>> getCopyTask() {
		return null;
	}
}