package org.ovirt.engine.core.aaa.filters;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
 * Servlet filter that lets a request continue down the chain only when
 * {@code FiltersHelper.isAuthenticated} reports it as authenticated; otherwise it
 * answers with HTTP 401 and advertises the available authentication schemes in
 * {@code WWW-Authenticate} response headers.
 *
 * <p>The advertised schemes are the union of:
 * <ul>
 *   <li>the list stored in the request attribute
 *       {@code FiltersHelper.Constants.REQUEST_SCHEMES_KEY} (if present), and</li>
 *   <li>the values of every filter init parameter whose name starts with
 *       {@code "scheme"}.</li>
 * </ul>
 */
public class EnforceAuthFilter implements Filter {

    /** Schemes configured through filter init parameters; filled once in {@link #init}. */
    private final List<String> additionalSchemes = new ArrayList<>();

    /**
     * Collects the value of every init parameter whose name starts with {@code "scheme"}.
     *
     * @param filterConfig the filter configuration supplied by the container
     * @throws ServletException declared by the {@link Filter} contract; not thrown here
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        for (String paramName : Collections.list(filterConfig.getInitParameterNames())) {
            if (paramName.startsWith("scheme")) {
                additionalSchemes.add(filterConfig.getInitParameter(paramName));
            }
        }
    }

    /**
     * Passes authenticated requests on; rejects all others with 401 and one
     * {@code WWW-Authenticate} header per known scheme.
     *
     * @param request  the incoming request; cast to {@link HttpServletRequest}
     * @param response the outgoing response; cast to {@link HttpServletResponse}
     * @param chain    the remaining filter chain
     * @throws IOException      if the chain or {@code sendError} fails with an I/O error
     * @throws ServletException if the chain fails
     */
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        HttpServletRequest req = (HttpServletRequest) request;
        HttpServletResponse res = (HttpServletResponse) response;

        if (FiltersHelper.isAuthenticated(req)) {
            chain.doFilter(request, response);
            return;
        }

        // NOTE(review): presumably populated by earlier filters in the chain - confirm.
        @SuppressWarnings("unchecked")
        List<String> requestSchemes =
                (List<String>) req.getAttribute(FiltersHelper.Constants.REQUEST_SCHEMES_KEY);

        // A set removes duplicates between per-request and configured schemes.
        Set<String> allSchemes = new HashSet<>();
        if (requestSchemes != null) {
            allSchemes.addAll(requestSchemes);
        }
        allSchemes.addAll(additionalSchemes);

        // setHeader() replaces the header value, so calling it for every scheme would
        // leave only the last one. Use it once (to replace any pre-existing value, as
        // before) and addHeader() for the rest so that every challenge is sent.
        boolean firstScheme = true;
        for (String scheme : allSchemes) {
            if (firstScheme) {
                res.setHeader(FiltersHelper.Constants.HEADER_WWW_AUTHENTICATE, scheme);
                firstScheme = false;
            } else {
                res.addHeader(FiltersHelper.Constants.HEADER_WWW_AUTHENTICATE, scheme);
            }
        }
        res.sendError(HttpServletResponse.SC_UNAUTHORIZED);
    }

    /** No resources to release. */
    @Override
    public void destroy() {
    }
}