package org.apereo.cas.impl.calcs;
import org.apereo.cas.authentication.Authentication;
import org.apereo.cas.authentication.adaptive.geo.GeoLocationRequest;
import org.apereo.cas.authentication.adaptive.geo.GeoLocationResponse;
import org.apereo.cas.authentication.adaptive.geo.GeoLocationService;
import org.apereo.cas.services.RegisteredService;
import org.apereo.cas.support.events.dao.CasEvent;
import org.apereo.cas.support.events.CasEventRepository;
import org.apereo.cas.web.support.WebUtils;
import org.apereo.inspektr.common.web.ClientInfoHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import javax.servlet.http.HttpServletRequest;
import java.math.BigDecimal;
import java.util.Collection;
/**
 * This is {@link GeoLocationAuthenticationRequestRiskCalculator}, a risk calculator
 * that scores an authentication request by comparing the geolocation of the current
 * request against the geolocation recorded on each supplied authentication event.
 * <p>
 * The location is taken from the current HTTP request when it is valid; otherwise
 * the client IP address is resolved through the configured {@link GeoLocationService}.
 *
 * @author Misagh Moayyed
 * @since 5.1.0
 */
public class GeoLocationAuthenticationRequestRiskCalculator extends BaseAuthenticationRequestRiskCalculator {
    private static final Logger LOGGER = LoggerFactory.getLogger(GeoLocationAuthenticationRequestRiskCalculator.class);

    /**
     * Geolocation service, used to resolve the client IP address to a location
     * when the request itself carries no valid geolocation.
     */
    @Autowired
    @Qualifier("geoLocationService")
    protected GeoLocationService geoLocationService;

    /**
     * Creates the calculator.
     *
     * @param casEventRepository the event repository, handed to the parent calculator
     */
    public GeoLocationAuthenticationRequestRiskCalculator(final CasEventRepository casEventRepository) {
        super(casEventRepository);
    }

    /**
     * Calculates the score by counting how many of the given events were recorded
     * at the location of the current request.
     * <ul>
     * <li>If every event matches the location, {@code LOWEST_RISK_SCORE} is returned.</li>
     * <li>Otherwise the result of {@code getFinalAveragedScore(matches, total)} is returned.</li>
     * <li>If neither the request nor the geolocation service yields a location,
     * {@code HIGHEST_RISK_SCORE} is returned.</li>
     * </ul>
     *
     * @param request        the http request (not consulted by this calculator)
     * @param authentication the authentication; its principal is only used for logging
     * @param service        the registered service (not consulted by this calculator)
     * @param events         the authentication events to compare against
     * @return the calculated score
     */
    @Override
    protected BigDecimal calculateScore(final HttpServletRequest request, final Authentication authentication,
                                        final RegisteredService service, final Collection<CasEvent> events) {
        final GeoLocationRequest loc = WebUtils.getHttpServletRequestGeoLocation();
        if (loc.isValid()) {
            LOGGER.debug("Filtering authentication events for geolocation [{}]", loc);
            final long count = countEventsMatchingLocation(events, loc);
            LOGGER.debug("Total authentication events found for [{}]: [{}]", loc, count);
            return scoreForMatchCount(count, events, authentication, loc);
        }

        // The request carried no usable coordinates: fall back to locating the client by IP address.
        final String remoteAddr = ClientInfoHolder.getClientInfo().getClientIpAddress();
        LOGGER.debug("Filtering authentication events for location based on ip [{}]", remoteAddr);
        final GeoLocationResponse response = this.geoLocationService.locate(remoteAddr);
        if (response != null) {
            // Build the comparison target once rather than once per event.
            final GeoLocationRequest resolved = new GeoLocationRequest(response.getLatitude(), response.getLongitude());
            final long count = countEventsMatchingLocation(events, resolved);
            LOGGER.debug("Total authentication events found for location of [{}]: [{}]", remoteAddr, count);
            // Report the resolved location here; the request-supplied one is invalid in this branch.
            return scoreForMatchCount(count, events, authentication, resolved);
        }

        LOGGER.debug("Request does not contain enough geolocation data");
        return HIGHEST_RISK_SCORE;
    }

    /**
     * Counts the events whose geolocation equals the given location.
     *
     * @param events   the events to inspect
     * @param location the non-null location to match
     * @return the number of matching events
     */
    private long countEventsMatchingLocation(final Collection<CasEvent> events, final GeoLocationRequest location) {
        // The known non-null location is the receiver, so an event without a location
        // counts as a non-match instead of raising a NullPointerException.
        return events.stream().filter(e -> location.equals(e.getGeoLocation())).count();
    }

    /**
     * Turns a match count into a score: the lowest score when all events matched,
     * otherwise whatever {@code getFinalAveragedScore} yields for the count and total.
     *
     * @param count          the number of events that matched the location
     * @param events         all events considered; its size is the total
     * @param authentication the authentication, used for logging the principal
     * @param location       the location that was matched, used for logging
     * @return the score
     */
    private BigDecimal scoreForMatchCount(final long count, final Collection<CasEvent> events,
                                          final Authentication authentication, final GeoLocationRequest location) {
        if (count == events.size()) {
            LOGGER.debug("Principal [{}] has always authenticated from [{}]", authentication.getPrincipal(), location);
            return LOWEST_RISK_SCORE;
        }
        return getFinalAveragedScore(count, events.size());
    }
}