package org.infinispan.query.remote.impl; import java.io.IOException; import java.util.Map; import java.util.Set; import java.util.TreeSet; import org.infinispan.commands.AbstractVisitor; import org.infinispan.commands.CommandsFactory; import org.infinispan.commands.FlagAffectedCommand; import org.infinispan.commands.VisitableCommand; import org.infinispan.commands.tx.PrepareCommand; import org.infinispan.commands.write.ClearCommand; import org.infinispan.commands.write.PutKeyValueCommand; import org.infinispan.commands.write.PutMapCommand; import org.infinispan.commands.write.RemoveCommand; import org.infinispan.commands.write.ReplaceCommand; import org.infinispan.commands.write.WriteCommand; import org.infinispan.commons.CacheException; import org.infinispan.commons.util.EnumUtil; import org.infinispan.context.InvocationContext; import org.infinispan.context.impl.FlagBitSets; import org.infinispan.context.impl.TxInvocationContext; import org.infinispan.factories.annotations.Inject; import org.infinispan.interceptors.AsyncInterceptorChain; import org.infinispan.interceptors.BaseCustomAsyncInterceptor; import org.infinispan.metadata.EmbeddedMetadata; import org.infinispan.metadata.Metadata; import org.infinispan.protostream.DescriptorParserException; import org.infinispan.protostream.FileDescriptorSource; import org.infinispan.protostream.SerializationContext; import org.infinispan.protostream.descriptors.FileDescriptor; import org.infinispan.query.remote.ProtobufMetadataManager; import org.infinispan.query.remote.client.ProtobufMetadataManagerConstants; /** * Intercepts updates to the protobuf schema file caches and updates the SerializationContext accordingly. * * @author anistor@redhat.com * @since 7.0 */ final class ProtobufMetadataManagerInterceptor extends BaseCustomAsyncInterceptor implements ProtobufMetadataManagerConstants { private static final Metadata DEFAULT_METADATA = new EmbeddedMetadata.Builder().build(); private CommandsFactory commandsFactory; private AsyncInterceptorChain invoker; private SerializationContext serializationContext; /** * A no-op callback. */ private static final FileDescriptorSource.ProgressCallback EMPTY_CALLBACK = new FileDescriptorSource.ProgressCallback() { }; private final class ProgressCallback implements FileDescriptorSource.ProgressCallback { private final InvocationContext ctx; private final long flagsBitSet; private final Set<String> errorFiles = new TreeSet<>(); private ProgressCallback(InvocationContext ctx, long flagsBitSet) { this.ctx = ctx; this.flagsBitSet = flagsBitSet; } Set<String> getErrorFiles() { return errorFiles; } @Override public void handleError(String fileName, DescriptorParserException exception) { // handle first error per file, ignore the rest if any if (errorFiles.add(fileName)) { VisitableCommand cmd = commandsFactory.buildPutKeyValueCommand(fileName + ERRORS_KEY_SUFFIX, exception.getMessage(), DEFAULT_METADATA, flagsBitSet); invoker.invoke(ctx, cmd); } } @Override public void handleSuccess(String fileName) { VisitableCommand cmd = commandsFactory.buildRemoveCommand(fileName + ERRORS_KEY_SUFFIX, null, flagsBitSet); invoker.invoke(ctx, cmd); } } /** * Visitor used for handling the list of modifications of a PrepareCommand. */ private final AbstractVisitor serializationContextUpdaterVisitor = new AbstractVisitor() { @Override public Object visitPutKeyValueCommand(InvocationContext ctx, PutKeyValueCommand command) throws Throwable { final String key = (String) command.getKey(); if (shouldIntercept(key)) { FileDescriptorSource source = new FileDescriptorSource() .withProgressCallback(EMPTY_CALLBACK) .addProtoFile(key, (String) command.getValue()); try { serializationContext.registerProtoFiles(source); } catch (IOException | DescriptorParserException e) { throw new CacheException("Failed to parse proto file : " + key, e); } } return null; } @Override public Object visitPutMapCommand(InvocationContext ctx, PutMapCommand command) throws Throwable { final Map<Object, Object> map = command.getMap(); FileDescriptorSource source = new FileDescriptorSource() .withProgressCallback(EMPTY_CALLBACK); for (Object key : map.keySet()) { if (shouldIntercept(key)) { source.addProtoFile((String) key, (String) map.get(key)); } } try { serializationContext.registerProtoFiles(source); } catch (IOException | DescriptorParserException e) { throw new CacheException(e); } return null; } @Override public Object visitReplaceCommand(InvocationContext ctx, ReplaceCommand command) throws Throwable { final String key = (String) command.getKey(); if (shouldIntercept(key)) { FileDescriptorSource source = new FileDescriptorSource() .withProgressCallback(EMPTY_CALLBACK) .addProtoFile(key, (String) command.getNewValue()); try { serializationContext.registerProtoFiles(source); } catch (IOException | DescriptorParserException e) { throw new CacheException("Failed to parse proto file : " + key, e); } } return null; } @Override public Object visitRemoveCommand(InvocationContext ctx, RemoveCommand command) throws Throwable { final String key = (String) command.getKey(); if (shouldIntercept(key)) { if (serializationContext.getFileDescriptors().containsKey(key)) { serializationContext.unregisterProtoFile(key); } } return null; } @Override public Object visitClearCommand(InvocationContext ctx, ClearCommand command) throws Throwable { for (String fileName : serializationContext.getFileDescriptors().keySet()) { serializationContext.unregisterProtoFile(fileName); } return null; } }; @Inject public void init(CommandsFactory commandsFactory, AsyncInterceptorChain invoker, ProtobufMetadataManager protobufMetadataManager) { this.commandsFactory = commandsFactory; this.invoker = invoker; this.serializationContext = ((ProtobufMetadataManagerImpl) protobufMetadataManager).getSerializationContext(); } @Override public Object visitPrepareCommand(TxInvocationContext ctx, PrepareCommand command) throws Throwable { return invokeNextThenAccept(ctx, command, (rCtx, rCommand, rv) -> { if (!rCtx.isOriginLocal()) { // apply updates to the serialization context for (WriteCommand wc : ((PrepareCommand) rCommand).getModifications()) { wc.acceptVisitor(rCtx, serializationContextUpdaterVisitor); } } }); } @Override public Object visitPutKeyValueCommand(final InvocationContext ctx, PutKeyValueCommand command) throws Throwable { final Object key = command.getKey(); final Object value = command.getValue(); if (ctx.isOriginLocal()) { if (!(key instanceof String)) { throw new CacheException("The key must be a string"); } if (!(value instanceof String)) { throw new CacheException("The value must be a string"); } if (shouldIntercept(key)) { if (!command.hasAnyFlag(FlagBitSets.PUT_FOR_STATE_TRANSFER | FlagBitSets.SKIP_LOCKING)) { if (!((String) key).endsWith(PROTO_KEY_SUFFIX)) { throw new CacheException("The key must be a string ending with \".proto\" : " + key); } // lock .errors key VisitableCommand cmd = commandsFactory.buildLockControlCommand(ERRORS_KEY_SUFFIX, command.getFlagsBitSet(), null); invoker.invoke(ctx, cmd); } } else { return invokeNext(ctx, command); } } return invokeNextThenAccept(ctx, command, (rCtx, rCommand, rv) -> { PutKeyValueCommand putKeyValueCommand = (PutKeyValueCommand) rCommand; long flagsBitSet = copyFlags(putKeyValueCommand); if (putKeyValueCommand.isSuccessful()) { FileDescriptorSource source = new FileDescriptorSource().addProtoFile((String) key, (String) value); ProgressCallback progressCallback = null; if (rCtx.isOriginLocal() && !putKeyValueCommand.hasAnyFlag(FlagBitSets.PUT_FOR_STATE_TRANSFER)) { progressCallback = new ProgressCallback(rCtx, flagsBitSet); source.withProgressCallback(progressCallback); } else { source.withProgressCallback(EMPTY_CALLBACK); } try { serializationContext.registerProtoFiles(source); } catch (IOException | DescriptorParserException e) { throw new CacheException("Failed to parse proto file : " + key, e); } if (progressCallback != null) { updateGlobalErrors(rCtx, progressCallback.getErrorFiles(), flagsBitSet); } } }); } /** * For preload, we need to copy the CACHE_MODE_LOCAL flag from the put command. * But we also need to remove the SKIP_CACHE_STORE flag, so that existing .errors keys are updated. */ private long copyFlags(FlagAffectedCommand command) { return EnumUtil.diffBitSets(command.getFlagsBitSet(), FlagBitSets.SKIP_CACHE_STORE); } @Override public Object visitPutMapCommand(final InvocationContext ctx, PutMapCommand command) throws Throwable { final Map<Object, Object> map = command.getMap(); FileDescriptorSource source = new FileDescriptorSource(); for (Object key : map.keySet()) { final Object value = map.get(key); if (!(key instanceof String)) { throw new CacheException("The key must be a string"); } if (!(value instanceof String)) { throw new CacheException("The value must be a string"); } if (shouldIntercept(key)) { if (!((String) key).endsWith(PROTO_KEY_SUFFIX)) { throw new CacheException("The key must be a string ending with \".proto\" : " + key); } source.addProtoFile((String) key, (String) value); } } // lock .errors key VisitableCommand cmd = commandsFactory.buildLockControlCommand(ERRORS_KEY_SUFFIX, command.getFlagsBitSet(), null); invoker.invoke(ctx, cmd); return invokeNextThenAccept(ctx, command, (rCtx, rCommand, rv) -> { long flagsBitSet = copyFlags(((PutMapCommand) rCommand)); ProgressCallback progressCallback = null; if (rCtx.isOriginLocal()) { progressCallback = new ProgressCallback(rCtx, flagsBitSet); source.withProgressCallback(progressCallback); } else { source.withProgressCallback(EMPTY_CALLBACK); } try { serializationContext.registerProtoFiles(source); } catch (IOException | DescriptorParserException e) { throw new CacheException(e); } if (progressCallback != null) { updateGlobalErrors(rCtx, progressCallback.getErrorFiles(), flagsBitSet); } }); } @Override public Object visitRemoveCommand(InvocationContext ctx, RemoveCommand command) throws Throwable { if (ctx.isOriginLocal()) { if (!(command.getKey() instanceof String)) { throw new CacheException("The key must be a string"); } String key = (String) command.getKey(); if (shouldIntercept(key)) { // lock .errors key long flagsBitSet = copyFlags(command); VisitableCommand cmd = commandsFactory.buildLockControlCommand(ERRORS_KEY_SUFFIX, flagsBitSet, null); invoker.invoke(ctx, cmd); cmd = commandsFactory.buildRemoveCommand(key + ERRORS_KEY_SUFFIX, null, flagsBitSet); invoker.invoke(ctx, cmd); if (serializationContext.getFileDescriptors().containsKey(key)) { serializationContext.unregisterProtoFile(key); } // put error key for all unresolved files and remove error key for all resolved files StringBuilder sb = new StringBuilder(); for (FileDescriptor fd : serializationContext.getFileDescriptors().values()) { if (fd.isResolved()) { cmd = commandsFactory.buildRemoveCommand(fd.getName() + ERRORS_KEY_SUFFIX, null, flagsBitSet); invoker.invoke(ctx, cmd); } else { if (sb.length() > 0) { sb.append('\n'); } sb.append(fd.getName()); PutKeyValueCommand put = commandsFactory.buildPutKeyValueCommand(fd.getName() + ERRORS_KEY_SUFFIX, "One of the imported files is missing or has errors", DEFAULT_METADATA, flagsBitSet); put.setPutIfAbsent(true); invoker.invoke(ctx, put); } } if (sb.length() > 0) { cmd = commandsFactory.buildPutKeyValueCommand(ERRORS_KEY_SUFFIX, sb.toString(), DEFAULT_METADATA, flagsBitSet); } else { cmd = commandsFactory.buildRemoveCommand(ERRORS_KEY_SUFFIX, null, flagsBitSet); } invoker.invoke(ctx, cmd); } } return invokeNext(ctx, command); } @Override public Object visitReplaceCommand(final InvocationContext ctx, ReplaceCommand command) throws Throwable { final Object key = command.getKey(); final Object value = command.getNewValue(); if (!ctx.isOriginLocal()) { return invokeNext(ctx, command); } if (!(key instanceof String)) { throw new CacheException("The key must be a string"); } if (!(value instanceof String)) { throw new CacheException("The value must be a string"); } if (!shouldIntercept(key)) { return invokeNext(ctx, command); } if (!((String) key).endsWith(PROTO_KEY_SUFFIX)) { throw new CacheException("The key must be a string ending with \".proto\" : " + key); } // lock .errors key VisitableCommand cmd = commandsFactory.buildLockControlCommand(ERRORS_KEY_SUFFIX, command.getFlagsBitSet(), null); invoker.invoke(ctx, cmd); return invokeNextThenAccept(ctx, command, (rCtx, rCommand, rv) -> { if (((WriteCommand) rCommand).isSuccessful()) { FileDescriptorSource source = new FileDescriptorSource().addProtoFile((String) key, (String) value); long flagsBitSet = copyFlags(((WriteCommand) rCommand)); ProgressCallback progressCallback = null; if (rCtx.isOriginLocal()) { progressCallback = new ProgressCallback(rCtx, flagsBitSet); source.withProgressCallback(progressCallback); } else { source.withProgressCallback(EMPTY_CALLBACK); } try { serializationContext.registerProtoFiles(source); } catch (IOException | DescriptorParserException e) { throw new CacheException("Failed to parse proto file : " + key, e); } if (progressCallback != null) { updateGlobalErrors(rCtx, progressCallback.getErrorFiles(), flagsBitSet); } } }); } @Override public Object visitClearCommand(InvocationContext ctx, ClearCommand command) throws Throwable { for (String fileName : serializationContext.getFileDescriptors().keySet()) { serializationContext.unregisterProtoFile(fileName); } return invokeNext(ctx, command); } private boolean shouldIntercept(Object key) { return !((String) key).endsWith(ERRORS_KEY_SUFFIX); } private void updateGlobalErrors(InvocationContext ctx, Set<String> errorFiles, long flagsBitSet) { // remove or update .errors accordingly VisitableCommand cmd; if (errorFiles.isEmpty()) { cmd = commandsFactory.buildRemoveCommand(ERRORS_KEY_SUFFIX, null, flagsBitSet); } else { StringBuilder sb = new StringBuilder(); for (String fileName : errorFiles) { if (sb.length() > 0) { sb.append('\n'); } sb.append(fileName); } cmd = commandsFactory.buildPutKeyValueCommand(ERRORS_KEY_SUFFIX, sb.toString(), DEFAULT_METADATA, flagsBitSet); } invoker.invoke(ctx, cmd); } }