Regex: Replace generic 'Matchers' with specialized functionality

Introduce CharacterClass and CharacterType Regex Op, and optimize
their evaluation.
This commit is contained in:
Maxime Coste 2017-11-25 18:14:15 +08:00
parent 0d44cf9591
commit 8b40f57145
2 changed files with 155 additions and 91 deletions

View File

@ -22,7 +22,8 @@ struct ParsedRegex
{ {
Literal, Literal,
AnyChar, AnyChar,
Matcher, Class,
CharacterType,
Sequence, Sequence,
Alternation, Alternation,
LineStart, LineStart,
@ -80,7 +81,8 @@ struct ParsedRegex
}; };
Vector<Node, MemoryDomain::Regex> nodes; Vector<Node, MemoryDomain::Regex> nodes;
Vector<std::function<bool (Codepoint)>, MemoryDomain::Regex> matchers;
Vector<CharacterClass, MemoryDomain::Regex> character_classes;
size_t capture_count; size_t capture_count;
}; };
@ -319,18 +321,9 @@ private:
} }
// CharacterClassEscape // CharacterClassEscape
auto class_it = find_if(character_class_escapes, auto class_it = find_if(character_class_escapes, [cp](auto& c) { return c.cp == cp; });
[cp = to_lower(cp)](auto& c) { return c.cp == cp; });
if (class_it != std::end(character_class_escapes)) if (class_it != std::end(character_class_escapes))
{ return new_node(ParsedRegex::CharacterType, (Codepoint)class_it->ctype);
auto matcher_id = m_parsed_regex.matchers.size();
m_parsed_regex.matchers.push_back(
[ctype = class_it->ctype ? wctype(class_it->ctype) : (wctype_t)0,
chars = class_it->additional_chars, neg = is_upper(cp)] (Codepoint cp) {
return ((ctype != 0 and iswctype(cp, ctype)) or contains(chars, cp)) != neg;
});
return new_node(ParsedRegex::Matcher, matcher_id);
}
// CharacterEscape // CharacterEscape
for (auto& control : control_escapes) for (auto& control : control_escapes)
@ -383,9 +376,7 @@ private:
parse_error(format("unknown atom escape '{}'", cp)); parse_error(format("unknown atom escape '{}'", cp));
} }
struct CharRange { Codepoint min, max; }; void normalize_ranges(Vector<CharacterClass::Range, MemoryDomain::Regex>& ranges)
void normalize_ranges(Vector<CharRange, MemoryDomain::Regex>& ranges)
{ {
if (ranges.empty()) if (ranges.empty())
return; return;
@ -411,19 +402,19 @@ private:
NodeIndex character_class() NodeIndex character_class()
{ {
const bool negative = m_pos != m_regex.end() and *m_pos == '^'; CharacterClass character_class;
if (negative)
character_class.ignore_case = m_ignore_case;
character_class.negative = m_pos != m_regex.end() and *m_pos == '^';
if (character_class.negative)
++m_pos; ++m_pos;
Vector<CharRange, MemoryDomain::Regex> ranges;
Vector<Codepoint, MemoryDomain::Regex> excluded;
Vector<std::pair<wctype_t, bool>, MemoryDomain::Regex> ctypes;
while (m_pos != m_regex.end() and *m_pos != ']') while (m_pos != m_regex.end() and *m_pos != ']')
{ {
auto cp = *m_pos++; auto cp = *m_pos++;
if (cp == '-') if (cp == '-')
{ {
ranges.push_back({ '-', '-' }); character_class.ranges.push_back({ '-', '-' });
continue; continue;
} }
@ -433,19 +424,10 @@ private:
if (cp == '\\') if (cp == '\\')
{ {
auto it = find_if(character_class_escapes, auto it = find_if(character_class_escapes,
[cp = to_lower(*m_pos)](auto& t) { return t.cp == cp; }); [cp = *m_pos](auto& t) { return t.cp == cp; });
if (it != std::end(character_class_escapes)) if (it != std::end(character_class_escapes))
{ {
auto negative = is_upper(*m_pos); character_class.ctypes |= it->ctype;
if (it->ctype)
ctypes.push_back({wctype(it->ctype), not negative});
for (auto& c : it->additional_chars)
{
if (negative)
excluded.push_back((Codepoint)c);
else
ranges.push_back({(Codepoint)c, (Codepoint)c});
}
++m_pos; ++m_pos;
continue; continue;
} }
@ -463,7 +445,7 @@ private:
} }
} }
CharRange range = { cp, cp }; CharacterClass::Range range = { cp, cp };
if (*m_pos == '-') if (*m_pos == '-')
{ {
if (++m_pos == m_regex.end()) if (++m_pos == m_regex.end())
@ -476,11 +458,11 @@ private:
} }
else else
{ {
ranges.push_back(range); character_class.ranges.push_back(range);
range = { '-', '-' }; range = { '-', '-' };
} }
} }
ranges.push_back(range); character_class.ranges.push_back(range);
} }
if (at_end()) if (at_end())
parse_error("unclosed character class"); parse_error("unclosed character class");
@ -488,45 +470,30 @@ private:
if (m_ignore_case) if (m_ignore_case)
{ {
for (auto& range : ranges) for (auto& range : character_class.ranges)
{ {
range.min = to_lower(range.min); range.min = to_lower(range.min);
range.max = to_lower(range.max); range.max = to_lower(range.max);
} }
for (auto& cp : excluded)
cp = to_lower(cp);
} }
normalize_ranges(ranges); normalize_ranges(character_class.ranges);
// Optimize the relatively common case of using a character class to // Optimize the relatively common case of using a character class to
// escape a character, such as [*] // escape a character, such as [*]
if (ctypes.empty() and excluded.empty() and not negative and if (character_class.ctypes == CharacterType::None and not character_class.negative and
ranges.size() == 1 and ranges.front().min == ranges.front().max) character_class.ranges.size() == 1 and
return new_node(ParsedRegex::Literal, ranges.front().min); character_class.ranges.front().min == character_class.ranges.front().max)
return new_node(ParsedRegex::Literal, character_class.ranges.front().min);
auto matcher = [ranges = std::move(ranges), if (character_class.ctypes != CharacterType::None and not character_class.negative and
ctypes = std::move(ctypes), character_class.ranges.empty())
excluded = std::move(excluded), return new_node(ParsedRegex::CharacterType, (Codepoint)character_class.ctypes);
negative, ignore_case = m_ignore_case] (Codepoint cp) {
if (ignore_case)
cp = to_lower(cp);
auto it = std::lower_bound(ranges.begin(), ranges.end(), cp, auto class_id = m_parsed_regex.character_classes.size();
[](auto& range, Codepoint cp) m_parsed_regex.character_classes.push_back(std::move(character_class));
{ return range.max < cp; });
auto found = (it != ranges.end() and it->min <= cp) or return new_node(ParsedRegex::Class, class_id);
contains_that(ctypes, [cp](auto& c) {
return (bool)iswctype(cp, c.first) == c.second;
}) or (not excluded.empty() and not contains(excluded, cp));
return negative ? not found : found;
};
auto matcher_id = m_parsed_regex.matchers.size();
m_parsed_regex.matchers.push_back(std::move(matcher));
return new_node(ParsedRegex::Matcher, matcher_id);
} }
ParsedRegex::Quantifier quantifier() ParsedRegex::Quantifier quantifier()
@ -612,8 +579,8 @@ private:
{ {
for_each_child(m_parsed_regex, index, [this](NodeIndex child_index) { for_each_child(m_parsed_regex, index, [this](NodeIndex child_index) {
auto& child = get_node(child_index); auto& child = get_node(child_index);
if (child.op != ParsedRegex::Literal and child.op != ParsedRegex::Matcher and if (child.op != ParsedRegex::Literal and child.op != ParsedRegex::Class and
child.op != ParsedRegex::AnyChar) child.op != ParsedRegex::CharacterType and child.op != ParsedRegex::AnyChar)
parse_error("Lookaround can only contain literals, any chars or character classes"); parse_error("Lookaround can only contain literals, any chars or character classes");
if (child.quantifier.type != ParsedRegex::Quantifier::One) if (child.quantifier.type != ParsedRegex::Quantifier::One)
parse_error("Quantifiers cannot be used in lookarounds"); parse_error("Quantifiers cannot be used in lookarounds");
@ -628,14 +595,12 @@ private:
static constexpr struct CharacterClassEscape { static constexpr struct CharacterClassEscape {
Codepoint cp; Codepoint cp;
const char* ctype; CharacterType ctype;
StringView additional_chars;
bool neg;
} character_class_escapes[] = { } character_class_escapes[] = {
{ 'd', "digit", "", false }, { 'd', CharacterType::Digit }, { 'D', CharacterType::NotDigit },
{ 'w', "alnum", "_", false }, { 'w', CharacterType::Word }, { 'W', CharacterType::NotWord },
{ 's', "space", "", false }, { 's', CharacterType::Whitespace }, { 'S', CharacterType::NotWhitespace },
{ 'h', nullptr, " \t", false }, { 'h', CharacterType::HorizontalWhitespace }, { 'H', CharacterType::NotHorizontalWhitespace },
}; };
static constexpr struct ControlEscape { static constexpr struct ControlEscape {
@ -661,7 +626,7 @@ struct RegexCompiler
write_search_prefix(); write_search_prefix();
compile_node(0); compile_node(0);
push_inst(CompiledRegex::Match); push_inst(CompiledRegex::Match);
m_program.matchers = m_parsed_regex.matchers; m_program.character_classes = m_parsed_regex.character_classes;
m_program.save_count = m_parsed_regex.capture_count * 2; m_program.save_count = m_parsed_regex.capture_count * 2;
m_program.direction = direction; m_program.direction = direction;
m_program.start_chars = compute_start_chars(); m_program.start_chars = compute_start_chars();
@ -695,8 +660,11 @@ private:
case ParsedRegex::AnyChar: case ParsedRegex::AnyChar:
push_inst(CompiledRegex::AnyChar); push_inst(CompiledRegex::AnyChar);
break; break;
case ParsedRegex::Matcher: case ParsedRegex::Class:
push_inst(CompiledRegex::Matcher, node.value); push_inst(CompiledRegex::Class, node.value);
break;
case ParsedRegex::CharacterType:
push_inst(CompiledRegex::CharacterType, node.value);
break; break;
case ParsedRegex::Sequence: case ParsedRegex::Sequence:
{ {
@ -871,8 +839,10 @@ private:
: character.value); : character.value);
else if (character.op == ParsedRegex::AnyChar) else if (character.op == ParsedRegex::AnyChar)
m_program.lookarounds.push_back(0xF000); m_program.lookarounds.push_back(0xF000);
else if (character.op == ParsedRegex::Matcher) else if (character.op == ParsedRegex::Class)
m_program.lookarounds.push_back(0xF0001 + character.value); m_program.lookarounds.push_back(0xF0001 + character.value);
else if (character.op == ParsedRegex::CharacterType)
m_program.lookarounds.push_back(0xF8000 | character.value);
else else
kak_assert(false); kak_assert(false);
return true; return true;
@ -915,12 +885,28 @@ private:
b = true; b = true;
start_chars.map[CompiledRegex::StartChars::other] = true; start_chars.map[CompiledRegex::StartChars::other] = true;
return node.quantifier.allows_none(); return node.quantifier.allows_none();
case ParsedRegex::Matcher: case ParsedRegex::Class:
for (Codepoint c = 0; c < CompiledRegex::StartChars::count; ++c) {
if (m_program.matchers[node.value](c)) auto& character_class = m_parsed_regex.character_classes[node.value];
start_chars.map[c] = true; for (Codepoint cp = 0; cp < CompiledRegex::StartChars::count; ++cp)
start_chars.map[CompiledRegex::StartChars::other] = true; // stay safe {
if (is_character_class(character_class, cp))
start_chars.map[cp] = true;
}
start_chars.map[CompiledRegex::StartChars::other] = true;
return node.quantifier.allows_none(); return node.quantifier.allows_none();
}
case ParsedRegex::CharacterType:
{
const CharacterType ctype = (CharacterType)node.value;
for (Codepoint cp = 0; cp < CompiledRegex::StartChars::count; ++cp)
{
if (is_ctype(ctype, cp))
start_chars.map[cp] = true;
}
start_chars.map[CompiledRegex::StartChars::other] = true;
return node.quantifier.allows_none();
}
case ParsedRegex::Sequence: case ParsedRegex::Sequence:
{ {
bool did_not_consume = false; bool did_not_consume = false;
@ -1015,8 +1001,11 @@ void dump_regex(const CompiledRegex& program)
case CompiledRegex::Save: case CompiledRegex::Save:
printf("save %d\n", inst.param); printf("save %d\n", inst.param);
break; break;
case CompiledRegex::Matcher: case CompiledRegex::Class:
printf("matcher %d\n", inst.param); printf("class %d\n", inst.param);
break;
case CompiledRegex::CharacterType:
printf("character type %d\n", inst.param);
break; break;
case CompiledRegex::LineStart: case CompiledRegex::LineStart:
printf("line start\n"); printf("line start\n");
@ -1084,6 +1073,43 @@ CompiledRegex compile_regex(StringView re, RegexCompileFlags flags, MatchDirecti
return RegexCompiler{RegexParser::parse(re), flags, direction}.get_compiled_regex(); return RegexCompiler{RegexParser::parse(re), flags, direction}.get_compiled_regex();
} }
bool is_character_class(const CharacterClass& character_class, Codepoint cp)
{
if (character_class.ignore_case)
cp = to_lower(cp);
auto it = std::lower_bound(character_class.ranges.begin(),
character_class.ranges.end(), cp,
[](auto& range, Codepoint cp)
{ return range.max < cp; });
auto found = (it != character_class.ranges.end() and it->min <= cp) or
is_ctype(character_class.ctypes, cp);
return found != character_class.negative;
}
bool is_ctype(CharacterType ctype, Codepoint cp)
{
if ((ctype & CharacterType::Digit) and iswdigit(cp))
return true;
if ((ctype & CharacterType::Word) and is_word(cp))
return true;
if ((ctype & CharacterType::Whitespace) and is_blank(cp))
return true;
if ((ctype & CharacterType::HorizontalWhitespace) and is_horizontal_blank(cp))
return true;
if ((ctype & CharacterType::NotDigit) and not iswdigit(cp))
return true;
if ((ctype & CharacterType::NotWord) and not is_word(cp))
return true;
if ((ctype & CharacterType::NotWhitespace) and not is_blank(cp))
return true;
if ((ctype & CharacterType::NotHorizontalWhitespace) and not is_horizontal_blank(cp))
return true;
return false;
}
namespace namespace
{ {
template<MatchDirection dir = MatchDirection::Forward> template<MatchDirection dir = MatchDirection::Forward>
@ -1257,7 +1283,7 @@ auto test_regex = UnitTest{[]{
} }
{ {
TestVM<> vm{R"((?=foo).)"}; TestVM<> vm{R"((?=fo[\w]).)"};
kak_assert(vm.exec("barfoo", RegexExecFlags::Search)); kak_assert(vm.exec("barfoo", RegexExecFlags::Search));
kak_assert(StringView{vm.captures()[0], vm.captures()[1]} == "f"); kak_assert(StringView{vm.captures()[0], vm.captures()[1]} == "f");
} }
@ -1274,7 +1300,7 @@ auto test_regex = UnitTest{[]{
} }
{ {
TestVM<> vm{R"(...(?<=f.o))"}; TestVM<> vm{R"(...(?<=f\w.))"};
kak_assert(vm.exec("foo")); kak_assert(vm.exec("foo"));
kak_assert(not vm.exec("qux")); kak_assert(not vm.exec("qux"));
} }

View File

@ -23,6 +23,33 @@ enum class MatchDirection
Backward Backward
}; };
enum class CharacterType : unsigned char
{
None = 0,
Word = 1 << 0,
Whitespace = 1 << 1,
HorizontalWhitespace = 1 << 2,
Digit = 1 << 3,
NotWord = 1 << 4,
NotWhitespace = 1 << 5,
NotHorizontalWhitespace = 1 << 6,
NotDigit = 1 << 7
};
constexpr bool with_bit_ops(Meta::Type<CharacterType>) { return true; }
struct CharacterClass
{
struct Range { Codepoint min, max; };
Vector<Range, MemoryDomain::Regex> ranges;
CharacterType ctypes = CharacterType::None;
bool negative = false;
bool ignore_case = false;
};
bool is_character_class(const CharacterClass& character_class, Codepoint cp);
bool is_ctype(CharacterType ctype, Codepoint cp);
struct CompiledRegex : RefCountable, UseMemoryDomain<MemoryDomain::Regex> struct CompiledRegex : RefCountable, UseMemoryDomain<MemoryDomain::Regex>
{ {
enum Op : char enum Op : char
@ -32,7 +59,8 @@ struct CompiledRegex : RefCountable, UseMemoryDomain<MemoryDomain::Regex>
Literal, Literal,
Literal_IgnoreCase, Literal_IgnoreCase,
AnyChar, AnyChar,
Matcher, Class,
CharacterType,
Jump, Jump,
Split_PrioritizeParent, Split_PrioritizeParent,
Split_PrioritizeChild, Split_PrioritizeChild,
@ -68,7 +96,7 @@ struct CompiledRegex : RefCountable, UseMemoryDomain<MemoryDomain::Regex>
explicit operator bool() const { return not instructions.empty(); } explicit operator bool() const { return not instructions.empty(); }
Vector<Instruction, MemoryDomain::Regex> instructions; Vector<Instruction, MemoryDomain::Regex> instructions;
Vector<std::function<bool (Codepoint)>, MemoryDomain::Regex> matchers; Vector<CharacterClass, MemoryDomain::Regex> character_classes;
Vector<Codepoint, MemoryDomain::Regex> lookarounds; Vector<Codepoint, MemoryDomain::Regex> lookarounds;
MatchDirection direction; MatchDirection direction;
size_t save_count; size_t save_count;
@ -289,11 +317,16 @@ private:
thread.saves->pos[inst.param] = get_base(pos); thread.saves->pos[inst.param] = get_base(pos);
break; break;
} }
case CompiledRegex::Matcher: case CompiledRegex::Class:
if (pos == m_end) if (pos == m_end)
return StepResult::Failed; return StepResult::Failed;
return m_program.matchers[inst.param](*pos) ? return is_character_class(m_program.character_classes[inst.param], *pos) ?
StepResult::Consumed : StepResult::Failed; StepResult::Consumed : StepResult::Failed;
case CompiledRegex::CharacterType:
if (pos == m_end)
return StepResult::Failed;
return is_ctype((CharacterType)inst.param, *pos) ?
StepResult::Consumed : StepResult::Failed;;
case CompiledRegex::LineStart: case CompiledRegex::LineStart:
if (not is_line_start(pos)) if (not is_line_start(pos))
return StepResult::Failed; return StepResult::Failed;
@ -457,9 +490,14 @@ private:
const Codepoint ref = *it; const Codepoint ref = *it;
if (ref == 0xF000) if (ref == 0xF000)
{} // any character matches {} // any character matches
else if (ref > 0xF0000 and ref <= 0xFFFFD) else if (ref > 0xF0000 and ref < 0xF8000)
{ {
if (not m_program.matchers[ref - 0xF0001](cp)) if (not is_character_class(m_program.character_classes[ref - 0xF0001], cp))
return false;
}
else if (ref >= 0xF8000 and ref <= 0xFFFFD)
{
if (not is_ctype((CharacterType)(ref & 0xFF), cp))
return false; return false;
} }
else if (ref != cp) else if (ref != cp)