diff --git a/src/client_manager.cc b/src/client_manager.cc index ad9a5009..6d4ef641 100644 --- a/src/client_manager.cc +++ b/src/client_manager.cc @@ -174,7 +174,8 @@ void ClientManager::clear_client_trash() bool ClientManager::validate_client_name(StringView name) const { - return const_cast(this)->get_client_ifp(name) == nullptr; + return all_of(name, is_identifier) and + const_cast(this)->get_client_ifp(name) == nullptr; } Client* ClientManager::get_client_ifp(StringView name) diff --git a/src/commands.cc b/src/commands.cc index eff21355..889a5307 100644 --- a/src/commands.cc +++ b/src/commands.cc @@ -846,7 +846,7 @@ void define_command(const ParametersParser& parser, Context& context, const Shel const String& cmd_name = parser[0]; auto& cm = CommandManager::instance(); - if (contains_that(cmd_name, is_blank)) + if (not all_of(cmd_name, is_identifier)) throw runtime_error(format("invalid command name: '{}'", cmd_name)); if (cm.command_defined(cmd_name) and not parser.get_switch("allow-override")) diff --git a/src/keymap_manager.cc b/src/keymap_manager.cc index fd83757c..844f150f 100644 --- a/src/keymap_manager.cc +++ b/src/keymap_manager.cc @@ -61,7 +61,7 @@ void KeymapManager::add_user_mode(String user_mode_name) if (contains(user_modes(), user_mode_name)) throw runtime_error(format("user mode '{}' already defined", user_mode_name)); - if (contains_that(user_mode_name, is_blank)) + if (not all_of(user_mode_name, is_identifier)) throw runtime_error(format("invalid mode name: '{}'", user_mode_name)); user_modes().push_back(std::move(user_mode_name)); diff --git a/src/option_manager.hh b/src/option_manager.hh index 674cf6da..84c1be98 100644 --- a/src/option_manager.hh +++ b/src/option_manager.hh @@ -214,13 +214,11 @@ public: const T& value, OptionFlags flags = OptionFlags::None) { - auto is_not_identifier = [](char c) { - return (c < 'a' or c > 'z') and - (c < 'A' or c > 'Z') and - (c < '0' or c > '9') and c != '_'; + auto is_option_identifier = [](char c) { + return is_basic_alpha(c) or is_basic_digit(c) or c == '_'; }; - if (contains_that(name, is_not_identifier)) + if (not all_of(name, is_option_identifier)) throw runtime_error{format("name '{}' contains char out of [a-zA-Z0-9_]", name)}; auto& opts = m_global_manager.m_options; diff --git a/src/ranges.hh b/src/ranges.hh index 58f14db4..a762dd2f 100644 --- a/src/ranges.hh +++ b/src/ranges.hh @@ -338,6 +338,20 @@ bool contains_that(Range&& range, T op) return find_if(range, op) != end(range); } +template +bool all_of(Range&& range, T op) +{ + using std::begin; using std::end; + return std::all_of(begin(range), end(range), op); +} + +template +bool any_of(Range&& range, T op) +{ + using std::begin; using std::end; + return std::any_of(begin(range), end(range), op); +} + template void unordered_erase(Range&& vec, U&& value) { diff --git a/src/remote.cc b/src/remote.cc index 521663f6..f3dd7d06 100644 --- a/src/remote.cc +++ b/src/remote.cc @@ -777,8 +777,8 @@ private: Server::Server(String session_name) : m_session{std::move(session_name)} { - if (contains(m_session, '/')) - throw runtime_error{"Cannot create sessions with '/' in their name"}; + if (not all_of(m_session, is_identifier)) + throw runtime_error{format("Invalid session name '{}'", session_name)}; int listen_sock = socket(AF_UNIX, SOCK_STREAM, 0); fcntl(listen_sock, F_SETFD, FD_CLOEXEC); @@ -816,8 +816,8 @@ Server::Server(String session_name) bool Server::rename_session(StringView name) { - if (contains(name, '/')) - throw runtime_error{"Cannot create sessions with '/' in their name"}; + if (not all_of(name, is_identifier)) + throw runtime_error{format("Invalid session name '{}'", name)}; String old_socket_file = format("{}/kakoune/{}/{}", tmpdir(), get_user_name(geteuid()), m_session); diff --git a/src/unicode.hh b/src/unicode.hh index 01fcfb23..1dcc836b 100644 --- a/src/unicode.hh +++ b/src/unicode.hh @@ -53,6 +53,17 @@ inline bool is_basic_alpha(Codepoint c) noexcept return (c >= 'a' and c <= 'z') or (c >= 'A' and c <= 'Z'); } +inline bool is_basic_digit(Codepoint c) noexcept +{ + return c >= '0' and c <= '9'; +} + +inline bool is_identifier(Codepoint c) noexcept +{ + return is_basic_alpha(c) or is_basic_digit(c) or + c == '_' or c == '-'; +} + inline ColumnCount codepoint_width(Codepoint c) noexcept { if (c == '\n')