/*
* Copyright 2015 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.drools.compiler.integrationtests;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import org.drools.core.marshalling.impl.ProtobufMarshaller;
import org.drools.core.util.DroolsStreamUtils;
import org.kie.api.KieBase;
import org.kie.api.marshalling.ObjectMarshallingStrategy;
import org.kie.api.runtime.EnvironmentName;
import org.kie.api.runtime.KieSession;
import org.kie.api.time.SessionClock;
import org.kie.internal.marshalling.MarshallerFactory;
import org.kie.internal.runtime.StatefulKnowledgeSession;
/**
* Marshalling helper class to perform serialize/de-serialize a given object
*/
public class SerializationHelper {
public static <T> T serializeObject(T obj) throws IOException,
ClassNotFoundException {
return serializeObject( obj,
null );
}
@SuppressWarnings("unchecked")
public static <T> T serializeObject(T obj,
ClassLoader classLoader) throws IOException,
ClassNotFoundException {
return (T) DroolsStreamUtils.streamIn( DroolsStreamUtils.streamOut( obj ),
classLoader );
}
public static StatefulKnowledgeSession getSerialisedStatefulKnowledgeSession(KieSession ksession,
boolean dispose) throws Exception {
return getSerialisedStatefulKnowledgeSession( ksession,
dispose,
true );
}
public static StatefulKnowledgeSession getSerialisedStatefulKnowledgeSession(KieSession ksession,
boolean dispose,
boolean testRoundTrip ) throws Exception {
return getSerialisedStatefulKnowledgeSession( ksession,ksession.getKieBase(), dispose, testRoundTrip );
}
public static StatefulKnowledgeSession getSerialisedStatefulKnowledgeSession(KieSession ksession,
KieBase kbase,
boolean dispose ) throws Exception {
return getSerialisedStatefulKnowledgeSession( ksession, kbase, dispose, true );
}
public static StatefulKnowledgeSession getSerialisedStatefulKnowledgeSession(KieSession ksession,
KieBase kbase,
boolean dispose,
boolean testRoundTrip ) throws Exception {
ProtobufMarshaller marshaller = (ProtobufMarshaller) MarshallerFactory.newMarshaller( kbase,
(ObjectMarshallingStrategy[])ksession.getEnvironment().get(EnvironmentName.OBJECT_MARSHALLING_STRATEGIES) );
long time = ksession.<SessionClock>getSessionClock().getCurrentTime();
// make sure globas are in the environment of the session
ksession.getEnvironment().set( EnvironmentName.GLOBALS, ksession.getGlobals() );
// Serialize object
final byte [] b1;
{
ByteArrayOutputStream bos = new ByteArrayOutputStream();
marshaller.marshall( bos,
ksession,
time );
b1 = bos.toByteArray();
bos.close();
}
// Deserialize object
StatefulKnowledgeSession ksession2;
{
ByteArrayInputStream bais = new ByteArrayInputStream( b1 );
ksession2 = marshaller.unmarshall( bais,
ksession.getSessionConfiguration(),
ksession.getEnvironment());
bais.close();
}
if( testRoundTrip ) {
// for now, we can ensure the IDs will match because queries are creating untraceable fact handles at the moment
// int previous_id = ((StatefulKnowledgeSessionImpl)ksession).session.getFactHandleFactory().getId();
// long previous_recency = ((StatefulKnowledgeSessionImpl)ksession).session.getFactHandleFactory().getRecency();
// int current_id = ((StatefulKnowledgeSessionImpl)ksession2).session.getFactHandleFactory().getId();
// long current_recency = ((StatefulKnowledgeSessionImpl)ksession2).session.getFactHandleFactory().getRecency();
// ((StatefulKnowledgeSessionImpl)ksession2).session.getFactHandleFactory().clear( previous_id, previous_recency );
// Reserialize and check that byte arrays are the same
final byte[] b2;
{
ByteArrayOutputStream bos = new ByteArrayOutputStream();
marshaller.marshall( bos,
ksession2,
time );
b2 = bos.toByteArray();
bos.close();
}
// bytes should be the same.
if ( !areByteArraysEqual( b1,
b2 ) ) {
// throw new IllegalArgumentException( "byte streams for serialisation test are not equal" );
}
// ((StatefulKnowledgeSessionImpl) ksession2).session.getFactHandleFactory().clear( current_id, current_recency );
// ((StatefulKnowledgeSessionImpl) ksession2).session.setGlobalResolver( ((StatefulKnowledgeSessionImpl) ksession).session.getGlobalResolver() );
}
if ( dispose ) {
ksession.dispose();
}
return ksession2;
}
private static boolean areByteArraysEqual(byte[] b1, byte[] b2) {
if ( b1.length != b2.length ) {
System.out.println( "Different length: b1=" + b1.length + " b2=" + b2.length );
return false;
}
// System.out.println( "b1" );
// for ( int i = 0, length = b1.length; i < length; i++ ) {
// if ( i == 81 ) {
// System.out.print( "!" );
// }
// System.out.print( b1[i] );
// if ( i == 83 ) {
// System.out.print( "!" );
// }
// }
//
// System.out.println( "\nb2" );
// for ( int i = 0, length = b2.length; i < length; i++ ) {
// if ( i == 81 ) {
// System.out.print( "!" );
// }
// System.out.print( b2[i] );
// if ( i == 83 ) {
// System.out.print( "!" );
// }
// }
boolean result = true;
for ( int i = 0, length = b1.length; i < length; i++ ) {
if ( b1[i] != b2[i] ) {
System.out.println( "Difference at " + i + ": [" + b1[i] + "] != [" + b2[i] + "]" );
result = false;
}
}
return result;
}
}