diff --git a/src/regex_impl.cc b/src/regex_impl.cc index 0bd44219..52144762 100644 --- a/src/regex_impl.cc +++ b/src/regex_impl.cc @@ -68,22 +68,58 @@ struct ParsedRegex }; struct AstNode; - using AstNodePtr = std::unique_ptr; + using AstNodeIndex = uint16_t; struct AstNode { Op op; bool ignore_case; + AstNodeIndex children_end; Codepoint value; Quantifier quantifier; - Vector children; }; - AstNodePtr ast; + Vector nodes; size_t capture_count; Vector, MemoryDomain::Regex> matchers; }; +namespace +{ +template +bool for_each_child(const ParsedRegex& parsed_regex, ParsedRegex::AstNodeIndex index, Func&& func) +{ + const auto end = parsed_regex.nodes[index].children_end; + for (auto child = index+1; child != end; + child = parsed_regex.nodes[child].children_end) + { + if (func(child) == false) + return false; + } + return true; +} + +template +bool for_each_child_reverse(const ParsedRegex& parsed_regex, ParsedRegex::AstNodeIndex index, Func&& func) +{ + auto find_last_child = [&](ParsedRegex::AstNodeIndex begin, ParsedRegex::AstNodeIndex end) { + while (parsed_regex.nodes[begin].children_end != end) + begin = parsed_regex.nodes[begin].children_end; + return begin; + }; + const auto first_child = index+1; + auto end = parsed_regex.nodes[index].children_end; + while (end != first_child) + { + auto child = find_last_child(first_child, end); + if (func(child) == false) + return false; + end = child; + } + return true; +} +} + // Recursive descent parser based on naming used in the ECMAScript // standard, although the syntax is not fully compatible. struct RegexParser @@ -92,7 +128,8 @@ struct RegexParser : m_regex{re}, m_pos{re.begin(), re} { m_parsed_regex.capture_count = 1; - m_parsed_regex.ast = disjunction(0); + AstNodeIndex root = disjunction(0); + kak_assert(root == 0); } ParsedRegex get_parsed_regex() { return std::move(m_parsed_regex); } @@ -106,38 +143,35 @@ private: }; using Iterator = utf8::iterator; - using AstNodePtr = ParsedRegex::AstNodePtr; + using AstNodeIndex = ParsedRegex::AstNodeIndex; - AstNodePtr disjunction(unsigned capture = -1) + AstNodeIndex disjunction(unsigned capture = -1) { - AstNodePtr node = alternative(); - if (at_end() or *m_pos != '|') - { - node->value = capture; - return node; - } - - AstNodePtr res = new_node(ParsedRegex::Alternation); - res->value = capture; - res->children.push_back(std::move(node)); - do + AstNodeIndex index = new_node(ParsedRegex::Alternation); + get_node(index).value = capture; + while (true) { + alternative(); + if (at_end() or *m_pos != '|') + break; ++m_pos; - res->children.push_back(alternative()); } - while (not at_end() and *m_pos == '|'); - return res; + get_node(index).children_end = m_parsed_regex.nodes.size(); + + return index; } - AstNodePtr alternative(ParsedRegex::Op op = ParsedRegex::Sequence) + AstNodeIndex alternative(ParsedRegex::Op op = ParsedRegex::Sequence) { - AstNodePtr res = new_node(op); - while (auto node = term()) - res->children.push_back(std::move(node)); - return res; + AstNodeIndex index = new_node(op); + while (auto t = term()) + {} + get_node(index).children_end = m_parsed_regex.nodes.size(); + + return index; } - AstNodePtr term() + Optional term() { while (modifiers()) // read all modifiers {} @@ -145,10 +179,10 @@ private: return node; if (auto node = atom()) { - node->quantifier = quantifier(); + get_node(*node).quantifier = quantifier(); return node; } - return nullptr; + return {}; } bool accept(StringView expected) @@ -178,10 +212,10 @@ private: return false; } - AstNodePtr assertion() + Optional assertion() { if (at_end()) - return nullptr; + return {}; switch (*m_pos) { @@ -189,7 +223,7 @@ private: case '$': ++m_pos; return new_node(ParsedRegex::LineEnd); case '\\': if (m_pos+1 == m_regex.end()) - return nullptr; + return {}; switch (*(m_pos+1)) { case 'b': m_pos += 2; return new_node(ParsedRegex::WordBoundary); @@ -217,9 +251,9 @@ private: } } if (not lookaround_op) - return nullptr; + return {}; - AstNodePtr lookaround = alternative(*lookaround_op); + AstNodeIndex lookaround = alternative(*lookaround_op); if (at_end() or *m_pos++ != ')') parse_error("unclosed parenthesis"); @@ -227,13 +261,13 @@ private: return lookaround; } } - return nullptr; + return {}; } - AstNodePtr atom() + Optional atom() { if (at_end()) - return nullptr; + return {}; const Codepoint cp = *m_pos; switch (cp) @@ -243,7 +277,7 @@ private: { ++m_pos; const bool capture = not accept("?:"); - AstNodePtr content = disjunction(capture ? m_parsed_regex.capture_count++ : -1); + AstNodeIndex content = disjunction(capture ? m_parsed_regex.capture_count++ : -1); if (at_end() or *m_pos++ != ')') parse_error("unclosed parenthesis"); return content; @@ -255,7 +289,7 @@ private: ++m_pos; return character_class(); case '|': case ')': - return nullptr; + return {}; default: if (contains("^$.*+?[]{}", cp)) parse_error(format("unexpected '{}'", cp)); @@ -264,7 +298,7 @@ private: } } - AstNodePtr atom_escape() + AstNodeIndex atom_escape() { const Codepoint cp = *m_pos++; @@ -272,9 +306,12 @@ private: { auto escaped_sequence = new_node(ParsedRegex::Sequence); constexpr StringView end_mark{"\\E"}; + auto quote_end = std::search(m_pos.base(), m_regex.end(), end_mark.begin(), end_mark.end()); while (m_pos != quote_end) - escaped_sequence->children.push_back(new_node(ParsedRegex::Literal, *m_pos++)); + new_node(ParsedRegex::Literal, *m_pos++); + get_node(escaped_sequence).children_end = m_parsed_regex.nodes.size(); + if (quote_end != m_regex.end()) m_pos += 2; @@ -372,7 +409,7 @@ private: ranges.erase(pos+1, ranges.end()); } - AstNodePtr character_class() + AstNodeIndex character_class() { const bool negative = m_pos != m_regex.end() and *m_pos == '^'; if (negative) @@ -543,14 +580,26 @@ private: } } - AstNodePtr new_node(ParsedRegex::Op op, Codepoint value = -1, - ParsedRegex::Quantifier quantifier = {ParsedRegex::Quantifier::One}) + AstNodeIndex new_node(ParsedRegex::Op op, Codepoint value = -1, + ParsedRegex::Quantifier quantifier = {ParsedRegex::Quantifier::One}) { - return AstNodePtr{new ParsedRegex::AstNode{op, m_ignore_case, value, quantifier, {}}}; + constexpr auto max_nodes = std::numeric_limits::max(); + const AstNodeIndex res = m_parsed_regex.nodes.size(); + if (res == max_nodes) + parse_error(format("regex parsed to more than {} ast nodes", max_nodes)); + const AstNodeIndex next = res+1; + m_parsed_regex.nodes.push_back({op, m_ignore_case, next, value, quantifier}); + return res; } bool at_end() const { return m_pos == m_regex.end(); } + ParsedRegex::AstNode& get_node(AstNodeIndex index) + { + return m_parsed_regex.nodes[index]; + } + + [[gnu::noreturn]] void parse_error(StringView error) const { @@ -559,16 +608,17 @@ private: StringView{m_pos.base(), m_regex.end()})); } - void validate_lookaround(const AstNodePtr& node) + void validate_lookaround(AstNodeIndex index) { - for (auto& child : node->children) - { - if (child->op != ParsedRegex::Literal and child->op != ParsedRegex::Matcher and - child->op != ParsedRegex::AnyChar) + for_each_child(m_parsed_regex, index, [this](AstNodeIndex child_index) { + auto& child = get_node(child_index); + if (child.op != ParsedRegex::Literal and child.op != ParsedRegex::Matcher and + child.op != ParsedRegex::AnyChar) parse_error("Lookaround can only contain literals, any chars or character classes"); - if (child->quantifier.type != ParsedRegex::Quantifier::One) + if (child.quantifier.type != ParsedRegex::Quantifier::One) parse_error("Quantifiers cannot be used in lookarounds"); - } + return true; + }); } ParsedRegex m_parsed_regex; @@ -609,7 +659,7 @@ struct RegexCompiler : m_parsed_regex{parsed_regex}, m_flags(flags), m_forward{direction == MatchDirection::Forward} { write_search_prefix(); - compile_node(m_parsed_regex.ast); + compile_node(0); push_inst(CompiledRegex::Match); m_program.matchers = m_parsed_regex.matchers; m_program.save_count = m_parsed_regex.capture_count * 2; @@ -621,61 +671,68 @@ struct RegexCompiler private: - uint32_t compile_node_inner(const ParsedRegex::AstNodePtr& node) + uint32_t compile_node_inner(ParsedRegex::AstNodeIndex index) { - const auto start_pos = m_program.instructions.size(); - const bool ignore_case = node->ignore_case; + auto& node = get_node(index); - const bool save = (node->op == ParsedRegex::Alternation or node->op == ParsedRegex::Sequence) and - (node->value == 0 or (node->value != -1 and not (m_flags & RegexCompileFlags::NoSubs))); + const uint32_t start_pos = (uint32_t)m_program.instructions.size(); + const bool ignore_case = node.ignore_case; + + const bool save = (node.op == ParsedRegex::Alternation or node.op == ParsedRegex::Sequence) and + (node.value == 0 or (node.value != -1 and not (m_flags & RegexCompileFlags::NoSubs))); if (save) - push_inst(CompiledRegex::Save, node->value * 2 + (m_forward ? 0 : 1)); + push_inst(CompiledRegex::Save, node.value * 2 + (m_forward ? 0 : 1)); Vector goto_inner_end_offsets; - switch (node->op) + switch (node.op) { case ParsedRegex::Literal: if (ignore_case) - push_inst(CompiledRegex::Literal_IgnoreCase, to_lower(node->value)); + push_inst(CompiledRegex::Literal_IgnoreCase, to_lower(node.value)); else - push_inst(CompiledRegex::Literal, node->value); + push_inst(CompiledRegex::Literal, node.value); break; case ParsedRegex::AnyChar: push_inst(CompiledRegex::AnyChar); break; case ParsedRegex::Matcher: - push_inst(CompiledRegex::Matcher, node->value); + push_inst(CompiledRegex::Matcher, node.value); break; case ParsedRegex::Sequence: { if (m_forward) - for (auto& child : node->children) - compile_node(child); + for_each_child(m_parsed_regex, index, [this](ParsedRegex::AstNodeIndex child) { + compile_node(child); return true; + }); else - for (auto& child : node->children | reverse()) - compile_node(child); + for_each_child_reverse(m_parsed_regex, index, [this](ParsedRegex::AstNodeIndex child) { + compile_node(child); return true; + }); break; } case ParsedRegex::Alternation: { - auto& children = node->children; - kak_assert(children.size() > 1); + //kak_assert(children.size() > 1); - const auto split_pos = m_program.instructions.size(); - for (int i = 0; i < children.size() - 1; ++i) - push_inst(CompiledRegex::Split_PrioritizeParent); + auto split_pos = m_program.instructions.size(); + for_each_child(m_parsed_regex, index, [this, index](ParsedRegex::AstNodeIndex child) { + if (child != index+1) + push_inst(CompiledRegex::Split_PrioritizeParent); + return true; + }); - for (int i = 0; i < children.size(); ++i) - { - auto node = compile_node(children[i]); - if (i > 0) - m_program.instructions[split_pos + i - 1].param = node; - if (i < children.size() - 1) + for_each_child(m_parsed_regex, index, + [&, end = node.children_end](ParsedRegex::AstNodeIndex child) { + auto node = compile_node(child); + if (child != index+1) + m_program.instructions[split_pos++].param = node; + if (get_node(child).children_end != end) { auto jump = push_inst(CompiledRegex::Jump); goto_inner_end_offsets.push_back(jump); } - } + return true; + }); break; } case ParsedRegex::LookAhead: @@ -683,28 +740,28 @@ private: : CompiledRegex::LookAhead) : (ignore_case ? CompiledRegex::LookBehind_IgnoreCase : CompiledRegex::LookBehind), - push_lookaround(node->children, false, ignore_case)); + push_lookaround(index, false, ignore_case)); break; case ParsedRegex::NegativeLookAhead: push_inst(m_forward ? (ignore_case ? CompiledRegex::NegativeLookAhead_IgnoreCase : CompiledRegex::NegativeLookAhead) : (ignore_case ? CompiledRegex::NegativeLookBehind_IgnoreCase : CompiledRegex::NegativeLookBehind), - push_lookaround(node->children, false, ignore_case)); + push_lookaround(index, false, ignore_case)); break; case ParsedRegex::LookBehind: push_inst(m_forward ? (ignore_case ? CompiledRegex::LookBehind_IgnoreCase : CompiledRegex::LookBehind) : (ignore_case ? CompiledRegex::LookAhead_IgnoreCase : CompiledRegex::LookAhead), - push_lookaround(node->children, true, ignore_case)); + push_lookaround(index, true, ignore_case)); break; case ParsedRegex::NegativeLookBehind: push_inst(m_forward ? (ignore_case ? CompiledRegex::NegativeLookBehind_IgnoreCase : CompiledRegex::NegativeLookBehind) : (ignore_case ? CompiledRegex::NegativeLookAhead_IgnoreCase : CompiledRegex::NegativeLookAhead), - push_lookaround(node->children, true, ignore_case)); + push_lookaround(index, true, ignore_case)); break; case ParsedRegex::LineStart: push_inst(m_forward ? CompiledRegex::LineStart @@ -737,17 +794,19 @@ private: m_program.instructions[offset].param = m_program.instructions.size(); if (save) - push_inst(CompiledRegex::Save, node->value * 2 + (m_forward ? 1 : 0)); + push_inst(CompiledRegex::Save, node.value * 2 + (m_forward ? 1 : 0)); return start_pos; } - uint32_t compile_node(const ParsedRegex::AstNodePtr& node) + uint32_t compile_node(ParsedRegex::AstNodeIndex index) { - uint32_t pos = m_program.instructions.size(); + auto& node = get_node(index); + + const uint32_t start_pos = (uint32_t)m_program.instructions.size(); Vector goto_ends; - auto& quantifier = node->quantifier; + auto& quantifier = node.quantifier; // TODO reverse, invert the way we write optional quantifiers ? @@ -758,10 +817,10 @@ private: goto_ends.push_back(split_pos); } - auto inner_pos = compile_node_inner(node); + auto inner_pos = compile_node_inner(index); // Write the node multiple times when we have a min count quantifier for (int i = 1; i < quantifier.min; ++i) - inner_pos = compile_node_inner(node); + inner_pos = compile_node_inner(index); if (quantifier.allows_infinite_repeat()) push_inst(quantifier.greedy ? CompiledRegex::Split_PrioritizeChild @@ -775,13 +834,13 @@ private: auto split_pos = push_inst(quantifier.greedy ? CompiledRegex::Split_PrioritizeParent : CompiledRegex::Split_PrioritizeChild); goto_ends.push_back(split_pos); - compile_node_inner(node); + compile_node_inner(index); } for (auto offset : goto_ends) m_program.instructions[offset].param = m_program.instructions.size(); - return pos; + return start_pos; } // Add an set of instruction prefix used in the search use case @@ -804,29 +863,27 @@ private: return res; } - uint32_t push_lookaround(ArrayView characters, - bool reversed, bool ignore_case) + uint32_t push_lookaround(ParsedRegex::AstNodeIndex index, bool reversed, bool ignore_case) { uint32_t res = m_program.lookarounds.size(); - auto write_lookaround = [this, ignore_case](auto&& characters) { - for (auto& character : characters) - { - if (character->op == ParsedRegex::Literal) - m_program.lookarounds.push_back(ignore_case ? to_lower(character->value) - : character->value); - else if (character->op == ParsedRegex::AnyChar) + auto write_matcher = [this, ignore_case](ParsedRegex::AstNodeIndex child) { + auto& character = get_node(child); + if (character.op == ParsedRegex::Literal) + m_program.lookarounds.push_back(ignore_case ? to_lower(character.value) + : character.value); + else if (character.op == ParsedRegex::AnyChar) m_program.lookarounds.push_back(0xF000); - else if (character->op == ParsedRegex::Matcher) - m_program.lookarounds.push_back(0xF0001 + character->value); + else if (character.op == ParsedRegex::Matcher) + m_program.lookarounds.push_back(0xF0001 + character.value); else kak_assert(false); - } + return true; }; if (reversed) - write_lookaround(characters | reverse()); + for_each_child_reverse(m_parsed_regex, index, write_matcher); else - write_lookaround(characters); + for_each_child(m_parsed_regex, index, write_matcher); m_program.lookarounds.push_back((Codepoint)-1); return res; @@ -835,57 +892,58 @@ private: // Fills accepted and rejected according to which chars can start the given node, // returns true if the node did not consume the char, hence a following node in // sequence would be still relevant for the parent node start chars computation. - bool compute_start_chars(const ParsedRegex::AstNodePtr& node, + bool compute_start_chars(ParsedRegex::AstNodeIndex index, CompiledRegex::StartChars& start_chars) const { - switch (node->op) + auto& node = get_node(index); + switch (node.op) { case ParsedRegex::Literal: - if (node->value < CompiledRegex::StartChars::count) + if (node.value < CompiledRegex::StartChars::count) { - if (node->ignore_case) + if (node.ignore_case) { - start_chars.map[to_lower(node->value)] = true; - start_chars.map[to_upper(node->value)] = true; + start_chars.map[to_lower(node.value)] = true; + start_chars.map[to_upper(node.value)] = true; } else - start_chars.map[node->value] = true; + start_chars.map[node.value] = true; } else start_chars.map[CompiledRegex::StartChars::other] = true; - return node->quantifier.allows_none(); + return node.quantifier.allows_none(); case ParsedRegex::AnyChar: for (auto& b : start_chars.map) b = true; start_chars.map[CompiledRegex::StartChars::other] = true; - return node->quantifier.allows_none(); + return node.quantifier.allows_none(); case ParsedRegex::Matcher: for (Codepoint c = 0; c < CompiledRegex::StartChars::count; ++c) - if (m_program.matchers[node->value](c)) + if (m_program.matchers[node.value](c)) start_chars.map[c] = true; start_chars.map[CompiledRegex::StartChars::other] = true; // stay safe - return node->quantifier.allows_none(); + return node.quantifier.allows_none(); case ParsedRegex::Sequence: { - bool consumed = false; - auto consumes = [&, this](auto& child) { - return not this->compute_start_chars(child, start_chars); + bool did_not_consume = false; + auto does_not_consume = [&, this](auto child) { + return this->compute_start_chars(child, start_chars); }; if (m_forward) - consumed = contains_that(node->children, consumes); + did_not_consume = for_each_child(m_parsed_regex, index, does_not_consume); else - consumed = contains_that(node->children | reverse(), consumes); + did_not_consume = for_each_child_reverse(m_parsed_regex, index, does_not_consume); - return not consumed or node->quantifier.allows_none(); + return did_not_consume or node.quantifier.allows_none(); } case ParsedRegex::Alternation: { - bool all_consumed = not node->quantifier.allows_none(); - for (auto& child : node->children) - { + bool all_consumed = not node.quantifier.allows_none(); + for_each_child(m_parsed_regex, index, [&](ParsedRegex::AstNodeIndex child) { if (compute_start_chars(child, start_chars)) all_consumed = false; - } + return true; + }); return not all_consumed; } case ParsedRegex::LineStart: @@ -908,7 +966,7 @@ private: std::unique_ptr compute_start_chars() const { CompiledRegex::StartChars start_chars{}; - if (compute_start_chars(m_parsed_regex.ast, start_chars)) + if (compute_start_chars(0, start_chars)) return nullptr; if (not contains(start_chars.map, false)) @@ -917,6 +975,11 @@ private: return std::make_unique(start_chars); } + const ParsedRegex::AstNode& get_node(ParsedRegex::AstNodeIndex index) const + { + return m_parsed_regex.nodes[index]; + } + CompiledRegex m_program; RegexCompileFlags m_flags; const ParsedRegex& m_parsed_regex;