/* ==================================================================
* QuerySecurityAspect.java - Dec 18, 2012 4:32:34 PM
*
* Copyright 2007-2012 SolarNetwork.net Dev Team
*
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License as
* published by the Free Software Foundation; either version 2 of
* the License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
* 02111-1307 USA
* ==================================================================
*/
package net.solarnetwork.central.query.aop;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.security.core.Authentication;
import org.springframework.util.AntPathMatcher;
import net.solarnetwork.central.datum.domain.AggregateGeneralNodeDatumFilter;
import net.solarnetwork.central.datum.domain.DatumFilter;
import net.solarnetwork.central.datum.domain.DatumFilterCommand;
import net.solarnetwork.central.datum.domain.GeneralNodeDatumFilter;
import net.solarnetwork.central.datum.domain.NodeDatumFilter;
import net.solarnetwork.central.domain.Filter;
import net.solarnetwork.central.domain.SortDescriptor;
import net.solarnetwork.central.query.biz.QueryBiz;
import net.solarnetwork.central.security.AuthorizationException;
import net.solarnetwork.central.security.SecurityPolicy;
import net.solarnetwork.central.security.SecurityPolicyEnforcer;
import net.solarnetwork.central.security.SecurityUtils;
import net.solarnetwork.central.user.dao.UserNodeDao;
import net.solarnetwork.central.user.support.AuthorizationSupport;
/**
 * Security enforcing AOP aspect for {@link QueryBiz}.
 *
 * <p>
 * The advice in this class verifies that the active actor has read access to
 * the node IDs named in a request, and applies the source ID restrictions of
 * the active {@link SecurityPolicy}, if any, before or after the intercepted
 * method runs.
 * </p>
 *
 * @author matt
 * @version 1.5
 */
@Aspect
public class QuerySecurityAspect extends AuthorizationSupport {

	/** The generic filter map key holding a single node ID. */
	public static final String FILTER_KEY_NODE_ID = "nodeId";

	/** The generic filter map key holding an array of node IDs. */
	public static final String FILTER_KEY_NODE_IDS = "nodeIds";

	private Set<String> nodeIdNotRequiredSet;

	/**
	 * Constructor.
	 *
	 * <p>
	 * Configures a case-sensitive, non-caching {@link AntPathMatcher} as the
	 * path matcher.
	 * </p>
	 *
	 * @param userNodeDao
	 *        the UserNodeDao
	 */
	public QuerySecurityAspect(UserNodeDao userNodeDao) {
		super(userNodeDao);
		AntPathMatcher antMatch = new AntPathMatcher();
		antMatch.setCachePatterns(false);
		antMatch.setCaseSensitive(true);
		setPathMatcher(antMatch);
	}

	/**
	 * Match {@code getReportableInterval} methods whose first two arguments
	 * are a node ID and a source ID.
	 *
	 * @param nodeId
	 *        the node ID
	 * @param sourceId
	 *        the source ID
	 */
	@Pointcut("bean(aop*) && execution(* net.solarnetwork.central.query.biz.*.getReportableInterval(..)) && args(nodeId,sourceId,..)")
	public void nodeReportableInterval(Long nodeId, String sourceId) {
	}

	/**
	 * Match {@code getAvailableSources} methods whose first argument is a
	 * node ID.
	 *
	 * @param nodeId
	 *        the node ID
	 */
	@Pointcut("bean(aop*) && execution(* net.solarnetwork.central.query.biz.*.getAvailableSources(..)) && args(nodeId,..)")
	public void nodeReportableSources(Long nodeId) {
	}

	/**
	 * Match {@code getMostRecentWeatherConditions} methods whose first
	 * argument is a node ID.
	 *
	 * @param nodeId
	 *        the node ID
	 */
	@Pointcut("bean(aop*) && execution(* net.solarnetwork.central.query.biz.*.getMostRecentWeatherConditions(..)) && args(nodeId,..)")
	public void nodeMostRecentWeatherConditions(Long nodeId) {
	}

	/**
	 * Match {@code findFiltered*} methods whose first argument is a
	 * {@link Filter}.
	 *
	 * @param filter
	 *        the filter
	 */
	@Pointcut("bean(aop*) && execution(* net.solarnetwork.central.query.biz.*.findFiltered*(..)) && args(filter,..)")
	public void nodeDatumFilter(Filter filter) {
	}

	/**
	 * Enforce node ID and security policy restrictions on filter based
	 * queries.
	 *
	 * <p>
	 * If the active policy restricts source IDs and the filter does not name a
	 * source ID, the filter is first populated with the sources available to
	 * its node so the policy enforcement performed by
	 * {@link #userNodeAccessCheck(Filter)} can remove the restricted ones. If
	 * that check returns a different filter instance, the join point proceeds
	 * with the replacement filter; when the replacement carries an aggregation
	 * and the intercepted method is {@code findFilteredGeneralNodeDatum}, the
	 * call is redirected to
	 * {@code findFilteredAggregateGeneralNodeDatum} instead.
	 * </p>
	 *
	 * @param pjp
	 *        The join point.
	 * @param filter
	 *        The filter passed to the intercepted method.
	 * @return The result of the (possibly redirected) invocation.
	 * @throws Throwable
	 *         If any error occurs, including an {@link AuthorizationException}
	 *         if access is denied.
	 */
	@Around(value = "nodeDatumFilter(filter)")
	public Object userNodeFilterAccessCheck(ProceedingJoinPoint pjp, Filter filter) throws Throwable {
		final Object target = pjp.getTarget();
		final QueryBiz queryBiz = (target instanceof QueryBiz ? (QueryBiz) target : null);
		final SecurityPolicy policy = getActiveSecurityPolicy();

		final boolean policyRestrictsSources = (policy != null && policy.getSourceIds() != null
				&& !policy.getSourceIds().isEmpty());
		if ( policyRestrictsSources && filter instanceof GeneralNodeDatumFilter
				&& ((GeneralNodeDatumFilter) filter).getSourceId() == null ) {
			// no source IDs provided, but policy restricts source IDs.
			// restrict the filter to the available source IDs if using a DatumFilterCommand,
			// and let call to userNodeAccessCheck later on filter out restricted values
			if ( queryBiz != null && filter instanceof DatumFilterCommand ) {
				DatumFilterCommand cmd = (DatumFilterCommand) filter;
				Set<String> availableSources = queryBiz.getAvailableSources(cmd.getNodeId(),
						cmd.getStartDate(), cmd.getEndDate());
				if ( availableSources != null && !availableSources.isEmpty() ) {
					cmd.setSourceIds(availableSources.toArray(new String[availableSources.size()]));
				}
			}
		}

		final Filter checked = userNodeAccessCheck(filter);
		if ( checked == filter ) {
			// nothing was replaced, so invoke with the original arguments
			return pjp.proceed();
		}

		// if an aggregate was injected (enforced) on the filter, then the join point method
		// might need to change to an aggregate one, e.g. from findFilteredGeneralNodeDatum
		// to findFilteredAggregateGeneralNodeDatum. This _could_ break the calling code if
		// it is expecting a specific result type, but in many cases it is simply returning
		// the result as JSON to some HTTP client and the difference does not matter.
		final Object[] args = pjp.getArgs();
		if ( queryBiz != null && checked instanceof AggregateGeneralNodeDatumFilter
				&& ((AggregateGeneralNodeDatumFilter) checked).getAggregation() != null
				&& "findFilteredGeneralNodeDatum".equals(pjp.getSignature().getName()) ) {
			// redirect this to findFilteredAggregateGeneralNodeDatum
			@SuppressWarnings("unchecked")
			List<SortDescriptor> sorts = (List<SortDescriptor>) args[1];
			return queryBiz.findFilteredAggregateGeneralNodeDatum(
					(AggregateGeneralNodeDatumFilter) checked, sorts, (Integer) args[2],
					(Integer) args[3]);
		}

		args[0] = checked;
		return pjp.proceed(args);
	}

	/**
	 * Enforce node ID and source ID policy restrictions when requesting the
	 * available sources of a node.
	 *
	 * First the node ID is verified. Then, for all returned source ID values,
	 * if the active policy has no source ID restrictions return all values,
	 * otherwise remove any value not included in the policy.
	 *
	 * @param pjp
	 *        The join point.
	 * @param nodeId
	 *        The node ID.
	 * @return The set of String source IDs.
	 * @throws Throwable
	 *         If any error occurs, including an {@link AuthorizationException}
	 *         if access to the node is denied.
	 */
	@Around("nodeReportableSources(nodeId)")
	public Object reportableSourcesAccessCheck(ProceedingJoinPoint pjp, Long nodeId) throws Throwable {
		// verify node ID
		requireNodeReadAccess(nodeId);

		// verify source IDs in result
		@SuppressWarnings("unchecked")
		Set<String> result = (Set<String>) pjp.proceed();
		if ( result == null || result.isEmpty() ) {
			return result;
		}
		SecurityPolicy policy = getActiveSecurityPolicy();
		if ( policy == null ) {
			return result;
		}
		Set<String> allowedSourceIds = policy.getSourceIds();
		if ( allowedSourceIds == null || allowedSourceIds.isEmpty() ) {
			return result;
		}

		SecurityPolicyEnforcer enforcer = new SecurityPolicyEnforcer(policy, currentPrincipalOrNull(),
				null, getPathMatcher());
		try {
			String[] permitted = enforcer.verifySourceIds(result.toArray(new String[result.size()]));
			if ( permitted == null ) {
				return Collections.<String> emptySet();
			}
			return new LinkedHashSet<String>(Arrays.asList(permitted));
		} catch ( AuthorizationException ignored ) {
			// deliberately not propagated: a denial here just means no sources are visible
			return Collections.<String> emptySet();
		}
	}

	/**
	 * Enforce node ID and source ID policy restrictions when requesting a
	 * reportable interval.
	 *
	 * If the active policy has source ID restrictions, then if no
	 * {@code sourceId} is provided fill in the first available value from the
	 * policy. Otherwise, if {@code sourceId} is provided, check that value is
	 * allowed by the policy, either because the policy lists it exactly or
	 * because a {@link SecurityPolicyEnforcer} using the configured path
	 * matcher accepts it.
	 *
	 * @param pjp
	 *        The join point.
	 * @param nodeId
	 *        The node ID.
	 * @param sourceId
	 *        The source ID, or {@code null}.
	 * @return The reportable interval.
	 * @throws Throwable
	 *         If any error occurs, including an {@link AuthorizationException}
	 *         if access to the node or source is denied.
	 */
	@Around("nodeReportableInterval(nodeId, sourceId)")
	public Object reportableIntervalAccessCheck(ProceedingJoinPoint pjp, Long nodeId, String sourceId)
			throws Throwable {
		// verify node ID
		requireNodeReadAccess(nodeId);

		// now verify source ID
		SecurityPolicy policy = getActiveSecurityPolicy();
		if ( policy == null ) {
			return pjp.proceed();
		}
		Set<String> allowedSourceIds = policy.getSourceIds();
		if ( allowedSourceIds == null || allowedSourceIds.isEmpty() ) {
			return pjp.proceed();
		}

		Object principal = currentPrincipalOrNull();
		if ( sourceId == null ) {
			// force the first allowed source ID
			// NOTE(review): if the policy value is a path pattern it is passed on verbatim
			String forcedSourceId = allowedSourceIds.iterator().next();
			log.info("Access RESTRICTED to source {} for {}", forcedSourceId, principal);
			Object[] args = pjp.getArgs();
			args[1] = forcedSourceId;
			return pjp.proceed(args);
		}

		// exact match first; otherwise defer to the enforcer so this method agrees
		// with how reportableSourcesAccessCheck() evaluates the same policy
		if ( !allowedSourceIds.contains(sourceId)
				&& !sourceIdMatchesPolicy(policy, principal, sourceId) ) {
			log.warn("Access DENIED to source {} for {}", sourceId, principal);
			throw new AuthorizationException(AuthorizationException.Reason.ACCESS_DENIED, sourceId);
		}
		return pjp.proceed();
	}

	/**
	 * Allow the current user (or current node) access to node data.
	 *
	 * <p>
	 * A {@code null} node ID is ignored.
	 * </p>
	 *
	 * @param nodeId
	 *        the ID of the node to verify
	 */
	@Before("nodeMostRecentWeatherConditions(nodeId)")
	public void userNodeAccessCheck(Long nodeId) {
		if ( nodeId == null ) {
			return;
		}
		requireNodeReadAccess(nodeId);
	}

	/**
	 * Enforce security policies on a {@link Filter}.
	 *
	 * <p>
	 * For a {@link NodeDatumFilter} whose type requires a node ID, at least
	 * one node ID must be present, each is verified for read access, and the
	 * filter is passed through the policy enforcer. For any other filter, node
	 * IDs are optional: if the filter map provides {@link #FILTER_KEY_NODE_IDS}
	 * or {@link #FILTER_KEY_NODE_ID} those IDs are verified for read access,
	 * and the filter is returned unchanged.
	 * </p>
	 *
	 * @param filter
	 *        The filter to verify.
	 * @return A possibly modified filter based on security policies.
	 * @throws AuthorizationException
	 *         if any authorization error occurs
	 */
	public <T extends Filter> T userNodeAccessCheck(T filter) {
		if ( filter instanceof NodeDatumFilter ) {
			NodeDatumFilter cmd = (NodeDatumFilter) filter;
			if ( !isNodeIdRequired(cmd) ) {
				return filter;
			}
			Long[] requiredNodeIds = cmd.getNodeIds();
			if ( requiredNodeIds == null || requiredNodeIds.length < 1 ) {
				log.warn("Access DENIED; no node ID provided");
				throw new AuthorizationException(AuthorizationException.Reason.ACCESS_DENIED, null);
			}
			for ( Long nodeId : requiredNodeIds ) {
				userNodeAccessCheck(nodeId);
			}
			return policyEnforcerCheck(filter);
		}

		// generic filter: node IDs are optional, but any that are supplied must
		// still be readable by the current actor
		Map<String, ?> filterMap = filter.getFilter();
		Long[] optionalNodeIds = getLongArrayParameter(filterMap, FILTER_KEY_NODE_IDS);
		if ( optionalNodeIds == null ) {
			optionalNodeIds = getLongArrayParameter(filterMap, FILTER_KEY_NODE_ID);
		}
		if ( optionalNodeIds != null ) {
			for ( Long nodeId : optionalNodeIds ) {
				userNodeAccessCheck(nodeId);
			}
		}
		return filter;
	}

	/**
	 * Check if a node ID is required of a filter instance. This will return
	 * <em>true</em> unless the {@link #getNodeIdNotRequiredSet()} set contains
	 * the lower-cased value returned by {@link DatumFilter#getType()}.
	 *
	 * @param filter
	 *        the filter
	 * @return <em>true</em> if a node ID is required for the given filter
	 */
	private boolean isNodeIdRequired(DatumFilter filter) {
		if ( nodeIdNotRequiredSet == null || filter == null ) {
			return true;
		}
		final String type = filter.getType();
		if ( type == null ) {
			// no type to look up; also avoids contains(null) on null-hostile sets
			return true;
		}
		// Locale.ROOT keeps the lookup key independent of the JVM default locale
		return !nodeIdNotRequiredSet.contains(type.toLowerCase(Locale.ROOT));
	}

	/**
	 * Extract an array of {@code Long} values from a filter map.
	 *
	 * @param map
	 *        the map to read from, or {@code null}
	 * @param key
	 *        the key to read
	 * @return the value as an array if it is a {@code Long[]} or a single
	 *         {@link Number}, otherwise {@code null}
	 */
	private Long[] getLongArrayParameter(final Map<String, ?> map, final String key) {
		if ( map == null ) {
			return null;
		}
		final Object value = map.get(key);
		if ( value instanceof Long[] ) {
			return (Long[]) value;
		}
		if ( value instanceof Number ) {
			return new Long[] { Long.valueOf(((Number) value).longValue()) };
		}
		return null;
	}

	/**
	 * Get the principal of the current authentication.
	 *
	 * @return the principal, or {@code null} if there is no current
	 *         authentication
	 */
	private Object currentPrincipalOrNull() {
		Authentication authentication = SecurityUtils.getCurrentAuthentication();
		return (authentication != null ? authentication.getPrincipal() : null);
	}

	/**
	 * Test if a single source ID is accepted by a policy, using a
	 * {@link SecurityPolicyEnforcer} configured with this aspect's path
	 * matcher.
	 *
	 * @param policy
	 *        the policy to evaluate
	 * @param principal
	 *        the active principal, or {@code null}
	 * @param sourceId
	 *        the source ID to test
	 * @return {@code true} only if the enforcer returns the given source ID
	 */
	private boolean sourceIdMatchesPolicy(SecurityPolicy policy, Object principal, String sourceId) {
		SecurityPolicyEnforcer enforcer = new SecurityPolicyEnforcer(policy, principal, null,
				getPathMatcher());
		try {
			String[] permitted = enforcer.verifySourceIds(new String[] { sourceId });
			return (permitted != null && Arrays.asList(permitted).contains(sourceId));
		} catch ( AuthorizationException ignored ) {
			// a denial is reported to the caller as "no match"
			return false;
		}
	}

	/**
	 * Get the set of lower-case filter types that do not require a node ID.
	 *
	 * @return the set, or {@code null} if every type requires a node ID
	 */
	public Set<String> getNodeIdNotRequiredSet() {
		return nodeIdNotRequiredSet;
	}

	/**
	 * Set the filter types that do not require a node ID. Values are compared
	 * against the lower-cased {@link DatumFilter#getType()} value.
	 *
	 * @param nodeIdNotRequiredSet
	 *        the set to use, or {@code null} to require a node ID for every
	 *        type
	 */
	public void setNodeIdNotRequiredSet(Set<String> nodeIdNotRequiredSet) {
		this.nodeIdNotRequiredSet = nodeIdNotRequiredSet;
	}

}