/*******************************************************************************
* Copyright 2015-2016 - CNRS (Centre National de Recherche Scientifique)
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*******************************************************************************/
package eu.project.ttc.test.func;
import static org.assertj.core.api.Assertions.assertThat;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import org.assertj.core.api.AbstractAssert;
import org.assertj.core.api.AbstractIterableAssert;
import org.assertj.core.util.Lists;
import com.google.common.base.Joiner;
import com.google.common.base.Objects;
import com.google.common.collect.Sets;
import com.google.common.primitives.Ints;
import eu.project.ttc.models.Term;
import eu.project.ttc.models.TermIndex;
import eu.project.ttc.models.TermVariation;
import eu.project.ttc.models.VariationType;
import eu.project.ttc.utils.TermIndexUtils;
/**
 * AssertJ custom assertions on a {@link TermIndex}.
 *
 * <p>Every failing check reports through {@link #failWithMessage(String, Object...)},
 * which throws an {@link AssertionError}. The {@code return this} statements that
 * follow a failure are only there to satisfy the compiler.
 */
public class TermIndexAssert extends AbstractAssert<TermIndexAssert, TermIndex> {

	/**
	 * Maximum number of "close" variations listed in a {@code containsVariation}
	 * failure message. The candidates include every variation of the requested
	 * type in the whole index, so the listing must be bounded to stay readable.
	 */
	private static final int MAX_CLOSE_VARIATIONS = 10;

	public TermIndexAssert(TermIndex actual) {
		super(actual, TermIndexAssert.class);
	}

	/**
	 * Asserts that the index holds exactly {@code expected} terms.
	 *
	 * @param expected the expected number of terms
	 * @return this assertion, for chaining
	 */
	public TermIndexAssert hasSize(int expected) {
		int actualSize = actual.getTerms().size();
		if(actualSize != expected)
			failWithMessage("Expected size was <%s>, but actual size is <%s>.",
					expected, actualSize);
		return this;
	}

	/**
	 * Asserts that a term with the given grouping key exists and has the given frequency.
	 *
	 * @param expectedTerm the grouping key of the expected term
	 * @param frequency the expected frequency of that term
	 * @return this assertion, for chaining
	 */
	public TermIndexAssert containsTerm(String expectedTerm, int frequency) {
		for(Term t:actual.getTerms()) {
			if(t.getGroupingKey().equals(expectedTerm)) {
				if(t.getFrequency() != frequency)
					failWithMessage("Expected frequency for term %s was <%s>, but actually is: <%s>.",
							expectedTerm,
							frequency,
							t.getFrequency());
				return this;
			}
		}
		failWithMessage("No such term <%s> found in term index.", expectedTerm);
		return this;
	}

	/**
	 * Asserts that a term with the given grouping key exists in the index.
	 *
	 * @param expectedTerm the grouping key of the expected term
	 * @return this assertion, for chaining
	 */
	public TermIndexAssert containsTerm(String expectedTerm) {
		for(Term t:actual.getTerms()) {
			if(t.getGroupingKey().equals(expectedTerm))
				return this;
		}
		failWithMessage("No such term <%s> found in term index.", expectedTerm);
		return this;
	}

	/**
	 * Asserts that the index contains a variation of the given type from the
	 * base term to the variant term.
	 *
	 * <p>On failure, the message lists up to {@value #MAX_CLOSE_VARIATIONS}
	 * distinct "close" variations, to help diagnose near-misses.
	 *
	 * @param baseGroupingKey grouping key of the base term
	 * @param type the expected variation type
	 * @param variantGroupingKey grouping key of the variant term
	 * @return this assertion, for chaining
	 */
	public TermIndexAssert containsVariation(String baseGroupingKey, VariationType type, String variantGroupingKey) {
		if(failToFindTerms(baseGroupingKey, variantGroupingKey))
			return this;
		List<TermVariation> sameBaseAndVariant = Lists.newArrayList();
		Set<TermVariation> sameType = Sets.newHashSet();
		for(TermVariation tv:getVariations()) {
			if(links(tv, baseGroupingKey, variantGroupingKey)) {
				sameBaseAndVariant.add(tv);
				if(tv.getVariationType() == type)
					return this;
			}
			if(type == tv.getVariationType())
				sameType.add(tv);
		}
		failWithMessage("No such variation <%s--%s--%s> found in term index. Closed variations: <%s>",
				baseGroupingKey, type, variantGroupingKey,
				Joiner.on(", ").join(closeVariations(baseGroupingKey, type, variantGroupingKey, sameBaseAndVariant, sameType))
				);
		return this;
	}

	/**
	 * Fails the assertion for each grouping key that has no term in the index.
	 *
	 * <p>Because {@code failWithMessage} throws, the first missing key aborts the
	 * check. A {@code true} result is therefore never observed by callers. The
	 * flag is kept so that callers can still guard on it.
	 *
	 * @param groupingKeys the grouping keys that must resolve to terms
	 * @return whether any key was missing
	 */
	private boolean failToFindTerms(String... groupingKeys) {
		boolean failed = false;
		for(String gKey:groupingKeys) {
			if(actual.getTermByGroupingKey(gKey) == null) {
				failed = true;
				failWithMessage("Could not find term <%s> in termIndex", gKey);
			}
		}
		return failed;
	}

	/**
	 * Asserts that the index contains a variation of the given type from the
	 * base term to the variant term, carrying the given info object.
	 *
	 * <p>Info equality uses null-safe {@code Objects.equal}. On failure, the
	 * message lists up to {@value #MAX_CLOSE_VARIATIONS} distinct "close"
	 * variations, to help diagnose near-misses.
	 *
	 * @param baseGroupingKey grouping key of the base term
	 * @param type the expected variation type
	 * @param variantGroupingKey grouping key of the variant term
	 * @param info the expected info attached to the variation (may be null)
	 * @return this assertion, for chaining
	 */
	public TermIndexAssert containsVariation(String baseGroupingKey, VariationType type, String variantGroupingKey, Object info) {
		if(failToFindTerms(baseGroupingKey, variantGroupingKey))
			return this;
		List<TermVariation> sameBaseAndVariant = Lists.newArrayList();
		Set<TermVariation> sameType = Sets.newHashSet();
		for(TermVariation tv:getVariations()) {
			if(links(tv, baseGroupingKey, variantGroupingKey)) {
				sameBaseAndVariant.add(tv);
				if(tv.getVariationType() == type && Objects.equal(info, tv.getInfo()))
					return this;
			}
			if(type == tv.getVariationType())
				sameType.add(tv);
		}
		failWithMessage("No such variation <%s--%s[%s]--%s> found in term index. Closed variations: <%s>",
				baseGroupingKey, type,
				info,
				variantGroupingKey,
				Joiner.on(", ").join(closeVariations(baseGroupingKey, type, variantGroupingKey, sameBaseAndVariant, sameType))
				);
		return this;
	}

	/**
	 * Tells whether a variation goes from the given base term to the given variant term.
	 */
	private static boolean links(TermVariation tv, String baseGroupingKey, String variantGroupingKey) {
		return tv.getBase().getGroupingKey().equals(baseGroupingKey)
				&& tv.getVariant().getGroupingKey().equals(variantGroupingKey);
	}

	/**
	 * Collects the distinct variations "close" to the expected one, most relevant
	 * first, truncated to {@value #MAX_CLOSE_VARIATIONS} entries.
	 *
	 * <p>Both terms are expected to exist; callers check this with
	 * {@link #failToFindTerms(String...)} first.
	 */
	private List<TermVariation> closeVariations(String baseGroupingKey, VariationType type, String variantGroupingKey,
			Collection<TermVariation> sameBaseAndVariant, Collection<TermVariation> sameType) {
		Term base = actual.getTermByGroupingKey(baseGroupingKey);
		Term variant = actual.getTermByGroupingKey(variantGroupingKey);
		// The sources below overlap heavily (e.g. getVariations(type) is a subset of
		// getVariations()). An insertion-ordered set drops the duplicates while keeping
		// the exact base/variant matches first.
		Set<TermVariation> candidates = Sets.newLinkedHashSet();
		candidates.addAll(sameBaseAndVariant);
		candidates.addAll(Sets.newHashSet(base.getVariations(type)));
		candidates.addAll(Sets.newHashSet(variant.getBases(type)));
		candidates.addAll(Sets.newHashSet(base.getVariations()));
		candidates.addAll(Sets.newHashSet(variant.getBases()));
		candidates.addAll(sameType);
		List<TermVariation> shown = Lists.newArrayList();
		for(TermVariation tv:candidates) {
			if(shown.size() >= MAX_CLOSE_VARIATIONS)
				break;
			shown.add(tv);
		}
		return shown;
	}

	/**
	 * Gathers every variation reachable via {@code Term#getVariations()} over all
	 * terms of the index, without duplicates.
	 */
	private Collection<TermVariation> getVariations() {
		Set<TermVariation> termVariations = Sets.newHashSet();
		for(Term t:actual.getTerms()) {
			for(TermVariation v:t.getVariations())
				termVariations.add(v);
		}
		return termVariations;
	}

	/**
	 * Asserts that the index contains exactly {@code expected} distinct variations of the given type.
	 *
	 * @param expected the expected count
	 * @param type the variation type to count
	 * @return this assertion, for chaining
	 */
	public TermIndexAssert hasNVariationsOfType(int expected, VariationType type) {
		int cnt = 0;
		for(TermVariation tv:getVariations()) {
			if(tv.getVariationType() == type)
				cnt++;
		}
		if(cnt != expected)
			failWithMessage("Expected <%s> variations of type <%s>. Got: <%s>", expected, type, cnt);
		return this;
	}

	/**
	 * Returns an iterable assertion over the variations whose info equals {@code object} (null-safe).
	 *
	 * @param object the info value to select variations by
	 * @return an AssertJ iterable assertion on the selected variations
	 */
	public AbstractIterableAssert<?, ? extends Iterable<? extends TermVariation>, TermVariation> asTermVariationsHavingObject(Object object) {
		Set<TermVariation> variations = Sets.newHashSet();
		for(TermVariation v:getVariations())
			if(Objects.equal(v.getInfo(), object))
				variations.add(v);
		return assertThat(variations);
	}

	/**
	 * Returns an iterable assertion over the variations selected by
	 * {@code TermIndexUtils.selectTermVariations} for the given types.
	 *
	 * @param variations the variation types to select
	 * @return an AssertJ iterable assertion on the selected variations
	 */
	public AbstractIterableAssert<?, ? extends Iterable<? extends TermVariation>, TermVariation> asTermVariations(VariationType... variations) {
		return assertThat(
				TermIndexUtils.selectTermVariations(actual, variations));
	}

	/**
	 * Returns an iterable assertion over the terms selected by {@code TermIndexUtils.selectCompounds}.
	 *
	 * @return an AssertJ iterable assertion on the selected compound terms
	 */
	public AbstractIterableAssert<?, ? extends Iterable<? extends Term>, Term> asCompoundList() {
		return assertThat(
				TermIndexUtils.selectCompounds(actual));
	}

	/**
	 * Returns an iterable assertion over the distinct info values of all
	 * syntactical and morphological variations.
	 *
	 * <p>The info is cast to {@code String}. NOTE(review): this assumes that
	 * these variation types always carry a rule name as info.
	 *
	 * @return an AssertJ iterable assertion on the matching rule names
	 */
	public AbstractIterableAssert<?, ? extends Iterable<? extends String>, String> asMatchingRules() {
		Set<String> matchingRuleNames = Sets.newHashSet();
		for(TermVariation tv:TermIndexUtils.selectTermVariations(actual, VariationType.SYNTACTICAL, VariationType.MORPHOLOGICAL))
			matchingRuleNames.add((String)tv.getInfo());
		return assertThat(matchingRuleNames);
	}
}