/* * Copyright 2017 LINE Corporation * * LINE Corporation licenses this file to you under the Apache License, * version 2.0 (the "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at: * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. */ package com.linecorp.armeria.common; import static com.google.common.base.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; import java.util.AbstractSet; import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; import java.util.Set; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Ascii; import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterators; /** * An immutable {@link Set} of {@link MediaType}s which provides useful methods for content negotiation. * * <p>This {@link Set} provides {@link #match(MediaType...)} and {@link #matchHeaders(CharSequence...)} * so that a user can find the preferred {@link MediaType} that matches the specified media ranges. For example: * <pre>{@code * MediaTypeSet set = new MediaTypeSet(MediaType.HTML_UTF_8, MediaType.PLAIN_TEXT_UTF_8); * * Optional<MediaType> negotiated1 = set.matchHeaders("text/html; q=0.5, text/plain"); * assert negotiated1.isPresent(); * assert negotiated1.get().equals(MediaType.PLAIN_TEXT_UTF_8); * * Optional<MediaType> negotiated2 = set.matchHeaders("audio/*, text/*"); * assert negotiated2.isPresent(); * assert negotiated2.get().equals(MediaType.HTML_UTF_8); * * Optional<MediaType> negotiated3 = set.matchHeaders("video/webm"); * assert !negotiated3.isPresent(); * }</pre> */ public final class MediaTypeSet extends AbstractSet<MediaType> { private static final String WILDCARD = MediaType.ANY_TYPE.type(); private static final String Q = "q"; private final MediaType[] mediaTypes; /** * Creates a new instance. */ public MediaTypeSet(MediaType... mediaTypes) { this(ImmutableList.copyOf(requireNonNull(mediaTypes, "mediaTypes"))); } /** * Creates a new instance. */ public MediaTypeSet(Iterable<MediaType> mediaTypes) { final Set<MediaType> mediaTypesCopy = new LinkedHashSet<>(); // Using a Set to deduplicate for (MediaType mediaType : requireNonNull(mediaTypes, "mediaTypes")) { requireNonNull(mediaType, "mediaTypes contains null."); checkArgument(!mediaType.hasWildcard(), "mediaTypes contains a wildcard media type: %s", mediaType); // Ensure qvalue does not exists. final List<String> otherQValues = mediaType.parameters().get(Q); checkArgument(otherQValues == null || otherQValues.isEmpty(), "mediaTypes contains a media type with a q-value parameter: %s", mediaType); mediaTypesCopy.add(mediaType); } this.mediaTypes = mediaTypesCopy.toArray(new MediaType[mediaTypesCopy.size()]); } @Override public boolean contains(Object o) { if (!(o instanceof MediaType)) { return false; } for (MediaType e : mediaTypes) { if (e.equals(o)) { return true; } } return false; } @Override public Iterator<MediaType> iterator() { return Iterators.forArray(mediaTypes); } @Override public int size() { return mediaTypes.length; } /** * Finds the {@link MediaType} in this {@link List} that matches one of the media ranges specified in the * specified string. * * @param acceptHeaders the values of the {@code "accept"} header, as defined in * <a href="https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html">the section 14.1, RFC2616</a> * @return the most preferred {@link MediaType} that matches one of the specified media ranges. * {@link Optional#empty()} if there are no matches or {@code acceptHeaders} does not contain * any valid ranges. */ public Optional<MediaType> matchHeaders(Iterable<? extends CharSequence> acceptHeaders) { requireNonNull(acceptHeaders, "acceptHeaders"); final List<MediaType> ranges = new ArrayList<>(4); for (CharSequence acceptHeader : acceptHeaders) { addRanges(ranges, acceptHeader); } return match(ranges); } /** * Finds the {@link MediaType} in this {@link List} that matches one of the media ranges specified in the * specified string. * * @param acceptHeaders the values of the {@code "accept"} header, as defined in * <a href="https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html">the section 14.1, RFC2616</a> * @return the most preferred {@link MediaType} that matches one of the specified media ranges. * {@link Optional#empty()} if there are no matches or {@code acceptHeaders} does not contain * any valid ranges. */ public Optional<MediaType> matchHeaders(CharSequence... acceptHeaders) { requireNonNull(acceptHeaders, "acceptHeaders"); final List<MediaType> ranges = new ArrayList<>(4); for (CharSequence acceptHeader : acceptHeaders) { addRanges(ranges, acceptHeader); } return match(ranges); } /** * Finds the {@link MediaType} in this {@link List} that matches one of the specified media ranges. * * @return the most preferred {@link MediaType} that matches one of the specified media ranges. * {@link Optional#empty()} if there are no matches or {@code ranges} does not contain * any valid ranges. */ public Optional<MediaType> match(MediaType... ranges) { return match(Arrays.asList(requireNonNull(ranges, "ranges"))); } /** * Finds the {@link MediaType} in this {@link List} that matches one of the specified media ranges. * * @return the most preferred {@link MediaType} that matches one of the specified media ranges. * {@link Optional#empty()} if there are no matches or {@code ranges} does not contain * any valid ranges. */ public Optional<MediaType> match(Iterable<MediaType> ranges) { requireNonNull(ranges, "ranges"); MediaType match = null; float matchQ = Float.NEGATIVE_INFINITY; // higher = better int matchNumWildcards = Integer.MAX_VALUE; // lower = better int matchNumParams = Integer.MIN_VALUE; // higher = better for (MediaType range : ranges) { requireNonNull(range, "ranges contains null."); for (MediaType candidate : mediaTypes) { if (!matches(range, candidate)) { continue; } float qValue = qValue(range); final int numWildcards = numWildcards(range); final int numParams; if (qValue < 0) { // qvalue does not exist; use the default value of 1.0. qValue = 1.0f; numParams = range.parameters().size(); } else { // Do not count the qvalue. numParams = range.parameters().size() - 1; } final boolean isBetter; if (qValue > matchQ) { isBetter = true; } else if (Math.copySign(qValue - matchQ, 1.0f) <= 0.0001f) { if (matchNumWildcards > numWildcards) { isBetter = true; } else if (matchNumWildcards == numWildcards) { isBetter = numParams > matchNumParams; } else { isBetter = false; } } else { isBetter = false; } if (isBetter) { match = candidate; matchQ = qValue; matchNumWildcards = numWildcards; matchNumParams = numParams; } } } return Optional.ofNullable(match); } private static boolean matches(MediaType range, MediaType candidate) { // Similar to what MediaType.is(MediaType) does except that this one // compares the parameters case-insensitively and excludes 'q' parameter. return (WILDCARD.equals(range.type()) || range.type().equals(candidate.type())) && (WILDCARD.equals(range.subtype()) || range.subtype().equals(candidate.subtype())) && containsAllParameters(range.parameters(), candidate.parameters()); } /** * Returns {@code true} if {@code actualParameters} contains all entries of {@code expectedParameters}. * Note that this method does *not* require {@code actualParameters} to contain *only* the entries of * {@code expectedParameters}. i.e. {@code actualParameters} can contain an entry that's non-existent * in {@code expectedParameters}. */ private static boolean containsAllParameters(Map<String, List<String>> expectedParameters, Map<String, List<String>> actualParameters) { if (expectedParameters.isEmpty()) { return true; } for (Entry<String, List<String>> requiredEntry : expectedParameters.entrySet()) { final String expectedName = requiredEntry.getKey(); final List<String> expectedValues = requiredEntry.getValue(); if (Q.equals(expectedName)) { continue; } final List<String> actualValues = actualParameters.get(expectedName); assert !expectedValues.isEmpty(); if (actualValues == null || actualValues.isEmpty()) { // Does not contain any required values. return false; } if (!containsAllValues(expectedValues, actualValues)) { return false; } } return true; } private static boolean containsAllValues(List<String> expectedValues, List<String> actualValues) { final int numRequiredValues = expectedValues.size(); for (int i = 0; i < numRequiredValues; i++) { if (!containsValue(expectedValues.get(i), actualValues)) { return false; } } return true; } private static boolean containsValue(String expectedValue, List<String> actualValues) { final int numActualValues = actualValues.size(); for (int i = 0; i < numActualValues; i++) { if (Ascii.equalsIgnoreCase(expectedValue, actualValues.get(i))) { return true; } } return false; } private static float qValue(MediaType range) { // Find 'q' or 'Q'. final List<String> qValues = range.parameters().get(Q); if (qValues == null || qValues.isEmpty()) { // qvalue does not exist. return Float.NEGATIVE_INFINITY; } try { // Parse the qvalue. Make sure it's within the range of [0, 1]. return Math.max(Math.min(Float.parseFloat(qValues.get(0)), 1.0f), 0.0f); } catch (NumberFormatException e) { // The range with a malformed qvalue gets the lowest possible preference. return 0.0f; } } private static int numWildcards(MediaType candidate) { int numWildcards = 0; if (WILDCARD.equals(candidate.type())) { numWildcards++; } if (WILDCARD.equals(candidate.subtype())) { numWildcards++; } return numWildcards; } private static final int ST_SKIP_LEADING_WHITESPACES = 0; private static final int ST_READ_QDTEXT = 1; private static final int ST_READ_QUOTED_STRING = 2; private static final int ST_READ_QUOTED_STRING_ESCAPED = 3; @VisibleForTesting @SuppressWarnings("checkstyle:FallThrough") static void addRanges(List<MediaType> ranges, CharSequence header) { final int length = header.length(); int state = ST_SKIP_LEADING_WHITESPACES; int firstCharIdx = 0; int lastCharIdx = -1; for (int i = 0; i < length; i++) { final char ch = header.charAt(i); switch (state) { case ST_SKIP_LEADING_WHITESPACES: if (!Character.isWhitespace(ch)) { state = ST_READ_QDTEXT; firstCharIdx = i; lastCharIdx = -1; } else { break; } case ST_READ_QDTEXT: if (Character.isWhitespace(ch)) { break; } switch (ch) { case '"': state = ST_READ_QUOTED_STRING; break; case ',': addRange(ranges, header, firstCharIdx, lastCharIdx); state = ST_SKIP_LEADING_WHITESPACES; break; default: lastCharIdx = i; } break; case ST_READ_QUOTED_STRING: switch (ch) { case '\\': state = ST_READ_QUOTED_STRING_ESCAPED; break; case '"': state = ST_READ_QDTEXT; lastCharIdx = i; break; } break; case ST_READ_QUOTED_STRING_ESCAPED: state = ST_READ_QUOTED_STRING; break; } } if (state == ST_READ_QDTEXT) { addRange(ranges, header, firstCharIdx, lastCharIdx); } } private static void addRange( List<MediaType> ranges, CharSequence header, int firstCharIdx, int lastCharIdx) { if (lastCharIdx >= 0) { try { ranges.add(MediaType.parse(header.subSequence(firstCharIdx, lastCharIdx + 1).toString())); } catch (IllegalArgumentException e) { // Ignore the malformed media range. } } } }