diff --git a/src/regex_impl.cc b/src/regex_impl.cc index db180a15..b7ac2de6 100644 --- a/src/regex_impl.cc +++ b/src/regex_impl.cc @@ -18,6 +18,7 @@ enum Op : char AnyChar, Jump, Split, + Save, LineStart, LineEnd, WordBoundary, @@ -83,7 +84,7 @@ struct AstNode using AstNodePtr = std::unique_ptr; -AstNodePtr make_ast_node(Op op, char value = 0, +AstNodePtr make_ast_node(Op op, char value = -1, Quantifier quantifier = {Quantifier::One}) { return AstNodePtr{new AstNode{op, value, quantifier, {}}}; @@ -94,25 +95,29 @@ AstNodePtr make_ast_node(Op op, char value = 0, template struct Parser { - static AstNodePtr parse(Iterator pos, Iterator end) + AstNodePtr parse(Iterator pos, Iterator end) { - return disjunction(pos, end); + return disjunction(pos, end, 0); } private: - static AstNodePtr disjunction(Iterator& pos, Iterator end) + AstNodePtr disjunction(Iterator& pos, Iterator end, char capture = -1) { AstNodePtr node = alternative(pos, end); if (pos == end or *pos != '|') + { + node->value = capture; return node; + } AstNodePtr res = make_ast_node(Op::Alternation); res->children.push_back(std::move(node)); res->children.push_back(disjunction(++pos, end)); + res->value = capture; return res; } - static AstNodePtr alternative(Iterator& pos, Iterator end) + AstNodePtr alternative(Iterator& pos, Iterator end) { AstNodePtr res = make_ast_node(Op::Sequence); while (auto node = term(pos, end)) @@ -120,7 +125,7 @@ private: return res; } - static AstNodePtr term(Iterator& pos, Iterator end) + AstNodePtr term(Iterator& pos, Iterator end) { if (auto node = assertion(pos, end)) return node; @@ -132,7 +137,7 @@ private: return nullptr; } - static AstNodePtr assertion(Iterator& pos, Iterator end) + AstNodePtr assertion(Iterator& pos, Iterator end) { switch (*pos) { @@ -154,7 +159,7 @@ private: return nullptr; } - static AstNodePtr atom(Iterator& pos, Iterator end) + AstNodePtr atom(Iterator& pos, Iterator end) { const auto c = *pos; switch (c) @@ -163,7 +168,8 @@ private: case '(': { ++pos; - auto content = disjunction(pos, end); + auto content = disjunction(pos, end, m_next_capture++); + if (pos == end or *pos != ')') throw runtime_error{"Unclosed parenthesis"}; ++pos; @@ -177,7 +183,7 @@ private: } } - static Quantifier quantifier(Iterator& pos, Iterator end) + Quantifier quantifier(Iterator& pos, Iterator end) { auto read_int = [](Iterator& pos, Iterator begin, Iterator end) { int res = 0; @@ -214,6 +220,8 @@ private: default: return {Quantifier::One}; } } + + char m_next_capture = 1; }; RegexProgram::Offset compile_node(Vector& program, const AstNodePtr& node); @@ -234,6 +242,13 @@ RegexProgram::Offset compile_node_inner(Vector& program, const AstNodePtr& { const auto start_pos = program.size(); + const char capture = (node->op == Op::Alternation or node->op == Op::Sequence) ? node->value : -1; + if (capture >= 0) + { + program.push_back(RegexProgram::Save); + program.push_back(capture * 2); + } + Vector goto_inner_end_offsets; switch (node->op) { @@ -288,6 +303,12 @@ RegexProgram::Offset compile_node_inner(Vector& program, const AstNodePtr& for (auto& offset : goto_inner_end_offsets) get_offset(program, offset) = program.size(); + if (capture >= 0) + { + program.push_back(RegexProgram::Save); + program.push_back(capture * 2 + 1); + } + return start_pos; } @@ -338,7 +359,7 @@ Vector compile(const AstNodePtr& node) template Vector compile(Iterator begin, Iterator end) { - return compile(Parser::parse(begin, end)); + return compile(Parser{}.parse(begin, end)); } } @@ -367,6 +388,9 @@ void dump(ConstArrayView program) pos += sizeof(RegexProgram::Offset); break; } + case RegexProgram::Save: + printf("save %d\n", program[pos++]); + break; case RegexProgram::LineStart: printf("line start\n"); break; @@ -395,72 +419,87 @@ struct ThreadedExecutor { ThreadedExecutor(ConstArrayView program) : m_program{program} {} - struct StepResult + struct Thread { - enum Result { Consumed, Matched, Failed } result; - const char* next = nullptr; + const char* inst; + Vector saves = {}; }; - StepResult step(const char* inst) + enum class StepResult { Consumed, Matched, Failed }; + StepResult step(size_t thread_index) { while (true) { + auto& thread = m_threads[thread_index]; char c = m_pos == m_subject.end() ? 0 : *m_pos; - const RegexProgram::Op op = (RegexProgram::Op)*inst++; + const RegexProgram::Op op = (RegexProgram::Op)*thread.inst++; switch (op) { case RegexProgram::Literal: - if (*inst++ == c) - return { StepResult::Consumed, inst }; - return { StepResult::Failed }; + if (*thread.inst++ == c) + return StepResult::Consumed; + return StepResult::Failed; case RegexProgram::AnyChar: - return { StepResult::Consumed, inst }; + return StepResult::Consumed; case RegexProgram::Jump: - inst = m_program.begin() + *reinterpret_cast(inst); + { + auto inst = m_program.begin() + *reinterpret_cast(thread.inst); // if instruction is already going to be executed, drop this thread - if (std::find(m_threads.begin(), m_threads.end(), inst) != m_threads.end()) - return { StepResult::Failed }; + if (std::find_if(m_threads.begin(), m_threads.end(), + [inst](const Thread& t) { return t.inst == inst; }) != m_threads.end()) + return StepResult::Failed; + thread.inst = inst; break; + } case RegexProgram::Split: { - add_thread(*reinterpret_cast(inst)); - inst += sizeof(RegexProgram::Offset); + add_thread(*reinterpret_cast(thread.inst), thread.saves); + // thread is invalidated now, as we mutated the m_thread vector + m_threads[thread_index].inst += sizeof(RegexProgram::Offset); + break; + } + case RegexProgram::Save: + { + const char index = *thread.inst++; + thread.saves[index] = m_pos; break; } case RegexProgram::LineStart: if (not is_line_start()) - return { StepResult::Failed }; + return StepResult::Failed; break; case RegexProgram::LineEnd: if (not is_line_end()) - return { StepResult::Failed }; + return StepResult::Failed; break; case RegexProgram::WordBoundary: if (not is_word_boundary()) - return { StepResult::Failed }; + return StepResult::Failed; break; case RegexProgram::NotWordBoundary: if (is_word_boundary()) - return { StepResult::Failed }; + return StepResult::Failed; break; case RegexProgram::SubjectBegin: if (m_pos != m_subject.begin()) - return { StepResult::Failed }; + return StepResult::Failed; break; case RegexProgram::SubjectEnd: if (m_pos != m_subject.end()) - return { StepResult::Failed }; + return StepResult::Failed; break; case RegexProgram::Match: - return { StepResult::Matched }; + return StepResult::Matched; } } - return { StepResult::Failed }; + return StepResult::Failed; } bool match(ConstArrayView program, StringView data) { - m_threads = Vector{program.begin()}; + m_threads.clear(); + add_thread(0, Vector(10, nullptr)); + m_subject = data; m_pos = data.begin(); @@ -468,30 +507,39 @@ struct ThreadedExecutor { for (int i = 0; i < m_threads.size(); ++i) { - auto res = step(m_threads[i]); - m_threads[i] = res.next; - if (res.result == StepResult::Matched) + const auto res = step(i); + if (res == StepResult::Matched) + { + m_captures = std::move(m_threads[i].saves); return true; + } + else if (res == StepResult::Failed) + m_threads[i].inst = nullptr; } - m_threads.erase(std::remove(m_threads.begin(), m_threads.end(), nullptr), m_threads.end()); + m_threads.erase(std::remove_if(m_threads.begin(), m_threads.end(), + [](const Thread& t) { return t.inst == nullptr; }), m_threads.end()); if (m_threads.empty()) - break; + return false; } // Step remaining threads to see if they match without consuming anything else for (int i = 0; i < m_threads.size(); ++i) { - if (step(m_threads[i]).result == StepResult::Matched) + if (step(i) == StepResult::Matched) + { + m_captures = std::move(m_threads[i].saves); return true; + } } return false; } - void add_thread(RegexProgram::Offset pos) + void add_thread(RegexProgram::Offset pos, Vector saves) { const char* inst = m_program.begin() + pos; - if (std::find(m_threads.begin(), m_threads.end(), inst) == m_threads.end()) - m_threads.push_back(inst); + if (std::find_if(m_threads.begin(), m_threads.end(), + [inst](const Thread& t) { return t.inst == inst; }) == m_threads.end()) + m_threads.push_back({inst, std::move(saves)}); } bool is_line_start() const @@ -512,7 +560,8 @@ struct ThreadedExecutor } ConstArrayView m_program; - Vector m_threads; + Vector m_threads; + Vector m_captures; StringView m_subject; const char* m_pos; }; @@ -549,6 +598,7 @@ auto test_regex = UnitTest{[]{ RegexProgram::dump(program); Exec exec{program}; kak_assert(exec.match(program, "fooquxbarbaz")); + kak_assert(StringView{exec.m_captures[2], exec.m_captures[3]} == "qux"); kak_assert(not exec.match(program, "fooquxbarbaze")); kak_assert(not exec.match(program, "quxbar")); kak_assert(not exec.match(program, "blahblah")); @@ -562,6 +612,7 @@ auto test_regex = UnitTest{[]{ RegexProgram::dump(program); Exec exec{program}; kak_assert(exec.match(program, "qux foo baz")); + kak_assert(StringView{exec.m_captures[2], exec.m_captures[3]} == "foo"); kak_assert(not exec.match(program, "quxfoobaz")); kak_assert(exec.match(program, "bar")); kak_assert(not exec.match(program, "foobar"));