/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.cep.nfa;
import org.apache.commons.io.output.ByteArrayOutputStream;
import org.apache.flink.cep.Event;
import org.apache.flink.cep.nfa.compiler.NFACompiler;
import org.apache.flink.cep.pattern.Pattern;
import org.apache.flink.cep.pattern.conditions.BooleanConditions;
import org.apache.flink.cep.pattern.conditions.IterativeCondition;
import org.apache.flink.cep.pattern.conditions.SimpleCondition;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.TestLogger;
import org.junit.Test;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.junit.Assert.assertEquals;
public class NFATest extends TestLogger {
@Test
public void testSimpleNFA() {
NFA<Event> nfa = new NFA<>(Event.createTypeSerializer(), 0, false);
List<StreamRecord<Event>> streamEvents = new ArrayList<>();
streamEvents.add(new StreamRecord<>(new Event(1, "start", 1.0), 1L));
streamEvents.add(new StreamRecord<>(new Event(2, "bar", 2.0), 2L));
streamEvents.add(new StreamRecord<>(new Event(3, "start", 3.0), 3L));
streamEvents.add(new StreamRecord<>(new Event(4, "end", 4.0), 4L));
State<Event> startState = new State<>("start", State.StateType.Start);
State<Event> endState = new State<>("end", State.StateType.Normal);
State<Event> endingState = new State<>("", State.StateType.Final);
startState.addTake(
endState,
new SimpleCondition<Event>() {
private static final long serialVersionUID = -4869589195918650396L;
@Override
public boolean filter(Event value) throws Exception {
return value.getName().equals("start");
}
});
endState.addTake(
endingState,
new SimpleCondition<Event>() {
private static final long serialVersionUID = 2979804163709590673L;
@Override
public boolean filter(Event value) throws Exception {
return value.getName().equals("end");
}
});
endState.addIgnore(BooleanConditions.<Event>trueFunction());
nfa.addState(startState);
nfa.addState(endState);
nfa.addState(endingState);
Set<Map<String, List<Event>>> expectedPatterns = new HashSet<>();
Map<String, List<Event>> firstPattern = new HashMap<>();
firstPattern.put("start", Collections.singletonList(new Event(1, "start", 1.0)));
firstPattern.put("end", Collections.singletonList(new Event(4, "end", 4.0)));
Map<String, List<Event>> secondPattern = new HashMap<>();
secondPattern.put("start", Collections.singletonList(new Event(3, "start", 3.0)));
secondPattern.put("end", Collections.singletonList(new Event(4, "end", 4.0)));
expectedPatterns.add(firstPattern);
expectedPatterns.add(secondPattern);
Collection<Map<String, List<Event>>> actualPatterns = runNFA(nfa, streamEvents);
assertEquals(expectedPatterns, actualPatterns);
}
@Test
public void testTimeoutWindowPruning() {
NFA<Event> nfa = createStartEndNFA(2);
List<StreamRecord<Event>> streamEvents = new ArrayList<>();
streamEvents.add(new StreamRecord<>(new Event(1, "start", 1.0), 1L));
streamEvents.add(new StreamRecord<>(new Event(2, "bar", 2.0), 2L));
streamEvents.add(new StreamRecord<>(new Event(3, "start", 3.0), 3L));
streamEvents.add(new StreamRecord<>(new Event(4, "end", 4.0), 4L));
Set<Map<String, List<Event>>> expectedPatterns = new HashSet<>();
Map<String, List<Event>> secondPattern = new HashMap<>();
secondPattern.put("start", Collections.singletonList(new Event(3, "start", 3.0)));
secondPattern.put("end", Collections.singletonList(new Event(4, "end", 4.0)));
expectedPatterns.add(secondPattern);
Collection<Map<String, List<Event>>> actualPatterns = runNFA(nfa, streamEvents);
assertEquals(expectedPatterns, actualPatterns);
}
/**
* Tests that elements whose timestamp difference is exactly the window length are not matched.
* The reaon is that the right window side (later elements) is exclusive.
*/
@Test
public void testWindowBorders() {
NFA<Event> nfa = createStartEndNFA(2);
List<StreamRecord<Event>> streamEvents = new ArrayList<>();
streamEvents.add(new StreamRecord<>(new Event(1, "start", 1.0), 1L));
streamEvents.add(new StreamRecord<>(new Event(2, "end", 2.0), 3L));
Set<Map<String, List<Event>>> expectedPatterns = Collections.emptySet();
Collection<Map<String, List<Event>>> actualPatterns = runNFA(nfa, streamEvents);
assertEquals(expectedPatterns, actualPatterns);
}
/**
* Tests that pruning shared buffer elements and computations state use the same window border
* semantics (left side inclusive and right side exclusive)
*/
@Test
public void testTimeoutWindowPruningWindowBorders() {
NFA<Event> nfa = createStartEndNFA(2);
List<StreamRecord<Event>> streamEvents = new ArrayList<>();
streamEvents.add(new StreamRecord<>(new Event(1, "start", 1.0), 1L));
streamEvents.add(new StreamRecord<>(new Event(2, "start", 2.0), 2L));
streamEvents.add(new StreamRecord<>(new Event(3, "foobar", 3.0), 3L));
streamEvents.add(new StreamRecord<>(new Event(4, "end", 4.0), 3L));
Set<Map<String, List<Event>>> expectedPatterns = new HashSet<>();
Map<String, List<Event>> secondPattern = new HashMap<>();
secondPattern.put("start", Collections.singletonList(new Event(2, "start", 2.0)));
secondPattern.put("end", Collections.singletonList(new Event(4, "end", 4.0)));
expectedPatterns.add(secondPattern);
Collection<Map<String, List<Event>>> actualPatterns = runNFA(nfa, streamEvents);
assertEquals(expectedPatterns, actualPatterns);
}
public <T> Collection<Map<String, List<T>>> runNFA(NFA<T> nfa, List<StreamRecord<T>> inputs) {
Set<Map<String, List<T>>> actualPatterns = new HashSet<>();
for (StreamRecord<T> streamEvent : inputs) {
Collection<Map<String, List<T>>> matchedPatterns = nfa.process(
streamEvent.getValue(),
streamEvent.getTimestamp()).f0;
actualPatterns.addAll(matchedPatterns);
}
return actualPatterns;
}
@Test
public void testNFASerialization() throws IOException, ClassNotFoundException {
Pattern<Event, ?> pattern1 = Pattern.<Event>begin("start").where(new SimpleCondition<Event>() {
private static final long serialVersionUID = 1858562682635302605L;
@Override
public boolean filter(Event value) throws Exception {
return value.getName().equals("a");
}
}).followedByAny("middle").where(new SimpleCondition<Event>() {
private static final long serialVersionUID = 8061969839441121955L;
@Override
public boolean filter(Event value) throws Exception {
return value.getName().equals("b");
}
}).oneOrMore().optional().allowCombinations().followedByAny("end").where(new SimpleCondition<Event>() {
private static final long serialVersionUID = 8061969839441121955L;
@Override
public boolean filter(Event value) throws Exception {
return value.getName().equals("d");
}
});
Pattern<Event, ?> pattern2 = Pattern.<Event>begin("start").where(new SimpleCondition<Event>() {
private static final long serialVersionUID = 1858562682635302605L;
@Override
public boolean filter(Event value) throws Exception {
return value.getName().equals("a");
}
}).notFollowedBy("not").where(new SimpleCondition<Event>() {
private static final long serialVersionUID = -6085237016591726715L;
@Override
public boolean filter(Event value) throws Exception {
return value.getName().equals("c");
}
}).followedByAny("middle").where(new SimpleCondition<Event>() {
private static final long serialVersionUID = 8061969839441121955L;
@Override
public boolean filter(Event value) throws Exception {
return value.getName().equals("b");
}
}).oneOrMore().optional().allowCombinations().followedByAny("end").where(new IterativeCondition<Event>() {
private static final long serialVersionUID = 8061969839441121955L;
@Override
public boolean filter(Event value, Context<Event> ctx) throws Exception {
double sum = 0.0;
for (Event e : ctx.getEventsForPattern("middle")) {
sum += e.getPrice();
}
return sum > 5.0;
}
});
Pattern<Event, ?> pattern3 = Pattern.<Event>begin("start")
.notFollowedBy("not").where(new SimpleCondition<Event>() {
private static final long serialVersionUID = -6085237016591726715L;
@Override
public boolean filter(Event value) throws Exception {
return value.getName().equals("c");
}
}).followedByAny("middle").where(new SimpleCondition<Event>() {
private static final long serialVersionUID = 8061969839441121955L;
@Override
public boolean filter(Event value) throws Exception {
return value.getName().equals("b");
}
}).oneOrMore().allowCombinations().followedByAny("end").where(new SimpleCondition<Event>() {
private static final long serialVersionUID = 8061969839441121955L;
@Override
public boolean filter(Event value) throws Exception {
return value.getName().equals("d");
}
});
List<Pattern<Event, ?>> patterns = new ArrayList<>();
patterns.add(pattern1);
patterns.add(pattern2);
patterns.add(pattern3);
for (Pattern<Event, ?> p: patterns) {
NFACompiler.NFAFactory<Event> nfaFactory = NFACompiler.compileFactory(p, Event.createTypeSerializer(), false);
NFA<Event> nfa = nfaFactory.createNFA();
Event a = new Event(40, "a", 1.0);
Event b = new Event(41, "b", 2.0);
Event c = new Event(42, "c", 3.0);
Event b1 = new Event(41, "b", 3.0);
Event b2 = new Event(41, "b", 4.0);
Event b3 = new Event(41, "b", 5.0);
Event d = new Event(43, "d", 4.0);
nfa.process(a, 1);
nfa.process(b, 2);
nfa.process(c, 3);
nfa.process(b1, 4);
nfa.process(b2, 5);
nfa.process(b3, 6);
nfa.process(d, 7);
nfa.process(a, 8);
NFA.NFASerializer<Event> serializer = new NFA.NFASerializer<>(Event.createTypeSerializer());
//serialize
ByteArrayOutputStream baos = new ByteArrayOutputStream();
serializer.serialize(nfa, new DataOutputViewStreamWrapper(baos));
baos.close();
// copy
NFA.NFASerializer<Event> copySerializer = new NFA.NFASerializer<>(Event.createTypeSerializer());
ByteArrayInputStream in = new ByteArrayInputStream(baos.toByteArray());
ByteArrayOutputStream out = new ByteArrayOutputStream();
copySerializer.copy(new DataInputViewStreamWrapper(in), new DataOutputViewStreamWrapper(out));
in.close();
out.close();
// deserialize
ByteArrayInputStream bais = new ByteArrayInputStream(out.toByteArray());
NFA.NFASerializer<Event> deserializer = new NFA.NFASerializer<>(Event.createTypeSerializer());
NFA<Event> copy = deserializer.deserialize(new DataInputViewStreamWrapper(bais));
bais.close();
assertEquals(nfa, copy);
}
}
private NFA<Event> createStartEndNFA(long windowLength) {
NFA<Event> nfa = new NFA<>(Event.createTypeSerializer(), windowLength, false);
State<Event> startState = new State<>("start", State.StateType.Start);
State<Event> endState = new State<>("end", State.StateType.Normal);
State<Event> endingState = new State<>("", State.StateType.Final);
startState.addTake(
endState,
new SimpleCondition<Event>() {
private static final long serialVersionUID = -4869589195918650396L;
@Override
public boolean filter(Event value) throws Exception {
return value.getName().equals("start");
}
});
endState.addTake(
endingState,
new SimpleCondition<Event>() {
private static final long serialVersionUID = 2979804163709590673L;
@Override
public boolean filter(Event value) throws Exception {
return value.getName().equals("end");
}
});
endState.addIgnore(BooleanConditions.<Event>trueFunction());
nfa.addState(startState);
nfa.addState(endState);
nfa.addState(endingState);
return nfa;
}
}