package org.ohdsi.webapi.shiro.management;
import io.buji.pac4j.filter.CallbackFilter;
import io.buji.pac4j.filter.SecurityFilter;
import io.buji.pac4j.realm.Pac4jRealm;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import javax.annotation.PostConstruct;
import javax.servlet.Filter;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.HttpMethod;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.authc.Authenticator;
import org.apache.shiro.authc.pam.ModularRealmAuthenticator;
import org.apache.shiro.realm.Realm;
import org.apache.shiro.web.filter.authz.SslFilter;
import org.apache.shiro.web.filter.session.NoSessionCreationFilter;
import org.apache.shiro.web.servlet.AdviceFilter;
import org.apache.shiro.web.util.WebUtils;
import org.ohdsi.webapi.shiro.CorsFilter;
import org.ohdsi.webapi.shiro.Entities.RoleEntity;
import org.ohdsi.webapi.shiro.ForceSessionCreationFilter;
import org.ohdsi.webapi.shiro.InvalidateAccessTokenFilter;
import org.ohdsi.webapi.shiro.JwtAuthFilter;
import org.ohdsi.webapi.shiro.JwtAuthRealm;
import org.ohdsi.webapi.shiro.LogoutFilter;
import org.ohdsi.webapi.shiro.PermissionManager;
import org.ohdsi.webapi.shiro.ProcessResponseContentFilter;
import org.ohdsi.webapi.shiro.RedirectOnFailedOAuthFilter;
import org.ohdsi.webapi.shiro.SendTokenInHeaderFilter;
import org.ohdsi.webapi.shiro.SendTokenInUrlFilter;
import org.ohdsi.webapi.shiro.SkipFurtherFilteringFilter;
import org.ohdsi.webapi.shiro.UpdateAccessTokenFilter;
import org.ohdsi.webapi.shiro.UrlBasedAuthorizingFilter;
import org.ohdsi.webapi.source.Source;
import org.ohdsi.webapi.source.SourceRepository;
import org.pac4j.core.client.Clients;
import org.pac4j.core.config.Config;
import org.pac4j.oauth.client.FacebookClient;
import org.pac4j.oauth.client.Google2Client;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import waffle.shiro.negotiate.NegotiateAuthenticationFilter;
import waffle.shiro.negotiate.NegotiateAuthenticationRealm;
import waffle.shiro.negotiate.NegotiateAuthenticationStrategy;
/**
 * Shiro security configuration for Atlas / WebAPI.
 *
 * <p>Defines the URL filter chain (which named filters run for which REST path), instantiates the
 * named filters that chain refers to, and supplies the realms and authenticator used by Shiro.
 * It also grants the creator of a cohort definition or concept set per-entity permissions, revokes
 * them when the entity is deleted, and creates a "Source user (KEY)" role for every source at
 * startup.
 *
 * @author gennadiy.anisimov
 */
@Component
public class AtlasSecurity extends Security {

  private final Log log = LogFactory.getLog(getClass());

  @Autowired
  private PermissionManager authorizer;

  @Autowired
  SourceRepository sourceRepository;

  @Value("${security.token.expiration}")
  private int tokenExpirationIntervalInSeconds;

  @Value("${server.port}")
  private int sslPort;

  @Value("${security.ssl.enabled}")
  private boolean sslEnabled;

  @Value("${security.oauth.callback.ui}")
  private String oauthUiCallback;

  @Value("${security.oauth.callback.api}")
  private String oauthApiCallback;

  @Value("${security.oauth.google.apiKey}")
  private String googleApiKey;

  @Value("${security.oauth.google.apiSecret}")
  private String googleApiSecret;

  @Value("${security.oauth.facebook.apiKey}")
  private String facebookApiKey;

  @Value("${security.oauth.facebook.apiSecret}")
  private String facebookApiSecret;

  /** Roles handed to {@link UpdateAccessTokenFilter}; contains only "public". */
  private final Set<String> defaultRoles = new LinkedHashSet<>();

  /** Permission -> description templates granted to a cohort definition's creator; %s is the id. */
  private final Map<String, String> cohortdefinitionCreatorPermissionTemplates = new LinkedHashMap<>();

  /** Permission -> description templates granted to a concept set's creator; %s is the id. */
  private final Map<String, String> conceptsetCreatorPermissionTemplates = new LinkedHashMap<>();

  /** Permission -> description templates for the per-source role; %s is the source key. */
  private final Map<String, String> sourcePermissionTemplates = new LinkedHashMap<>();

  public AtlasSecurity() {
    this.defaultRoles.add("public");

    this.cohortdefinitionCreatorPermissionTemplates.put("cohortdefinition:%s:put", "Update Cohort Definition with ID = %s");
    this.cohortdefinitionCreatorPermissionTemplates.put("cohortdefinition:%s:delete", "Delete Cohort Definition with ID = %s");

    // These must match the methods the filter chain actually sends to "authz" for concept sets:
    // PUT /conceptset/{id}, PUT /conceptset/{id}/items and DELETE /conceptset/{id}. Every other
    // method is skipped, so ":post" grants were never checked. Under Shiro WildcardPermission rules,
    // "conceptset:ID:delete:post" does not imply "conceptset:ID:delete". The naming assumes
    // UrlBasedAuthorizingFilter builds permissions as resource:id[:sub]:method, as the
    // cohortdefinition templates above do -- TODO confirm against that filter.
    this.conceptsetCreatorPermissionTemplates.put("conceptset:%s:put", "Update Concept Set with ID = %s");
    this.conceptsetCreatorPermissionTemplates.put("conceptset:%s:items:put", "Update Items of Concept Set with ID = %s");
    this.conceptsetCreatorPermissionTemplates.put("conceptset:%s:delete", "Delete Concept Set with ID = %s");

    this.sourcePermissionTemplates.put("cohortdefinition:*:report:%s:get", "Get Inclusion Rule Report for Source with SourceKey = %s");
    this.sourcePermissionTemplates.put("cohortdefinition:*:generate:%s:get", "Generate Cohort on Source with SourceKey = %s");
  }

  /**
   * Builds the ordered path -> filter-list mapping. Order matters: the first matching path wins.
   *
   * @return the filter chain definitions in evaluation order
   */
  @Override
  public Map<String, String> getFilterChain() {
    return new FilterChainBuilder()
        .setOAuthFilters("ssl, cors, forceSessionCreation", "updateToken, sendTokenInUrl")
        .setRestFilters("ssl, noSessionCreation, cors")
        .setAuthcFilter("jwtAuthc")
        .setAuthzFilter("authz")
        // The order matters - the first matching path wins.
        // protected resources
        //
        // login/logout
        .addRestPath("/user/login", "negotiateAuthc, updateToken, sendTokenInHeader")
        .addRestPath("/user/refresh", "jwtAuthc, updateToken, sendTokenInHeader")
        .addRestPath("/user/logout", "invalidateToken, logout")
        .addOAuthPath("/user/oauth/google", "googleAuthc")
        .addOAuthPath("/user/oauth/facebook", "facebookAuthc")
        .addPath("/user/oauth/callback", "ssl, handleUnsuccessfullOAuth, oauthCallback")
        // permissions
        .addProtectedRestPath("/user/**")
        .addProtectedRestPath("/role/**")
        .addProtectedRestPath("/permission/**")
        // concept set
        .addRestPath("/conceptset", "skipFurtherFiltersIfNotPutOrPost, jwtAuthc, authz, createPermissionsOnCreateConceptSet") // only PUT and POST methods are protected
        .addRestPath("/conceptset/*/items", "skipFurtherFiltersIfNotPut, jwtAuthc, authz") // only PUT method is protected
        .addRestPath("/conceptset/*", "skipFurtherFiltersIfNotPutOrDelete, jwtAuthc, authz, deletePermissionsOnDeleteConceptSet") // only PUT and DELETE methods are protected
        // cohort definition
        .addProtectedRestPath("/cohortdefinition", "createPermissionsOnCreateCohortDefinition")
        .addProtectedRestPath("/cohortdefinition/*/copy", "createPermissionsOnCreateCohortDefinition")
        .addProtectedRestPath("/cohortdefinition/*", "deletePermissionsOnDeleteCohortDefinition")
        .addProtectedRestPath("/cohortdefinition/*/info")
        .addProtectedRestPath("/cohortdefinition/sql")
        .addProtectedRestPath("/cohortdefinition/*/generate/*")
        .addProtectedRestPath("/cohortdefinition/*/report/*")
        .addProtectedRestPath("/*/cohortresults/*/breakdown")
        .addProtectedRestPath("/job/execution")
        // not protected resources - all the rest
        .addRestPath("/**")
        .build();
  }

  /**
   * Instantiates every named filter referenced by {@link #getFilterChain()}.
   *
   * @return filter name -> filter instance
   */
  @Override
  public Map<String, Filter> getFilters() {
    Map<String, Filter> filters = new HashMap<>();

    filters.put("logout", new LogoutFilter());
    filters.put("noSessionCreation", new NoSessionCreationFilter());
    filters.put("forceSessionCreation", new ForceSessionCreationFilter());
    filters.put("jwtAuthc", new JwtAuthFilter());
    filters.put("negotiateAuthc", new NegotiateAuthenticationFilter());
    filters.put("updateToken", new UpdateAccessTokenFilter(this.authorizer, this.defaultRoles, this.tokenExpirationIntervalInSeconds));
    filters.put("invalidateToken", new InvalidateAccessTokenFilter());
    filters.put("authz", new UrlBasedAuthorizingFilter());

    filters.put("createPermissionsOnCreateCohortDefinition", this.getCreatePermissionsOnCreateCohortDefinitionFilter());
    filters.put("createPermissionsOnCreateConceptSet", this.getCreatePermissionsOnCreateConceptSetFilter());
    filters.put("deletePermissionsOnDeleteCohortDefinition", this.getDeletePermissionsOnDeleteFilter(this.cohortdefinitionCreatorPermissionTemplates));
    filters.put("deletePermissionsOnDeleteConceptSet", this.getDeletePermissionsOnDeleteFilter(this.conceptsetCreatorPermissionTemplates));

    filters.put("cors", new CorsFilter());

    filters.put("skipFurtherFiltersIfNotPost", this.getSkipFurtherFiltersUnlessMethodFilter(HttpMethod.POST));
    filters.put("skipFurtherFiltersIfNotPut", this.getSkipFurtherFiltersUnlessMethodFilter(HttpMethod.PUT));
    filters.put("skipFurtherFiltersIfNotPutOrPost", this.getSkipFurtherFiltersUnlessMethodFilter(HttpMethod.PUT, HttpMethod.POST));
    filters.put("skipFurtherFiltersIfNotPutOrDelete", this.getSkipFurtherFiltersUnlessMethodFilter(HttpMethod.PUT, HttpMethod.DELETE));

    filters.put("sendTokenInUrl", new SendTokenInUrlFilter(this.oauthUiCallback));
    filters.put("sendTokenInHeader", new SendTokenInHeaderFilter());
    filters.put("ssl", this.getSslFilter());

    filters.putAll(this.getOAuthFilters());

    return filters;
  }

  /**
   * Creates the pac4j OAuth login filters, the shared OAuth callback filter and the
   * failed-OAuth redirect filter.
   */
  private Map<String, Filter> getOAuthFilters() {
    Google2Client googleClient = new Google2Client(this.googleApiKey, this.googleApiSecret);
    googleClient.setScope(Google2Client.Google2Scope.EMAIL);

    FacebookClient facebookClient = new FacebookClient(this.facebookApiKey, this.facebookApiSecret);
    facebookClient.setScope("email");
    facebookClient.setFields("email");

    // ... put new clients here and then assign them to filters below ...
    Config cfg = new Config(new Clients(this.oauthApiCallback, googleClient, facebookClient));

    Map<String, Filter> oauthFilters = new HashMap<>();
    oauthFilters.put("googleAuthc", createOAuthSecurityFilter(cfg, "Google2Client"));
    oauthFilters.put("facebookAuthc", createOAuthSecurityFilter(cfg, "FacebookClient"));

    CallbackFilter callbackFilter = new CallbackFilter();
    callbackFilter.setConfig(cfg);
    oauthFilters.put("oauthCallback", callbackFilter);
    oauthFilters.put("handleUnsuccessfullOAuth", new RedirectOnFailedOAuthFilter(this.oauthUiCallback));

    return oauthFilters;
  }

  /** Creates a pac4j security filter bound to the single client named {@code clientName}. */
  private static SecurityFilter createOAuthSecurityFilter(Config cfg, String clientName) {
    SecurityFilter filter = new SecurityFilter();
    filter.setConfig(cfg);
    filter.setClients(clientName);
    return filter;
  }

  @Override
  public Set<Realm> getRealms() {
    Set<Realm> realms = new LinkedHashSet<>();
    realms.add(new JwtAuthRealm(this.authorizer));
    realms.add(new NegotiateAuthenticationRealm());
    realms.add(new Pac4jRealm());
    return realms;
  }

  @Override
  public Authenticator getAuthenticator() {
    ModularRealmAuthenticator authenticator = new ModularRealmAuthenticator();
    authenticator.setAuthenticationStrategy(new NegotiateAuthenticationStrategy());
    return authenticator;
  }

  /**
   * Creates the "Source user (KEY)" role with the source permission templates, unless a role
   * with that name already exists.
   *
   * @param sourceKey key of the source the role is created for
   * @throws Exception propagated from the permission manager
   */
  private void addSourceRole(String sourceKey) throws Exception {
    String roleName = String.format("Source user (%s)", sourceKey);
    if (this.authorizer.roleExists(roleName)) {
      return;
    }
    RoleEntity role = this.authorizer.addRole(roleName);
    this.authorizer.addPermissionsFromTemplate(role, this.sourcePermissionTemplates, sourceKey);
  }

  /**
   * Ensures every configured source has its source-user role. Runs once after injection. A
   * failure for one source is logged with its stack trace and does not prevent the remaining
   * sources from being processed.
   */
  @PostConstruct
  private void initRolesForSources() {
    try {
      for (Source source : sourceRepository.findAll()) {
        String sourceKey = source.getSourceKey();
        try {
          this.addSourceRole(sourceKey);
        } catch (Exception e) {
          log.error(String.format("Failed to initialize role for source '%s'", sourceKey), e);
        }
      }
    } catch (Exception e) {
      log.error("Failed to initialize roles for sources", e);
    }
  }

  /**
   * Filter that grants the current user's personal role the cohort-definition creator
   * permissions. It applies to POST requests, and to GET requests whose path ends with "copy".
   * The new definition's id is read from the "id" field of the JSON response.
   */
  private Filter getCreatePermissionsOnCreateCohortDefinitionFilter() {
    return new ProcessResponseContentFilter() {
      @Override
      protected boolean shouldProcess(ServletRequest request, ServletResponse response) {
        HttpServletRequest httpRequest = WebUtils.toHttp(request);
        String method = httpRequest.getMethod();
        String pathInfo = httpRequest.getPathInfo();
        String path = pathInfo == null ? "" : pathInfo.replaceAll("/+$", "");
        // GET .../{id}/copy presumably yields a new definition, so its creator gets permissions too.
        if (StringUtils.endsWithIgnoreCase(path, "copy")) {
          return HttpMethod.GET.equalsIgnoreCase(method);
        }
        return HttpMethod.POST.equalsIgnoreCase(method);
      }

      @Override
      protected void doProcessResponseContent(String content) throws Exception {
        String id = this.parseJsonField(content, "id");
        RoleEntity currentUserPersonalRole = authorizer.getCurrentUserPersonalRole();
        authorizer.addPermissionsFromTemplate(currentUserPersonalRole, cohortdefinitionCreatorPermissionTemplates, id);
      }
    };
  }

  /**
   * Filter that, on POST, grants the current user's personal role the concept-set creator
   * permissions. The new concept set's id is read from the "id" field of the JSON response.
   */
  private Filter getCreatePermissionsOnCreateConceptSetFilter() {
    return new ProcessResponseContentFilter() {
      @Override
      protected boolean shouldProcess(ServletRequest request, ServletResponse response) {
        return HttpMethod.POST.equalsIgnoreCase(WebUtils.toHttp(request).getMethod());
      }

      @Override
      protected void doProcessResponseContent(String content) throws Exception {
        String id = this.parseJsonField(content, "id");
        RoleEntity currentUserPersonalRole = authorizer.getCurrentUserPersonalRole();
        authorizer.addPermissionsFromTemplate(currentUserPersonalRole, conceptsetCreatorPermissionTemplates, id);
      }
    };
  }

  /**
   * Filter that, after a DELETE request has been handled, removes the permissions produced by
   * {@code templates} for the entity id. The id is taken as the second path segment, e.g. "5"
   * in "/cohortdefinition/5".
   *
   * @param templates the creator permission templates to revoke
   */
  private Filter getDeletePermissionsOnDeleteFilter(final Map<String, String> templates) {
    return new AdviceFilter() {
      @Override
      protected void postHandle(ServletRequest request, ServletResponse response) {
        HttpServletRequest httpRequest = WebUtils.toHttp(request);
        if (!HttpMethod.DELETE.equalsIgnoreCase(httpRequest.getMethod())) {
          return;
        }
        String id = extractPathSegment(httpRequest, 1);
        if (id == null) {
          AtlasSecurity.this.log.warn("Cannot determine entity id from path: " + httpRequest.getPathInfo());
          return;
        }
        authorizer.removePermissionsFromTemplate(templates, id);
      }
    };
  }

  /**
   * Returns the zero-based {@code index}-th segment of the request's path info, ignoring leading
   * and trailing slashes. Returns {@code null} when there is no path info or too few segments.
   */
  private static String extractPathSegment(HttpServletRequest request, int index) {
    String pathInfo = request.getPathInfo();
    if (pathInfo == null) {
      return null;
    }
    String[] segments = pathInfo
        .replaceAll("^/+", "")
        .replaceAll("/+$", "")
        .split("/");
    return index < segments.length ? segments[index] : null;
  }

  /**
   * Filter that skips the rest of the chain unless the request method is one of
   * {@code allowedMethods} (case-insensitive).
   *
   * @param allowedMethods HTTP methods that continue through the chain
   */
  private Filter getSkipFurtherFiltersUnlessMethodFilter(final String... allowedMethods) {
    return new SkipFurtherFilteringFilter() {
      @Override
      protected boolean shouldSkip(ServletRequest request, ServletResponse response) {
        String httpMethod = WebUtils.toHttp(request).getMethod();
        for (String allowedMethod : allowedMethods) {
          if (allowedMethod.equalsIgnoreCase(httpMethod)) {
            return false;
          }
        }
        return true;
      }
    };
  }

  /** SSL filter using the configured server port and the security.ssl.enabled flag. */
  private Filter getSslFilter() {
    SslFilter sslFilter = new SslFilter();
    sslFilter.setPort(sslPort);
    sslFilter.setEnabled(sslEnabled);
    return sslFilter;
  }

  /**
   * Returns the authenticated subject's name, or "anonymous" when the subject is not authenticated.
   */
  @Override
  public String getSubject() {
    if (SecurityUtils.getSubject().isAuthenticated()) {
      return authorizer.getSubjectName();
    }
    return "anonymous";
  }

  /**
   * Fluent builder for the ordered Shiro filter chain map. Each path is stored both with and
   * without a trailing slash (unless it ends in a wildcard).
   */
  private static class FilterChainBuilder {
    private final Map<String, String> filterChain = new LinkedHashMap<>();
    private String restFilters;
    private String authcFilter;
    private String authzFilter;
    private String filtersBeforeOAuth;
    private String filtersAfterOAuth;

    public FilterChainBuilder setRestFilters(String restFilters) {
      this.restFilters = restFilters;
      return this;
    }

    public FilterChainBuilder setOAuthFilters(String filtersBeforeOAuth, String filtersAfterOAuth) {
      this.filtersBeforeOAuth = filtersBeforeOAuth;
      this.filtersAfterOAuth = filtersAfterOAuth;
      return this;
    }

    public FilterChainBuilder setAuthcFilter(String authcFilter) {
      this.authcFilter = authcFilter;
      return this;
    }

    public FilterChainBuilder setAuthzFilter(String authzFilter) {
      this.authzFilter = authzFilter;
      return this;
    }

    public FilterChainBuilder addRestPath(String path, String filters) {
      return this.addPath(path, this.restFilters + ", " + filters);
    }

    public FilterChainBuilder addRestPath(String path) {
      return this.addPath(path, this.restFilters);
    }

    public FilterChainBuilder addOAuthPath(String path, String oauthFilter) {
      return this.addPath(path, this.filtersBeforeOAuth + ", " + oauthFilter + ", " + this.filtersAfterOAuth);
    }

    public FilterChainBuilder addProtectedRestPath(String path) {
      return this.addRestPath(path, this.authcFilter + ", " + this.authzFilter);
    }

    public FilterChainBuilder addProtectedRestPath(String path, String filters) {
      return this.addRestPath(path, this.authcFilter + ", " + this.authzFilter + ", " + filters);
    }

    public FilterChainBuilder addPath(String path, String filters) {
      String normalizedPath = path.replaceAll("/+$", "");
      this.filterChain.put(normalizedPath, filters);
      // If the path ends with a non-wildcard character, register two paths: one without and one
      // with a trailing slash. URLs like www.domain.com/myapp/mypath and
      // www.domain.com/myapp/mypath/ reach the same resource method, but the filter chain
      // treats them as different paths.
      if (!normalizedPath.endsWith("*")) {
        this.filterChain.put(normalizedPath + "/", filters);
      }
      return this;
    }

    public Map<String, String> build() {
      return filterChain;
    }
  }
}