/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.gcp.testing;
import static com.google.common.base.Preconditions.checkArgument;
import com.google.api.client.http.HttpTransport;
import com.google.api.client.json.JsonFactory;
import com.google.api.client.util.BackOff;
import com.google.api.client.util.BackOffUtils;
import com.google.api.client.util.Sleeper;
import com.google.api.services.bigquery.Bigquery;
import com.google.api.services.bigquery.BigqueryScopes;
import com.google.api.services.bigquery.model.QueryRequest;
import com.google.api.services.bigquery.model.QueryResponse;
import com.google.api.services.bigquery.model.TableCell;
import com.google.api.services.bigquery.model.TableRow;
import com.google.auth.Credentials;
import com.google.auth.http.HttpCredentialsAdapter;
import com.google.auth.oauth2.GoogleCredentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import com.google.common.collect.Lists;
import com.google.common.hash.HashCode;
import com.google.common.hash.Hashing;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import javax.annotation.Nonnull;
import javax.annotation.concurrent.NotThreadSafe;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.testing.SerializableMatcher;
import org.apache.beam.sdk.util.BackOffAdapter;
import org.apache.beam.sdk.util.FluentBackoff;
import org.apache.beam.sdk.util.Transport;
import org.hamcrest.Description;
import org.hamcrest.TypeSafeMatcher;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A matcher to verify data in BigQuery by processing given query
 * and comparing with content's checksum.
 *
 * <p>The matcher runs {@code query} in {@code projectId}, hashes each returned row (cell values
 * are sorted before hashing, so the per-row hash does not depend on cell order), combines the row
 * hashes with an order-independent combiner, and compares the result with
 * {@code expectedChecksum}. The last query response and computed checksum are stored on the
 * instance so that a mismatch description can report them; instances are therefore not
 * thread-safe.
 *
 * <p>Example:
 * <pre>{@code
 * assertThat(job, new BigqueryMatcher(appName, projectId, queryString, expectedChecksum));
 * }</pre>
 */
@NotThreadSafe
public class BigqueryMatcher extends TypeSafeMatcher<PipelineResult>
    implements SerializableMatcher<PipelineResult> {
  private static final Logger LOG = LoggerFactory.getLogger(BigqueryMatcher.class);

  // The maximum number of retries to execute a BigQuery RPC
  static final int MAX_QUERY_RETRIES = 4;

  // The initial backoff for executing a BigQuery RPC
  private static final Duration INITIAL_BACKOFF = Duration.standardSeconds(1L);

  // The total number of rows in query response to be formatted for debugging purpose
  private static final int TOTAL_FORMATTED_ROWS = 20;

  // The backoff factory with initial configs
  static final FluentBackoff BACKOFF_FACTORY =
      FluentBackoff.DEFAULT
          .withMaxRetries(MAX_QUERY_RETRIES)
          .withInitialBackoff(INITIAL_BACKOFF);

  private final String applicationName;
  private final String projectId;
  private final String query;
  private final String expectedChecksum;

  // Checksum computed by the most recent matchesSafely call; null if none could be computed.
  private String actualChecksum;

  // Response of the most recent query; transient, so it is null after deserialization.
  private transient QueryResponse response;

  /**
   * Creates a matcher that runs {@code query} and compares its checksum with
   * {@code expectedChecksum}.
   *
   * @param applicationName application name set on the BigQuery client
   * @param projectId project the query job is run in
   * @param query query to execute
   * @param expectedChecksum checksum the queried rows are expected to hash to
   * @throws IllegalArgumentException if any argument is null or empty
   */
  public BigqueryMatcher(
      String applicationName, String projectId, String query, String expectedChecksum) {
    validateArgument("applicationName", applicationName);
    validateArgument("projectId", projectId);
    validateArgument("query", query);
    validateArgument("expectedChecksum", expectedChecksum);

    this.applicationName = applicationName;
    this.projectId = projectId;
    this.query = query;
    this.expectedChecksum = expectedChecksum;
  }

  /**
   * Runs the query (with retries) and checks whether the checksum of the returned rows equals
   * the expected checksum.
   *
   * @return false if the query job did not complete, returned no rows, or the checksum differs
   * @throws RuntimeException if the BigQuery data could not be fetched
   */
  @Override
  protected boolean matchesSafely(PipelineResult pipelineResult) {
    LOG.info("Verifying Bigquery data");
    // Do not let a checksum from an earlier evaluation leak into this one.
    actualChecksum = null;
    Bigquery bigqueryClient = newBigqueryClient(applicationName);

    // execute query
    LOG.debug("Executing query: {}", query);
    try {
      QueryRequest queryContent = new QueryRequest();
      queryContent.setQuery(query);

      response = queryWithRetries(
          bigqueryClient, queryContent, Sleeper.DEFAULT,
          BackOffAdapter.toGcpBackOff(BACKOFF_FACTORY.backoff()));
    } catch (IOException | InterruptedException e) {
      // Both exception types signal an interrupt; restore the flag so callers can observe it.
      if (e instanceof InterruptedException || e instanceof InterruptedIOException) {
        Thread.currentThread().interrupt();
      }
      throw new RuntimeException("Failed to fetch BigQuery data.", e);
    }

    if (!isJobComplete(response)) {
      // query job not complete, verification failed
      return false;
    }

    List<TableRow> rows = response.getRows();
    if (rows == null || rows.isEmpty()) {
      // No rows to hash (the combiner needs at least one hash); report as a mismatch.
      LOG.debug("Query returned no rows; no checksum was computed.");
      return false;
    }

    // compute checksum
    actualChecksum = generateHash(rows);
    LOG.debug("Generated a SHA1 checksum based on queried data: {}", actualChecksum);
    return expectedChecksum.equals(actualChecksum);
  }

  /**
   * Builds a BigQuery client using the shared HTTP transport, JSON factory and the application
   * default credentials.
   */
  @VisibleForTesting
  Bigquery newBigqueryClient(String applicationName) {
    HttpTransport transport = Transport.getTransport();
    JsonFactory jsonFactory = Transport.getJsonFactory();
    Credentials credential = getDefaultCredential();

    return new Bigquery.Builder(transport, jsonFactory, new HttpCredentialsAdapter(credential))
        .setApplicationName(applicationName)
        .build();
  }

  /**
   * Executes {@code queryContent} in {@code projectId}, retrying on IOException or a null
   * response until {@code backOff} is exhausted.
   *
   * @return the first non-null query response
   * @throws IOException if the backoff itself fails
   * @throws InterruptedException if interrupted while sleeping between attempts
   * @throws RuntimeException if every attempt failed; the last failure is attached as the cause
   */
  @Nonnull
  @VisibleForTesting
  QueryResponse queryWithRetries(Bigquery bigqueryClient, QueryRequest queryContent,
      Sleeper sleeper, BackOff backOff)
      throws IOException, InterruptedException {
    IOException lastException = null;
    do {
      if (lastException != null) {
        LOG.warn("Retrying query ({}) after exception", queryContent.getQuery(), lastException);
      }
      try {
        QueryResponse queryResponse =
            bigqueryClient.jobs().query(projectId, queryContent).execute();
        if (queryResponse != null) {
          return queryResponse;
        } else {
          lastException =
              new IOException("Expected valid response from query job, but received null.");
        }
      } catch (IOException e) {
        // ignore and retry
        lastException = e;
      }
    } while (BackOffUtils.next(sleeper, backOff));

    throw new RuntimeException(
        String.format(
            "Unable to get BigQuery response after retrying %d times using query (%s)",
            MAX_QUERY_RETRIES,
            queryContent.getQuery()),
        lastException);
  }

  /** Rejects null or empty constructor arguments with an IllegalArgumentException. */
  private void validateArgument(String name, String value) {
    checkArgument(
        !Strings.isNullOrEmpty(value), "Expected valid %s, but was %s", name, value);
  }

  /**
   * Loads the application default credentials, scoping them to read-only cloud-platform access
   * when the credential type requires explicit scopes.
   *
   * @throws RuntimeException if the application default credentials cannot be loaded
   */
  private Credentials getDefaultCredential() {
    GoogleCredentials credential;
    try {
      credential = GoogleCredentials.getApplicationDefault();
    } catch (IOException e) {
      throw new RuntimeException("Failed to get application default credential.", e);
    }

    if (credential.createScopedRequired()) {
      Collection<String> bigqueryScope =
          Lists.newArrayList(BigqueryScopes.CLOUD_PLATFORM_READ_ONLY);
      credential = credential.createScoped(bigqueryScope);
    }
    return credential;
  }

  /** Returns true only when the response explicitly reports the query job as complete. */
  private static boolean isJobComplete(QueryResponse queryResponse) {
    // getJobComplete() is a boxed Boolean; a missing value must not cause an unboxing NPE.
    return Boolean.TRUE.equals(queryResponse.getJobComplete());
  }

  /**
   * Computes a checksum over {@code rows}: each row's cell values are stringified, sorted and
   * SHA-1 hashed, and the row hashes are combined in an order-independent way.
   *
   * @param rows rows to hash; must contain at least one row
   */
  private String generateHash(@Nonnull List<TableRow> rows) {
    List<HashCode> rowHashes = Lists.newArrayList();
    for (TableRow row : rows) {
      List<String> cellsInOneRow = Lists.newArrayList();
      for (TableCell cell : row.getF()) {
        cellsInOneRow.add(Objects.toString(cell.getV()));
      }
      // Sort once per row (not per cell) so the row hash does not depend on cell order.
      Collections.sort(cellsInOneRow);
      rowHashes.add(
          Hashing.sha1().hashString(cellsInOneRow.toString(), StandardCharsets.UTF_8));
    }
    return Hashing.combineUnordered(rowHashes).toString();
  }

  @Override
  public void describeTo(Description description) {
    description
        .appendText("Expected checksum is (")
        .appendText(expectedChecksum)
        .appendText(")");
  }

  /**
   * Describes why the last evaluation failed: missing response, incomplete job, empty result,
   * or a checksum mismatch together with a sample of the queried rows.
   */
  @Override
  public void describeMismatchSafely(PipelineResult pResult, Description description) {
    String info;
    if (response == null) {
      // e.g. the matcher was deserialized: the transient response is not available
      info = "No BigQuery response is available to describe.";
    } else if (!isJobComplete(response)) {
      // query job not complete
      info = String.format("The query job hasn't completed. Got response: %s", response);
    } else if (actualChecksum == null) {
      // job completed but there were no rows to compute a checksum from
      info = String.format("The query returned no rows. Got response: %s", response);
    } else {
      // checksum mismatch
      info = String.format("was (%s).%n"
          + "\tTotal number of rows are: %d.%n"
          + "\tQueried data details:%s",
          actualChecksum, response.getTotalRows(), formatRows(TOTAL_FORMATTED_ROWS));
    }
    description.appendText(info);
  }

  /**
   * Formats up to {@code totalNumRows} rows of the last response, one row per line with each
   * cell padded to 10 characters, appending "..." when more rows exist. Returns an empty string
   * when the response carries no rows.
   */
  private String formatRows(int totalNumRows) {
    StringBuilder samples = new StringBuilder();
    List<TableRow> rows = response.getRows();
    if (rows == null) {
      // The response may omit the rows field entirely; treat that as no rows.
      return samples.toString();
    }
    for (int i = 0; i < totalNumRows && i < rows.size(); i++) {
      samples.append(String.format("%n\t\t"));
      for (TableCell field : rows.get(i).getF()) {
        samples.append(String.format("%-10s", field.getV()));
      }
    }
    if (rows.size() > totalNumRows) {
      samples.append(String.format("%n\t\t..."));
    }
    return samples.toString();
  }
}