Regex: Optimize parsing and compilation

AstNodes are now POD, stored in a single vector, accessed through
their index. The children list is implicit, with nodes storing only
the node index at which their child graph ends.

That makes reverse iteration slower, but that is only used for reverse
matching regex, which are uncommon. In the general case compilation
is now faster.
This commit is contained in:
Maxime Coste 2017-10-23 15:35:43 +08:00
parent aea2de885d
commit 18a02ccacd

View File

@ -68,22 +68,58 @@ struct ParsedRegex
};
struct AstNode;
using AstNodePtr = std::unique_ptr<AstNode>;
using AstNodeIndex = uint16_t;
struct AstNode
{
Op op;
bool ignore_case;
AstNodeIndex children_end;
Codepoint value;
Quantifier quantifier;
Vector<AstNodePtr, MemoryDomain::Regex> children;
};
AstNodePtr ast;
Vector<AstNode, MemoryDomain::Regex> nodes;
size_t capture_count;
Vector<std::function<bool (Codepoint)>, MemoryDomain::Regex> matchers;
};
namespace
{
template<typename Func>
bool for_each_child(const ParsedRegex& parsed_regex, ParsedRegex::AstNodeIndex index, Func&& func)
{
const auto end = parsed_regex.nodes[index].children_end;
for (auto child = index+1; child != end;
child = parsed_regex.nodes[child].children_end)
{
if (func(child) == false)
return false;
}
return true;
}
template<typename Func>
bool for_each_child_reverse(const ParsedRegex& parsed_regex, ParsedRegex::AstNodeIndex index, Func&& func)
{
auto find_last_child = [&](ParsedRegex::AstNodeIndex begin, ParsedRegex::AstNodeIndex end) {
while (parsed_regex.nodes[begin].children_end != end)
begin = parsed_regex.nodes[begin].children_end;
return begin;
};
const auto first_child = index+1;
auto end = parsed_regex.nodes[index].children_end;
while (end != first_child)
{
auto child = find_last_child(first_child, end);
if (func(child) == false)
return false;
end = child;
}
return true;
}
}
// Recursive descent parser based on naming used in the ECMAScript
// standard, although the syntax is not fully compatible.
struct RegexParser
@ -92,7 +128,8 @@ struct RegexParser
: m_regex{re}, m_pos{re.begin(), re}
{
m_parsed_regex.capture_count = 1;
m_parsed_regex.ast = disjunction(0);
AstNodeIndex root = disjunction(0);
kak_assert(root == 0);
}
ParsedRegex get_parsed_regex() { return std::move(m_parsed_regex); }
@ -106,38 +143,35 @@ private:
};
using Iterator = utf8::iterator<const char*, Codepoint, int, InvalidPolicy>;
using AstNodePtr = ParsedRegex::AstNodePtr;
using AstNodeIndex = ParsedRegex::AstNodeIndex;
AstNodePtr disjunction(unsigned capture = -1)
AstNodeIndex disjunction(unsigned capture = -1)
{
AstNodePtr node = alternative();
if (at_end() or *m_pos != '|')
{
node->value = capture;
return node;
}
AstNodePtr res = new_node(ParsedRegex::Alternation);
res->value = capture;
res->children.push_back(std::move(node));
do
AstNodeIndex index = new_node(ParsedRegex::Alternation);
get_node(index).value = capture;
while (true)
{
alternative();
if (at_end() or *m_pos != '|')
break;
++m_pos;
res->children.push_back(alternative());
}
while (not at_end() and *m_pos == '|');
return res;
get_node(index).children_end = m_parsed_regex.nodes.size();
return index;
}
AstNodePtr alternative(ParsedRegex::Op op = ParsedRegex::Sequence)
AstNodeIndex alternative(ParsedRegex::Op op = ParsedRegex::Sequence)
{
AstNodePtr res = new_node(op);
while (auto node = term())
res->children.push_back(std::move(node));
return res;
AstNodeIndex index = new_node(op);
while (auto t = term())
{}
get_node(index).children_end = m_parsed_regex.nodes.size();
return index;
}
AstNodePtr term()
Optional<AstNodeIndex> term()
{
while (modifiers()) // read all modifiers
{}
@ -145,10 +179,10 @@ private:
return node;
if (auto node = atom())
{
node->quantifier = quantifier();
get_node(*node).quantifier = quantifier();
return node;
}
return nullptr;
return {};
}
bool accept(StringView expected)
@ -178,10 +212,10 @@ private:
return false;
}
AstNodePtr assertion()
Optional<AstNodeIndex> assertion()
{
if (at_end())
return nullptr;
return {};
switch (*m_pos)
{
@ -189,7 +223,7 @@ private:
case '$': ++m_pos; return new_node(ParsedRegex::LineEnd);
case '\\':
if (m_pos+1 == m_regex.end())
return nullptr;
return {};
switch (*(m_pos+1))
{
case 'b': m_pos += 2; return new_node(ParsedRegex::WordBoundary);
@ -217,9 +251,9 @@ private:
}
}
if (not lookaround_op)
return nullptr;
return {};
AstNodePtr lookaround = alternative(*lookaround_op);
AstNodeIndex lookaround = alternative(*lookaround_op);
if (at_end() or *m_pos++ != ')')
parse_error("unclosed parenthesis");
@ -227,13 +261,13 @@ private:
return lookaround;
}
}
return nullptr;
return {};
}
AstNodePtr atom()
Optional<AstNodeIndex> atom()
{
if (at_end())
return nullptr;
return {};
const Codepoint cp = *m_pos;
switch (cp)
@ -243,7 +277,7 @@ private:
{
++m_pos;
const bool capture = not accept("?:");
AstNodePtr content = disjunction(capture ? m_parsed_regex.capture_count++ : -1);
AstNodeIndex content = disjunction(capture ? m_parsed_regex.capture_count++ : -1);
if (at_end() or *m_pos++ != ')')
parse_error("unclosed parenthesis");
return content;
@ -255,7 +289,7 @@ private:
++m_pos;
return character_class();
case '|': case ')':
return nullptr;
return {};
default:
if (contains("^$.*+?[]{}", cp))
parse_error(format("unexpected '{}'", cp));
@ -264,7 +298,7 @@ private:
}
}
AstNodePtr atom_escape()
AstNodeIndex atom_escape()
{
const Codepoint cp = *m_pos++;
@ -272,9 +306,12 @@ private:
{
auto escaped_sequence = new_node(ParsedRegex::Sequence);
constexpr StringView end_mark{"\\E"};
auto quote_end = std::search(m_pos.base(), m_regex.end(), end_mark.begin(), end_mark.end());
while (m_pos != quote_end)
escaped_sequence->children.push_back(new_node(ParsedRegex::Literal, *m_pos++));
new_node(ParsedRegex::Literal, *m_pos++);
get_node(escaped_sequence).children_end = m_parsed_regex.nodes.size();
if (quote_end != m_regex.end())
m_pos += 2;
@ -372,7 +409,7 @@ private:
ranges.erase(pos+1, ranges.end());
}
AstNodePtr character_class()
AstNodeIndex character_class()
{
const bool negative = m_pos != m_regex.end() and *m_pos == '^';
if (negative)
@ -543,14 +580,26 @@ private:
}
}
AstNodePtr new_node(ParsedRegex::Op op, Codepoint value = -1,
ParsedRegex::Quantifier quantifier = {ParsedRegex::Quantifier::One})
AstNodeIndex new_node(ParsedRegex::Op op, Codepoint value = -1,
ParsedRegex::Quantifier quantifier = {ParsedRegex::Quantifier::One})
{
return AstNodePtr{new ParsedRegex::AstNode{op, m_ignore_case, value, quantifier, {}}};
constexpr auto max_nodes = std::numeric_limits<uint16_t>::max();
const AstNodeIndex res = m_parsed_regex.nodes.size();
if (res == max_nodes)
parse_error(format("regex parsed to more than {} ast nodes", max_nodes));
const AstNodeIndex next = res+1;
m_parsed_regex.nodes.push_back({op, m_ignore_case, next, value, quantifier});
return res;
}
bool at_end() const { return m_pos == m_regex.end(); }
ParsedRegex::AstNode& get_node(AstNodeIndex index)
{
return m_parsed_regex.nodes[index];
}
[[gnu::noreturn]]
void parse_error(StringView error) const
{
@ -559,16 +608,17 @@ private:
StringView{m_pos.base(), m_regex.end()}));
}
void validate_lookaround(const AstNodePtr& node)
void validate_lookaround(AstNodeIndex index)
{
for (auto& child : node->children)
{
if (child->op != ParsedRegex::Literal and child->op != ParsedRegex::Matcher and
child->op != ParsedRegex::AnyChar)
for_each_child(m_parsed_regex, index, [this](AstNodeIndex child_index) {
auto& child = get_node(child_index);
if (child.op != ParsedRegex::Literal and child.op != ParsedRegex::Matcher and
child.op != ParsedRegex::AnyChar)
parse_error("Lookaround can only contain literals, any chars or character classes");
if (child->quantifier.type != ParsedRegex::Quantifier::One)
if (child.quantifier.type != ParsedRegex::Quantifier::One)
parse_error("Quantifiers cannot be used in lookarounds");
}
return true;
});
}
ParsedRegex m_parsed_regex;
@ -609,7 +659,7 @@ struct RegexCompiler
: m_parsed_regex{parsed_regex}, m_flags(flags), m_forward{direction == MatchDirection::Forward}
{
write_search_prefix();
compile_node(m_parsed_regex.ast);
compile_node(0);
push_inst(CompiledRegex::Match);
m_program.matchers = m_parsed_regex.matchers;
m_program.save_count = m_parsed_regex.capture_count * 2;
@ -621,61 +671,68 @@ struct RegexCompiler
private:
uint32_t compile_node_inner(const ParsedRegex::AstNodePtr& node)
uint32_t compile_node_inner(ParsedRegex::AstNodeIndex index)
{
const auto start_pos = m_program.instructions.size();
const bool ignore_case = node->ignore_case;
auto& node = get_node(index);
const bool save = (node->op == ParsedRegex::Alternation or node->op == ParsedRegex::Sequence) and
(node->value == 0 or (node->value != -1 and not (m_flags & RegexCompileFlags::NoSubs)));
const uint32_t start_pos = (uint32_t)m_program.instructions.size();
const bool ignore_case = node.ignore_case;
const bool save = (node.op == ParsedRegex::Alternation or node.op == ParsedRegex::Sequence) and
(node.value == 0 or (node.value != -1 and not (m_flags & RegexCompileFlags::NoSubs)));
if (save)
push_inst(CompiledRegex::Save, node->value * 2 + (m_forward ? 0 : 1));
push_inst(CompiledRegex::Save, node.value * 2 + (m_forward ? 0 : 1));
Vector<uint32_t> goto_inner_end_offsets;
switch (node->op)
switch (node.op)
{
case ParsedRegex::Literal:
if (ignore_case)
push_inst(CompiledRegex::Literal_IgnoreCase, to_lower(node->value));
push_inst(CompiledRegex::Literal_IgnoreCase, to_lower(node.value));
else
push_inst(CompiledRegex::Literal, node->value);
push_inst(CompiledRegex::Literal, node.value);
break;
case ParsedRegex::AnyChar:
push_inst(CompiledRegex::AnyChar);
break;
case ParsedRegex::Matcher:
push_inst(CompiledRegex::Matcher, node->value);
push_inst(CompiledRegex::Matcher, node.value);
break;
case ParsedRegex::Sequence:
{
if (m_forward)
for (auto& child : node->children)
compile_node(child);
for_each_child(m_parsed_regex, index, [this](ParsedRegex::AstNodeIndex child) {
compile_node(child); return true;
});
else
for (auto& child : node->children | reverse())
compile_node(child);
for_each_child_reverse(m_parsed_regex, index, [this](ParsedRegex::AstNodeIndex child) {
compile_node(child); return true;
});
break;
}
case ParsedRegex::Alternation:
{
auto& children = node->children;
kak_assert(children.size() > 1);
//kak_assert(children.size() > 1);
const auto split_pos = m_program.instructions.size();
for (int i = 0; i < children.size() - 1; ++i)
push_inst(CompiledRegex::Split_PrioritizeParent);
auto split_pos = m_program.instructions.size();
for_each_child(m_parsed_regex, index, [this, index](ParsedRegex::AstNodeIndex child) {
if (child != index+1)
push_inst(CompiledRegex::Split_PrioritizeParent);
return true;
});
for (int i = 0; i < children.size(); ++i)
{
auto node = compile_node(children[i]);
if (i > 0)
m_program.instructions[split_pos + i - 1].param = node;
if (i < children.size() - 1)
for_each_child(m_parsed_regex, index,
[&, end = node.children_end](ParsedRegex::AstNodeIndex child) {
auto node = compile_node(child);
if (child != index+1)
m_program.instructions[split_pos++].param = node;
if (get_node(child).children_end != end)
{
auto jump = push_inst(CompiledRegex::Jump);
goto_inner_end_offsets.push_back(jump);
}
}
return true;
});
break;
}
case ParsedRegex::LookAhead:
@ -683,28 +740,28 @@ private:
: CompiledRegex::LookAhead)
: (ignore_case ? CompiledRegex::LookBehind_IgnoreCase
: CompiledRegex::LookBehind),
push_lookaround(node->children, false, ignore_case));
push_lookaround(index, false, ignore_case));
break;
case ParsedRegex::NegativeLookAhead:
push_inst(m_forward ? (ignore_case ? CompiledRegex::NegativeLookAhead_IgnoreCase
: CompiledRegex::NegativeLookAhead)
: (ignore_case ? CompiledRegex::NegativeLookBehind_IgnoreCase
: CompiledRegex::NegativeLookBehind),
push_lookaround(node->children, false, ignore_case));
push_lookaround(index, false, ignore_case));
break;
case ParsedRegex::LookBehind:
push_inst(m_forward ? (ignore_case ? CompiledRegex::LookBehind_IgnoreCase
: CompiledRegex::LookBehind)
: (ignore_case ? CompiledRegex::LookAhead_IgnoreCase
: CompiledRegex::LookAhead),
push_lookaround(node->children, true, ignore_case));
push_lookaround(index, true, ignore_case));
break;
case ParsedRegex::NegativeLookBehind:
push_inst(m_forward ? (ignore_case ? CompiledRegex::NegativeLookBehind_IgnoreCase
: CompiledRegex::NegativeLookBehind)
: (ignore_case ? CompiledRegex::NegativeLookAhead_IgnoreCase
: CompiledRegex::NegativeLookAhead),
push_lookaround(node->children, true, ignore_case));
push_lookaround(index, true, ignore_case));
break;
case ParsedRegex::LineStart:
push_inst(m_forward ? CompiledRegex::LineStart
@ -737,17 +794,19 @@ private:
m_program.instructions[offset].param = m_program.instructions.size();
if (save)
push_inst(CompiledRegex::Save, node->value * 2 + (m_forward ? 1 : 0));
push_inst(CompiledRegex::Save, node.value * 2 + (m_forward ? 1 : 0));
return start_pos;
}
uint32_t compile_node(const ParsedRegex::AstNodePtr& node)
uint32_t compile_node(ParsedRegex::AstNodeIndex index)
{
uint32_t pos = m_program.instructions.size();
auto& node = get_node(index);
const uint32_t start_pos = (uint32_t)m_program.instructions.size();
Vector<uint32_t> goto_ends;
auto& quantifier = node->quantifier;
auto& quantifier = node.quantifier;
// TODO reverse, invert the way we write optional quantifiers ?
@ -758,10 +817,10 @@ private:
goto_ends.push_back(split_pos);
}
auto inner_pos = compile_node_inner(node);
auto inner_pos = compile_node_inner(index);
// Write the node multiple times when we have a min count quantifier
for (int i = 1; i < quantifier.min; ++i)
inner_pos = compile_node_inner(node);
inner_pos = compile_node_inner(index);
if (quantifier.allows_infinite_repeat())
push_inst(quantifier.greedy ? CompiledRegex::Split_PrioritizeChild
@ -775,13 +834,13 @@ private:
auto split_pos = push_inst(quantifier.greedy ? CompiledRegex::Split_PrioritizeParent
: CompiledRegex::Split_PrioritizeChild);
goto_ends.push_back(split_pos);
compile_node_inner(node);
compile_node_inner(index);
}
for (auto offset : goto_ends)
m_program.instructions[offset].param = m_program.instructions.size();
return pos;
return start_pos;
}
// Add an set of instruction prefix used in the search use case
@ -804,29 +863,27 @@ private:
return res;
}
uint32_t push_lookaround(ArrayView<const ParsedRegex::AstNodePtr> characters,
bool reversed, bool ignore_case)
uint32_t push_lookaround(ParsedRegex::AstNodeIndex index, bool reversed, bool ignore_case)
{
uint32_t res = m_program.lookarounds.size();
auto write_lookaround = [this, ignore_case](auto&& characters) {
for (auto& character : characters)
{
if (character->op == ParsedRegex::Literal)
m_program.lookarounds.push_back(ignore_case ? to_lower(character->value)
: character->value);
else if (character->op == ParsedRegex::AnyChar)
auto write_matcher = [this, ignore_case](ParsedRegex::AstNodeIndex child) {
auto& character = get_node(child);
if (character.op == ParsedRegex::Literal)
m_program.lookarounds.push_back(ignore_case ? to_lower(character.value)
: character.value);
else if (character.op == ParsedRegex::AnyChar)
m_program.lookarounds.push_back(0xF000);
else if (character->op == ParsedRegex::Matcher)
m_program.lookarounds.push_back(0xF0001 + character->value);
else if (character.op == ParsedRegex::Matcher)
m_program.lookarounds.push_back(0xF0001 + character.value);
else
kak_assert(false);
}
return true;
};
if (reversed)
write_lookaround(characters | reverse());
for_each_child_reverse(m_parsed_regex, index, write_matcher);
else
write_lookaround(characters);
for_each_child(m_parsed_regex, index, write_matcher);
m_program.lookarounds.push_back((Codepoint)-1);
return res;
@ -835,57 +892,58 @@ private:
// Fills accepted and rejected according to which chars can start the given node,
// returns true if the node did not consume the char, hence a following node in
// sequence would be still relevant for the parent node start chars computation.
bool compute_start_chars(const ParsedRegex::AstNodePtr& node,
bool compute_start_chars(ParsedRegex::AstNodeIndex index,
CompiledRegex::StartChars& start_chars) const
{
switch (node->op)
auto& node = get_node(index);
switch (node.op)
{
case ParsedRegex::Literal:
if (node->value < CompiledRegex::StartChars::count)
if (node.value < CompiledRegex::StartChars::count)
{
if (node->ignore_case)
if (node.ignore_case)
{
start_chars.map[to_lower(node->value)] = true;
start_chars.map[to_upper(node->value)] = true;
start_chars.map[to_lower(node.value)] = true;
start_chars.map[to_upper(node.value)] = true;
}
else
start_chars.map[node->value] = true;
start_chars.map[node.value] = true;
}
else
start_chars.map[CompiledRegex::StartChars::other] = true;
return node->quantifier.allows_none();
return node.quantifier.allows_none();
case ParsedRegex::AnyChar:
for (auto& b : start_chars.map)
b = true;
start_chars.map[CompiledRegex::StartChars::other] = true;
return node->quantifier.allows_none();
return node.quantifier.allows_none();
case ParsedRegex::Matcher:
for (Codepoint c = 0; c < CompiledRegex::StartChars::count; ++c)
if (m_program.matchers[node->value](c))
if (m_program.matchers[node.value](c))
start_chars.map[c] = true;
start_chars.map[CompiledRegex::StartChars::other] = true; // stay safe
return node->quantifier.allows_none();
return node.quantifier.allows_none();
case ParsedRegex::Sequence:
{
bool consumed = false;
auto consumes = [&, this](auto& child) {
return not this->compute_start_chars(child, start_chars);
bool did_not_consume = false;
auto does_not_consume = [&, this](auto child) {
return this->compute_start_chars(child, start_chars);
};
if (m_forward)
consumed = contains_that(node->children, consumes);
did_not_consume = for_each_child(m_parsed_regex, index, does_not_consume);
else
consumed = contains_that(node->children | reverse(), consumes);
did_not_consume = for_each_child_reverse(m_parsed_regex, index, does_not_consume);
return not consumed or node->quantifier.allows_none();
return did_not_consume or node.quantifier.allows_none();
}
case ParsedRegex::Alternation:
{
bool all_consumed = not node->quantifier.allows_none();
for (auto& child : node->children)
{
bool all_consumed = not node.quantifier.allows_none();
for_each_child(m_parsed_regex, index, [&](ParsedRegex::AstNodeIndex child) {
if (compute_start_chars(child, start_chars))
all_consumed = false;
}
return true;
});
return not all_consumed;
}
case ParsedRegex::LineStart:
@ -908,7 +966,7 @@ private:
std::unique_ptr<CompiledRegex::StartChars> compute_start_chars() const
{
CompiledRegex::StartChars start_chars{};
if (compute_start_chars(m_parsed_regex.ast, start_chars))
if (compute_start_chars(0, start_chars))
return nullptr;
if (not contains(start_chars.map, false))
@ -917,6 +975,11 @@ private:
return std::make_unique<CompiledRegex::StartChars>(start_chars);
}
const ParsedRegex::AstNode& get_node(ParsedRegex::AstNodeIndex index) const
{
return m_parsed_regex.nodes[index];
}
CompiledRegex m_program;
RegexCompileFlags m_flags;
const ParsedRegex& m_parsed_regex;