diff --git a/src/buffer_utils.cc b/src/buffer_utils.cc index 0162822c..d8daa944 100644 --- a/src/buffer_utils.cc +++ b/src/buffer_utils.cc @@ -265,4 +265,15 @@ Vector undo_group_as_strings(const Buffer::UndoGroup& undo_group) return res; } +String generate_buffer_name(StringView pattern) +{ + auto& buffer_manager = BufferManager::instance(); + for (int i = 0; true; ++i) + { + String name = format(pattern, i); + if (buffer_manager.get_buffer_ifp(name) == nullptr) + return name; + } +} + } diff --git a/src/buffer_utils.hh b/src/buffer_utils.hh index 8a9c6d24..5af08e0e 100644 --- a/src/buffer_utils.hh +++ b/src/buffer_utils.hh @@ -87,6 +87,8 @@ void write_to_debug_buffer(StringView str); Vector history_as_strings(const Vector& history); Vector undo_group_as_strings(const Buffer::UndoGroup& undo_group); +String generate_buffer_name(StringView pattern); + } #endif // buffer_utils_hh_INCLUDED diff --git a/src/commands.cc b/src/commands.cc index accb3e37..d6feb501 100644 --- a/src/commands.cc +++ b/src/commands.cc @@ -356,16 +356,8 @@ void edit(const ParametersParser& parser, Context& context, const ShellContext&) (parser.get_switch("debug") ? Buffer::Flags::Debug : Buffer::Flags::None); auto& buffer_manager = BufferManager::instance(); - auto generate_scratch_name = [&] { - for (int i = 0; true; ++i) - { - String name = format("*scratch-{}*", i); - if (buffer_manager.get_buffer_ifp(name) == nullptr) - return name; - } - }; const auto& name = parser.positional_count() > 0 ? - parser[0] : (scratch ? generate_scratch_name() : context.buffer().name()); + parser[0] : (scratch ? generate_buffer_name("*scratch-{}*") : context.buffer().name()); Buffer* buffer = buffer_manager.get_buffer_ifp(name); if (scratch) diff --git a/src/main.cc b/src/main.cc index 4ec8effa..2ee7bc1d 100644 --- a/src/main.cc +++ b/src/main.cc @@ -661,9 +661,21 @@ int run_client(StringView session, StringView name, StringView client_init, { try { + Optional stdin_fd; + if (not isatty(0)) + { + // move stdin to another fd, and restore tty as stdin + stdin_fd = dup(0); + int tty = open("/dev/tty", O_RDONLY); + dup2(tty, 0); + close(tty); + } + EventManager event_manager; RemoteClient client{session, name, make_ui(ui_type), getpid(), get_env_vars(), - client_init, std::move(init_coord)}; + client_init, std::move(init_coord), stdin_fd}; + stdin_fd.map(close); + if (suspend) raise(SIGTSTP); while (not client.exit_status() and client.is_ui_ok()) diff --git a/src/optional.hh b/src/optional.hh index e3e0c7e9..8531042e 100644 --- a/src/optional.hh +++ b/src/optional.hh @@ -60,11 +60,12 @@ public: bool operator!=(const Optional& other) const { return !(*this == other); } template - void emplace(Args&&... args) + T& emplace(Args&&... args) { destruct_ifn(); new (&m_value) T{std::forward(args)...}; m_valid = true; + return m_value; } T& operator*() diff --git a/src/remote.cc b/src/remote.cc index 8e475e72..dea4d853 100644 --- a/src/remote.cc +++ b/src/remote.cc @@ -261,25 +261,56 @@ public: return Reader::read(*this); } + Optional ancillary_fd() + { + auto res = m_ancillary_fd; + m_ancillary_fd.reset(); + return res; + } + + ~MsgReader() + { + m_ancillary_fd.map(close); + } + void reset() { m_stream.resize(0); m_write_pos = 0; m_read_pos = header_size; + m_ancillary_fd.map(close); } private: void read_from_socket(int sock, size_t size) { kak_assert(m_write_pos + size <= m_stream.size()); - int res = ::read(sock, m_stream.data() + m_write_pos, size); + iovec io{m_stream.data() + m_write_pos, size}; + alignas(cmsghdr) char fdbuf[CMSG_SPACE(sizeof(int))]; + + msghdr msg{}; + msg.msg_iov = &io; + msg.msg_iovlen = 1; + msg.msg_control = fdbuf; + msg.msg_controllen = sizeof(fdbuf); + + int res = recvmsg(sock, &msg, MSG_CMSG_CLOEXEC); if (res <= 0) throw disconnected{format("socket read failed: {}", strerror(errno))}; + m_write_pos += res; + + if (cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + cmsg && cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS && cmsg->cmsg_len == CMSG_LEN(sizeof(int))) + { + m_ancillary_fd.map(close); + memcpy(&m_ancillary_fd.emplace(), CMSG_DATA(cmsg), sizeof(int)); + } } static constexpr uint32_t header_size = sizeof(MessageType) + sizeof(uint32_t); Vector m_stream; + Optional m_ancillary_fd; uint32_t m_write_pos = 0; uint32_t m_read_pos = header_size; }; @@ -398,14 +429,32 @@ private: RemoteBuffer m_send_buffer; }; -static bool send_data(int fd, RemoteBuffer& buffer) +static bool send_data(int fd, RemoteBuffer& buffer, Optional ancillary_fd = {}) { while (not buffer.empty() and fd_writable(fd)) { - int res = ::write(fd, buffer.data(), buffer.size()); - if (res <= 0) - throw disconnected{format("socket write failed: {}", strerror(errno))}; - buffer.erase(buffer.begin(), buffer.begin() + res); + iovec io{buffer.data(), buffer.size()}; + alignas(cmsghdr) char fdbuf[CMSG_SPACE(sizeof(int))]; + + msghdr msg{}; + msg.msg_iov = &io; + msg.msg_iovlen = 1; + if (ancillary_fd) + { + msg.msg_control = fdbuf; + msg.msg_controllen = sizeof(fdbuf); + + cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); + cmsg->cmsg_len = CMSG_LEN(sizeof(int)); + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + memcpy(CMSG_DATA(cmsg), &*ancillary_fd, sizeof(int)); + } + + int res = sendmsg(fd, &msg, 0); + if (res <= 0) + throw disconnected{format("socket write failed: {}", strerror(errno))}; + buffer.erase(buffer.begin(), buffer.begin() + res); } return buffer.empty(); } @@ -592,7 +641,7 @@ bool check_session(StringView session) RemoteClient::RemoteClient(StringView session, StringView name, std::unique_ptr&& ui, int pid, const EnvVarMap& env_vars, StringView init_command, - Optional init_coord) + Optional init_coord, Optional stdin_fd) : m_ui(std::move(ui)) { int sock = connect_to(session); @@ -601,6 +650,7 @@ RemoteClient::RemoteClient(StringView session, StringView name, std::unique_ptr< MsgWriter msg{m_send_buffer, MessageType::Connect}; msg.write(pid, name, init_command, init_coord, m_ui->dimensions(), env_vars); } + send_data(sock, m_send_buffer, stdin_fd); m_ui->set_on_key([this](Key key){ MsgWriter msg(m_send_buffer, MessageType::Key); @@ -750,6 +800,10 @@ private: auto init_coord = m_reader.read>(); auto dimensions = m_reader.read(); auto env_vars = m_reader.read>(); + + if (auto stdin_fd = m_reader.ancillary_fd()) + create_fifo_buffer(generate_buffer_name("*stdin-{}*"), *stdin_fd, Buffer::Flags::None); + auto* ui = new RemoteUI{sock, dimensions}; ClientManager::instance().create_client( std::unique_ptr(ui), pid, std::move(name), diff --git a/src/remote.hh b/src/remote.hh index e465d087..89c2f438 100644 --- a/src/remote.hh +++ b/src/remote.hh @@ -32,7 +32,7 @@ class RemoteClient public: RemoteClient(StringView session, StringView name, std::unique_ptr&& ui, int pid, const EnvVarMap& env_vars, StringView init_command, - Optional init_coord); + Optional init_coord, Optional stdin_fd); bool is_ui_ok() const; const Optional& exit_status() const { return m_exit_status; }