package com.monkeyk.os.oauth.shiro;
import com.monkeyk.os.domain.oauth.AccessToken;
import com.monkeyk.os.domain.oauth.ClientDetails;
import com.monkeyk.os.infrastructure.shiro.MkkJdbcRealm;
import com.monkeyk.os.service.OAuthRSService;
import org.apache.shiro.authc.*;
import org.apache.shiro.authz.AuthorizationInfo;
import org.apache.shiro.authz.SimpleAuthorizationInfo;
import org.apache.shiro.subject.PrincipalCollection;
import org.apache.shiro.util.JdbcUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Set;
/**
 * Apache Shiro realm that authenticates OAuth2 access tokens ({@link OAuth2Token}) and
 * authorizes the resulting user through JDBC role/permission lookups.
 * <p>
 * Authentication resolves the presented access token through {@link OAuthRSService}. It rejects
 * tokens that are missing or expired. It also rejects tokens whose client details are missing or
 * archived for the requested resource id. Authorization reuses the role and permission lookups
 * ({@code getRoleNamesForUser}, {@code getPermissions}) provided by the {@link MkkJdbcRealm}
 * superclass.
 * <p>
 * Created 2015/9/29.
 *
 * @author Shengzhao Li
 * @see org.apache.shiro.realm.jdbc.JdbcRealm
 */
public class OAuth2JdbcRealm extends MkkJdbcRealm {

    private static final Logger LOG = LoggerFactory.getLogger(OAuth2JdbcRealm.class);

    /** Resource-server service used to load access tokens and client details; must be injected. */
    private OAuthRSService rsService;

    /**
     * Creates the realm and sets {@link OAuth2Token} as its authentication token class.
     */
    public OAuth2JdbcRealm() {
        super();
        setAuthenticationTokenClass(OAuth2Token.class);
    }

    /**
     * Checks that a loaded access token exists and has not expired.
     *
     * @param token       the raw access_token value presented by the caller (used for messages)
     * @param accessToken the token entity loaded for {@code token}, possibly {@code null}
     * @throws OAuth2AuthenticationException if {@code accessToken} is {@code null} or expired
     */
    private void validateToken(String token, AccessToken accessToken) throws OAuth2AuthenticationException {
        if (accessToken == null) {
            LOG.debug("Invalid access_token: {}, because it is null", token);
            throw new OAuth2AuthenticationException("Invalid access_token: " + token);
        }
        if (accessToken.tokenExpired()) {
            LOG.debug("Invalid access_token: {}, because it is expired", token);
            throw new OAuth2AuthenticationException("Invalid access_token: " + token);
        }
    }

    /**
     * Checks that the client owning an access token exists and is not archived.
     *
     * @param token         the raw access_token value (used for the exception message)
     * @param accessToken   the validated token entity; must be non-null (its client id is logged)
     * @param clientDetails the client details loaded for the token's client, possibly {@code null}
     * @throws OAuth2AuthenticationException if {@code clientDetails} is {@code null} or archived
     */
    private void validateClientDetails(String token, AccessToken accessToken, ClientDetails clientDetails) throws OAuth2AuthenticationException {
        if (clientDetails == null || clientDetails.archived()) {
            LOG.debug("Invalid ClientDetails: {} by client_id: {}, it is null or archived", clientDetails, accessToken.clientId());
            throw new OAuth2AuthenticationException("Invalid client by token: " + token);
        }
    }

    /**
     * Authenticates an {@link OAuth2Token} by validating its access_token value.
     * <p>
     * Validation proceeds in this order:
     * <ol>
     * <li>The token value must be non-empty.</li>
     * <li>The token must resolve to a token that exists and has not expired.</li>
     * <li>The owning client must exist for the token's resource id and must not be archived.</li>
     * <li>The token must carry a non-null username.</li>
     * </ol>
     *
     * @param token the submitted token; expected to be an {@link OAuth2Token}
     * @return authentication info whose principal is the token's username and whose credentials
     *         are the access_token value
     * @throws OAuth2AuthenticationException if the token or its client fails validation
     * @throws AccountException              if the token carries no username
     */
    @Override
    public AuthenticationInfo doGetAuthenticationInfo(AuthenticationToken token) throws AuthenticationException {
        OAuth2Token upToken = (OAuth2Token) token;
        final String accessToken = (String) upToken.getCredentials();
        // hasLength(String) is the non-deprecated equivalent of isEmpty(Object): true unless null or "".
        if (!StringUtils.hasLength(accessToken)) {
            throw new OAuth2AuthenticationException("Invalid access_token: " + accessToken);
        }

        // Validate the access token itself (exists, not expired).
        AccessToken aToken = rsService.loadAccessTokenByTokenId(accessToken);
        validateToken(accessToken, aToken);

        // Validate the owning client for the requested resource id.
        final ClientDetails clientDetails = rsService.loadClientDetails(aToken.clientId(), upToken.getResourceId());
        validateClientDetails(accessToken, aToken, clientDetails);

        String username = aToken.username();
        // The username becomes the Shiro principal, so it cannot be null.
        if (username == null) {
            throw new AccountException("Null usernames are not allowed by this realm.");
        }
        return new SimpleAuthenticationInfo(username, accessToken, getName());
    }

    /**
     * Loads the roles of the authenticated user and, when {@code permissionsLookupEnabled} is set,
     * the user's permissions, using a connection from {@code dataSource}.
     *
     * @param principals the principals of the authenticated subject; must not be {@code null}
     * @return authorization info with the user's role names and (possibly {@code null}) permissions
     * @throws OAuth2AuthenticationException if {@code principals} is {@code null} or a SQL error occurs
     */
    @Override
    public AuthorizationInfo doGetAuthorizationInfo(PrincipalCollection principals) {
        // A null principal collection cannot be authorized.
        // NOTE(review): Shiro's JdbcRealm throws AuthorizationException here. This realm keeps
        // OAuth2AuthenticationException so that existing callers catching it are unaffected.
        if (principals == null) {
            throw new OAuth2AuthenticationException("PrincipalCollection method argument cannot be null.");
        }

        String username = (String) getAvailablePrincipal(principals);

        Connection conn = null;
        Set<String> roleNames = null;
        Set<String> permissions = null;
        try {
            conn = dataSource.getConnection();

            // Retrieve roles and, optionally, permissions from the database.
            roleNames = getRoleNamesForUser(conn, username);
            if (permissionsLookupEnabled) {
                permissions = getPermissions(conn, username, roleNames);
            }
        } catch (SQLException e) {
            final String message = "There was a SQL error while authorizing user [" + username + "]";
            LOG.error(message, e);
            // Rethrow SQL errors wrapped in the realm's exception type, preserving the cause.
            throw new OAuth2AuthenticationException(message, e);
        } finally {
            JdbcUtils.closeConnection(conn);
        }

        SimpleAuthorizationInfo info = new SimpleAuthorizationInfo(roleNames);
        info.setStringPermissions(permissions);
        return info;
    }

    /**
     * Injects the resource-server service used to load access tokens and client details.
     *
     * @param rsService the service; checked for non-null in {@link #afterPropertiesSet()}
     */
    public void setRsService(OAuthRSService rsService) {
        this.rsService = rsService;
    }

    /**
     * Runs the superclass initialization, then verifies that {@code rsService} was injected.
     *
     * @throws Exception if superclass initialization fails
     * @throws IllegalArgumentException if {@code rsService} is {@code null}
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        super.afterPropertiesSet();
        Assert.notNull(this.rsService, "rsService is required for OAuth2JdbcRealm");
    }
}