Regex: Replace generic 'Matchers' with specialized functionality

Introduce CharacterClass and CharacterType Regex Op, and optimize
their evaluation.
This commit is contained in:
Maxime Coste 2017-11-25 18:14:15 +08:00
parent 0d44cf9591
commit 8b40f57145
2 changed files with 155 additions and 91 deletions

View File

@ -22,7 +22,8 @@ struct ParsedRegex
{
Literal,
AnyChar,
Matcher,
Class,
CharacterType,
Sequence,
Alternation,
LineStart,
@ -80,7 +81,8 @@ struct ParsedRegex
};
Vector<Node, MemoryDomain::Regex> nodes;
Vector<std::function<bool (Codepoint)>, MemoryDomain::Regex> matchers;
Vector<CharacterClass, MemoryDomain::Regex> character_classes;
size_t capture_count;
};
@ -319,18 +321,9 @@ private:
}
// CharacterClassEscape
auto class_it = find_if(character_class_escapes,
[cp = to_lower(cp)](auto& c) { return c.cp == cp; });
auto class_it = find_if(character_class_escapes, [cp](auto& c) { return c.cp == cp; });
if (class_it != std::end(character_class_escapes))
{
auto matcher_id = m_parsed_regex.matchers.size();
m_parsed_regex.matchers.push_back(
[ctype = class_it->ctype ? wctype(class_it->ctype) : (wctype_t)0,
chars = class_it->additional_chars, neg = is_upper(cp)] (Codepoint cp) {
return ((ctype != 0 and iswctype(cp, ctype)) or contains(chars, cp)) != neg;
});
return new_node(ParsedRegex::Matcher, matcher_id);
}
return new_node(ParsedRegex::CharacterType, (Codepoint)class_it->ctype);
// CharacterEscape
for (auto& control : control_escapes)
@ -383,9 +376,7 @@ private:
parse_error(format("unknown atom escape '{}'", cp));
}
struct CharRange { Codepoint min, max; };
void normalize_ranges(Vector<CharRange, MemoryDomain::Regex>& ranges)
void normalize_ranges(Vector<CharacterClass::Range, MemoryDomain::Regex>& ranges)
{
if (ranges.empty())
return;
@ -411,19 +402,19 @@ private:
NodeIndex character_class()
{
const bool negative = m_pos != m_regex.end() and *m_pos == '^';
if (negative)
CharacterClass character_class;
character_class.ignore_case = m_ignore_case;
character_class.negative = m_pos != m_regex.end() and *m_pos == '^';
if (character_class.negative)
++m_pos;
Vector<CharRange, MemoryDomain::Regex> ranges;
Vector<Codepoint, MemoryDomain::Regex> excluded;
Vector<std::pair<wctype_t, bool>, MemoryDomain::Regex> ctypes;
while (m_pos != m_regex.end() and *m_pos != ']')
{
auto cp = *m_pos++;
if (cp == '-')
{
ranges.push_back({ '-', '-' });
character_class.ranges.push_back({ '-', '-' });
continue;
}
@ -433,19 +424,10 @@ private:
if (cp == '\\')
{
auto it = find_if(character_class_escapes,
[cp = to_lower(*m_pos)](auto& t) { return t.cp == cp; });
[cp = *m_pos](auto& t) { return t.cp == cp; });
if (it != std::end(character_class_escapes))
{
auto negative = is_upper(*m_pos);
if (it->ctype)
ctypes.push_back({wctype(it->ctype), not negative});
for (auto& c : it->additional_chars)
{
if (negative)
excluded.push_back((Codepoint)c);
else
ranges.push_back({(Codepoint)c, (Codepoint)c});
}
character_class.ctypes |= it->ctype;
++m_pos;
continue;
}
@ -463,7 +445,7 @@ private:
}
}
CharRange range = { cp, cp };
CharacterClass::Range range = { cp, cp };
if (*m_pos == '-')
{
if (++m_pos == m_regex.end())
@ -476,11 +458,11 @@ private:
}
else
{
ranges.push_back(range);
character_class.ranges.push_back(range);
range = { '-', '-' };
}
}
ranges.push_back(range);
character_class.ranges.push_back(range);
}
if (at_end())
parse_error("unclosed character class");
@ -488,45 +470,30 @@ private:
if (m_ignore_case)
{
for (auto& range : ranges)
for (auto& range : character_class.ranges)
{
range.min = to_lower(range.min);
range.max = to_lower(range.max);
}
for (auto& cp : excluded)
cp = to_lower(cp);
}
normalize_ranges(ranges);
normalize_ranges(character_class.ranges);
// Optimize the relatively common case of using a character class to
// escape a character, such as [*]
if (ctypes.empty() and excluded.empty() and not negative and
ranges.size() == 1 and ranges.front().min == ranges.front().max)
return new_node(ParsedRegex::Literal, ranges.front().min);
if (character_class.ctypes == CharacterType::None and not character_class.negative and
character_class.ranges.size() == 1 and
character_class.ranges.front().min == character_class.ranges.front().max)
return new_node(ParsedRegex::Literal, character_class.ranges.front().min);
auto matcher = [ranges = std::move(ranges),
ctypes = std::move(ctypes),
excluded = std::move(excluded),
negative, ignore_case = m_ignore_case] (Codepoint cp) {
if (ignore_case)
cp = to_lower(cp);
if (character_class.ctypes != CharacterType::None and not character_class.negative and
character_class.ranges.empty())
return new_node(ParsedRegex::CharacterType, (Codepoint)character_class.ctypes);
auto it = std::lower_bound(ranges.begin(), ranges.end(), cp,
[](auto& range, Codepoint cp)
{ return range.max < cp; });
auto class_id = m_parsed_regex.character_classes.size();
m_parsed_regex.character_classes.push_back(std::move(character_class));
auto found = (it != ranges.end() and it->min <= cp) or
contains_that(ctypes, [cp](auto& c) {
return (bool)iswctype(cp, c.first) == c.second;
}) or (not excluded.empty() and not contains(excluded, cp));
return negative ? not found : found;
};
auto matcher_id = m_parsed_regex.matchers.size();
m_parsed_regex.matchers.push_back(std::move(matcher));
return new_node(ParsedRegex::Matcher, matcher_id);
return new_node(ParsedRegex::Class, class_id);
}
ParsedRegex::Quantifier quantifier()
@ -612,8 +579,8 @@ private:
{
for_each_child(m_parsed_regex, index, [this](NodeIndex child_index) {
auto& child = get_node(child_index);
if (child.op != ParsedRegex::Literal and child.op != ParsedRegex::Matcher and
child.op != ParsedRegex::AnyChar)
if (child.op != ParsedRegex::Literal and child.op != ParsedRegex::Class and
child.op != ParsedRegex::CharacterType and child.op != ParsedRegex::AnyChar)
parse_error("Lookaround can only contain literals, any chars or character classes");
if (child.quantifier.type != ParsedRegex::Quantifier::One)
parse_error("Quantifiers cannot be used in lookarounds");
@ -628,14 +595,12 @@ private:
static constexpr struct CharacterClassEscape {
Codepoint cp;
const char* ctype;
StringView additional_chars;
bool neg;
CharacterType ctype;
} character_class_escapes[] = {
{ 'd', "digit", "", false },
{ 'w', "alnum", "_", false },
{ 's', "space", "", false },
{ 'h', nullptr, " \t", false },
{ 'd', CharacterType::Digit }, { 'D', CharacterType::NotDigit },
{ 'w', CharacterType::Word }, { 'W', CharacterType::NotWord },
{ 's', CharacterType::Whitespace }, { 'S', CharacterType::NotWhitespace },
{ 'h', CharacterType::HorizontalWhitespace }, { 'H', CharacterType::NotHorizontalWhitespace },
};
static constexpr struct ControlEscape {
@ -661,7 +626,7 @@ struct RegexCompiler
write_search_prefix();
compile_node(0);
push_inst(CompiledRegex::Match);
m_program.matchers = m_parsed_regex.matchers;
m_program.character_classes = m_parsed_regex.character_classes;
m_program.save_count = m_parsed_regex.capture_count * 2;
m_program.direction = direction;
m_program.start_chars = compute_start_chars();
@ -695,8 +660,11 @@ private:
case ParsedRegex::AnyChar:
push_inst(CompiledRegex::AnyChar);
break;
case ParsedRegex::Matcher:
push_inst(CompiledRegex::Matcher, node.value);
case ParsedRegex::Class:
push_inst(CompiledRegex::Class, node.value);
break;
case ParsedRegex::CharacterType:
push_inst(CompiledRegex::CharacterType, node.value);
break;
case ParsedRegex::Sequence:
{
@ -871,8 +839,10 @@ private:
: character.value);
else if (character.op == ParsedRegex::AnyChar)
m_program.lookarounds.push_back(0xF000);
else if (character.op == ParsedRegex::Matcher)
else if (character.op == ParsedRegex::Class)
m_program.lookarounds.push_back(0xF0001 + character.value);
else if (character.op == ParsedRegex::CharacterType)
m_program.lookarounds.push_back(0xF8000 | character.value);
else
kak_assert(false);
return true;
@ -915,12 +885,28 @@ private:
b = true;
start_chars.map[CompiledRegex::StartChars::other] = true;
return node.quantifier.allows_none();
case ParsedRegex::Matcher:
for (Codepoint c = 0; c < CompiledRegex::StartChars::count; ++c)
if (m_program.matchers[node.value](c))
start_chars.map[c] = true;
start_chars.map[CompiledRegex::StartChars::other] = true; // stay safe
case ParsedRegex::Class:
{
auto& character_class = m_parsed_regex.character_classes[node.value];
for (Codepoint cp = 0; cp < CompiledRegex::StartChars::count; ++cp)
{
if (is_character_class(character_class, cp))
start_chars.map[cp] = true;
}
start_chars.map[CompiledRegex::StartChars::other] = true;
return node.quantifier.allows_none();
}
case ParsedRegex::CharacterType:
{
const CharacterType ctype = (CharacterType)node.value;
for (Codepoint cp = 0; cp < CompiledRegex::StartChars::count; ++cp)
{
if (is_ctype(ctype, cp))
start_chars.map[cp] = true;
}
start_chars.map[CompiledRegex::StartChars::other] = true;
return node.quantifier.allows_none();
}
case ParsedRegex::Sequence:
{
bool did_not_consume = false;
@ -1015,8 +1001,11 @@ void dump_regex(const CompiledRegex& program)
case CompiledRegex::Save:
printf("save %d\n", inst.param);
break;
case CompiledRegex::Matcher:
printf("matcher %d\n", inst.param);
case CompiledRegex::Class:
printf("class %d\n", inst.param);
break;
case CompiledRegex::CharacterType:
printf("character type %d\n", inst.param);
break;
case CompiledRegex::LineStart:
printf("line start\n");
@ -1084,6 +1073,43 @@ CompiledRegex compile_regex(StringView re, RegexCompileFlags flags, MatchDirecti
return RegexCompiler{RegexParser::parse(re), flags, direction}.get_compiled_regex();
}
bool is_character_class(const CharacterClass& character_class, Codepoint cp)
{
if (character_class.ignore_case)
cp = to_lower(cp);
auto it = std::lower_bound(character_class.ranges.begin(),
character_class.ranges.end(), cp,
[](auto& range, Codepoint cp)
{ return range.max < cp; });
auto found = (it != character_class.ranges.end() and it->min <= cp) or
is_ctype(character_class.ctypes, cp);
return found != character_class.negative;
}
bool is_ctype(CharacterType ctype, Codepoint cp)
{
if ((ctype & CharacterType::Digit) and iswdigit(cp))
return true;
if ((ctype & CharacterType::Word) and is_word(cp))
return true;
if ((ctype & CharacterType::Whitespace) and is_blank(cp))
return true;
if ((ctype & CharacterType::HorizontalWhitespace) and is_horizontal_blank(cp))
return true;
if ((ctype & CharacterType::NotDigit) and not iswdigit(cp))
return true;
if ((ctype & CharacterType::NotWord) and not is_word(cp))
return true;
if ((ctype & CharacterType::NotWhitespace) and not is_blank(cp))
return true;
if ((ctype & CharacterType::NotHorizontalWhitespace) and not is_horizontal_blank(cp))
return true;
return false;
}
namespace
{
template<MatchDirection dir = MatchDirection::Forward>
@ -1257,7 +1283,7 @@ auto test_regex = UnitTest{[]{
}
{
TestVM<> vm{R"((?=foo).)"};
TestVM<> vm{R"((?=fo[\w]).)"};
kak_assert(vm.exec("barfoo", RegexExecFlags::Search));
kak_assert(StringView{vm.captures()[0], vm.captures()[1]} == "f");
}
@ -1274,7 +1300,7 @@ auto test_regex = UnitTest{[]{
}
{
TestVM<> vm{R"(...(?<=f.o))"};
TestVM<> vm{R"(...(?<=f\w.))"};
kak_assert(vm.exec("foo"));
kak_assert(not vm.exec("qux"));
}

View File

@ -23,6 +23,33 @@ enum class MatchDirection
Backward
};
enum class CharacterType : unsigned char
{
None = 0,
Word = 1 << 0,
Whitespace = 1 << 1,
HorizontalWhitespace = 1 << 2,
Digit = 1 << 3,
NotWord = 1 << 4,
NotWhitespace = 1 << 5,
NotHorizontalWhitespace = 1 << 6,
NotDigit = 1 << 7
};
constexpr bool with_bit_ops(Meta::Type<CharacterType>) { return true; }
struct CharacterClass
{
struct Range { Codepoint min, max; };
Vector<Range, MemoryDomain::Regex> ranges;
CharacterType ctypes = CharacterType::None;
bool negative = false;
bool ignore_case = false;
};
bool is_character_class(const CharacterClass& character_class, Codepoint cp);
bool is_ctype(CharacterType ctype, Codepoint cp);
struct CompiledRegex : RefCountable, UseMemoryDomain<MemoryDomain::Regex>
{
enum Op : char
@ -32,7 +59,8 @@ struct CompiledRegex : RefCountable, UseMemoryDomain<MemoryDomain::Regex>
Literal,
Literal_IgnoreCase,
AnyChar,
Matcher,
Class,
CharacterType,
Jump,
Split_PrioritizeParent,
Split_PrioritizeChild,
@ -68,7 +96,7 @@ struct CompiledRegex : RefCountable, UseMemoryDomain<MemoryDomain::Regex>
explicit operator bool() const { return not instructions.empty(); }
Vector<Instruction, MemoryDomain::Regex> instructions;
Vector<std::function<bool (Codepoint)>, MemoryDomain::Regex> matchers;
Vector<CharacterClass, MemoryDomain::Regex> character_classes;
Vector<Codepoint, MemoryDomain::Regex> lookarounds;
MatchDirection direction;
size_t save_count;
@ -289,11 +317,16 @@ private:
thread.saves->pos[inst.param] = get_base(pos);
break;
}
case CompiledRegex::Matcher:
case CompiledRegex::Class:
if (pos == m_end)
return StepResult::Failed;
return m_program.matchers[inst.param](*pos) ?
return is_character_class(m_program.character_classes[inst.param], *pos) ?
StepResult::Consumed : StepResult::Failed;
case CompiledRegex::CharacterType:
if (pos == m_end)
return StepResult::Failed;
return is_ctype((CharacterType)inst.param, *pos) ?
StepResult::Consumed : StepResult::Failed;;
case CompiledRegex::LineStart:
if (not is_line_start(pos))
return StepResult::Failed;
@ -457,9 +490,14 @@ private:
const Codepoint ref = *it;
if (ref == 0xF000)
{} // any character matches
else if (ref > 0xF0000 and ref <= 0xFFFFD)
else if (ref > 0xF0000 and ref < 0xF8000)
{
if (not m_program.matchers[ref - 0xF0001](cp))
if (not is_character_class(m_program.character_classes[ref - 0xF0001], cp))
return false;
}
else if (ref >= 0xF8000 and ref <= 0xFFFFD)
{
if (not is_ctype((CharacterType)(ref & 0xFF), cp))
return false;
}
else if (ref != cp)