/*
 * Copyright (C) 2015 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 * in compliance with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 */
package com.google.cloud.genomics.dataflow.functions;

import static com.google.common.collect.Lists.newArrayList;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;

import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.testing.DataflowAssert;
import com.google.cloud.dataflow.sdk.testing.TestPipeline;
import com.google.cloud.dataflow.sdk.transforms.Create;
import com.google.cloud.dataflow.sdk.transforms.DoFnTester;
import com.google.cloud.dataflow.sdk.transforms.SerializableFunction;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.genomics.utils.grpc.VariantUtils;
import com.google.common.collect.Lists;
import com.google.genomics.v1.Variant;
import com.google.genomics.v1.VariantCall;

import org.hamcrest.CoreMatchers;
import org.hamcrest.collection.IsIterableWithSize;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/**
 * Tests for the {@link JoinNonVariantSegmentsWithVariants} transforms and for the
 * {@link VariantUtils} comparator and overlap helpers, using a small fixture of three variants and
 * two block records (records without alternate bases) on {@code chr7}.
 */
@RunWith(JUnit4.class)
public class JoinNonVariantSegmentsWithVariantsTest {

  /** Calls carried by every variant in the fixture (snp1, snp2 and insert). */
  private static final List<VariantCall> variantCalls = Lists.newArrayList(
      VariantCall.newBuilder()
          .setCallSetName("het-alt sample").addGenotype(1).addGenotype(0).build(),
      VariantCall.newBuilder()
          .setCallSetName("hom-alt sample").addGenotype(1).addGenotype(1).build());

  /** Calls carried by {@code blockRecord1}. */
  private static final List<VariantCall> blockRecord1Calls = Lists.newArrayList(
      VariantCall.newBuilder()
          .setCallSetName("hom sample").addGenotype(0).addGenotype(0).build(),
      VariantCall.newBuilder()
          .setCallSetName("no call sample").addGenotype(-1).addGenotype(-1).build());

  /** Calls carried by {@code blockRecord2}. */
  private static final List<VariantCall> blockRecord2Calls = Lists.newArrayList(
      VariantCall.newBuilder()
          .setCallSetName("hom no-call sample").addGenotype(-1).addGenotype(0).build());

  /** Expected merge result for snp1: its own calls followed by those of blockRecord1. */
  private static final Variant expectedSnp1 = Variant.newBuilder()
      .setReferenceName("chr7")
      .setStart(200010)
      .setEnd(200011)
      .setReferenceBases("A")
      .addAlternateBases("C")
      .addAllCalls(variantCalls)
      .addAllCalls(blockRecord1Calls)
      .build();

  /** Expected merge result for snp2: its own calls followed by those of both block records. */
  private static final Variant expectedSnp2 = Variant.newBuilder()
      .setReferenceName("chr7")
      .setStart(200019)
      .setEnd(200020)
      .setReferenceBases("T")
      .addAlternateBases("G")
      .addAllCalls(variantCalls)
      .addAllCalls(blockRecord1Calls)
      .addAllCalls(blockRecord2Calls)
      .build();

  /** Expected merge result for the insertion: only its own calls, no block record calls. */
  private static final Variant expectedInsert = Variant.newBuilder()
      .setReferenceName("chr7")
      .setStart(200010)
      .setEnd(200011)
      .setReferenceBases("A")
      .addAlternateBases("AC")
      .addAllCalls(variantCalls)
      .build();

  private Variant snp1;
  private Variant snp2;
  private Variant insert;
  private Variant blockRecord1;
  private Variant blockRecord2;

  /** All five fixture records, in the order snp1, snp2, insert, blockRecord1, blockRecord2. */
  private Variant[] input;

  /** Rebuilds the fixture records before each test. */
  @Before
  public void setUp() {
    snp1 = Variant.newBuilder()
        .setReferenceName("chr7")
        .setStart(200010)
        .setEnd(200011)
        .setReferenceBases("A")
        .addAlternateBases("C")
        .addAllCalls(variantCalls)
        .build();

    snp2 = Variant.newBuilder()
        .setReferenceName("chr7")
        .setStart(200019)
        .setEnd(200020)
        .setReferenceBases("T")
        .addAlternateBases("G")
        .addAllCalls(variantCalls)
        .build();

    insert = Variant.newBuilder()
        .setReferenceName("chr7")
        .setStart(200010)
        .setEnd(200011)
        .setReferenceBases("A")
        .addAlternateBases("AC")
        .addAllCalls(variantCalls)
        .build();

    // Block records carry no alternate bases.
    blockRecord1 = Variant.newBuilder()
        .setReferenceName("chr7")
        .setStart(199005)
        .setEnd(202050)
        .setReferenceBases("A")
        .addAllCalls(blockRecord1Calls)
        .build();

    blockRecord2 = Variant.newBuilder()
        .setReferenceName("chr7")
        .setStart(200011)
        .setEnd(200020)
        .setReferenceBases("A")
        .addAllCalls(blockRecord2Calls)
        .build();

    input = new Variant[] {snp1, snp2, insert, blockRecord1, blockRecord2};
  }

  /**
   * Verifies the ordering produced by {@link VariantUtils#NON_VARIANT_SEGMENT_COMPARATOR} on the
   * fixture records and checks that the comparator is antisymmetric over every pair of them.
   */
  @Test
  public void testVariantVariantComparator() {
    Comparator<Variant> comparator = VariantUtils.NON_VARIANT_SEGMENT_COMPARATOR;

    // These assertions pin the exact return values (-1 / 1), not merely their sign.
    assertEquals(-1, comparator.compare(blockRecord1, snp1));
    assertEquals(1, comparator.compare(blockRecord2, snp1));
    assertEquals(-1, comparator.compare(snp1, snp2));

    // Two variants at the same location
    Variant snp1DifferentAlt = Variant.newBuilder(snp1)
        .clearAlternateBases()
        .addAlternateBases("G")
        .build();
    assertTrue(comparator.compare(snp1, snp1DifferentAlt) < 0);

    // Block record and variant at the same location
    Variant blockRecordForSnp1 = Variant.newBuilder(snp1)
        .clearAlternateBases()
        .build();
    assertEquals(1, comparator.compare(snp1, blockRecordForSnp1));

    List<Variant> variants = newArrayList(input);
    variants.add(snp1DifferentAlt);
    variants.add(blockRecordForSnp1);

    // Check antisymmetry for every ordered pair (self-pairs included):
    // sgn(compare(a, b)) must equal -sgn(compare(b, a)).
    for (Variant v1 : variants) {
      for (Variant v2 : variants) {
        // assertEquals (rather than assertTrue on ==) reports both signs when the check fails.
        assertEquals("compare(v1, v2) and compare(v2, v1) must have opposite signs",
            Integer.signum(comparator.compare(v1, v2)),
            -Integer.signum(comparator.compare(v2, v1)));
      }
    }
  }

  /**
   * Verifies {@link VariantUtils#isOverlapping}: blockRecord1 (199005-202050) overlaps both SNPs,
   * while blockRecord2 (200011-200020) overlaps snp2 (200019-200020) but not snp1 (200010-200011).
   */
  @Test
  public void testIsOverlapping() {
    assertTrue(VariantUtils.isOverlapping(blockRecord1, snp1));
    assertTrue(VariantUtils.isOverlapping(blockRecord1, snp2));
    assertFalse(VariantUtils.isOverlapping(blockRecord2, snp1));
    assertTrue(VariantUtils.isOverlapping(blockRecord2, snp2));
  }

  /**
   * Feeds all fixture records to {@code CombineVariantsFn} under a single key and expects the three
   * merged variants to be among the outputs.
   */
  @Test
  public void testCombineVariantsFn() {
    DoFnTester<KV<KV<String, Long>, Iterable<Variant>>, Variant> fn =
        DoFnTester.of(new JoinNonVariantSegmentsWithVariants.CombineVariantsFn());

    // The explicit cast fixes the value type to Iterable<Variant> so the KV matches the DoFn's
    // input type (presumably required for pre-Java-8 type inference -- confirm before removing).
    Assert.assertThat(
        fn.processBatch(
            KV.of(KV.of("chr7", 200000L), (Iterable<Variant>) Arrays.asList(input))),
        CoreMatchers.hasItems(expectedSnp1, expectedSnp2, expectedInsert));
  }

  /**
   * Runs {@code BinVariantsFn} over the fixture and expects exactly eight keyed outputs: one each
   * for snp1, snp2, insert and blockRecord2 under key ("chr7", 200000), and four for blockRecord1
   * under keys 199000, 200000, 201000 and 202000.
   */
  @Test
  public void testBinVariantsFn() {
    DoFnTester<Variant, KV<KV<String, Long>, Variant>> binVariantsFn = DoFnTester.of(
        new JoinNonVariantSegmentsWithVariants.BinShuffleAndCombineTransform.BinVariantsFn());

    List<KV<KV<String, Long>, Variant>> binVariantsOutput = binVariantsFn.processBatch(input);

    assertThat(binVariantsOutput,
        CoreMatchers.hasItem(KV.of(KV.of("chr7", 200000L), snp1)));
    assertThat(binVariantsOutput,
        CoreMatchers.hasItem(KV.of(KV.of("chr7", 200000L), snp2)));
    assertThat(binVariantsOutput,
        CoreMatchers.hasItem(KV.of(KV.of("chr7", 200000L), insert)));

    // blockRecord1 (199005-202050) is expected once per bin key from 199000 through 202000.
    assertThat(binVariantsOutput,
        CoreMatchers.hasItem(KV.of(KV.of("chr7", 199000L), blockRecord1)));
    assertThat(binVariantsOutput,
        CoreMatchers.hasItem(KV.of(KV.of("chr7", 200000L), blockRecord1)));
    assertThat(binVariantsOutput,
        CoreMatchers.hasItem(KV.of(KV.of("chr7", 201000L), blockRecord1)));
    assertThat(binVariantsOutput,
        CoreMatchers.hasItem(KV.of(KV.of("chr7", 202000L), blockRecord1)));

    assertThat(binVariantsOutput,
        CoreMatchers.hasItem(KV.of(KV.of("chr7", 200000L), blockRecord2)));

    // No outputs beyond the eight asserted above.
    assertEquals(8, binVariantsOutput.size());
  }

  /**
   * Runs the full {@code BinShuffleAndCombineTransform} in a test pipeline and checks that the
   * output consists of exactly the three expected merged variants.
   */
  @Test
  public void testBinShuffleAndCombine() {
    Pipeline p = TestPipeline.create();

    PCollection<Variant> mergedVariants = p
        .apply(Create.of(input))
        .apply(new JoinNonVariantSegmentsWithVariants.BinShuffleAndCombineTransform());

    DataflowAssert.that(mergedVariants)
        .satisfies(new AssertThatHasExpectedContentsForTestJoinVariants());

    p.run();
  }

  /**
   * Assertion passed to {@link DataflowAssert}: the collection must contain the three expected
   * merged variants and have exactly three elements.
   */
  static class AssertThatHasExpectedContentsForTestJoinVariants
      implements SerializableFunction<Iterable<Variant>, Void> {

    @Override
    public Void apply(Iterable<Variant> actual) {
      assertThat(actual, CoreMatchers.hasItem(expectedSnp1));
      assertThat(actual, CoreMatchers.hasItem(expectedSnp2));
      assertThat(actual, CoreMatchers.hasItem(expectedInsert));
      assertThat(actual, IsIterableWithSize.<Variant>iterableWithSize(3));
      return null;
    }
  }
}