/**
*
* Copyright (c) 2006-2017, Speedment, Inc. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); You may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/
package com.speedment.runtime.core.manager.sql;
import com.speedment.runtime.core.component.sql.SqlStreamOptimizerInfo;
import com.speedment.runtime.core.db.AsynchronousQueryResult;
import com.speedment.runtime.core.db.DbmsType;
import com.speedment.runtime.core.internal.component.sql.SqlStreamOptimizerComponentImpl;
import com.speedment.runtime.core.internal.component.sql.override.SqlStreamTerminatorComponentImpl;
import com.speedment.runtime.core.internal.db.AsynchronousQueryResultImpl;
import com.speedment.runtime.core.internal.manager.sql.SqlStreamTerminator;
import com.speedment.runtime.core.internal.stream.builder.action.reference.FilterAction;
import com.speedment.runtime.core.internal.stream.builder.action.reference.MapAction;
import com.speedment.runtime.core.internal.stream.builder.pipeline.PipelineImpl;
import com.speedment.runtime.core.internal.stream.builder.pipeline.ReferencePipeline;
import com.speedment.runtime.core.stream.action.Action;
import com.speedment.runtime.core.stream.parallel.ParallelStrategy;
import com.speedment.runtime.test_support.MockDbmsType;
import com.speedment.runtime.test_support.MockEntity;
import com.speedment.runtime.test_support.MockEntityUtil;
import java.util.ArrayList;
import static java.util.Collections.singletonList;
import java.util.List;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.BaseStream;
import java.util.stream.Stream;
import static junit.framework.TestCase.assertEquals;
import static junit.framework.TestCase.assertNull;
import org.junit.Test;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Tests for {@link SqlStreamTerminator#count}, checking when a {@code count()}
 * terminal operation is handed to the SQL counter as a {@code COUNT(*)} query
 * and when it is instead computed by streaming the entities.
 * <p>
 * The SQL counter given to the terminator is a stub that records the SQL and
 * parameter values it receives (in {@code lastCountingSql} and
 * {@code lastCountingValues}) and always answers {@code SQL_COUNT_RESULT}.
 * The streamed entities come from
 * {@code MockEntityUtil.stream((int) SQL_COUNT_RESULT)}.
 */
public class SqlStreamTerminatorTest {

    /** Fixed value answered by the stubbed SQL counter. */
    private static final long SQL_COUNT_RESULT = 100L;
    private static final String SELECT_SQL = "SELECT * FROM table";
    private static final String SELECT_COUNT_SQL = "SELECT COUNT(*) FROM table";
    /** WHERE-clause fragment expected for {@code MockEntity.NAME.equal(...)}. */
    private static final String PREDICATE_SQL_FRAGMENT = "(name = ?)";
    /** Filtered SELECT that is expected to be wrapped by {@link #makeCountSql}. */
    private static final String SELECT_WHERE_SQL =
        String.join(" WHERE ", SELECT_SQL, PREDICATE_SQL_FRAGMENT);

    /** SQL last passed to the stubbed counter; stays null if it was never called. */
    private String lastCountingSql;
    /** Parameter values last passed to the stubbed counter; null if never called. */
    private List<Object> lastCountingValues;

    /**
     * A filter backed by a plain lambda is expected to be counted by streaming:
     * the result reflects the in-memory filter and the SQL counter is never invoked.
     */
    @Test
    public void testCountGeneralFilter() {
        final Action<Stream<MockEntity>, Stream<MockEntity>> filterAction =
            new FilterAction<>(e -> e.getId() % 10 == 3);
        assertEquals(10, countStreamOf(filterAction));
        assertNull(lastCountingSql);
        assertNull(lastCountingValues);
    }

    /**
     * A map action does not change the number of elements, so the count is
     * expected to come straight from the plain {@code SELECT COUNT(*)} query.
     */
    @Test
    public void testCountSizePreservingFilter() {
        final Action<Stream<MockEntity>, Stream<Integer>> mapAction =
            new MapAction<>(MockEntity::getId);
        assertEquals(SQL_COUNT_RESULT, countStreamOf(mapAction));
        assertEquals(SELECT_COUNT_SQL, lastCountingSql);
    }

    /**
     * A field predicate is expected to be rendered into a WHERE clause, wrapped
     * in a count query, with the compared value passed as a bound parameter.
     */
    @Test
    @SuppressWarnings("unchecked")
    public void testCountFieldPredicateFilter() {
        final Predicate<MockEntity> predicate = MockEntity.NAME.equal("ABBA");
        final Action<Stream<MockEntity>, Stream<MockEntity>> filterAction =
            new FilterAction<>(predicate);
        assertEquals(SQL_COUNT_RESULT, countStreamOf(filterAction));
        assertEquals(makeCountSql(SELECT_WHERE_SQL), lastCountingSql);
        assertEquals(singletonList("ABBA"), lastCountingValues);
    }

    /**
     * A field predicate OR:ed with a plain lambda is expected to be counted by
     * streaming, so the SQL counter must not be invoked. The expected count of 0
     * presumes no mock entity is named "ABBA" or "Olle".
     */
    @Test
    @SuppressWarnings("unchecked")
    public void testCountFieldPredicateFilterPolluted() {
        final Predicate<MockEntity> predicate =
            MockEntity.NAME.equal("ABBA").or(me -> me.getName().equals("Olle"));
        final Action<Stream<MockEntity>, Stream<MockEntity>> filterAction =
            new FilterAction<>(predicate);
        assertEquals(0, countStreamOf(filterAction));
        assertNull(lastCountingSql); // Make sure counter was not called
        assertNull(lastCountingValues);
    }

    /**
     * Wraps a SELECT statement in the derived-table count query the terminator
     * is expected to generate.
     *
     * @param sql the inner SELECT statement
     * @return {@code SELECT COUNT(*) FROM (<sql>) AS A}
     */
    private static String makeCountSql(String sql) {
        return "SELECT COUNT(*) FROM (" + sql + ") AS A";
    }

    /**
     * Builds a {@link SqlStreamTerminator} around a stubbed SQL counter and
     * returns the result of counting a pipeline containing {@code action}.
     *
     * @param action the single intermediate action to put in the pipeline
     * @return the count reported by the terminator
     */
    private long countStreamOf(Action<?, ?> action) {
        @SuppressWarnings("unchecked")
        final AsynchronousQueryResult<MockEntity> asynchronousQueryResult =
            new AsynchronousQueryResultImpl<>(
                SELECT_SQL,
                new ArrayList<>(),
                rs -> new MockEntity(1),
                () -> null, // getConnection(): no real connection is supplied
                ParallelStrategy.computeIntensityDefault(),
                ps -> { }, // no-op
                rs -> { }  // no-op
            );

        final SqlStreamOptimizerInfo<MockEntity> info = SqlStreamOptimizerInfo.of(
            createDbmsType(),
            SELECT_SQL,
            SELECT_COUNT_SQL,
            (sql, values) -> {
                // Record what the terminator asked to count, then answer a fixed value
                lastCountingSql = sql;
                lastCountingValues = values;
                return SQL_COUNT_RESULT;
            },
            f -> f.identifier().getColumnName(),
            f -> Object.class
        );

        final SqlStreamTerminator<MockEntity> terminator = new SqlStreamTerminator<>(
            info,
            asynchronousQueryResult,
            new SqlStreamOptimizerComponentImpl(),
            new SqlStreamTerminatorComponentImpl(),
            true
        );
        return terminator.count(createPipeline(action));
    }

    /**
     * Creates a pipeline whose source is a mocked supplier of mock entities,
     * with {@code action} appended as its only action.
     *
     * @param action the action to append
     * @return the populated pipeline
     */
    private ReferencePipeline<MockEntity> createPipeline(Action<?, ?> action) {
        @SuppressWarnings("unchecked")
        final Supplier<Stream<MockEntity>> supplier = mock(Supplier.class);
        // Streams are single-use, so every get() must hand out a fresh stream
        when(supplier.get()).thenAnswer(
            invocation -> MockEntityUtil.stream((int) SQL_COUNT_RESULT)
        );
        @SuppressWarnings("unchecked")
        final ReferencePipeline<MockEntity> pipeline =
            new PipelineImpl<>((Supplier<BaseStream<?, ?>>) (Object) supplier);
        pipeline.add(action);
        return pipeline;
    }

    /**
     * Creates the DBMS type passed to the optimizer info.
     *
     * @return a new {@link MockDbmsType}
     */
    private static DbmsType createDbmsType() {
        return new MockDbmsType();
    }
}