/* * Copyright (c) 2015 Spotify AB. * * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ package com.spotify.heroic.shell; import com.spotify.heroic.shell.protocol.Acknowledge; import com.spotify.heroic.shell.protocol.CommandDefinition; import com.spotify.heroic.shell.protocol.CommandDone; import com.spotify.heroic.shell.protocol.CommandOutput; import com.spotify.heroic.shell.protocol.CommandOutputFlush; import com.spotify.heroic.shell.protocol.CommandsRequest; import com.spotify.heroic.shell.protocol.CommandsResponse; import com.spotify.heroic.shell.protocol.EvaluateRequest; import com.spotify.heroic.shell.protocol.FileClose; import com.spotify.heroic.shell.protocol.FileFlush; import com.spotify.heroic.shell.protocol.FileNewInputStream; import com.spotify.heroic.shell.protocol.FileNewOutputStream; import com.spotify.heroic.shell.protocol.FileOpened; import com.spotify.heroic.shell.protocol.FileRead; import com.spotify.heroic.shell.protocol.FileReadResult; import com.spotify.heroic.shell.protocol.FileWrite; import com.spotify.heroic.shell.protocol.Message; import com.spotify.heroic.shell.protocol.SimpleMessageVisitor; import eu.toolchain.async.AsyncFramework; import eu.toolchain.async.AsyncFuture; import eu.toolchain.serializer.SerializerFramework; import lombok.extern.slf4j.Slf4j; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.InetSocketAddress; import java.net.Socket; import java.nio.file.Paths; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.Callable; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import static java.util.Optional.empty; import static java.util.Optional.of; @Slf4j public class RemoteCoreInterface implements CoreInterface { public static final int DEFAULT_PORT = 9190; public static final int MAX_READ = 4096; final InetSocketAddress address; final AsyncFramework async; final SerializerFramework serializer; public RemoteCoreInterface( InetSocketAddress address, AsyncFramework async, SerializerFramework serializer ) throws IOException { this.address = address; this.async = async; this.serializer = serializer; } @Override public AsyncFuture<Void> evaluate(final List<String> command, final ShellIO io) throws Exception { return async.call(() -> { final AtomicBoolean running = new AtomicBoolean(true); final AtomicInteger fileCounter = new AtomicInteger(); final Map<Integer, InputStream> reading = new HashMap<>(); final Map<Integer, OutputStream> writing = new HashMap<>(); final Map<Integer, Callable<Void>> closers = new HashMap<>(); try (final ShellConnection c = connect()) { c.send(new EvaluateRequest(command)); final Message.Visitor<Optional<Message>> visitor = setupVisitor(io, running, fileCounter, reading, writing, closers); while (true) { final Message in = c.receive(); final Optional<Message> out = in.visit(visitor); if (!running.get()) { break; } if (out.isPresent()) { final Message o = out.get(); try { c.send(o); } catch (Exception e) { throw new Exception("Failed to send response: " + o, e); } } } } return null; }); } private SimpleMessageVisitor<Optional<Message>> setupVisitor( final ShellIO io, final AtomicBoolean running, final AtomicInteger fileCounter, final Map<Integer, InputStream> reading, final Map<Integer, OutputStream> writing, final Map<Integer, Callable<Void>> closers ) { return new SimpleMessageVisitor<Optional<Message>>() { public Optional<Message> visitCommandDone(CommandDone m) { running.set(false); return empty(); } @Override public Optional<Message> visitCommandOutput(CommandOutput m) { io.out().write(m.getData()); return empty(); } @Override public Optional<Message> visitCommandOutputFlush(CommandOutputFlush m) { io.out().flush(); return empty(); } @Override public Optional<Message> visitFileNewInputStream(FileNewInputStream m) throws Exception { final InputStream in = io.newInputStream(Paths.get(m.getPath()), m.getOptionsAsArray()); final int h = fileCounter.incrementAndGet(); reading.put(h, in); closers.put(h, () -> { in.close(); reading.remove(h); return null; }); return of(new FileOpened(h)); } @Override public Optional<Message> visitFileNewOutputStream(FileNewOutputStream m) throws Exception { final OutputStream out = io.newOutputStream(Paths.get(m.getPath()), m.getOptionsAsArray()); final int h = fileCounter.incrementAndGet(); writing.put(h, out); closers.put(h, () -> { out.close(); writing.remove(h); return null; }); return of(new FileOpened(h)); } @Override public Optional<Message> visitFileFlush(FileFlush m) throws Exception { writer(m.getHandle()).flush(); return of(new Acknowledge()); } @Override public Optional<Message> visitFileClose(FileClose m) throws Exception { closer(m.getHandle()).call(); return of(new Acknowledge()); } @Override public Optional<Message> visitFileWrite(FileWrite m) throws Exception { final byte[] data = m.getData(); writer(m.getHandle()).write(data, 0, data.length); return empty(); } @Override public Optional<Message> visitFileRead(FileRead m) throws Exception { final byte[] buffer = new byte[Math.min(MAX_READ, m.getLength())]; int read = reader(m.getHandle()).read(buffer); if (read == 0) { return of(new FileReadResult(new byte[0])); } return of(new FileReadResult(Arrays.copyOf(buffer, read))); } @Override protected Optional<Message> visitUnknown(Message message) { throw new IllegalArgumentException("Unhandled message: " + message); } private InputStream reader(int handle) throws Exception { final InputStream r = reading.get(handle); if (r == null) { throw new Exception("No such handle: " + handle); } return r; } private OutputStream writer(int handle) throws Exception { final OutputStream w = writing.get(handle); if (w == null) { throw new Exception("No such handle: " + handle); } return w; } private Callable<Void> closer(int handle) throws Exception { final Callable<Void> closer = closers.get(handle); if (closer == null) { throw new Exception("No such handle: " + handle); } return closer; } }; } @Override public List<CommandDefinition> commands() throws Exception { try (final ShellConnection c = connect()) { return c.request(new CommandsRequest(), CommandsResponse.class).getCommands(); } } @Override public void shutdown() throws Exception { } public static RemoteCoreInterface fromConnectString( String connect, AsyncFramework async, SerializerFramework serializer ) throws IOException { final String host; final int port; final int index; if ((index = connect.indexOf(':')) > 0) { host = connect.substring(0, index); port = Integer.parseInt(connect.substring(index + 1)); } else { host = connect; port = DEFAULT_PORT; } return new RemoteCoreInterface(new InetSocketAddress(host, port), async, serializer); } private ShellConnection connect() throws IOException { final Socket socket = new Socket(); socket.connect(address); return new ShellConnection(serializer, socket); } }