/***********************************************************************************************************************
 * Copyright (C) 2010-2013 by the Stratosphere project (http://stratosphere.eu)
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
 * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations under the License.
 **********************************************************************************************************************/

package eu.stratosphere.pact.runtime.task;

import java.util.List;

import eu.stratosphere.api.common.functions.GenericJoiner;
import eu.stratosphere.api.common.typeutils.TypeComparator;
import eu.stratosphere.api.common.typeutils.TypePairComparatorFactory;
import eu.stratosphere.api.common.typeutils.TypeSerializer;
import eu.stratosphere.core.memory.MemorySegment;
import eu.stratosphere.pact.runtime.hash.MutableHashTable;
import eu.stratosphere.pact.runtime.task.util.TaskConfig;
import eu.stratosphere.pact.runtime.util.EmptyMutableObjectIterator;
import eu.stratosphere.util.Collector;
import eu.stratosphere.util.MutableObjectIterator;

/**
 * Match (join) driver that builds a hash table over one input in {@link #initialize()} and probes it with the
 * records of the other input in {@link #run()}.
 * <p>
 * The hash table is created once in {@link #initialize()} and is only closed in {@link #teardown()} or
 * {@link #cancel()}; {@link #reset()} leaves it untouched, so the build side stays cached between runs.
 *
 * @param <IT1> The type of the records of the first input.
 * @param <IT2> The type of the records of the second input.
 * @param <OT> The type of the records produced by the join function.
 */
public abstract class AbstractCachedBuildSideMatchDriver<IT1, IT2, OT> extends MatchDriver<IT1, IT2, OT>
		implements ResettablePactDriver<GenericJoiner<IT1, IT2, OT>, OT>
{
	/**
	 * We keep it without generic parameters, because they vary depending on which input is the build side.
	 */
	protected volatile MutableHashTable<?, ?> hashJoin;

	/** Index (0 or 1) of the input that is inserted into the hash table. */
	private final int buildSideIndex;

	/** Index (0 or 1) of the input that probes the hash table. */
	private final int probeSideIndex;

	/**
	 * Creates a driver with the given role assignment for the two inputs.
	 * The combination is validated in {@link #initialize()}, not here.
	 *
	 * @param buildSideIndex The index of the input used as build side.
	 * @param probeSideIndex The index of the input used as probe side.
	 */
	protected AbstractCachedBuildSideMatchDriver(int buildSideIndex, int probeSideIndex) {
		this.buildSideIndex = buildSideIndex;
		this.probeSideIndex = probeSideIndex;
	}

	// --------------------------------------------------------------------------------------------

	/**
	 * Reports whether the given input is the build side of this driver.
	 *
	 * @param inputNum The index of the input, must be 0 or 1.
	 * @return {@code true} if the input is the build side, {@code false} otherwise.
	 * @throws IndexOutOfBoundsException If {@code inputNum} is neither 0 nor 1.
	 */
	@Override
	public boolean isInputResettable(int inputNum) {
		if (inputNum < 0 || inputNum > 1) {
			throw new IndexOutOfBoundsException("Input index must be 0 or 1, but was " + inputNum);
		}
		return inputNum == buildSideIndex;
	}

	/**
	 * Allocates the driver memory, creates the hash table with the type arguments matching the configured
	 * build side, and opens it on the build side input together with an empty probe side iterator.
	 *
	 * @throws Exception If the build/probe index combination is not (0, 1) or (1, 0), or if setting up the
	 *                   hash table fails.
	 */
	@Override
	public void initialize() throws Exception {
		// Validate the configuration before acquiring any memory, so that an invalid setup
		// does not leave allocated pages behind.
		final boolean buildOnFirstInput = buildSideIndex == 0 && probeSideIndex == 1;
		final boolean buildOnSecondInput = buildSideIndex == 1 && probeSideIndex == 0;
		if (!buildOnFirstInput && !buildOnSecondInput) {
			throw new Exception("Error: Inconcistent setup for repeatable hash join driver.");
		}

		TaskConfig config = this.taskContext.getTaskConfig();

		TypeSerializer<IT1> serializer1 = this.taskContext.<IT1>getInputSerializer(0).getSerializer();
		TypeSerializer<IT2> serializer2 = this.taskContext.<IT2>getInputSerializer(1).getSerializer();
		TypeComparator<IT1> comparator1 = this.taskContext.getInputComparator(0);
		TypeComparator<IT2> comparator2 = this.taskContext.getInputComparator(1);
		MutableObjectIterator<IT1> input1 = this.taskContext.getInput(0);
		MutableObjectIterator<IT2> input2 = this.taskContext.getInput(1);

		TypePairComparatorFactory<IT1, IT2> pairComparatorFactory =
				config.getPairComparatorFactory(this.taskContext.getUserCodeClassLoader());

		int numMemoryPages = this.taskContext.getMemoryManager().computeNumberOfPages(config.getMemoryDriver());
		List<MemorySegment> memSegments = this.taskContext.getMemoryManager().allocatePages(
				this.taskContext.getOwningNepheleTask(), numMemoryPages);

		if (buildOnFirstInput) {
			MutableHashTable<IT1, IT2> hashJoin = new MutableHashTable<IT1, IT2>(serializer1, serializer2,
					comparator1, comparator2, pairComparatorFactory.createComparator21(comparator1, comparator2),
					memSegments, this.taskContext.getIOManager());
			// publish the table before opening it, as the original code did
			this.hashJoin = hashJoin;
			hashJoin.open(input1, EmptyMutableObjectIterator.<IT2>get());
		} else {
			MutableHashTable<IT2, IT1> hashJoin = new MutableHashTable<IT2, IT1>(serializer2, serializer1,
					comparator2, comparator1, pairComparatorFactory.createComparator12(comparator1, comparator2),
					memSegments, this.taskContext.getIOManager());
			this.hashJoin = hashJoin;
			hashJoin.open(input2, EmptyMutableObjectIterator.<IT1>get());
		}
	}

	/**
	 * Nothing to prepare; all setup happens in {@link #initialize()}.
	 */
	@Override
	public void prepare() throws Exception {
		// nothing
	}

	/**
	 * Reads the probe side input record by record, looks up the matching build side records in the hash
	 * table, and calls the join function for every matching pair.
	 *
	 * @throws Exception If the build side index is neither 0 nor 1, or if reading the input, probing the
	 *                   table, or the join function fails.
	 */
	@Override
	public void run() throws Exception {
		final GenericJoiner<IT1, IT2, OT> matchStub = this.taskContext.getStub();
		final Collector<OT> collector = this.taskContext.getOutputCollector();

		if (buildSideIndex == 0) {
			probeWithSecondInput(matchStub, collector);
		} else if (buildSideIndex == 1) {
			probeWithFirstInput(matchStub, collector);
		} else {
			throw new Exception("Error: Invalid build side index " + buildSideIndex
					+ " for repeatable hash join driver.");
		}
	}

	/**
	 * Probes the hash table built over the first input with the records of the second input.
	 */
	private void probeWithSecondInput(GenericJoiner<IT1, IT2, OT> matchStub, Collector<OT> collector)
			throws Exception
	{
		final TypeSerializer<IT1> buildSideSerializer = taskContext.<IT1>getInputSerializer(0).getSerializer();
		final TypeSerializer<IT2> probeSideSerializer = taskContext.<IT2>getInputSerializer(1).getSerializer();

		IT1 buildSideRecordFirst;
		IT1 buildSideRecordOther;
		IT2 probeSideRecord;
		IT2 probeSideRecordCopy;

		final IT1 buildSideRecordFirstReuse = buildSideSerializer.createInstance();
		final IT1 buildSideRecordOtherReuse = buildSideSerializer.createInstance();
		final IT2 probeSideRecordReuse = probeSideSerializer.createInstance();
		final IT2 probeSideRecordCopyReuse = probeSideSerializer.createInstance();

		@SuppressWarnings("unchecked")
		final MutableHashTable<IT1, IT2> join = (MutableHashTable<IT1, IT2>) this.hashJoin;

		final MutableObjectIterator<IT2> probeSideInput = taskContext.<IT2>getInput(1);

		while (this.running && ((probeSideRecord = probeSideInput.next(probeSideRecordReuse)) != null)) {
			final MutableHashTable.HashBucketIterator<IT1, IT2> bucket = join.getMatchesFor(probeSideRecord);

			if ((buildSideRecordFirst = bucket.next(buildSideRecordFirstReuse)) != null) {
				// All matches except the first one are joined with a copy of the probe record; the
				// original probe record is handed out only in the final call.
				// NOTE(review): presumably because the join function may modify its arguments - confirm.
				while ((buildSideRecordOther = bucket.next(buildSideRecordOtherReuse)) != null) {
					probeSideRecordCopy = probeSideSerializer.copy(probeSideRecord, probeSideRecordCopyReuse);
					matchStub.join(buildSideRecordOther, probeSideRecordCopy, collector);
				}
				matchStub.join(buildSideRecordFirst, probeSideRecord, collector);
			}
		}
	}

	/**
	 * Probes the hash table built over the second input with the records of the first input.
	 */
	private void probeWithFirstInput(GenericJoiner<IT1, IT2, OT> matchStub, Collector<OT> collector)
			throws Exception
	{
		final TypeSerializer<IT2> buildSideSerializer = taskContext.<IT2>getInputSerializer(1).getSerializer();
		final TypeSerializer<IT1> probeSideSerializer = taskContext.<IT1>getInputSerializer(0).getSerializer();

		IT2 buildSideRecordFirst;
		IT2 buildSideRecordOther;
		IT1 probeSideRecord;
		IT1 probeSideRecordCopy;

		final IT2 buildSideRecordFirstReuse = buildSideSerializer.createInstance();
		final IT2 buildSideRecordOtherReuse = buildSideSerializer.createInstance();
		final IT1 probeSideRecordReuse = probeSideSerializer.createInstance();
		final IT1 probeSideRecordCopyReuse = probeSideSerializer.createInstance();

		@SuppressWarnings("unchecked")
		final MutableHashTable<IT2, IT1> join = (MutableHashTable<IT2, IT1>) this.hashJoin;

		final MutableObjectIterator<IT1> probeSideInput = taskContext.<IT1>getInput(0);

		while (this.running && ((probeSideRecord = probeSideInput.next(probeSideRecordReuse)) != null)) {
			final MutableHashTable.HashBucketIterator<IT2, IT1> bucket = join.getMatchesFor(probeSideRecord);

			if ((buildSideRecordFirst = bucket.next(buildSideRecordFirstReuse)) != null) {
				// Same scheme as in the mirrored helper: copies of the probe record for all matches but
				// the first, the original probe record for the final call.
				while ((buildSideRecordOther = bucket.next(buildSideRecordOtherReuse)) != null) {
					probeSideRecordCopy = probeSideSerializer.copy(probeSideRecord, probeSideRecordCopyReuse);
					matchStub.join(probeSideRecordCopy, buildSideRecordOther, collector);
				}
				matchStub.join(probeSideRecord, buildSideRecordFirst, collector);
			}
		}
	}

	/**
	 * Nothing to clean up after a run; the hash table is kept for subsequent runs.
	 */
	@Override
	public void cleanup() throws Exception {}

	/**
	 * Nothing to reset; the hash table is left as it is.
	 */
	@Override
	public void reset() throws Exception {}

	/**
	 * Closes the hash table, if one was created.
	 */
	@Override
	public void teardown() {
		MutableHashTable<?, ?> ht = this.hashJoin;
		if (ht != null) {
			ht.close();
		}
	}

	/**
	 * Signals the probe loop in {@link #run()} to stop and closes the hash table, if one was created.
	 */
	@Override
	public void cancel() {
		this.running = false;

		// read the volatile field once, consistent with teardown()
		MutableHashTable<?, ?> ht = this.hashJoin;
		if (ht != null) {
			ht.close();
		}
	}
}