Use proper buffering when reading remote messages

Messages now have their size in a header, along with their type
and are only executed once fully received. We dont block anymore
while trying to read a full message.
This commit is contained in:
Maxime Coste 2016-08-31 19:33:02 +01:00
parent 044a6ce860
commit 563497ade7
3 changed files with 229 additions and 167 deletions

View File

@ -32,6 +32,7 @@ enum class MemoryDomain
WordDB, WordDB,
Selections, Selections,
History, History,
Remote,
Count Count
}; };
@ -59,6 +60,7 @@ inline const char* domain_name(MemoryDomain domain)
case MemoryDomain::Client: return "Client"; case MemoryDomain::Client: return "Client";
case MemoryDomain::Selections: return "Selections"; case MemoryDomain::Selections: return "Selections";
case MemoryDomain::History: return "History"; case MemoryDomain::History: return "History";
case MemoryDomain::Remote: return "Remote";
case MemoryDomain::Count: break; case MemoryDomain::Count: break;
} }
kak_assert(false); kak_assert(false);

View File

@ -21,8 +21,9 @@
namespace Kakoune namespace Kakoune
{ {
enum class MessageType enum class MessageType : char
{ {
Unknown,
Connect, Connect,
Command, Command,
MenuShow, MenuShow,
@ -39,18 +40,18 @@ enum class MessageType
struct socket_error{}; struct socket_error{};
class Message class MsgWriter
{ {
public: public:
Message(int sock, MessageType type) : m_socket(sock) MsgWriter(int sock, MessageType type) : m_socket(sock)
{ {
write(type); write(type);
write((uint32_t)0); // message size, to be patched on write
} }
~Message() noexcept(false) ~MsgWriter() noexcept(false)
{ {
if (m_stream.size() == 0) *reinterpret_cast<uint32_t*>(m_stream.data()+1) = (uint32_t)m_stream.size();
return;
int res = ::write(m_socket, m_stream.data(), m_stream.size()); int res = ::write(m_socket, m_stream.data(), m_stream.size());
if (res == 0) if (res == 0)
throw peer_disconnected{}; throw peer_disconnected{};
@ -131,112 +132,163 @@ public:
} }
private: private:
Vector<char> m_stream; Vector<char, MemoryDomain::Remote> m_stream;
int m_socket; int m_socket;
}; };
void read(int socket, char* buffer, size_t size) class MsgReader
{ {
while (size) public:
void read_available(int sock)
{ {
int res = ::read(socket, buffer, size); if (m_write_pos < header_size)
{
m_stream.resize(header_size);
read_from_socket(sock, header_size - m_write_pos);
if (m_write_pos == header_size)
m_stream.resize(size());
}
else
read_from_socket(sock, size() - m_write_pos);
}
bool ready() const
{
return m_write_pos >= header_size and m_write_pos == size();
}
uint32_t size() const
{
kak_assert(m_write_pos >= header_size);
return *reinterpret_cast<const uint32_t*>(m_stream.data()+1);
}
MessageType type() const
{
kak_assert(m_write_pos >= header_size);
return *reinterpret_cast<const MessageType*>(m_stream.data());
}
void read(char* buffer, size_t size)
{
if (m_stream.size() - m_read_pos < size)
throw peer_disconnected{};
memcpy(buffer, m_stream.data() + m_read_pos, size);
m_read_pos += size;
}
template<typename T>
T read()
{
union U
{
T object;
alignas(T) char data[sizeof(T)];
U() {}
~U() { object.~T(); }
} u;
read(u.data, sizeof(T));
return u.object;
}
template<typename T>
Vector<T> read_vector()
{
uint32_t size = read<uint32_t>();
Vector<T> res;
res.reserve(size);
while (size--)
res.push_back(read<T>());
return res;
}
template<typename Val, MemoryDomain domain>
IdMap<Val, domain> read_idmap()
{
uint32_t size = read<uint32_t>();
IdMap<Val, domain> res;
res.reserve(size);
while (size--)
{
auto key = read<String>();
auto val = read<Val>();
res.append({std::move(key), std::move(val)});
}
return res;
}
void reset()
{
m_stream.resize(0);
m_write_pos = 0;
m_read_pos = header_size;
}
private:
void read_from_socket(int sock, size_t size)
{
int res = ::read(sock, m_stream.data() + m_write_pos, size);
if (res == 0) if (res == 0)
throw peer_disconnected{}; throw peer_disconnected{};
if (res < 0) if (res < 0)
throw socket_error{}; throw socket_error{};
m_write_pos += res;
buffer += res;
size -= res;
} }
}
template<typename T>
T read(int socket)
{
union U
{
T object;
alignas(T) char data[sizeof(T)];
U() {}
~U() { object.~T(); }
} u;
read(socket, u.data, sizeof(T));
return u.object;
}
static constexpr uint32_t header_size = sizeof(MessageType) + sizeof(uint32_t);
Vector<char, MemoryDomain::Remote> m_stream;
uint32_t m_write_pos = 0;
uint32_t m_read_pos = header_size;
};
template<> template<>
String read<String>(int socket) String MsgReader::read<String>()
{ {
ByteCount length = read<ByteCount>(socket); ByteCount length = read<ByteCount>();
String res; String res;
if (length > 0) if (length > 0)
{ {
res.force_size((int)length); res.force_size((int)length);
read(socket, &res[0_byte], (int)length); read(&res[0_byte], (int)length);
} }
return res; return res;
} }
template<typename T>
Vector<T> read_vector(int socket)
{
uint32_t size = read<uint32_t>(socket);
Vector<T> res;
res.reserve(size);
while (size--)
res.push_back(read<T>(socket));
return res;
}
template<> template<>
Color read<Color>(int socket) Color MsgReader::read<Color>()
{ {
Color res; Color res;
res.color = read<Color::NamedColor>(socket); res.color = read<Color::NamedColor>();
if (res.color == Color::RGB) if (res.color == Color::RGB)
{ {
res.r = read<unsigned char>(socket); res.r = read<unsigned char>();
res.g = read<unsigned char>(socket); res.g = read<unsigned char>();
res.b = read<unsigned char>(socket); res.b = read<unsigned char>();
} }
return res; return res;
} }
template<> template<>
DisplayAtom read<DisplayAtom>(int socket) DisplayAtom MsgReader::read<DisplayAtom>()
{ {
DisplayAtom atom(read<String>(socket)); DisplayAtom atom(read<String>());
atom.face = read<Face>(socket); atom.face = read<Face>();
return atom; return atom;
} }
template<>
DisplayLine read<DisplayLine>(int socket)
{
return DisplayLine(read_vector<DisplayAtom>(socket));
}
template<> template<>
DisplayBuffer read<DisplayBuffer>(int socket) DisplayLine MsgReader::read<DisplayLine>()
{
return DisplayLine(read_vector<DisplayAtom>());
}
template<>
DisplayBuffer MsgReader::read<DisplayBuffer>()
{ {
DisplayBuffer db; DisplayBuffer db;
db.lines() = read_vector<DisplayLine>(socket); db.lines() = read_vector<DisplayLine>();
return db; return db;
} }
template<typename Val, MemoryDomain domain>
IdMap<Val, domain> read_idmap(int socket)
{
uint32_t size = read<uint32_t>(socket);
IdMap<Val, domain> res;
res.reserve(size);
while (size--)
{
auto key = read<String>(socket);
auto val = read<Val>(socket);
res.append({std::move(key), std::move(val)});
}
return res;
}
class RemoteUI : public UserInterface class RemoteUI : public UserInterface
{ {
@ -275,14 +327,19 @@ public:
private: private:
FDWatcher m_socket_watcher; FDWatcher m_socket_watcher;
MsgReader m_reader;
CharCoord m_dimensions; CharCoord m_dimensions;
InputCallback m_input_callback; InputCallback m_input_callback;
}; };
RemoteUI::RemoteUI(int socket, CharCoord dimensions) RemoteUI::RemoteUI(int socket, CharCoord dimensions)
: m_socket_watcher(socket, [this](FDWatcher&, EventMode mode) { : m_socket_watcher(socket, [this](FDWatcher& watcher, EventMode mode) {
if (m_input_callback) const int sock = watcher.fd();
while (fd_readable(sock) and not m_reader.ready())
m_reader.read_available(sock);
if (m_reader.ready() and m_input_callback)
m_input_callback(mode); m_input_callback(mode);
}), }),
m_dimensions(dimensions) m_dimensions(dimensions)
@ -300,7 +357,7 @@ void RemoteUI::menu_show(ConstArrayView<DisplayLine> choices,
CharCoord anchor, Face fg, Face bg, CharCoord anchor, Face fg, Face bg,
MenuStyle style) MenuStyle style)
{ {
Message msg{m_socket_watcher.fd(), MessageType::MenuShow}; MsgWriter msg{m_socket_watcher.fd(), MessageType::MenuShow};
msg.write(choices); msg.write(choices);
msg.write(anchor); msg.write(anchor);
msg.write(fg); msg.write(fg);
@ -310,20 +367,20 @@ void RemoteUI::menu_show(ConstArrayView<DisplayLine> choices,
void RemoteUI::menu_select(int selected) void RemoteUI::menu_select(int selected)
{ {
Message msg{m_socket_watcher.fd(), MessageType::MenuSelect}; MsgWriter msg{m_socket_watcher.fd(), MessageType::MenuSelect};
msg.write(selected); msg.write(selected);
} }
void RemoteUI::menu_hide() void RemoteUI::menu_hide()
{ {
Message msg{m_socket_watcher.fd(), MessageType::MenuHide}; MsgWriter msg{m_socket_watcher.fd(), MessageType::MenuHide};
} }
void RemoteUI::info_show(StringView title, StringView content, void RemoteUI::info_show(StringView title, StringView content,
CharCoord anchor, Face face, CharCoord anchor, Face face,
InfoStyle style) InfoStyle style)
{ {
Message msg{m_socket_watcher.fd(), MessageType::InfoShow}; MsgWriter msg{m_socket_watcher.fd(), MessageType::InfoShow};
msg.write(title); msg.write(title);
msg.write(content); msg.write(content);
msg.write(anchor); msg.write(anchor);
@ -333,14 +390,14 @@ void RemoteUI::info_show(StringView title, StringView content,
void RemoteUI::info_hide() void RemoteUI::info_hide()
{ {
Message msg{m_socket_watcher.fd(), MessageType::InfoHide}; MsgWriter msg{m_socket_watcher.fd(), MessageType::InfoHide};
} }
void RemoteUI::draw(const DisplayBuffer& display_buffer, void RemoteUI::draw(const DisplayBuffer& display_buffer,
const Face& default_face, const Face& default_face,
const Face& padding_face) const Face& padding_face)
{ {
Message msg{m_socket_watcher.fd(), MessageType::Draw}; MsgWriter msg{m_socket_watcher.fd(), MessageType::Draw};
msg.write(display_buffer); msg.write(display_buffer);
msg.write(default_face); msg.write(default_face);
msg.write(padding_face); msg.write(padding_face);
@ -350,7 +407,7 @@ void RemoteUI::draw_status(const DisplayLine& status_line,
const DisplayLine& mode_line, const DisplayLine& mode_line,
const Face& default_face) const Face& default_face)
{ {
Message msg{m_socket_watcher.fd(), MessageType::DrawStatus}; MsgWriter msg{m_socket_watcher.fd(), MessageType::DrawStatus};
msg.write(status_line); msg.write(status_line);
msg.write(mode_line); msg.write(mode_line);
msg.write(default_face); msg.write(default_face);
@ -358,31 +415,31 @@ void RemoteUI::draw_status(const DisplayLine& status_line,
void RemoteUI::refresh(bool force) void RemoteUI::refresh(bool force)
{ {
Message msg{m_socket_watcher.fd(), MessageType::Refresh}; MsgWriter msg{m_socket_watcher.fd(), MessageType::Refresh};
msg.write(force); msg.write(force);
} }
void RemoteUI::set_ui_options(const Options& options) void RemoteUI::set_ui_options(const Options& options)
{ {
Message msg{m_socket_watcher.fd(), MessageType::SetOptions}; MsgWriter msg{m_socket_watcher.fd(), MessageType::SetOptions};
msg.write(options); msg.write(options);
} }
bool RemoteUI::is_key_available() bool RemoteUI::is_key_available()
{ {
return fd_readable(m_socket_watcher.fd()); return m_reader.ready();
} }
Key RemoteUI::get_key() Key RemoteUI::get_key()
{ {
kak_assert(m_reader.ready());
try try
{ {
const int sock = m_socket_watcher.fd(); if (m_reader.type() != MessageType::Key)
const auto msg = read<MessageType>(sock);
if (msg != MessageType::Key)
throw client_removed{ false }; throw client_removed{ false };
Key key = read<Key>(sock); Key key = m_reader.read<Key>();
m_reader.reset();
if (key.modifiers == Key::Modifiers::Resize) if (key.modifiers == Key::Modifiers::Resize)
m_dimensions = key.coord(); m_dimensions = key.coord();
return key; return key;
@ -444,7 +501,7 @@ RemoteClient::RemoteClient(StringView session, std::unique_ptr<UserInterface>&&
int sock = connect_to(session); int sock = connect_to(session);
{ {
Message msg{sock, MessageType::Connect}; MsgWriter msg{sock, MessageType::Connect};
msg.write(init_command); msg.write(init_command);
msg.write(m_ui->dimensions()); msg.write(m_ui->dimensions());
msg.write(env_vars); msg.write(env_vars);
@ -452,82 +509,78 @@ RemoteClient::RemoteClient(StringView session, std::unique_ptr<UserInterface>&&
m_ui->set_input_callback([this](EventMode){ write_next_key(); }); m_ui->set_input_callback([this](EventMode){ write_next_key(); });
m_socket_watcher.reset(new FDWatcher{sock, [this](FDWatcher&, EventMode){ process_available_messages(); }}); MsgReader reader;
} m_socket_watcher.reset(new FDWatcher{sock, [this, reader](FDWatcher& watcher, EventMode) mutable {
const int sock = watcher.fd();
while (fd_readable(sock) and not reader.ready())
reader.read_available(sock);
void RemoteClient::process_available_messages() if (not reader.ready())
{ return;
int socket = m_socket_watcher->fd();
do {
process_next_message();
} while (fd_readable(socket));
}
void RemoteClient::process_next_message() auto clear_reader = on_scope_end([&reader] { reader.reset(); });
{ switch (reader.type())
int socket = m_socket_watcher->fd(); {
const auto msg = read<MessageType>(socket); case MessageType::MenuShow:
switch (msg) {
{ auto choices = reader.read_vector<DisplayLine>();
case MessageType::MenuShow: auto anchor = reader.read<CharCoord>();
{ auto fg = reader.read<Face>();
auto choices = read_vector<DisplayLine>(socket); auto bg = reader.read<Face>();
auto anchor = read<CharCoord>(socket); auto style = reader.read<MenuStyle>();
auto fg = read<Face>(socket); m_ui->menu_show(choices, anchor, fg, bg, style);
auto bg = read<Face>(socket); break;
auto style = read<MenuStyle>(socket); }
m_ui->menu_show(choices, anchor, fg, bg, style); case MessageType::MenuSelect:
break; m_ui->menu_select(reader.read<int>());
} break;
case MessageType::MenuSelect: case MessageType::MenuHide:
m_ui->menu_select(read<int>(socket)); m_ui->menu_hide();
break; break;
case MessageType::MenuHide: case MessageType::InfoShow:
m_ui->menu_hide(); {
break; auto title = reader.read<String>();
case MessageType::InfoShow: auto content = reader.read<String>();
{ auto anchor = reader.read<CharCoord>();
auto title = read<String>(socket); auto face = reader.read<Face>();
auto content = read<String>(socket); auto style = reader.read<InfoStyle>();
auto anchor = read<CharCoord>(socket); m_ui->info_show(title, content, anchor, face, style);
auto face = read<Face>(socket); break;
auto style = read<InfoStyle>(socket); }
m_ui->info_show(title, content, anchor, face, style); case MessageType::InfoHide:
break; m_ui->info_hide();
} break;
case MessageType::InfoHide: case MessageType::Draw:
m_ui->info_hide(); {
break; auto display_buffer = reader.read<DisplayBuffer>();
case MessageType::Draw: auto default_face = reader.read<Face>();
{ auto padding_face = reader.read<Face>();
auto display_buffer = read<DisplayBuffer>(socket); m_ui->draw(display_buffer, default_face, padding_face);
auto default_face = read<Face>(socket); break;
auto padding_face = read<Face>(socket); }
m_ui->draw(display_buffer, default_face, padding_face); case MessageType::DrawStatus:
break; {
} auto status_line = reader.read<DisplayLine>();
case MessageType::DrawStatus: auto mode_line = reader.read<DisplayLine>();
{ auto default_face = reader.read<Face>();
auto status_line = read<DisplayLine>(socket); m_ui->draw_status(status_line, mode_line, default_face);
auto mode_line = read<DisplayLine>(socket); break;
auto default_face = read<Face>(socket); }
m_ui->draw_status(status_line, mode_line, default_face); case MessageType::Refresh:
break; m_ui->refresh(reader.read<bool>());
} break;
case MessageType::Refresh: case MessageType::SetOptions:
m_ui->refresh(read<bool>(socket)); m_ui->set_ui_options(reader.read_idmap<String, MemoryDomain::Options>());
break; break;
case MessageType::SetOptions: default:
m_ui->set_ui_options(read_idmap<String, MemoryDomain::Options>(socket)); kak_assert(false);
break; }
default: }});
kak_assert(false);
}
} }
void RemoteClient::write_next_key() void RemoteClient::write_next_key()
{ {
Message msg(m_socket_watcher->fd(), MessageType::Key); MsgWriter msg(m_socket_watcher->fd(), MessageType::Key);
// do that before checking dimensions as get_key may // do that before checking dimensions as get_key may
// handle a resize event. // handle a resize event.
msg.write(m_ui->get_key()); msg.write(m_ui->get_key());
@ -537,7 +590,7 @@ void send_command(StringView session, StringView command)
{ {
int sock = connect_to(session); int sock = connect_to(session);
auto close_sock = on_scope_end([sock]{ close(sock); }); auto close_sock = on_scope_end([sock]{ close(sock); });
Message msg{sock, MessageType::Command}; MsgWriter msg{sock, MessageType::Command};
msg.write(command); msg.write(command);
} }
@ -563,14 +616,22 @@ private:
void handle_available_input() void handle_available_input()
{ {
const int sock = m_socket_watcher.fd(); const int sock = m_socket_watcher.fd();
const auto msg = read<MessageType>(sock); do
switch (msg) {
m_reader.read_available(sock);
}
while (fd_readable(sock) and not m_reader.ready());
if (not m_reader.ready())
return;
switch (m_reader.type())
{ {
case MessageType::Connect: case MessageType::Connect:
{ {
auto init_command = read<String>(sock); auto init_command = m_reader.read<String>();
auto dimensions = read<CharCoord>(sock); auto dimensions = m_reader.read<CharCoord>();
auto env_vars = read_idmap<String, MemoryDomain::EnvVars>(sock); auto env_vars = m_reader.read_idmap<String, MemoryDomain::EnvVars>();
std::unique_ptr<UserInterface> ui{new RemoteUI{sock, dimensions}}; std::unique_ptr<UserInterface> ui{new RemoteUI{sock, dimensions}};
ClientManager::instance().create_client(std::move(ui), ClientManager::instance().create_client(std::move(ui),
std::move(env_vars), std::move(env_vars),
@ -580,7 +641,7 @@ private:
} }
case MessageType::Command: case MessageType::Command:
{ {
auto command = read<String>(sock); auto command = m_reader.read<String>();
if (not command.empty()) try if (not command.empty()) try
{ {
Context context{Context::EmptyContextFlag{}}; Context context{Context::EmptyContextFlag{}};
@ -605,6 +666,7 @@ private:
} }
FDWatcher m_socket_watcher; FDWatcher m_socket_watcher;
MsgReader m_reader;
}; };
Server::Server(String session_name) Server::Server(String session_name)

View File

@ -32,8 +32,6 @@ public:
const EnvVarMap& env_vars, StringView init_command); const EnvVarMap& env_vars, StringView init_command);
private: private:
void process_available_messages();
void process_next_message();
void write_next_key(); void write_next_key();
std::unique_ptr<UserInterface> m_ui; std::unique_ptr<UserInterface> m_ui;