/*
* Copyright (c) 2002-2017 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.neo4j.driver.internal.util;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.hamcrest.TypeSafeMatcher;
import java.util.Objects;
import org.neo4j.driver.internal.DirectConnectionProvider;
import org.neo4j.driver.internal.InternalDriver;
import org.neo4j.driver.internal.SessionFactory;
import org.neo4j.driver.internal.SessionFactoryImpl;
import org.neo4j.driver.internal.cluster.LoadBalancer;
import org.neo4j.driver.internal.cluster.RoundRobinAddressSet;
import org.neo4j.driver.internal.cluster.RoutingTable;
import org.neo4j.driver.internal.net.BoltServerAddress;
import org.neo4j.driver.internal.spi.ConnectionProvider;
import org.neo4j.driver.v1.Driver;
/**
 * Static factories for Hamcrest matchers over {@link RoutingTable} and {@link Driver} instances.
 * <p>
 * The class exposes only static methods and cannot be instantiated.
 */
public final class Matchers
{
    private Matchers()
    {
        // holder for static factories only
    }

    /**
     * Creates a matcher that accepts a routing table when one of its routers equals the given address.
     * <p>
     * Matching calls {@code nextRouter()} on the table, at most {@code routerSize()} times, and stops at the
     * first router that equals the address.
     *
     * @param address the router address to look for.
     * @return matcher for routing tables.
     */
    public static Matcher<RoutingTable> containsRouter( final BoltServerAddress address )
    {
        return new TypeSafeMatcher<RoutingTable>()
        {
            @Override
            public void describeTo( Description description )
            {
                Description withPrefix = description.appendText( "routing table that contains router " );
                withPrefix.appendValue( address );
            }

            @Override
            protected boolean matchesSafely( RoutingTable routingTable )
            {
                int probes = 0;
                // the size is re-read before every probe, exactly one router is fetched per probe
                while ( probes < routingTable.routerSize() )
                {
                    if ( routingTable.nextRouter().equals( address ) )
                    {
                        return true;
                    }
                    probes++;
                }
                return false;
            }
        };
    }

    /**
     * Creates a matcher that accepts a routing table when its {@code readers()} set yields the given address.
     *
     * @param address the reader address to look for.
     * @return matcher for routing tables.
     */
    public static Matcher<RoutingTable> containsReader( final BoltServerAddress address )
    {
        return new TypeSafeMatcher<RoutingTable>()
        {
            @Override
            public void describeTo( Description description )
            {
                Description withPrefix = description.appendText( "routing table that contains reader " );
                withPrefix.appendValue( address );
            }

            @Override
            protected boolean matchesSafely( RoutingTable routingTable )
            {
                return yieldsAddress( routingTable.readers(), address );
            }
        };
    }

    /**
     * Creates a matcher that accepts a routing table when its {@code writers()} set yields the given address.
     *
     * @param address the writer address to look for.
     * @return matcher for routing tables.
     */
    public static Matcher<RoutingTable> containsWriter( final BoltServerAddress address )
    {
        return new TypeSafeMatcher<RoutingTable>()
        {
            @Override
            public void describeTo( Description description )
            {
                Description withPrefix = description.appendText( "routing table that contains writer " );
                withPrefix.appendValue( address );
            }

            @Override
            protected boolean matchesSafely( RoutingTable routingTable )
            {
                return yieldsAddress( routingTable.writers(), address );
            }
        };
    }

    /**
     * Creates a matcher that accepts a driver whose connection provider is a {@link DirectConnectionProvider}.
     *
     * @return matcher for drivers.
     */
    public static Matcher<Driver> directDriver()
    {
        return driverBackedBy( DirectConnectionProvider.class, "direct 'bolt://' driver " );
    }

    /**
     * Creates a matcher that accepts a driver whose connection provider is a {@link DirectConnectionProvider}
     * reporting the given address.
     *
     * @param address the address the direct connection provider is expected to return from {@code getAddress()}.
     * @return matcher for drivers.
     */
    public static Matcher<Driver> directDriverWithAddress( final BoltServerAddress address )
    {
        return new TypeSafeMatcher<Driver>()
        {
            @Override
            public void describeTo( Description description )
            {
                Description withPrefix = description.appendText( "direct driver with address bolt://" );
                withPrefix.appendValue( address );
            }

            @Override
            protected boolean matchesSafely( Driver driver )
            {
                DirectConnectionProvider direct = extractConnectionProvider( driver, DirectConnectionProvider.class );
                if ( direct == null )
                {
                    // not a driver with a direct connection provider
                    return false;
                }
                return Objects.equals( direct.getAddress(), address );
            }
        };
    }

    /**
     * Creates a matcher that accepts a driver whose connection provider is a {@link LoadBalancer}.
     *
     * @return matcher for drivers.
     */
    public static Matcher<Driver> clusterDriver()
    {
        return driverBackedBy( LoadBalancer.class, "cluster 'bolt+routing://' driver " );
    }

    /**
     * Builds a driver matcher that checks only the type of the driver's connection provider.
     *
     * @param providerClass the connection provider type the driver must expose.
     * @param text the text appended to the description of the matcher.
     * @return matcher for drivers.
     */
    private static Matcher<Driver> driverBackedBy( final Class<? extends ConnectionProvider> providerClass,
            final String text )
    {
        return new TypeSafeMatcher<Driver>()
        {
            @Override
            public void describeTo( Description description )
            {
                description.appendText( text );
            }

            @Override
            protected boolean matchesSafely( Driver driver )
            {
                return extractConnectionProvider( driver, providerClass ) != null;
            }
        };
    }

    /**
     * Calls {@code next()} on the set, at most {@code size()} times, looking for an element equal to the address.
     *
     * @param set the address set to probe.
     * @param address the address to look for.
     * @return {@code true} as soon as a returned element equals the address, {@code false} if none did.
     */
    private static boolean yieldsAddress( RoundRobinAddressSet set, BoltServerAddress address )
    {
        int probes = 0;
        // the size is re-read before every probe, exactly one element is fetched per probe
        while ( probes < set.size() )
        {
            if ( set.next().equals( address ) )
            {
                return true;
            }
            probes++;
        }
        return false;
    }

    /**
     * Digs the connection provider out of the given driver.
     *
     * @param driver the driver to inspect.
     * @param providerClass the expected connection provider type.
     * @param <T> the expected connection provider type.
     * @return the connection provider cast to the expected type, or {@code null} when the driver is not an
     * {@link InternalDriver}, its session factory is not a {@link SessionFactoryImpl}, or the provider is not
     * an instance of the expected type.
     */
    private static <T extends ConnectionProvider> T extractConnectionProvider( Driver driver, Class<T> providerClass )
    {
        if ( !(driver instanceof InternalDriver) )
        {
            return null;
        }

        SessionFactory factory = ((InternalDriver) driver).getSessionFactory();
        if ( !(factory instanceof SessionFactoryImpl) )
        {
            return null;
        }

        ConnectionProvider candidate = ((SessionFactoryImpl) factory).getConnectionProvider();
        return providerClass.isInstance( candidate ) ? providerClass.cast( candidate ) : null;
    }
}