// Patch summary (src/regex_impl.hh, struct ThreadedRegexVM):
//
//  * Adds a `Saves` record (a `refcount` plus a `pos` vector) and a
//    `clone_saves()` helper. The helper takes a record from `m_free_saves`
//    when one is available, otherwise appends a new one to `m_saves`; it then
//    sets `refcount` to 1 and copies `pos` from the record it was given.
//  * `Thread::saves` becomes a `Saves*` instead of a per-thread vector.
//  * `Jump` now only updates `thread.inst`; the "another thread already sits
//    on the target instruction" check is removed from that case.
//  * Both `Split_*` cases increment `thread.saves->refcount` and insert the new
//    thread at `thread_index + 1` with the same `Saves` pointer, replacing the
//    calls to `add_thread()`, which this patch deletes.
//  * `Save` clones the record first when `refcount > 1` (after decrementing
//    it), so the write to `pos[index]` lands in a record with refcount 1.
//  * In the per-position loop, failed or duplicate threads are marked with
//    `inst = nullptr` and passed to `release_saves` (which pushes the record
//    onto `m_free_saves` once its refcount reaches 0); marked threads are
//    erased in one `remove_if` pass after the inner loop, and the inner loop
//    now always advances with `++i`.
//  * Adds the `m_saves` and `m_free_saves` members.
//
// NOTE(review): template argument lists appear to be missing in this capture
//   (`Vector pos;`, `std::make_unique()`, `reinterpret_cast(thread.inst)`);
//   presumably lost in extraction -- confirm against the real patch.
// NOTE(review): on a match, `m_captures = std::move(m_threads[i].saves->pos)`
//   moves out of a record that may still be referenced by other threads
//   (refcount > 1) -- confirm no surviving thread reads it afterwards.
// NOTE(review): `m_threads.resize(i)` drops threads without calling
//   `release_saves`, and nothing in this patch clears `m_saves` or
//   `m_free_saves` -- confirm whether records are meant to be recycled there.
diff --git a/src/regex_impl.hh b/src/regex_impl.hh index 8b7fa9b8..881651c9 100644 --- a/src/regex_impl.hh +++ b/src/regex_impl.hh @@ -68,10 +68,35 @@ struct ThreadedRegexVM ThreadedRegexVM(const CompiledRegex& program) : m_program{program} { kak_assert(m_program); } + struct Saves + { + int refcount; + Vector pos; + }; + + Saves* clone_saves(Saves* saves) + { + Saves* res = nullptr; + if (not m_free_saves.empty()) + { + res = m_free_saves.back(); + m_free_saves.pop_back(); + } + else + { + m_saves.push_back(std::make_unique()); + res = m_saves.back().get(); + } + + res->refcount = 1; + res->pos = saves->pos; + return res; + } + struct Thread { const char* inst; - Vector saves = {}; + Saves* saves; }; enum class StepResult { Consumed, Matched, Failed }; @@ -97,33 +122,35 @@ struct ThreadedRegexVM case CompiledRegex::AnyChar: return StepResult::Consumed; case CompiledRegex::Jump: - { - auto inst = prog_start + *reinterpret_cast(thread.inst); - // if instruction is already going to be executed by another thread, drop this thread - if (std::find_if(m_threads.begin(), m_threads.end(), - [inst](const Thread& t) { return t.inst == inst; }) != m_threads.end()) - return StepResult::Failed; - thread.inst = inst; + thread.inst = prog_start + *reinterpret_cast(thread.inst); break; - } case CompiledRegex::Split_PrioritizeParent: { - auto new_thread_inst = prog_start + *reinterpret_cast(thread.inst); - thread.inst += sizeof(CompiledRegex::Offset); - add_thread(thread_index+1, new_thread_inst, thread.saves); + auto parent = thread.inst + sizeof(CompiledRegex::Offset); + auto child = prog_start + *reinterpret_cast(thread.inst); + thread.inst = parent; + ++thread.saves->refcount; + m_threads.insert(m_threads.begin() + thread_index + 1, {child, thread.saves}); break; } case CompiledRegex::Split_PrioritizeChild: { - auto new_thread_inst = thread.inst + sizeof(CompiledRegex::Offset); - thread.inst = prog_start + *reinterpret_cast(thread.inst); - add_thread(thread_index+1, 
new_thread_inst, thread.saves); + auto parent = thread.inst + sizeof(CompiledRegex::Offset); + auto child = prog_start + *reinterpret_cast(thread.inst); + thread.inst = child; + ++thread.saves->refcount; + m_threads.insert(m_threads.begin() + thread_index + 1, {parent, thread.saves}); break; } case CompiledRegex::Save: { const char index = *thread.inst++; - thread.saves[index] = m_pos.base(); + if (thread.saves->refcount > 1) + { + --thread.saves->refcount; + thread.saves = clone_saves(thread.saves); + } + thread.saves->pos[index] = m_pos.base(); break; } case CompiledRegex::Matcher: @@ -194,8 +221,8 @@ struct ThreadedRegexVM bool found_match = false; m_threads.clear(); const auto start_offset = (flags & RegexExecFlags::Search) ? 0 : CompiledRegex::search_prefix_size; - add_thread(0, m_program.bytecode.data() + start_offset, - Vector(m_program.save_count, Iterator{})); + m_saves.push_back(std::make_unique(Saves{1, Vector(m_program.save_count, Iterator{})})); + m_threads.push_back({m_program.bytecode.data() + start_offset, m_saves.back().get()}); m_begin = begin; m_end = end; @@ -204,9 +231,14 @@ struct ThreadedRegexVM if (flags & RegexExecFlags::NotInitialNull and m_begin == m_end) return false; + auto release_saves = [this](Saves* saves) { + if (--saves->refcount == 0) + m_free_saves.push_back(saves); + }; + for (m_pos = Utf8It{m_begin, m_begin, m_end}; m_pos != m_end; ++m_pos) { - for (int i = 0; i < m_threads.size(); ) + for (int i = 0; i < m_threads.size(); ++i) { const auto res = step(i); if (res == StepResult::Matched) @@ -214,11 +246,12 @@ struct ThreadedRegexVM if (not (flags & RegexExecFlags::Search) or // We are not at end, this is not a full match (flags & RegexExecFlags::NotInitialNull and m_pos == m_begin)) { - m_threads.erase(m_threads.begin() + i); + m_threads[i].inst = nullptr; + release_saves(m_threads[i].saves); continue; } - m_captures = std::move(m_threads[i].saves); + m_captures = std::move(m_threads[i].saves->pos); if (flags & 
RegexExecFlags::AnyMatch) return true; @@ -226,17 +259,25 @@ struct ThreadedRegexVM m_threads.resize(i); // remove this and lower priority threads } else if (res == StepResult::Failed) - m_threads.erase(m_threads.begin() + i); + { + m_threads[i].inst = nullptr; + release_saves(m_threads[i].saves); + } else { auto it = m_threads.begin() + i; if (std::find_if(m_threads.begin(), it, [inst = it->inst](auto& t) { return t.inst == inst; }) != it) - m_threads.erase(it); - else - ++i; + { + m_threads[i].inst = nullptr; + release_saves(m_threads[i].saves); + } } } + // Remove dead threads + m_threads.erase(std::remove_if(m_threads.begin(), m_threads.end(), + [](auto& t) { return t.inst == nullptr; }), + m_threads.end()); // we should never have more than one thread on the same instruction kak_assert(m_threads.size() <= m_program.bytecode.size()); if (m_threads.empty()) @@ -250,21 +291,13 @@ struct ThreadedRegexVM { if (step(i) == StepResult::Matched) { - m_captures = std::move(m_threads[i].saves); + m_captures = std::move(m_threads[i].saves->pos); return true; } } return false; } - void add_thread(int index, const char* inst, Vector saves) - { - if (std::find_if(m_threads.begin(), m_threads.end(), - [inst](const Thread& t) { return t.inst == inst; }) == m_threads.end()) - m_threads.insert(m_threads.begin() + index, {inst, std::move(saves)}); - kak_assert(m_threads.size() < m_program.bytecode.size()); - } - bool is_line_start() const { return (m_pos == m_begin and not (m_flags & RegexExecFlags::NotBeginOfLine)) or @@ -294,6 +327,9 @@ struct ThreadedRegexVM Utf8It m_pos; RegexExecFlags m_flags; + Vector> m_saves; + Vector m_free_saves; + Vector m_captures; };