/* * Copyright (c) [2011-2017] "Pivotal Software, Inc." / "Neo Technology" / "Graph Aware Ltd." * * This product is licensed to you under the Apache License, Version 2.0 (the "License"). * You may not use this product except in compliance with the License. * * This product may include a number of subcomponents with * separate copyright notices and license terms. Your use of the source * code for these subcomponents is subject to the terms and * conditions of the subcomponent's license, as noted in the LICENSE file. * */ package org.springframework.data.neo4j.repository.support; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Optional; import org.neo4j.ogm.cypher.query.Pagination; import org.neo4j.ogm.session.Session; import org.springframework.data.domain.Page; import org.springframework.data.domain.PageImpl; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; import org.springframework.data.neo4j.repository.Neo4jRepository; import org.springframework.data.neo4j.util.PagingAndSortingUtils; import org.springframework.stereotype.Repository; import org.springframework.transaction.annotation.Transactional; import org.springframework.util.Assert; /** * Default implementation of the {@link org.springframework.data.repository.CrudRepository} interface. This will offer * you a more sophisticated interface than the plain {@link Session} . * * @param <T> the type of the entity to handle * * @author Vince Bickers * @author Luanne Misquitta * @author Mark Angrish * @author Mark Paluch * @author Jens Schauder */ @Repository @Transactional(readOnly = true) public class SimpleNeo4jRepository<T, ID extends Serializable> implements Neo4jRepository<T, ID> { private static final int DEFAULT_QUERY_DEPTH = 1; private static final String ID_MUST_NOT_BE_NULL = "The given id must not be null!"; private final Class<T> clazz; private final Session session; /** * Creates a new {@link SimpleNeo4jRepository} to manage objects of the given domain type. * * @param domainClass must not be {@literal null}. * @param session must not be {@literal null}. */ public SimpleNeo4jRepository(Class<T> domainClass, Session session) { Assert.notNull(domainClass, "Domain class must not be null!"); Assert.notNull(session, "Session must not be null!"); this.clazz = domainClass; this.session = session; } @Transactional @Override public <S extends T> S save(S entity) { session.save(entity); return entity; } @Transactional @Override public <S extends T> Iterable<S> saveAll(Iterable<S> entities) { for (S entity : entities) { session.save(entity); } return entities; } @Override public Optional<T> findById(ID id) { Assert.notNull(id, ID_MUST_NOT_BE_NULL); return Optional.ofNullable(session.load(clazz, id)); } @Override public boolean existsById(ID id) { return findById(id).isPresent(); } @Override public long count() { return session.countEntitiesOfType(clazz); } @Transactional @Override public void deleteById(ID id) { findById(id).ifPresent(session::delete); } @Transactional @Override public void delete(T t) { session.delete(t); } @Transactional @Override public void deleteAll(Iterable<? extends T> ts) { for (T t : ts) { session.delete(t); } } @Transactional @Override public void deleteAll() { session.deleteAll(clazz); } @Transactional @Override public <S extends T> S save(S s, int depth) { session.save(s, depth); return s; } @Transactional @Override public <S extends T> Iterable<S> save(Iterable<S> ses, int depth) { session.save(ses, depth); return ses; } @Override public Optional<T> findById(ID id, int depth) { return Optional.ofNullable(session.load(clazz, id, depth)); } // findAll and variants @Override public Iterable<T> findAll() { return findAll(DEFAULT_QUERY_DEPTH); } @Override public Iterable<T> findAll(int depth) { return session.loadAll(clazz, depth); } @Override public Iterable<T> findAllById(Iterable<ID> longs) { return findAllById(longs, DEFAULT_QUERY_DEPTH); } @Override public Iterable<T> findAllById(Iterable<ID> ids, int depth) { return session.loadAll(clazz, (Collection<ID>) ids, depth); } @Override public Iterable<T> findAll(Sort sort) { return findAll(sort, DEFAULT_QUERY_DEPTH); } @Override public Iterable<T> findAll(Sort sort, int depth) { return session.loadAll(clazz, PagingAndSortingUtils.convert(sort), depth); } @Override public Iterable<T> findAllById(Iterable<ID> ids, Sort sort) { return findAllById(ids, sort, DEFAULT_QUERY_DEPTH); } @Override public Iterable<T> findAllById(Iterable<ID> ids, Sort sort, int depth) { return session.loadAll(clazz, (Collection<ID>) ids, PagingAndSortingUtils.convert(sort), depth); } @Override public Page<T> findAll(Pageable pageable) { return findAll(pageable, DEFAULT_QUERY_DEPTH); } @Override public Page<T> findAll(Pageable pageable, int depth) { Collection<T> data = session.loadAll(clazz, PagingAndSortingUtils.convert(pageable.getSort()) , new Pagination(pageable.getPageNumber(), pageable.getPageSize()), depth); return updatePage(pageable, new ArrayList<>(data)); } /* * This is a cheap trick to estimate the total number of objects without actually knowing the real value. * Essentially, if the result size is the same as the page size, we assume more data can be fetched, so * we set the expected total to the current total retrieved so far + the current page size. As soon as the * result size is less than the page size, we know there are no more, so we set the total to the number * retrieved so far. This will ensure that page.next() returns false. */ private Page<T> updatePage(Pageable pageable, List<T> results) { int pageSize = pageable.getPageSize(); long pageOffset = pageable.getOffset(); long total = pageOffset + results.size() + (results.size() == pageSize ? pageSize : 0); return new PageImpl<T>(results, pageable, total); } }