diff --git a/src/buffer_manager.cc b/src/buffer_manager.cc index 74cacdb2..db3e37c5 100644 --- a/src/buffer_manager.cc +++ b/src/buffer_manager.cc @@ -36,7 +36,7 @@ void BufferManager::unregister_buffer(Buffer& buffer) { for (auto it = m_buffers.begin(); it != m_buffers.end(); ++it) { - if (*it == &buffer) + if (it->get() == &buffer) { m_buffers.erase(it); return; @@ -44,7 +44,7 @@ void BufferManager::unregister_buffer(Buffer& buffer) } for (auto it = m_buffer_trash.begin(); it != m_buffer_trash.end(); ++it) { - if (*it == &buffer) + if (it->get() == &buffer) { m_buffer_trash.erase(it); return; @@ -57,7 +57,7 @@ void BufferManager::delete_buffer(Buffer& buffer) { for (auto it = m_buffers.begin(); it != m_buffers.end(); ++it) { - if (*it == &buffer) + if (it->get() == &buffer) { if (ClientManager::has_instance()) ClientManager::instance().ensure_no_client_uses_buffer(buffer); @@ -98,7 +98,8 @@ Buffer& BufferManager::get_buffer(StringView name) void BufferManager::set_last_used_buffer(Buffer& buffer) { - auto it = find(m_buffers, &buffer); + auto it = find_if(m_buffers, [&buffer](const SafePtr& p) + { return p.get() == &buffer; }); kak_assert(it != m_buffers.end()); m_buffers.erase(it); m_buffers.emplace(m_buffers.begin(), &buffer); diff --git a/src/client_manager.cc b/src/client_manager.cc index f1da3159..b1b97888 100644 --- a/src/client_manager.cc +++ b/src/client_manager.cc @@ -121,7 +121,7 @@ void ClientManager::ensure_no_client_uses_buffer(Buffer& buffer) // access, this selects a sensible buffer to display. for (auto& buf : BufferManager::instance()) { - if (buf != &buffer) + if (buf.get() != &buffer) { client->context().change_buffer(*buf); break; diff --git a/src/ref_ptr.hh b/src/ref_ptr.hh index 09ec1f48..792e765c 100644 --- a/src/ref_ptr.hh +++ b/src/ref_ptr.hh @@ -1,17 +1,26 @@ #ifndef ref_ptr_hh_INCLUDED #define ref_ptr_hh_INCLUDED +#include + namespace Kakoune { -template +struct WorstMatch { [[gnu::always_inline]] WorstMatch(...) {} }; + +[[gnu::always_inline]] +inline void ref_ptr_moved(WorstMatch, void*, void*) noexcept {} + +template struct RefPtr { RefPtr() = default; - RefPtr(T* ptr) : m_ptr(ptr) { acquire(); } + explicit RefPtr(T* ptr) : m_ptr(ptr) { acquire(); } ~RefPtr() { release(); } RefPtr(const RefPtr& other) : m_ptr(other.m_ptr) { acquire(); } - RefPtr(RefPtr&& other) : m_ptr(other.m_ptr) { other.m_ptr = nullptr; } + RefPtr(RefPtr&& other) + noexcept(noexcept(std::declval().moved(nullptr))) + : m_ptr(other.m_ptr) { other.m_ptr = nullptr; moved(&other); } RefPtr& operator=(const RefPtr& other) { @@ -25,6 +34,7 @@ struct RefPtr release(); m_ptr = other.m_ptr; other.m_ptr = nullptr; + moved(&other); return *this; } @@ -33,31 +43,42 @@ struct RefPtr T* get() const { return m_ptr; } - explicit operator bool() { return m_ptr; } + explicit operator bool() const { return m_ptr; } - friend bool operator==(const RefPtr& lhs, const RefPtr& rhs) + void reset(T* ptr = nullptr) { - return lhs.m_ptr == rhs.m_ptr; - } - friend bool operator!=(const RefPtr& lhs, const RefPtr& rhs) - { - return lhs.m_ptr != rhs.m_ptr; + if (ptr == m_ptr) + return; + release(); + m_ptr = ptr; + acquire(); } + + friend bool operator==(const RefPtr& lhs, const RefPtr& rhs) { return lhs.m_ptr == rhs.m_ptr; } + friend bool operator!=(const RefPtr& lhs, const RefPtr& rhs) { return lhs.m_ptr != rhs.m_ptr; } + private: T* m_ptr = nullptr; void acquire() { if (m_ptr) - inc_ref_count(m_ptr); + inc_ref_count(static_cast(m_ptr), this); } void release() { if (m_ptr) - dec_ref_count(m_ptr); + dec_ref_count(static_cast(m_ptr), this); m_ptr = nullptr; - } + } + + void moved(void* from) + noexcept(noexcept(ref_ptr_moved(static_cast(nullptr), nullptr, nullptr))) + { + if (m_ptr) + ref_ptr_moved(static_cast(m_ptr), from, this); + } }; } diff --git a/src/safe_ptr.hh b/src/safe_ptr.hh index 6366bd13..2e3b3d2b 100644 --- a/src/safe_ptr.hh +++ b/src/safe_ptr.hh @@ -4,6 +4,10 @@ // #define SAFE_PTR_TRACK_CALLSTACKS #include "assert.hh" +#include "ref_ptr.hh" + +#include +#include #ifdef SAFE_PTR_TRACK_CALLSTACKS #include "vector.hh" @@ -15,84 +19,6 @@ namespace Kakoune // *** SafePtr: objects that assert nobody references them when they die *** -template -class SafePtr -{ -public: - SafePtr() : m_ptr(nullptr) {} - explicit SafePtr(T* ptr) : m_ptr(ptr) - { - #ifdef KAK_DEBUG - if (m_ptr) - m_ptr->inc_safe_count(this); - #endif - } - SafePtr(const SafePtr& other) : SafePtr(other.m_ptr) {} - SafePtr(SafePtr&& other) noexcept : m_ptr(other.m_ptr) - { - other.m_ptr = nullptr; - #ifdef KAK_DEBUG - if (m_ptr) - m_ptr->safe_ptr_moved(&other, this); - #endif - } - ~SafePtr() - { - #ifdef KAK_DEBUG - if (m_ptr) - m_ptr->dec_safe_count(this); - #endif - } - - SafePtr& operator=(const SafePtr& other) - { - #ifdef KAK_DEBUG - if (m_ptr != other.m_ptr) - { - if (m_ptr) - m_ptr->dec_safe_count(this); - if (other.m_ptr) - other.m_ptr->inc_safe_count(this); - } - #endif - m_ptr = other.m_ptr; - return *this; - } - - SafePtr& operator=(SafePtr&& other) noexcept - { - #ifdef KAK_DEBUG - if (m_ptr) - m_ptr->dec_safe_count(this); - if (other.m_ptr) - other.m_ptr->safe_ptr_moved(&other, this); - #endif - m_ptr = other.m_ptr; - other.m_ptr = nullptr; - return *this; - } - - void reset(T* ptr = nullptr) - { - *this = SafePtr(ptr); - } - - bool operator== (const SafePtr& other) const { return m_ptr == other.m_ptr; } - bool operator!= (const SafePtr& other) const { return m_ptr != other.m_ptr; } - bool operator== (T* ptr) const { return m_ptr == ptr; } - bool operator!= (T* ptr) const { return m_ptr != ptr; } - - T& operator* () const { return *m_ptr; } - T* operator-> () const { return m_ptr; } - - T* get() const { return m_ptr; } - - explicit operator bool() const { return m_ptr; } - -private: - T* m_ptr; -}; - class SafeCountable { public: @@ -106,31 +32,32 @@ public: #endif } - void inc_safe_count(void* ptr) const + friend void inc_ref_count(const SafeCountable* sc, void* ptr) { - ++m_count; + ++sc->m_count; #ifdef SAFE_PTR_TRACK_CALLSTACKS - m_callstacks.emplace_back(ptr); - #endif - } - void dec_safe_count(void* ptr) const - { - --m_count; - kak_assert(m_count >= 0); - #ifdef SAFE_PTR_TRACK_CALLSTACKS - auto it = std::find_if(m_callstacks.begin(), m_callstacks.end(), - [=](const Callstack& cs) { return cs.ptr == ptr; }); - kak_assert(it != m_callstacks.end()); - m_callstacks.erase(it); + sc->m_callstacks.emplace_back(ptr); #endif } - void safe_ptr_moved(void* from, void* to) const + friend void dec_ref_count(const SafeCountable* sc, void* ptr) + { + --sc->m_count; + kak_assert(sc->m_count >= 0); + #ifdef SAFE_PTR_TRACK_CALLSTACKS + auto it = std::find_if(sc->m_callstacks.begin(), sc->m_callstacks.end(), + [=](const Callstack& cs) { return cs.ptr == ptr; }); + kak_assert(it != sc->m_callstacks.end()); + sc->m_callstacks.erase(it); + #endif + } + + friend void ref_ptr_moved(const SafeCountable* sc, void* from, void* to) { #ifdef SAFE_PTR_TRACK_CALLSTACKS - auto it = std::find_if(m_callstacks.begin(), m_callstacks.end(), + auto it = std::find_if(sc->m_callstacks.begin(), sc->m_callstacks.end(), [=](const Callstack& cs) { return cs.ptr == from; }); - kak_assert(it != m_callstacks.end()); + kak_assert(it != sc->m_callstacks.end()); it->ptr = to; #endif } @@ -157,9 +84,19 @@ private: mutable Vector m_callstacks; #endif mutable int m_count; +#else + [[gnu::always_inline]] + friend void inc_ref_count(const SafeCountable* sc, void* ptr) {} + + [[gnu::always_inline]] + friend void dec_ref_count(const SafeCountable* sc, void* ptr) {} #endif }; +template using SafePtr = + RefPtr::value, + const SafeCountable, SafeCountable>::type>; + } #endif // safe_ptr_hh_INCLUDED diff --git a/src/shared_string.hh b/src/shared_string.hh index a73f8dd2..8af9820c 100644 --- a/src/shared_string.hh +++ b/src/shared_string.hh @@ -21,7 +21,7 @@ struct StringStorage : UseMemoryDomain [[gnu::always_inline]] StringView strview() const { return {data(), length}; } - static StringStorage* create(StringView str, char back = 0) + static RefPtr create(StringView str, char back = 0) { const int len = (int)str.length() + (back != 0 ? 1 : 0); void* ptr = StringStorage::operator new(sizeof(StringStorage) + len + 1); @@ -32,7 +32,7 @@ struct StringStorage : UseMemoryDomain if (back != 0) res->data()[len-1] = back; res->data()[len] = 0; - return res; + return RefPtr(res); } static void destroy(StringStorage* s) @@ -40,8 +40,8 @@ struct StringStorage : UseMemoryDomain StringStorage::operator delete(s, sizeof(StringStorage) + s->length + 1); } - friend void inc_ref_count(StringStorage* s) { ++s->refcount; } - friend void dec_ref_count(StringStorage* s) { if (--s->refcount == 0) StringStorage::destroy(s); } + friend void inc_ref_count(StringStorage* s, void*) { ++s->refcount; } + friend void dec_ref_count(StringStorage* s, void*) { if (--s->refcount == 0) StringStorage::destroy(s); } }; inline RefPtr operator"" _ss(const char* ptr, size_t len)