/******************************************************************************* * Cloud Foundry * Copyright (c) [2009-2016] Pivotal Software, Inc. All Rights Reserved. * * This product is licensed to you under the Apache License, Version 2.0 (the "License"). * You may not use this product except in compliance with the License. * * This product includes a number of subcomponents with * separate copyright notices and license terms. Your use of these * subcomponents is subject to the terms and conditions of the * subcomponent's license, as noted in the LICENSE file. *******************************************************************************/ package org.cloudfoundry.identity.uaa.config; import static java.lang.annotation.RetentionPolicy.RUNTIME; import static org.junit.Assert.assertEquals; import java.lang.annotation.Documented; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.Target; import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; import java.util.Map; import javax.validation.Constraint; import javax.validation.ConstraintValidator; import javax.validation.ConstraintValidatorContext; import javax.validation.Payload; import javax.validation.constraints.NotNull; import org.cloudfoundry.identity.uaa.config.YamlBindingTests.OAuthConfiguration.OAuthConfigurationValidator; import org.cloudfoundry.identity.uaa.impl.config.YamlPropertiesFactoryBean; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.springframework.beans.BeanWrapperImpl; import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.PropertyValue; import org.springframework.context.support.StaticMessageSource; import org.springframework.core.io.ByteArrayResource; import org.springframework.util.StringUtils; import org.springframework.validation.BindingResult; import org.springframework.validation.DataBinder; import org.springframework.validation.FieldError; import org.springframework.validation.beanvalidation.LocalValidatorFactoryBean; /** * @author Dave Syer * */ public class YamlBindingTests { @Rule public ExpectedException expected = ExpectedException.none(); @Test public void testBindString() { VanillaTarget target = new VanillaTarget(); bind(target, "foo: bar"); assertEquals("bar", target.getFoo()); } @Test public void testBindNumber() { VanillaTarget target = new VanillaTarget(); bind(target, "foo: bar\nvalue: 123"); assertEquals(123, target.getValue()); } @Test public void testSimpleValidation() { ValidatedTarget target = new ValidatedTarget(); BindingResult result = bind(target, ""); assertEquals(1, result.getErrorCount()); } @Test public void testRequiredFieldsValidation() { TargetWithValidatedMap target = new TargetWithValidatedMap(); BindingResult result = bind(target, "info:\n foo: bar"); assertEquals(2, result.getErrorCount()); for (FieldError error : result.getFieldErrors()) { System.err.println(new StaticMessageSource().getMessage(error, Locale.getDefault())); } } @Test public void testBindNested() { TargetWithNestedObject target = new TargetWithNestedObject(); bind(target, "nested:\n foo: bar\n value: 123"); assertEquals(123, target.getNested().getValue()); } @Test public void testBindNestedMap() { TargetWithNestedMap target = new TargetWithNestedMap(); bind(target, "nested:\n foo: bar\n value: 123"); assertEquals(123, target.getNested().get("value")); } @Test public void testBindNestedMapBracketReferenced() { TargetWithNestedMap target = new TargetWithNestedMap(); bind(target, "nested[foo]: bar\nnested[value]: 123"); assertEquals(123, target.getNested().get("value")); } @SuppressWarnings("unchecked") @Test public void testBindDoubleNestedMap() { TargetWithNestedMap target = new TargetWithNestedMap(); bind(target, "nested:\n foo: bar\n oof:\n spam: bucket\n value: 123"); assertEquals(123, ((Map<String, Object>) target.getNested().get("oof")).get("value")); } @Test public void testBindErrorTypeMismatch() { VanillaTarget target = new VanillaTarget(); BindingResult result = bind(target, "foo: bar\nvalue: foo"); assertEquals(1, result.getErrorCount()); } @Test public void testBindErrorNotWritable() { expected.expectMessage("property 'spam'"); expected.expectMessage("not writable"); VanillaTarget target = new VanillaTarget(); BindingResult result = bind(target, "spam: bar\nvalue: 123"); assertEquals(1, result.getErrorCount()); } private BindingResult bind(Object target, String values) { YamlPropertiesFactoryBean factory = new YamlPropertiesFactoryBean(); factory.setResources(new ByteArrayResource[] { new ByteArrayResource(values.getBytes()) }); Map<Object, Object> map = factory.getObject(); DataBinder binder = new DataBinder(target) { @Override protected void doBind(MutablePropertyValues mpvs) { modifyProperties(mpvs, getTarget()); super.doBind(mpvs); } private void modifyProperties(MutablePropertyValues mpvs, Object target) { List<PropertyValue> list = mpvs.getPropertyValueList(); BeanWrapperImpl bw = new BeanWrapperImpl(target); for (int i = 0; i < list.size(); i++) { PropertyValue pv = list.get(i); String name = pv.getName(); StringBuilder builder = new StringBuilder(); for (String key : StringUtils.delimitedListToStringArray(name, ".")) { if (builder.length() != 0) { builder.append("."); } builder.append(key); String base = builder.toString(); Class<?> type = bw.getPropertyType(base); if (type != null && Map.class.isAssignableFrom(type)) { String suffix = name.substring(base.length()); Map<String, Object> nested = new LinkedHashMap<String, Object>(); if (bw.getPropertyValue(base) != null) { @SuppressWarnings("unchecked") Map<String, Object> existing = (Map<String, Object>) bw.getPropertyValue(base); nested = existing; } else { bw.setPropertyValue(base, nested); } Map<String, Object> value = nested; String[] tree = StringUtils.delimitedListToStringArray(suffix, "."); for (int j = 1; j < tree.length - 1; j++) { String subtree = tree[j]; value.put(subtree, nested); value = nested; } String refName = base + suffix.replaceAll("\\.([a-zA-Z0-9]*)", "[$1]"); mpvs.setPropertyValueAt(new PropertyValue(refName, pv.getValue()), i); break; } } } } }; binder.setIgnoreUnknownFields(false); LocalValidatorFactoryBean validatorFactoryBean = new LocalValidatorFactoryBean(); validatorFactoryBean.afterPropertiesSet(); binder.setValidator(validatorFactoryBean); binder.bind(new MutablePropertyValues(map)); binder.validate(); return binder.getBindingResult(); } @Documented @Target({ ElementType.TYPE }) @Retention(RUNTIME) @Constraint(validatedBy = OAuthConfigurationValidator.class) public @interface ValidOAuthConfiguration { } @ValidOAuthConfiguration public static class OAuthConfiguration { private Client client; private Map<String, OAuthClient> clients; public Client getClient() { return client; } public void setClient(Client client) { this.client = client; } public Map<String, OAuthClient> getClients() { return clients; } public void setClients(Map<String, OAuthClient> clients) { this.clients = clients; } public static class Client { private List<String> autoapprove; public List<String> getAutoapprove() { return autoapprove; } public void setAutoapprove(List<String> autoapprove) { this.autoapprove = autoapprove; } } public static class OAuthClient { private String id; public String getId() { return id; } public void setId(String id) { this.id = id; } } public static class OAuthConfigurationValidator implements ConstraintValidator<ValidOAuthConfiguration, OAuthConfiguration> { @Override public void initialize(ValidOAuthConfiguration constraintAnnotation) { } @Override public boolean isValid(OAuthConfiguration value, ConstraintValidatorContext context) { boolean valid = true; if (value.client != null && value.client.autoapprove != null) { if (value.clients != null) { context.buildConstraintViolationWithTemplate( "Please use oauth.clients to specifiy autoapprove not client.autoapprove") .addConstraintViolation(); valid = false; } } return valid; } } } @Documented @Target({ ElementType.FIELD }) @Retention(RUNTIME) @Constraint(validatedBy = RequiredKeysValidator.class) public @interface RequiredKeys { String[] value(); String message() default "Required keys are not provided for field"; Class<?>[] groups() default {}; Class<? extends Payload>[] payload() default {}; } public static class RequiredKeysValidator implements ConstraintValidator<RequiredKeys, Map<String, Object>> { private String[] requiredKeys; @Override public void initialize(RequiredKeys constraintAnnotation) { requiredKeys = constraintAnnotation.value(); } @Override public boolean isValid(Map<String, Object> value, ConstraintValidatorContext context) { boolean valid = true; for (String key : requiredKeys) { if (!value.containsKey(key)) { context.buildConstraintViolationWithTemplate("Missing key ''" + key + "''") .addConstraintViolation(); valid = false; } } return valid; } } public static class TargetWithValidatedMap { @RequiredKeys({ "foo", "baz" }) private Map<String, Object> info; public Map<String, Object> getInfo() { return info; } public void setInfo(Map<String, Object> nested) { this.info = nested; } } public static class TargetWithNestedMap { private Map<String, Object> nested; public Map<String, Object> getNested() { return nested; } public void setNested(Map<String, Object> nested) { this.nested = nested; } } public static class TargetWithNestedObject { private VanillaTarget nested; public VanillaTarget getNested() { return nested; } public void setNested(VanillaTarget nested) { this.nested = nested; } } public static class VanillaTarget { private String foo; private int value; public int getValue() { return value; } public void setValue(int value) { this.value = value; } public String getFoo() { return foo; } public void setFoo(String foo) { this.foo = foo; } } public static class ValidatedTarget { @NotNull private String foo; public String getFoo() { return foo; } public void setFoo(String foo) { this.foo = foo; } } }