/*
 * Copyright 2005 Red Hat, Inc. and/or its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *       http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.drools.core.reteoo;

import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Set;

import org.drools.core.common.BaseNode;
import org.drools.core.common.InternalWorkingMemory;
import org.drools.core.common.Memory;
import org.drools.core.common.MemoryFactory;
import org.drools.core.impl.InternalKnowledgeBase;
import org.kie.api.runtime.KieSession;

/**
 * Diagnostic utility that walks the Rete network of a session's knowledge base, starting from its
 * entry point nodes, and verifies that:
 * <ul>
 *   <li>every node that owns a memory has a memory of the type expected for that node kind;</li>
 *   <li>every terminal node's {@link PathMemory} has a complete, correctly ordered segment array
 *       whose first segment is rooted at a left tuple source and whose last segment ends at a
 *       terminal node.</li>
 * </ul>
 * Any inconsistency is reported by throwing a {@link RuntimeException}.
 */
public class ReteMemoryChecker {

    /**
     * Checks the memories of every node reachable from the entry point nodes of the session's
     * knowledge base.
     *
     * @param session the session to check; its knowledge base is cast to
     *                {@link InternalKnowledgeBase} and the session itself to
     *                {@link InternalWorkingMemory}
     * @throws RuntimeException if a node memory has an unexpected type or a path memory is malformed
     */
    public static void checkNodeMemories(KieSession session) {
        InternalKnowledgeBase kbase = (InternalKnowledgeBase) session.getKieBase();
        InternalWorkingMemory wm = (InternalWorkingMemory) session;

        // A node can be reachable through more than one parent (the network is a DAG, not a tree),
        // so remember what has already been checked. The set is shared across entry points and uses
        // identity semantics so distinct node instances are never conflated, whatever equals() does.
        Set<BaseNode> visited = Collections.newSetFromMap( new IdentityHashMap<BaseNode, Boolean>() );

        for (EntryPointNode entryPointNode : kbase.getRete().getEntryPointNodes().values()) {
            checkNodeMemory( wm, entryPointNode, visited );
        }
    }

    /**
     * Depth-first check of {@code node} and all of its {@link BaseNode} sinks, skipping nodes that
     * were already checked.
     *
     * @param wm      working memory the node memories are fetched from
     * @param node    the node to check
     * @param visited nodes already checked during this walk; updated in place
     */
    private static void checkNodeMemory(InternalWorkingMemory wm, BaseNode node, Set<BaseNode> visited) {
        if ( !visited.add( node ) ) {
            // Already checked (together with its whole sub-network) via another parent.
            return;
        }

        if ( node instanceof MemoryFactory ) {
            Memory memory = wm.getNodeMemory( (MemoryFactory) node );
            checkMemoryType( node, memory );
        }

        Sink[] sinks = node.getSinks();
        if ( sinks != null ) {
            for (Sink sink : sinks) {
                if ( sink instanceof BaseNode ) {
                    checkNodeMemory( wm, (BaseNode) sink, visited );
                }
            }
        }
    }

    /**
     * Verifies that {@code memory} has the type expected for the kind of {@code node}. Node kinds not
     * listed here are not checked. For terminal nodes the path memory is also validated.
     * The branch order matters and mirrors the node-type tests used by the original checker.
     */
    private static void checkMemoryType(BaseNode node, Memory memory) {
        if ( NodeTypeEnums.ObjectTypeNode == node.getType() ) {
            requireMemoryType( node, memory, ObjectTypeNode.ObjectTypeNodeMemory.class );
        } else if ( NodeTypeEnums.LeftInputAdapterNode == node.getType() ) {
            requireMemoryType( node, memory, LeftInputAdapterNode.LiaNodeMemory.class );
        } else if ( NodeTypeEnums.isBetaNode( node ) ) {
            if ( NodeTypeEnums.AccumulateNode == node.getType() ) {
                requireMemoryType( node, memory, AccumulateNode.AccumulateMemory.class );
            } else {
                requireMemoryType( node, memory, BetaMemory.class );
            }
        } else if ( NodeTypeEnums.FromNode == node.getType() ) {
            requireMemoryType( node, memory, FromNode.FromMemory.class );
        } else if ( NodeTypeEnums.WindowNode == node.getType() ) {
            requireMemoryType( node, memory, WindowNode.WindowMemory.class );
        } else if ( NodeTypeEnums.isTerminalNode( node ) ) {
            requireMemoryType( node, memory, PathMemory.class );
            checkPathMemory( (PathMemory) memory );
        }
    }

    /**
     * Throws if {@code memory} is not an instance of {@code expected} (a {@code null} memory never is).
     */
    private static void requireMemoryType(BaseNode node, Memory memory, Class<?> expected) {
        if ( !expected.isInstance( memory ) ) {
            throw new RuntimeException( "Invalid memory type. Node: " + node + " has memory " + memory );
        }
    }

    /**
     * Validates the segment layout of a path memory. The checks are:
     * <ul>
     *   <li>the segment array is non-null and non-empty;</li>
     *   <li>the first segment's root node is a left tuple source;</li>
     *   <li>the last segment's tip node is a terminal node;</li>
     *   <li>no segment is {@code null};</li>
     *   <li>every segment's {@code getPos()} equals its index.</li>
     * </ul>
     *
     * @throws RuntimeException describing the first inconsistency found
     */
    private static void checkPathMemory( PathMemory pathMemory ) {
        SegmentMemory[] smems = pathMemory.getSegmentMemories();
        if ( smems == null || smems.length == 0 ) {
            throw new RuntimeException( "No segments found for path " + pathMemory );
        }

        int last = smems.length - 1;
        // The boundary segments are dereferenced below; report them as missing instead of hitting an NPE.
        requireSegment( smems, 0, pathMemory );
        requireSegment( smems, last, pathMemory );

        if ( !NodeTypeEnums.isLeftTupleSource( smems[0].getRootNode() ) ) {
            throw new RuntimeException( "The root node for path " + pathMemory + " has to be a LeftTupleSource but is a " + smems[0].getRootNode() );
        }
        if ( !NodeTypeEnums.isTerminalNode( smems[last].getTipNode() ) ) {
            throw new RuntimeException( "The tip node for path " + pathMemory + " has to be a TerminalNode but is a " + smems[last].getTipNode() );
        }

        for (int i = 0; i < smems.length; i++) {
            requireSegment( smems, i, pathMemory );
            if ( i != smems[i].getPos() ) {
                throw new RuntimeException( "Segment " + smems[i] + " is expected to be in position " + i + " but it is in position " + smems[i].getPos() );
            }
        }
    }

    /**
     * Throws if the segment at index {@code i} is {@code null}.
     */
    private static void requireSegment( SegmentMemory[] smems, int i, PathMemory pathMemory ) {
        if ( smems[i] == null ) {
            throw new RuntimeException( "Missing segment in position " + i + " for " + pathMemory );
        }
    }
}