diff --git a/src/commands.cc b/src/commands.cc index 4260c9ab..fd6b7552 100644 --- a/src/commands.cc +++ b/src/commands.cc @@ -12,6 +12,7 @@ #include "event_manager.hh" #include "face_registry.hh" #include "file.hh" +#include "hash_map.hh" #include "highlighter.hh" #include "highlighters.hh" #include "option_manager.hh" @@ -1126,7 +1127,7 @@ const CommandDesc debug_cmd = { make_completer( [](const Context& context, CompletionFlags flags, const String& prefix, ByteCount cursor_pos) -> Completions { - auto c = {"info", "buffers", "options", "memory", "shared-strings"}; + auto c = {"info", "buffers", "options", "memory", "shared-strings", "profile-hash-maps"}; return { 0_byte, cursor_pos, complete(prefix, cursor_pos, c) }; }), [](const ParametersParser& parser, Context& context, const ShellContext&) @@ -1167,6 +1168,10 @@ const CommandDesc debug_cmd = { { StringRegistry::instance().debug_stats(); } + else if (parser[0] == "profile-hash-maps") + { + profile_hash_maps(); + } else throw runtime_error(format("unknown debug command '{}'", parser[0])); } diff --git a/src/hash.hh b/src/hash.hh index 896478f3..ddce1cff 100644 --- a/src/hash.hh +++ b/src/hash.hh @@ -58,6 +58,13 @@ struct Hash } }; +// Traits specifying if two types have compatible hashing, that is, +// if lhs == rhs => hash_value(lhs) == hash_value(rhs) +template +struct HashCompatible : std::false_type {}; + +template struct HashCompatible : std::true_type {}; + } #endif // hash_hh_INCLUDED diff --git a/src/hash_map.cc b/src/hash_map.cc new file mode 100644 index 00000000..bc2d6ce2 --- /dev/null +++ b/src/hash_map.cc @@ -0,0 +1,156 @@ +#include "hash_map.hh" + +#include "clock.hh" +#include "string.hh" +#include "buffer_utils.hh" +#include "unit_tests.hh" + +#include + +namespace Kakoune +{ + +UnitTest test_hash_map{[] { + // Basic usage + { + HashMap map; + map.insert({10, 1}); + map.insert({20, 2}); + kak_assert(map.find_index(0) == -1); + kak_assert(map.find_index(10) == 0); + kak_assert(map.find_index(20) == 1); + kak_assert(map[10] == 1); + kak_assert(map[20] == 2); + kak_assert(map[30] == 0); + map[30] = 3; + kak_assert(map.find_index(30) == 2); + map.remove(20); + kak_assert(map.find_index(30) == 1); + kak_assert(map.size() == 2); + } + + // Multiple entries with the same key + { + HashMap map; + map.insert({10, 1}); + map.insert({10, 2}); + kak_assert(map.find_index(10) == 0); + map.remove(10); + kak_assert(map.find_index(10) == 0); + map.remove(10); + kak_assert(map.find_index(10) == -1); + map.insert({20, 1}); + map.insert({20, 2}); + map.remove_all(20); + kak_assert(map.find_index(20) == -1); + } + + // Check hash compatible support + { + HashMap map; + map.insert({"test", 10}); + kak_assert(map["test"_sv] == 10); + map.remove("test"_sv); + } + + // make sure we get what we expect from the hash map + { + std::random_device dev; + std::default_random_engine re{dev()}; + std::uniform_int_distribution dist; + + HashMap map; + Vector> ref; + + for (int i = 0; i < 100; ++i) + { + auto key = dist(re), value = dist(re); + ref.push_back({key, value}); + map.insert({key, value}); + + std::random_shuffle(ref.begin(), ref.end()); + for (auto& elem : ref) + { + auto it = map.find(elem.first); + kak_assert(it != map.end() and it->value == elem.second); + } + } + } +}}; + +struct HashStats +{ + size_t max_dist; + float mean_dist; + float fill_rate; +}; + +template +HashStats HashIndex::compute_stats() const +{ + size_t count = 0; + size_t max_dist = 0; + size_t sum_dist = 0; + for (size_t slot = 0; slot < m_entries.size(); ++slot) + { + auto& entry = m_entries[slot]; + if (entry.index == -1) + continue; + ++count; + auto dist = slot - compute_slot(entry.hash); + max_dist = std::max(max_dist, dist); + sum_dist += dist; + } + + return { max_dist, (float)sum_dist / count, (float)count / m_entries.size() }; +} + +template +HashStats HashMap::compute_stats() const +{ + return m_index.compute_stats(); +} + +template +void do_profile(size_t count, StringView type) +{ + std::random_device dev; + std::default_random_engine re{dev()}; + std::uniform_int_distribution dist{0, count}; + + Vector vec; + for (size_t i = 0; i < count; ++i) + vec.push_back(i); + std::random_shuffle(vec.begin(), vec.end()); + + Map map; + auto start = Clock::now(); + + for (auto v : vec) + map.insert({v, dist(re)}); + auto after_insert = Clock::now(); + + for (size_t i = 0; i < count; ++i) + ++map[dist(re)]; + auto after_read = Clock::now(); + + for (size_t i = 0; i < count; ++i) + map.erase(dist(re)); + auto after_remove = Clock::now(); + + write_to_debug_buffer(format("{} ({}) -- inserts: {}ms, reads: {}ms, remove: {}ms", type, count, + std::chrono::duration_cast(after_insert - start).count(), + std::chrono::duration_cast(after_read - after_insert).count(), + std::chrono::duration_cast(after_remove - after_read).count())); +} + +void profile_hash_maps() +{ + for (auto i : { 1000, 10000, 100000, 1000000, 10000000 }) + { + do_profile>(i, "UnorderedMap"); + do_profile>(i, " HashMap "); + } +} + +} diff --git a/src/hash_map.hh b/src/hash_map.hh new file mode 100644 index 00000000..6e4e1899 --- /dev/null +++ b/src/hash_map.hh @@ -0,0 +1,299 @@ +#ifndef hash_map_hh_INCLUDED +#define hash_map_hh_INCLUDED + +#include "hash.hh" +#include "memory.hh" +#include "vector.hh" + +namespace Kakoune +{ + +class String; + +struct HashStats; + +template +struct HashIndex +{ + struct Entry + { + size_t hash; + int index; + }; + + void grow() + { + Vector old_entries = std::move(m_entries); + constexpr size_t init_size = 4; + m_entries.resize(old_entries.empty() ? init_size : old_entries.size() * 2, {0,-1}); + for (auto& entry : old_entries) + { + if (entry.index >= 0) + add(entry.hash, entry.index); + } + } + + void add(size_t hash, int index) + { + ++m_count; + if ((float)m_count / m_entries.size() > m_max_fill_rate) + grow(); + + Entry entry{hash, index}; + while (true) + { + auto target_slot = compute_slot(entry.hash); + for (auto slot = target_slot; slot < m_entries.size(); ++slot) + { + if (m_entries[slot].index == -1) + { + m_entries[slot] = entry; + return; + } + + // Robin hood hashing + auto candidate_slot = compute_slot(m_entries[slot].hash); + if (target_slot < candidate_slot) + { + std::swap(m_entries[slot], entry); + target_slot = candidate_slot; + } + } + // no free entries found, grow, try again + grow(); + } + } + + void remove(size_t hash, int index) + { + --m_count; + for (auto slot = compute_slot(hash); slot < m_entries.size(); ++slot) + { + kak_assert(m_entries[slot].index >= 0); + if (m_entries[slot].index == index) + { + m_entries[slot].index = -1; + // Recompact following entries + for (auto next = slot+1; next < m_entries.size(); ++next) + { + if (m_entries[next].index == -1 or + compute_slot(m_entries[next].hash) == next) + break; + kak_assert(compute_slot(m_entries[next].hash) < next); + std::swap(m_entries[next-1], m_entries[next]); + } + break; + } + } + } + + void ordered_fix_entries(int index) + { + // Fix entries index + for (auto& entry : m_entries) + { + if (entry.index >= index) + --entry.index; + } + } + + void unordered_fix_entries(size_t hash, int old_index, int new_index) + { + for (auto slot = compute_slot(hash); slot < m_entries.size(); ++slot) + { + if (m_entries[slot].index == old_index) + { + m_entries[slot].index = new_index; + return; + } + } + kak_assert(false); // entry not found ?! + } + + const Entry& operator[](size_t index) const { return m_entries[index]; } + size_t size() const { return m_entries.size(); } + size_t compute_slot(size_t hash) const + { + // We assume entries.size() is power of 2 + return m_entries.empty() ? 0 : hash & (m_entries.size()-1); + } + + void clear() { m_entries.clear(); } + + HashStats compute_stats() const; + +private: + size_t m_count = 0; + float m_max_fill_rate = 0.5f; + Vector m_entries; +}; + +template +struct HashMap +{ + struct Item + { + Key key; + Value value; + }; + + HashMap() = default; + + HashMap(std::initializer_list val) : m_items{val} + { + for (int i = 0; i < m_items.size(); ++i) + m_index.add(hash_value(m_items[i].key), i); + } + + Value& insert(Item item) + { + m_index.add(hash_value(item.key), (int)m_items.size()); + m_items.push_back(std::move(item)); + return m_items.back().value; + } + + template + using EnableIfHashCompatible = typename std::enable_if< + HashCompatible::type>::value + >::type; + + // For IdMap inteface compatibility, to remove + using Element = Item; + Value& append(Item item) { return insert(std::move(item)); } + static const String& get_id(const Element& e) { return e.key; } + + template> + int find_index(const KeyType& key, size_t hash) const + { + for (auto slot = m_index.compute_slot(hash); slot < m_index.size(); ++slot) + { + auto& entry = m_index[slot]; + if (entry.index == -1) + return -1; + if (entry.hash == hash and m_items[entry.index].key == key) + return entry.index; + } + return -1; + } + + template> + int find_index(const KeyType& key) const { return find_index(key, hash_value(key)); } + + template> + bool contains(const KeyType& key) const { return find_index(key) >= 0; } + + template> + Value& operator[](KeyType&& key) + { + const auto hash = hash_value(key); + auto index = find_index(key, hash); + if (index >= 0) + return m_items[index].value; + + m_index.add(hash, (int)m_items.size()); + m_items.push_back({Key{std::forward(key)}, {}}); + return m_items.back().value; + } + + template> + void remove(const KeyType& key) + { + const auto hash = hash_value(key); + int index = find_index(key, hash); + if (index >= 0) + { + m_items.erase(m_items.begin() + index); + m_index.remove(hash, index); + m_index.ordered_fix_entries(index); + } + } + + template> + void unordered_remove(const KeyType& key) + { + const auto hash = hash_value(key); + int index = find_index(key, hash); + if (index >= 0) + { + std::swap(m_items[index], m_items.back()); + m_items.pop_back(); + m_index.remove(hash, index); + if (index != m_items.size()) + m_index.unordered_fix_entries(hash_value(m_items[index].key), m_items.size(), index); + } + } + + void erase(const Key& key) { unordered_remove(key); } + + template> + void remove_all(const KeyType& key) + { + const auto hash = hash_value(key); + for (int index = find_index(key, hash); index >= 0; + index = find_index(key, hash)) + { + m_items.erase(m_items.begin() + index); + m_index.remove(hash, index); + m_index.ordered_fix_entries(index); + } + } + + using iterator = typename Vector::iterator; + iterator begin() { return m_items.begin(); } + iterator end() { return m_items.end(); } + + using const_iterator = typename Vector::const_iterator; + const_iterator begin() const { return m_items.begin(); } + const_iterator end() const { return m_items.end(); } + + template> + iterator find(const KeyType& key) + { + auto index = find_index(key); + return index >= 0 ? begin() + index : end(); + } + + template> + const_iterator find(const KeyType& key) const + { + return const_cast(this)->find(key); + } + + void clear() { m_items.clear(); m_index.clear(); } + + size_t size() const { return m_items.size(); } + bool empty() const { return m_items.empty(); } + void reserve(size_t size) + { + m_items.reserve(size); + // TODO: Reserve in the index as well + } + + // Equality is taking the order of insertion into account + template + bool operator==(const HashMap& other) const + { + return size() == other.size() and + std::equal(begin(), end(), other.begin(), + [](const Item& lhs, const Item& rhs) { + return lhs.key == rhs.key and lhs.value == rhs.value; + }); + } + + template + bool operator!=(const HashMap& other) const + { + return not (*this == other); + } + + HashStats compute_stats() const; +private: + Vector m_items; + HashIndex m_index; +}; + +void profile_hash_maps(); + +} + +#endif // hash_map_hh_INCLUDED diff --git a/src/string.hh b/src/string.hh index 14e92bd7..6c0a467e 100644 --- a/src/string.hh +++ b/src/string.hh @@ -121,6 +121,8 @@ public: } String(const char* begin, const char* end) : m_data(begin, end-begin) {} + explicit String(StringView str); + [[gnu::always_inline]] char* data() { return m_data.data(); } @@ -254,6 +256,10 @@ private: static_assert(std::is_trivial::value, ""); +template<> struct HashCompatible : std::true_type {}; + +inline String::String(StringView str) : String{str.begin(), str.length()} {} + template inline StringView StringOps::substr(ByteCount from, ByteCount length) const { @@ -319,6 +325,11 @@ inline String operator"" _str(const char* str, size_t) return String(str); } +inline StringView operator"" _sv(const char* str, size_t) +{ + return StringView{str}; +} + Vector split(StringView str, char separator, char escape); Vector split(StringView str, char separator);