/*
* Seldon -- open source prediction engine
* =======================================
*
* Copyright 2011-2015 Seldon Technologies Ltd and Rummble Ltd (http://www.seldon.io/)
*
* ********************************************************************************************
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* ********************************************************************************************
*/
package io.seldon.api.service;
import io.seldon.api.APIException;
import io.seldon.api.Constants;
import io.seldon.api.jdo.Consumer;
import io.seldon.api.jdo.ConsumerPeer;
import io.seldon.api.jdo.Token;
import io.seldon.api.jdo.TokenPeer;
import io.seldon.api.resource.ScopedConsumerBean;
import io.seldon.api.resource.TokenBean;
import io.seldon.api.state.ClientConfigHandler;
import io.seldon.api.state.NewClientListener;
import io.seldon.db.jdo.DbConfigHandler;
import io.seldon.db.jdo.DbConfigListener;
import io.seldon.db.jdo.JDOFactory;
import io.seldon.memcache.MemCacheKeys;
import io.seldon.memcache.MemCachePeer;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.annotation.PostConstruct;
import javax.servlet.http.HttpServletRequest;
import org.apache.commons.lang.StringUtils;
import org.apache.log4j.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
/**
* @author claudio
*/
@Service
public class AuthorizationServer implements DbConfigListener {
private final static Pattern pattern = Pattern.compile("/S+");
private final static Logger logger = Logger.getLogger(AuthorizationServer.class);
private final static Map<String, ScopedConsumerBean> consumerCache = new ConcurrentHashMap<>();
public static final int CONSUMER_REFRESH_INTERVAL = 300000;
private JDOFactory jdoFactory;
private ConsumerPeer consumerPeer;
private TokenPeer tokenPeer;
@Autowired
public AuthorizationServer(JDOFactory jdoFactory, ConsumerPeer consumerPeer, TokenPeer tokenPeer,
DbConfigHandler dbConfigHandler){
this.jdoFactory = jdoFactory;
this.consumerPeer = consumerPeer;
this.tokenPeer = tokenPeer;
dbConfigHandler.addDbConfigListener(this);
}
//METHODS
public ScopedConsumerBean getConsumer(HttpServletRequest request) throws APIException {
if (request == null) {
throw new APIException(APIException.NOT_VALID_CONNECTION);
}
String consumerKey = request.getParameter(Constants.CONSUMER_KEY);
return retrieveConsumerBean(consumerKey);
}
private ScopedConsumerBean retrieveConsumerBean(String consumerKey) {
return retrieveConsumerBean(consumerKey, true);
}
private ScopedConsumerBean retrieveConsumerBean(String consumerKey, boolean allowCached) {
if (StringUtils.isBlank(consumerKey)) {
throw new APIException(APIException.NOT_SPECIFIED_CONS_KEY);
}
final ScopedConsumerBean cachedConsumerBean = consumerCache.get(consumerKey);
if (cachedConsumerBean != null && allowCached) {
return cachedConsumerBean;
}
final Consumer consumer = consumerPeer.findConsumer(consumerKey);
if (consumer == null || StringUtils.isNotBlank(consumer.getSecret())) {
// As a precaution: This method is only valid for consumers set up without a secret.
throw new APIException(APIException.NOT_AUTHORIZED_CONS);
}
if (!consumer.isActive()) {
throw new APIException(APIException.NOT_AUTHORIZED_CONS);
}
final ScopedConsumerBean consumerBean = new ScopedConsumerBean(consumer.getShort_name(), consumer.getScope(), consumer.getUrl());
consumerCache.put(consumerKey, consumerBean);
return consumerBean;
}
@Scheduled(fixedDelay = CONSUMER_REFRESH_INTERVAL)
public void refreshConsumerCache() {
try {
logger.info("Refreshing consumer cache (" + consumerCache.size() + " entries).");
for (String consumerKey : consumerCache.keySet()) {
try {
final ScopedConsumerBean consumerBean = retrieveConsumerBean(consumerKey, false);
consumerCache.put(consumerKey, consumerBean);
} catch (Exception e) {
logger.warn("Problem refreshing consumer cache entry for: " + consumerKey, e);
consumerCache.remove(consumerKey);
}
}
} finally {
jdoFactory.cleanupPM();
}
}
public Token getToken(HttpServletRequest req) throws APIException {
return getToken(req,true);
}
/**
* @return token
* Create and return an access token for a specific consumer.
* It generates a new token even if the consumer has already a valid token active
*/
public Token getToken(HttpServletRequest req,boolean makeTransient) throws APIException {
//init
String consumerKey = null;
String consumerSecret = null;
Token token = null;
//if request is null
if(req == null) {
throw new APIException(APIException.NOT_VALID_CONNECTION);
}
//retrieve from the request the consumer key and the consumer secret
consumerKey = req.getParameter(Constants.CONSUMER_KEY);
consumerSecret = req.getParameter(Constants.CONSUMER_SECRET);
//check if the consumerId is set
if(consumerKey == null || consumerKey.trim().equals("")) {
throw new APIException(APIException.NOT_SPECIFIED_CONS_KEY);
}
//check if the consumerSecret is set
if(consumerSecret == null || consumerSecret.trim().equals("")) {
throw new APIException(APIException.NOT_SPECIFIED_CONS_SECRET);
}
//check if the consumer credentials are valid
Consumer consumer = isConsumerValid(consumerKey,consumerSecret);
//check if the consumer is secure and request is TLS
if(consumer.isSecure() && !req.isSecure()) {
throw new APIException(APIException.NOT_SSL_CONN);
}
token = new Token(consumer);
//make the token persistent
tokenPeer.saveToken(token);
if (makeTransient)
{
//RAS-34 (ensure token is transient so JDO doesn't try to refresh against the read-replica the fields
//Maybe a better solution?
jdoFactory.getPersistenceManager(Constants.API_DB).makeTransient(token);
}
return token;
}
/**
* @param consumerId
* @param consumerSecret
* @return boolean
* Check if the pair consumerId/consumerSecret is valid
*/
public Consumer isConsumerValid(String consumerId, String consumerSecret) throws APIException {
Consumer consumer = null;
if(consumerId == null || consumerId.trim().equals("") ||consumerSecret == null || consumerSecret.trim().equals("")) {
throw new APIException(APIException.NOT_AUTHORIZED_CONS);
}
//if consumer key does not exists
consumer = consumerPeer.findConsumer(consumerId);
if(consumer == null) {
throw new APIException(APIException.NOT_VALID_KEY_CONS);
}
//if consumer secret is not valid
if(!consumer.getSecret().trim().equals(consumerSecret.trim())) {
throw new APIException(APIException.NOT_VALID_SECRET_CONS);
}
//if consumer is not active
if(!consumer.isActive()) {
throw new APIException(APIException.NOT_AUTHORIZED_CONS);
}
return consumer;
}
public TokenBean getTokenBeanFromKey(String tokenKey)
{
TokenBean res = (TokenBean) MemCachePeer.get(MemCacheKeys.getTokenBeanKey(tokenKey));
if(res==null) {
Token t = tokenPeer.findToken(tokenKey);
//if token not existing
if(t == null) {
throw new APIException(APIException.NOT_VALID_TOKEN_KEY);
}
//if token expired or no longer valid
if(tokenPeer.isExpired(t)) {
throw new APIException(APIException.NOT_VALID_TOKEN_EXPIRED);
}
res = new TokenBean(t);
}
return res;
}
public TokenBean isTokenValid(HttpServletRequest req) throws APIException {
//init
String tokenKey = null;
boolean safe = true;
//if request is null
if(req == null) {
throw new APIException(APIException.NOT_VALID_CONNECTION);
}
//retrieve from the request the consumer key and the consumer secret
//from heder
String authorization = req.getHeader(Constants.AUTHORIZATION);
if(authorization!=null && !authorization.trim().equals("")) {
Matcher matcher = pattern.matcher(authorization);
if(matcher.find()) {
tokenKey = matcher.group(1);
}
}
//try to retrieve the token from the parameters
if(tokenKey == null || tokenKey.trim().equals("")) {
tokenKey = req.getParameter(Constants.OAUTH_TOKEN);
//token not sent in a safe way
safe = false;
}
//TODO
//check presence of token in the body
//if tokenKey empty
if(tokenKey == null || tokenKey.trim().equals("")) {
throw new APIException(APIException.NOT_SPECIFIED_TOKEN);
}
//check if token is in memcached
//if is in memcached replicate the expired and security controls?
TokenBean res = (TokenBean) MemCachePeer.get(MemCacheKeys.getTokenBeanKey(tokenKey));
if(res==null) {
Token t = tokenPeer.findToken(tokenKey);
//if token not existing
if(t == null) {
throw new APIException(APIException.NOT_VALID_TOKEN_KEY);
}
if(t.getConsumer().isSecure() && !safe) {
throw new APIException(APIException.NOT_SECURE_TOKEN);
}
//if token expired or no longer valid
if(tokenPeer.isExpired(t)) {
throw new APIException(APIException.NOT_VALID_TOKEN_EXPIRED);
}
res = new TokenBean(t);
}
return res;
}
public void clientDeleted(String client) {
// ignore
}
@Override
public void dbConfigInitialised(String client) {
logger.info("Reloading consumer table due to new client " + client);
refreshConsumerCache();
}
}