package core.framework.impl.mongo;
import com.mongodb.ReadPreference;
import com.mongodb.bulk.BulkWriteResult;
import com.mongodb.client.AggregateIterable;
import com.mongodb.client.FindIterable;
import com.mongodb.client.MapReduceIterable;
import com.mongodb.client.MongoCursor;
import com.mongodb.client.MongoIterable;
import com.mongodb.client.model.BulkWriteOptions;
import com.mongodb.client.model.CountOptions;
import com.mongodb.client.model.DeleteOneModel;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.InsertManyOptions;
import com.mongodb.client.model.ReplaceOneModel;
import com.mongodb.client.model.UpdateOptions;
import com.mongodb.client.result.DeleteResult;
import com.mongodb.client.result.UpdateResult;
import core.framework.api.log.ActionLogContext;
import core.framework.api.log.Markers;
import core.framework.api.mongo.Aggregate;
import core.framework.api.mongo.Collection;
import core.framework.api.mongo.Count;
import core.framework.api.mongo.FindOne;
import core.framework.api.mongo.Get;
import core.framework.api.mongo.MapReduce;
import core.framework.api.mongo.MongoCollection;
import core.framework.api.mongo.Query;
import core.framework.api.util.Exceptions;
import core.framework.api.util.Lists;
import core.framework.api.util.StopWatch;
import core.framework.api.util.Strings;
import org.bson.BsonDocument;
import org.bson.codecs.configuration.CodecRegistry;
import org.bson.conversions.Bson;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
/**
 * {@link MongoCollection} implementation that delegates to a MongoDB driver collection for a single entity class.
 * <p>
 * Every operation is timed with a {@link StopWatch}. The elapsed time is reported to {@link ActionLogContext}
 * under "mongoDB" and logged at debug level. For every operation except {@link #forEach(Query, Consumer)},
 * it is also compared against the slow-operation threshold configured on {@link MongoImpl}.
 *
 * @author neo
 */
class MongoCollectionImpl<T> implements MongoCollection<T> {
    private static final Logger logger = LoggerFactory.getLogger(MongoCollectionImpl.class);
    private final MongoImpl mongo;
    private final Class<T> entityClass;
    private final String collectionName;
    private final EntityValidator<T> validator;
    // resolved lazily from mongo on first use, see collection()
    private com.mongodb.client.MongoCollection<T> collection;

    /**
     * @param mongo       owning mongo instance, supplies the driver collection, codecs, registry and thresholds
     * @param entityClass entity type; its {@link Collection} annotation provides the collection name
     *                    (NOTE(review): assumes the annotation is present, otherwise this throws NPE)
     */
    MongoCollectionImpl(MongoImpl mongo, Class<T> entityClass) {
        this.mongo = mongo;
        this.entityClass = entityClass;
        validator = new EntityValidator<>(entityClass);
        collectionName = entityClass.getDeclaredAnnotation(Collection.class).name();
    }

    /**
     * Counts documents matching {@code count.filter}, or all documents when the filter is null.
     */
    @Override
    public long count(Count count) {
        StopWatch watch = new StopWatch();
        Bson filter = count.filter == null ? new BsonDocument() : count.filter;
        try {
            return collection(count.readPreference).count(filter, new CountOptions().maxTime(mongo.timeoutInMs, TimeUnit.MILLISECONDS));
        } finally {
            long elapsedTime = watch.elapsedTime();
            ActionLogContext.track("mongoDB", elapsedTime);
            logger.debug("count, collection={}, filter={}, readPreference={}, elapsedTime={}",
                collectionName,
                new BsonParam(filter, mongo.registry),
                count.readPreference == null ? null : count.readPreference.getName(),
                elapsedTime);
            checkSlowOperation(elapsedTime);
        }
    }

    /**
     * Validates and inserts a single entity.
     */
    @Override
    public void insert(T entity) {
        StopWatch watch = new StopWatch();
        validator.validate(entity);
        try {
            collection().insertOne(entity);
        } finally {
            long elapsedTime = watch.elapsedTime();
            ActionLogContext.track("mongoDB", elapsedTime);
            logger.debug("insert, collection={}, elapsedTime={}", collectionName, elapsedTime);
            checkSlowOperation(elapsedTime);
        }
    }

    /**
     * Validates all entities, then inserts them in one unordered insertMany call.
     *
     * @throws Error (via Exceptions.error) if entities is null or empty
     */
    @Override
    public void bulkInsert(List<T> entities) {
        if (entities == null || entities.isEmpty()) throw Exceptions.error("entities must not be empty");
        StopWatch watch = new StopWatch();
        for (T entity : entities) {
            validator.validate(entity);
        }
        try {
            // unordered: the driver does not stop at the first failing document
            collection().insertMany(entities, new InsertManyOptions().ordered(false));
        } finally {
            long elapsedTime = watch.elapsedTime();
            ActionLogContext.track("mongoDB", elapsedTime);
            logger.debug("bulkInsert, collection={}, size={}, elapsedTime={}", collectionName, entities.size(), elapsedTime);
            checkSlowOperation(elapsedTime);
        }
    }

    /**
     * Looks up a document by {@code _id}.
     *
     * @return the entity, or empty if no document has that id
     * @throws Error if get.id is null
     */
    @Override
    public Optional<T> get(Get get) {
        if (get.id == null) throw new Error("get.id must not be null");
        StopWatch watch = new StopWatch();
        try {
            T result = collection(get.readPreference)
                .find(Filters.eq("_id", get.id))
                .maxTime(mongo.timeoutInMs, TimeUnit.MILLISECONDS)
                .first();
            return Optional.ofNullable(result);
        } finally {
            long elapsedTime = watch.elapsedTime();
            ActionLogContext.track("mongoDB", elapsedTime);
            logger.debug("get, collection={}, id={}, readPreference={}, elapsedTime={}",
                collectionName,
                get.id,
                get.readPreference == null ? null : get.readPreference.getName(),
                elapsedTime);
            checkSlowOperation(elapsedTime);
        }
    }

    /**
     * Finds the single document matching {@code findOne.filter} (all documents if null).
     *
     * @return the entity, or empty if nothing matches
     * @throws Error (via Exceptions.error) if more than one document matches
     */
    @Override
    public Optional<T> findOne(FindOne findOne) {
        StopWatch watch = new StopWatch();
        Bson filter = findOne.filter == null ? new BsonDocument() : findOne.filter;
        try {
            List<T> results = new ArrayList<>(2);
            // limit 2 is enough to detect the "more than one row" case without reading everything
            FindIterable<T> query = collection(findOne.readPreference)
                .find(filter)
                .limit(2)
                .maxTime(mongo.timeoutInMs, TimeUnit.MILLISECONDS);
            fetch(query, results);
            if (results.isEmpty()) return Optional.empty();
            if (results.size() > 1) throw Exceptions.error("more than one row returned, size={}", results.size());
            return Optional.of(results.get(0));
        } finally {
            long elapsedTime = watch.elapsedTime();
            ActionLogContext.track("mongoDB", elapsedTime);
            logger.debug("findOne, collection={}, filter={}, readPreference={}, elapsedTime={}",
                collectionName,
                new BsonParam(filter, mongo.registry),
                findOne.readPreference == null ? null : findOne.readPreference.getName(),
                elapsedTime);
            checkSlowOperation(elapsedTime);
        }
    }

    /**
     * Runs the query and collects all results into a list; warns when the result size exceeds the threshold.
     */
    @Override
    public List<T> find(Query query) {
        StopWatch watch = new StopWatch();
        try {
            // presize only for a positive limit; new ArrayList<>(negative) would throw before the query runs
            List<T> results = query.limit == null || query.limit <= 0 ? Lists.newArrayList() : new ArrayList<>(query.limit);
            FindIterable<T> mongoQuery = mongoQuery(query).maxTime(mongo.timeoutInMs, TimeUnit.MILLISECONDS);
            fetch(mongoQuery, results);
            checkTooManyRowsReturned(results.size());
            return results;
        } finally {
            long elapsedTime = watch.elapsedTime();
            ActionLogContext.track("mongoDB", elapsedTime);
            logger.debug("find, collection={}, filter={}, projection={}, sort={}, skip={}, limit={}, readPreference={}, elapsedTime={}",
                collectionName,
                new BsonParam(query.filter, mongo.registry),
                new BsonParam(query.projection, mongo.registry),
                new BsonParam(query.sort, mongo.registry),
                query.skip,
                query.limit,
                query.readPreference == null ? null : query.readPreference.getName(),
                elapsedTime);
            checkSlowOperation(elapsedTime);
        }
    }

    /**
     * Streams query results to the consumer one document at a time, without collecting them.
     * <p>
     * No maxTime is applied. The measured elapsed time includes the time spent in the consumer,
     * and it is not checked against the slow-operation threshold.
     */
    @Override
    public void forEach(Query query, Consumer<T> consumer) {
        StopWatch watch = new StopWatch();
        Integer total = null;   // stays null in the log if the iteration fails
        try {
            FindIterable<T> mongoQuery = mongoQuery(query);
            total = apply(mongoQuery, consumer);
        } finally {
            long elapsedTime = watch.elapsedTime();
            ActionLogContext.track("mongoDB", elapsedTime);
            logger.debug("forEach, collection={}, filter={}, projection={}, sort={}, skip={}, limit={}, readPreference={}, total={}, elapsedTime={}",
                collectionName,
                new BsonParam(query.filter, mongo.registry),
                new BsonParam(query.projection, mongo.registry),
                new BsonParam(query.sort, mongo.registry),
                query.skip,
                query.limit,
                query.readPreference == null ? null : query.readPreference.getName(),
                total,
                elapsedTime);
        }
    }

    // builds the driver query from the optional filter/projection/sort/skip/limit of the Query
    private FindIterable<T> mongoQuery(Query query) {
        FindIterable<T> mongoQuery = collection(query.readPreference).find(query.filter == null ? new BsonDocument() : query.filter);
        if (query.projection != null) mongoQuery.projection(query.projection);
        if (query.sort != null) mongoQuery.sort(query.sort);
        if (query.skip != null) mongoQuery.skip(query.skip);
        if (query.limit != null) mongoQuery.limit(query.limit);
        return mongoQuery;
    }

    // passes each result to the consumer, closes the cursor, returns the number of results
    private int apply(MongoIterable<T> mongoQuery, Consumer<T> consumer) {
        int total = 0;
        try (MongoCursor<T> cursor = mongoQuery.iterator()) {
            while (cursor.hasNext()) {
                T result = cursor.next();
                total++;
                consumer.accept(result);
            }
        }
        return total;
    }

    /**
     * Runs an aggregation pipeline and collects the results as {@code aggregate.resultClass}.
     *
     * @throws Error if the pipeline is null/empty or the result class is null
     */
    @Override
    public <V> List<V> aggregate(Aggregate<V> aggregate) {
        if (aggregate.pipeline == null || aggregate.pipeline.isEmpty()) throw new Error("aggregate.pipeline must not be empty");
        if (aggregate.resultClass == null) throw new Error("aggregate.resultClass must not be null");
        StopWatch watch = new StopWatch();
        try {
            List<V> results = Lists.newArrayList();
            AggregateIterable<V> query = collection(aggregate.readPreference)
                .aggregate(aggregate.pipeline, aggregate.resultClass)
                .maxTime(mongo.timeoutInMs, TimeUnit.MILLISECONDS);
            fetch(query, results);
            checkTooManyRowsReturned(results.size());
            return results;
        } finally {
            long elapsedTime = watch.elapsedTime();
            ActionLogContext.track("mongoDB", elapsedTime);
            logger.debug("aggregate, collection={}, pipeline={}, readPreference={}, elapsedTime={}",
                collectionName,
                aggregate.pipeline.stream().map(stage -> new BsonParam(stage, mongo.registry)).toArray(),
                aggregate.readPreference == null ? null : aggregate.readPreference.getName(),
                elapsedTime);
            checkSlowOperation(elapsedTime);
        }
    }

    /**
     * Runs a map-reduce with the given JavaScript functions and optional filter, and collects the results.
     *
     * @throws Error if a function is empty or the result class is null
     */
    @Override
    public <V> List<V> mapReduce(MapReduce<V> mapReduce) {
        if (Strings.isEmpty(mapReduce.mapFunction)) throw new Error("mapReduce.mapFunction must not be empty");
        if (Strings.isEmpty(mapReduce.reduceFunction)) throw new Error("mapReduce.reduceFunction must not be empty");
        if (mapReduce.resultClass == null) throw new Error("mapReduce.resultClass must not be null");
        StopWatch watch = new StopWatch();
        try {
            List<V> results = Lists.newArrayList();
            MapReduceIterable<V> query = collection(mapReduce.readPreference)
                .mapReduce(mapReduce.mapFunction, mapReduce.reduceFunction, mapReduce.resultClass)
                .maxTime(mongo.timeoutInMs, TimeUnit.MILLISECONDS);
            if (mapReduce.filter != null) query.filter(mapReduce.filter);
            fetch(query, results);
            checkTooManyRowsReturned(results.size());
            return results;
        } finally {
            long elapsedTime = watch.elapsedTime();
            ActionLogContext.track("mongoDB", elapsedTime);
            logger.debug("mapReduce, collection={}, map={}, reduce={}, filter={}, readPreference={}, elapsedTime={}",
                collectionName,
                mapReduce.mapFunction,
                mapReduce.reduceFunction,
                new BsonParam(mapReduce.filter, mongo.registry),
                mapReduce.readPreference == null ? null : mapReduce.readPreference.getName(),
                elapsedTime);
            checkSlowOperation(elapsedTime);
        }
    }

    // drains the iterable into results, always closing the cursor
    private <V> void fetch(MongoIterable<V> iterable, List<V> results) {
        try (MongoCursor<V> cursor = iterable.iterator()) {
            while (cursor.hasNext()) {
                results.add(cursor.next());
            }
        }
    }

    /**
     * Validates the entity and replaces the document with the same {@code _id}, inserting it if absent (upsert).
     *
     * @throws Error (via Exceptions.error) if the entity has no id
     */
    @Override
    public void replace(T entity) {
        StopWatch watch = new StopWatch();
        Object id = null;
        validator.validate(entity);
        try {
            id = mongo.codecs.id(entity);
            if (id == null) throw Exceptions.error("entity must have id, entityClass={}", entityClass.getCanonicalName());
            collection().replaceOne(Filters.eq("_id", id), entity, new UpdateOptions().upsert(true));
        } finally {
            long elapsedTime = watch.elapsedTime();
            ActionLogContext.track("mongoDB", elapsedTime);
            logger.debug("replace, collection={}, id={}, elapsedTime={}", collectionName, id, elapsedTime);
            checkSlowOperation(elapsedTime);
        }
    }

    /**
     * Validates all entities and upserts them by {@code _id} in one unordered bulk write.
     *
     * @throws Error (via Exceptions.error) if entities is null/empty or any entity has no id
     */
    @Override
    public void bulkReplace(List<T> entities) {
        StopWatch watch = new StopWatch();
        if (entities == null || entities.isEmpty()) throw Exceptions.error("entities must not be empty");
        for (T entity : entities) {
            validator.validate(entity);
        }
        try {
            List<ReplaceOneModel<T>> models = new ArrayList<>(entities.size());
            for (T entity : entities) {
                Object id = mongo.codecs.id(entity);
                if (id == null) throw Exceptions.error("entity must have id, entityClass={}", entityClass.getCanonicalName());
                models.add(new ReplaceOneModel<>(Filters.eq("_id", id), entity, new UpdateOptions().upsert(true)));
            }
            collection().bulkWrite(models, new BulkWriteOptions().ordered(false));
        } finally {
            long elapsedTime = watch.elapsedTime();
            ActionLogContext.track("mongoDB", elapsedTime);
            logger.debug("bulkReplace, collection={}, size={}, elapsedTime={}", collectionName, entities.size(), elapsedTime);
            checkSlowOperation(elapsedTime);
        }
    }

    /**
     * Applies the update to every document matching the filter.
     *
     * @return the number of documents actually modified (not the number matched)
     */
    @Override
    public long update(Bson filter, Bson update) {
        StopWatch watch = new StopWatch();
        try {
            UpdateResult result = collection().updateMany(filter, update);
            return result.getModifiedCount();
        } finally {
            long elapsedTime = watch.elapsedTime();
            ActionLogContext.track("mongoDB", elapsedTime);
            // render both Bson arguments through the codec registry, like every other logged Bson
            logger.debug("update, collection={}, filter={}, update={}, elapsedTime={}",
                collectionName,
                new BsonParam(filter, mongo.registry),
                new BsonParam(update, mongo.registry),
                elapsedTime);
            checkSlowOperation(elapsedTime);
        }
    }

    /**
     * Deletes the document with the given {@code _id}.
     *
     * @return the number of deleted documents (0 or 1)
     */
    @Override
    public long delete(Object id) {
        StopWatch watch = new StopWatch();
        try {
            DeleteResult result = collection().deleteOne(Filters.eq("_id", id));
            return result.getDeletedCount();
        } finally {
            long elapsedTime = watch.elapsedTime();
            ActionLogContext.track("mongoDB", elapsedTime);
            logger.debug("delete, collection={}, id={}, elapsedTime={}", collectionName, id, elapsedTime);
            checkSlowOperation(elapsedTime);
        }
    }

    /**
     * Deletes every document matching the filter; a null filter matches (and deletes) all documents.
     *
     * @return the number of deleted documents
     */
    @Override
    public long delete(Bson filter) {
        StopWatch watch = new StopWatch();
        try {
            DeleteResult result = collection().deleteMany(filter == null ? new BsonDocument() : filter);
            return result.getDeletedCount();
        } finally {
            long elapsedTime = watch.elapsedTime();
            ActionLogContext.track("mongoDB", elapsedTime);
            logger.debug("delete, collection={}, filter={}, elapsedTime={}", collectionName, new BsonParam(filter, mongo.registry), elapsedTime);
            checkSlowOperation(elapsedTime);
        }
    }

    /**
     * Deletes the documents with the given ids in one unordered bulk write.
     *
     * @return the number of deleted documents
     * @throws Error (via Exceptions.error) if ids is null or empty
     */
    @Override
    public long bulkDelete(List<?> ids) {
        // fail fast like bulkInsert/bulkReplace instead of an NPE (null) or a driver rejection (empty request list)
        if (ids == null || ids.isEmpty()) throw Exceptions.error("ids must not be empty");
        StopWatch watch = new StopWatch();
        try {
            List<DeleteOneModel<T>> models = new ArrayList<>(ids.size());
            for (Object id : ids) {
                models.add(new DeleteOneModel<>(Filters.eq("_id", id)));
            }
            BulkWriteResult result = collection().bulkWrite(models, new BulkWriteOptions().ordered(false));
            return result.getDeletedCount();
        } finally {
            long elapsedTime = watch.elapsedTime();
            ActionLogContext.track("mongoDB", elapsedTime);
            logger.debug("bulkDelete, collection={}, ids={}, elapsedTime={}", collectionName, ids, elapsedTime);
            checkSlowOperation(elapsedTime);
        }
    }

    // warns with SLOW_MONGODB when elapsedTime exceeds mongo.slowOperationThresholdInNanos
    // (elapsedTime is therefore assumed to be in nanoseconds)
    private void checkSlowOperation(long elapsedTime) {
        if (elapsedTime > mongo.slowOperationThresholdInNanos) {
            logger.warn(Markers.errorCode("SLOW_MONGODB"), "slow mongoDB query, elapsedTime={}", elapsedTime);
        }
    }

    // warns with TOO_MANY_ROWS_RETURNED when a collected result list exceeds the configured threshold
    private void checkTooManyRowsReturned(int size) {
        if (size > mongo.tooManyRowsReturnedThreshold) {
            logger.warn(Markers.errorCode("TOO_MANY_ROWS_RETURNED"), "too many rows returned, returnedRows={}", size);
        }
    }

    // driver collection with the given read preference applied; the default collection when readPreference is null
    private com.mongodb.client.MongoCollection<T> collection(ReadPreference readPreference) {
        if (readPreference != null)
            return collection().withReadPreference(readPreference);
        return collection();
    }

    // lazily resolves the driver collection from mongo on first access and caches it
    private com.mongodb.client.MongoCollection<T> collection() {
        if (collection == null) {
            collection = mongo.mongoCollection(entityClass);
        }
        return collection;
    }

    /**
     * Log argument that renders a Bson value as JSON using the codec registry, only when toString() is called,
     * so the conversion is skipped when the log level is disabled.
     */
    static class BsonParam {
        final Bson bson;
        final CodecRegistry registry;

        BsonParam(Bson bson, CodecRegistry registry) {
            this.bson = bson;
            this.registry = registry;
        }

        @Override
        public String toString() {
            if (bson == null) return "null";
            return bson.toBsonDocument(null, registry).toJson();
        }
    }
}