/**
* =============================================================================
*
* ORCID (R) Open Source
* http://orcid.org
*
* Copyright (c) 2012-2014 ORCID, Inc.
* Licensed under an MIT-Style License (MIT)
* http://orcid.org/open-source-license
*
* This copyright and license information (including a link to the full license)
* shall be included in its entirety in all copies or substantial portion of
* the software.
*
* =============================================================================
*/
package org.orcid.core.web.filters;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.servlet.http.HttpServletRequest;
import org.orcid.core.manager.impl.OrcidUrlManager;
import org.orcid.pojo.ajaxForm.PojoUtil;
import org.springframework.beans.factory.annotation.Value;
/**
 * Decides whether a cross-domain request may be served.
 *
 * A request is allowed when any of the following holds:
 * <ul>
 * <li>its <code>origin</code> header names a host that matches one of the
 * configured allowed domains;</li>
 * <li>it has no <code>origin</code> header and its <code>referer</code> header
 * points at <code>localhost</code>;</li>
 * <li>its path (without the context path) is one of the publicly reachable
 * paths: <code>/public/...</code>, <code>/userStatus.json</code> or
 * <code>/lang.json</code>.</li>
 * </ul>
 *
 * The allowed domains come from the
 * <code>org.orcid.security.cors.allowed_domains</code> property, a comma
 * separated list of domain patterns.
 */
public class CrossDomainWebManger {

    private static final String LOCALHOST = "localhost";

    /** Paths that are served regardless of the requesting domain. */
    private static final Pattern PUBLIC_PATH_PATTERN = Pattern.compile("^/public/.*|^/userStatus\\.json|^/lang\\.json");

    @Value("${org.orcid.security.cors.allowed_domains:qa.orcid.org,sandbox.orcid.org,orcid.org}")
    private String allowedDomains;

    /**
     * Compiled form of {@link #allowedDomains}, built on first use. Declared
     * volatile and only ever assigned a fully populated list, so a reader can
     * never observe a half-built list (NOTE(review): assumes one instance is
     * shared between request threads - confirm against the bean definition).
     */
    private volatile List<Pattern> domainPatterns;

    /**
     * Tells whether the given request passes the cross-domain checks.
     *
     * A header value that cannot be parsed as a URL (for example the literal
     * <code>null</code> origin) is treated like a domain that is not allowed:
     * the decision then depends on the request path only.
     *
     * @param request
     *            the incoming request
     * @return true if the origin domain is allowed, the referer is localhost,
     *         or the requested path is public; false otherwise
     * @throws MalformedURLException
     *             kept in the signature for backward compatibility; malformed
     *             header values are handled inside this method
     */
    public boolean allowed(HttpServletRequest request) throws MalformedURLException {
        String path = OrcidUrlManager.getPathWithoutContextPath(request);
        String origin = request.getHeader("origin");

        if (!PojoUtil.isEmpty(origin)) {
            // An origin header is present: allow if it names a valid domain
            if (isAllowedOrigin(origin)) {
                return true;
            }
        } else {
            // No origin header: check the referer header for localhost
            String referer = request.getHeader("referer");
            if (!PojoUtil.isEmpty(referer) && isLocalhostUrl(referer)) {
                return true;
            }
        }

        // If it is an invalid domain, validate the path
        return validatePath(path);
    }

    /**
     * Tells whether the host of the given URL matches one of the allowed
     * domain patterns. The whole host must match the pattern.
     *
     * @param url
     *            an absolute URL, e.g. the value of an origin header
     * @return true if the host matches an allowed domain
     * @throws MalformedURLException
     *             if the given string cannot be parsed as a URL
     * @throws java.util.regex.PatternSyntaxException
     *             if a configured domain pattern is not a valid regular
     *             expression once its dots have been escaped
     */
    public boolean validateDomain(String url) throws MalformedURLException {
        String host = new URL(url).getHost();
        if (host == null) {
            return false;
        }
        for (Pattern allowedDomain : getAllowedDomainPatterns()) {
            if (allowedDomain.matcher(host).matches()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Tells whether the given path is one of the public paths.
     *
     * @param path
     *            the request path without the context path
     * @return true if the whole path matches the public path pattern; false
     *         otherwise, including when the path is null
     */
    public boolean validatePath(String path) {
        return path != null && PUBLIC_PATH_PATTERN.matcher(path).matches();
    }

    /** Same as {@link #validateDomain(String)}, but an unparsable origin counts as not allowed. */
    private boolean isAllowedOrigin(String origin) {
        try {
            return validateDomain(origin);
        } catch (MalformedURLException ignored) {
            // Not a URL, so it cannot name an allowed domain.
            return false;
        }
    }

    /** Tells whether the given URL points at localhost; an unparsable URL does not. */
    private boolean isLocalhostUrl(String url) {
        try {
            return LOCALHOST.equals(new URL(url).getHost());
        } catch (MalformedURLException ignored) {
            // Not a URL, so it cannot point at localhost.
            return false;
        }
    }

    /** Returns the compiled allowed-domain patterns, building them on first use. */
    private List<Pattern> getAllowedDomainPatterns() {
        List<Pattern> patterns = domainPatterns;
        if (patterns == null) {
            // Fill a local list first and publish it only when complete.
            patterns = new ArrayList<Pattern>();
            if (allowedDomains != null) {
                for (String allowedDomain : allowedDomains.split(",")) {
                    String trimmed = allowedDomain.trim();
                    // A blank entry would compile to a pattern matching the empty host.
                    if (trimmed.length() > 0) {
                        patterns.add(Pattern.compile(transformPatternIntoRegex(trimmed)));
                    }
                }
            }
            domainPatterns = patterns;
        }
        return patterns;
    }

    /**
     * Turns a configured domain pattern into a regular expression by escaping
     * its dots, so that they only match a literal dot. Any other character
     * keeps whatever meaning it has in a regular expression.
     */
    private String transformPatternIntoRegex(String domainPattern) {
        return domainPattern.replace(".", "\\.");
    }
}