/* * Copyright 2011 Google Inc. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.waveprotocol.box.waveimport.google.oauth; import com.google.common.base.Charsets; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.inject.Inject; import java.io.IOException; import java.util.List; import java.util.logging.Logger; import org.apache.commons.httpclient.Header; import org.apache.commons.httpclient.HttpClient; import org.apache.commons.httpclient.URIException; import org.apache.commons.httpclient.methods.EntityEnclosingMethod; /** * A fetch service that signs fetch requests with OAuth2. * * @author hearnden@google.com (David Hearnden) * @author ohler@google.com (Christian Ohler) */ public final class OAuthedFetchService { /** * Detects whether an HTTP response indicates that the OAuth token has * expired. */ public interface TokenRefreshNeededDetector { boolean refreshNeeded(int responseCode, Header[] responseHeaders, byte[] responseContent) throws IOException; } public static final TokenRefreshNeededDetector RESPONSE_CODE_401_DETECTOR = new TokenRefreshNeededDetector() { @Override public boolean refreshNeeded(int responseCode, Header[] responseHeaders, byte[] responseContent) { return responseCode == 401; } }; @SuppressWarnings("unused") private static final Logger log = Logger.getLogger(OAuthedFetchService.class.getName()); private final HttpClient fetch; private final OAuthRequestHelper helper; @Inject public OAuthedFetchService(HttpClient fetch, OAuthRequestHelper helper) { this.fetch = fetch; this.helper = helper; } private String describeRequest(EntityEnclosingMethod req) throws URIException { StringBuilder b = new StringBuilder(req.getName() + " " + req.getURI()); for (Header h : req.getRequestHeaders()) { b.append("\n" + h.getName() + ": " + h.getValue()); } return "" + b; } private String describeResponse(int responseCode, Header[] responseHeaders, byte[] responseBody, boolean includeBody) { StringBuilder b = new StringBuilder(responseCode + " with " + responseBody.length + " bytes of content"); for (Header h : responseHeaders) { b.append("\n" + h.getName() + ": " + h.getValue()); } if (includeBody) { b.append("\n" + new String(responseBody, Charsets.UTF_8)); } else { b.append("\n<content elided>"); } return "" + b; } private int fetch1(EntityEnclosingMethod req, TokenRefreshNeededDetector refreshNeeded, boolean tokenJustRefreshed) throws IOException { log.info("Sending request (token just refreshed: " + tokenJustRefreshed + "): " + describeRequest(req)); helper.authorize(req); //log.info("req after authorizing: " + describeRequest(req)); int code = fetch.executeMethod(req); log.info("response: " + describeResponse(code, req.getResponseHeaders(), req.getResponseBody(), false)); if (refreshNeeded.refreshNeeded(code, req.getResponseHeaders(), req.getResponseBody())) { if (tokenJustRefreshed) { throw new NeedNewOAuthTokenException("Token just refreshed, still no good: " + describeResponse(code, req.getResponseHeaders(), req.getResponseBody(), true)); } else { helper.refreshToken(); return fetch1(req, refreshNeeded, true); } } else { return code; } } public int fetch(EntityEnclosingMethod request, TokenRefreshNeededDetector refreshNeeded) throws IOException { return fetch1(request, refreshNeeded, false); } public int fetch(EntityEnclosingMethod request) throws IOException { return fetch(request, RESPONSE_CODE_401_DETECTOR); } // TODO(ohler): Move these static utility methods to some other utility class. /** Gets the values of all headers with the name {@code headerName}. */ public static List<String> getHeaders(Header[] respHeaders, String headerName) { ImmutableList.Builder<String> b = ImmutableList.builder(); for (Header h : respHeaders) { // HTTP header names are case-insensitive. App Engine downcases them when // deployed but not when running locally. if (headerName.equalsIgnoreCase(h.getName())) { b.add(h.getValue()); } } return b.build(); } /** * Checks that exactly one header named {@code headerName} is present and * returns its value. */ public static String getSingleHeader(Header[] respHeaders, String headerName) { return Iterables.getOnlyElement(getHeaders(respHeaders, headerName)); } /** Returns the body of {@code resp}, assuming that its encoding is UTF-8. */ private static String getUtf8ResponseBodyUnchecked(byte[] respContent) { if (respContent == null) { return ""; } else { return new String(respContent, Charsets.UTF_8); } } /** * Checks that the Content-Type of {@code resp} is * {@code expectedUtf8ContentType} (which is assumed to imply UTF-8 encoding) * and returns the body as a String. */ public static String getUtf8ResponseBody(Header[] respHeaders, byte[] respContent, String expectedUtf8ContentType) throws IOException { String contentType = getSingleHeader(respHeaders, "Content-Type"); if (!expectedUtf8ContentType.equals(contentType)) { throw new IOException("Unexpected Content-Type: " + contentType + " (wanted " + expectedUtf8ContentType + "); body as UTF-8: " + getUtf8ResponseBodyUnchecked(respContent)); } return getUtf8ResponseBodyUnchecked(respContent); } }