diff --git a/src/regex_impl.cc b/src/regex_impl.cc index 3197292d..9d7d98ab 100644 --- a/src/regex_impl.cc +++ b/src/regex_impl.cc @@ -22,7 +22,8 @@ struct ParsedRegex { Literal, AnyChar, - Matcher, + Class, + CharacterType, Sequence, Alternation, LineStart, @@ -80,7 +81,8 @@ struct ParsedRegex }; Vector nodes; - Vector, MemoryDomain::Regex> matchers; + + Vector character_classes; size_t capture_count; }; @@ -319,18 +321,9 @@ private: } // CharacterClassEscape - auto class_it = find_if(character_class_escapes, - [cp = to_lower(cp)](auto& c) { return c.cp == cp; }); + auto class_it = find_if(character_class_escapes, [cp](auto& c) { return c.cp == cp; }); if (class_it != std::end(character_class_escapes)) - { - auto matcher_id = m_parsed_regex.matchers.size(); - m_parsed_regex.matchers.push_back( - [ctype = class_it->ctype ? wctype(class_it->ctype) : (wctype_t)0, - chars = class_it->additional_chars, neg = is_upper(cp)] (Codepoint cp) { - return ((ctype != 0 and iswctype(cp, ctype)) or contains(chars, cp)) != neg; - }); - return new_node(ParsedRegex::Matcher, matcher_id); - } + return new_node(ParsedRegex::CharacterType, (Codepoint)class_it->ctype); // CharacterEscape for (auto& control : control_escapes) @@ -383,9 +376,7 @@ private: parse_error(format("unknown atom escape '{}'", cp)); } - struct CharRange { Codepoint min, max; }; - - void normalize_ranges(Vector& ranges) + void normalize_ranges(Vector& ranges) { if (ranges.empty()) return; @@ -411,19 +402,19 @@ private: NodeIndex character_class() { - const bool negative = m_pos != m_regex.end() and *m_pos == '^'; - if (negative) + CharacterClass character_class; + + character_class.ignore_case = m_ignore_case; + character_class.negative = m_pos != m_regex.end() and *m_pos == '^'; + if (character_class.negative) ++m_pos; - Vector ranges; - Vector excluded; - Vector, MemoryDomain::Regex> ctypes; while (m_pos != m_regex.end() and *m_pos != ']') { auto cp = *m_pos++; if (cp == '-') { - ranges.push_back({ '-', '-' }); + character_class.ranges.push_back({ '-', '-' }); continue; } @@ -433,19 +424,10 @@ private: if (cp == '\\') { auto it = find_if(character_class_escapes, - [cp = to_lower(*m_pos)](auto& t) { return t.cp == cp; }); + [cp = *m_pos](auto& t) { return t.cp == cp; }); if (it != std::end(character_class_escapes)) { - auto negative = is_upper(*m_pos); - if (it->ctype) - ctypes.push_back({wctype(it->ctype), not negative}); - for (auto& c : it->additional_chars) - { - if (negative) - excluded.push_back((Codepoint)c); - else - ranges.push_back({(Codepoint)c, (Codepoint)c}); - } + character_class.ctypes |= it->ctype; ++m_pos; continue; } @@ -463,7 +445,7 @@ private: } } - CharRange range = { cp, cp }; + CharacterClass::Range range = { cp, cp }; if (*m_pos == '-') { if (++m_pos == m_regex.end()) @@ -476,11 +458,11 @@ private: } else { - ranges.push_back(range); + character_class.ranges.push_back(range); range = { '-', '-' }; } } - ranges.push_back(range); + character_class.ranges.push_back(range); } if (at_end()) parse_error("unclosed character class"); @@ -488,45 +470,30 @@ private: if (m_ignore_case) { - for (auto& range : ranges) + for (auto& range : character_class.ranges) { range.min = to_lower(range.min); range.max = to_lower(range.max); } - for (auto& cp : excluded) - cp = to_lower(cp); } - normalize_ranges(ranges); + normalize_ranges(character_class.ranges); // Optimize the relatively common case of using a character class to // escape a character, such as [*] - if (ctypes.empty() and excluded.empty() and not negative and - ranges.size() == 1 and ranges.front().min == ranges.front().max) - return new_node(ParsedRegex::Literal, ranges.front().min); + if (character_class.ctypes == CharacterType::None and not character_class.negative and + character_class.ranges.size() == 1 and + character_class.ranges.front().min == character_class.ranges.front().max) + return new_node(ParsedRegex::Literal, character_class.ranges.front().min); - auto matcher = [ranges = std::move(ranges), - ctypes = std::move(ctypes), - excluded = std::move(excluded), - negative, ignore_case = m_ignore_case] (Codepoint cp) { - if (ignore_case) - cp = to_lower(cp); + if (character_class.ctypes != CharacterType::None and not character_class.negative and + character_class.ranges.empty()) + return new_node(ParsedRegex::CharacterType, (Codepoint)character_class.ctypes); - auto it = std::lower_bound(ranges.begin(), ranges.end(), cp, - [](auto& range, Codepoint cp) - { return range.max < cp; }); + auto class_id = m_parsed_regex.character_classes.size(); + m_parsed_regex.character_classes.push_back(std::move(character_class)); - auto found = (it != ranges.end() and it->min <= cp) or - contains_that(ctypes, [cp](auto& c) { - return (bool)iswctype(cp, c.first) == c.second; - }) or (not excluded.empty() and not contains(excluded, cp)); - return negative ? not found : found; - }; - - auto matcher_id = m_parsed_regex.matchers.size(); - m_parsed_regex.matchers.push_back(std::move(matcher)); - - return new_node(ParsedRegex::Matcher, matcher_id); + return new_node(ParsedRegex::Class, class_id); } ParsedRegex::Quantifier quantifier() @@ -612,8 +579,8 @@ private: { for_each_child(m_parsed_regex, index, [this](NodeIndex child_index) { auto& child = get_node(child_index); - if (child.op != ParsedRegex::Literal and child.op != ParsedRegex::Matcher and - child.op != ParsedRegex::AnyChar) + if (child.op != ParsedRegex::Literal and child.op != ParsedRegex::Class and + child.op != ParsedRegex::CharacterType and child.op != ParsedRegex::AnyChar) parse_error("Lookaround can only contain literals, any chars or character classes"); if (child.quantifier.type != ParsedRegex::Quantifier::One) parse_error("Quantifiers cannot be used in lookarounds"); @@ -628,14 +595,12 @@ private: static constexpr struct CharacterClassEscape { Codepoint cp; - const char* ctype; - StringView additional_chars; - bool neg; + CharacterType ctype; } character_class_escapes[] = { - { 'd', "digit", "", false }, - { 'w', "alnum", "_", false }, - { 's', "space", "", false }, - { 'h', nullptr, " \t", false }, + { 'd', CharacterType::Digit }, { 'D', CharacterType::NotDigit }, + { 'w', CharacterType::Word }, { 'W', CharacterType::NotWord }, + { 's', CharacterType::Whitespace }, { 'S', CharacterType::NotWhitespace }, + { 'h', CharacterType::HorizontalWhitespace }, { 'H', CharacterType::NotHorizontalWhitespace }, }; static constexpr struct ControlEscape { @@ -661,7 +626,7 @@ struct RegexCompiler write_search_prefix(); compile_node(0); push_inst(CompiledRegex::Match); - m_program.matchers = m_parsed_regex.matchers; + m_program.character_classes = m_parsed_regex.character_classes; m_program.save_count = m_parsed_regex.capture_count * 2; m_program.direction = direction; m_program.start_chars = compute_start_chars(); @@ -695,8 +660,11 @@ private: case ParsedRegex::AnyChar: push_inst(CompiledRegex::AnyChar); break; - case ParsedRegex::Matcher: - push_inst(CompiledRegex::Matcher, node.value); + case ParsedRegex::Class: + push_inst(CompiledRegex::Class, node.value); + break; + case ParsedRegex::CharacterType: + push_inst(CompiledRegex::CharacterType, node.value); break; case ParsedRegex::Sequence: { @@ -871,8 +839,10 @@ private: : character.value); else if (character.op == ParsedRegex::AnyChar) m_program.lookarounds.push_back(0xF000); - else if (character.op == ParsedRegex::Matcher) + else if (character.op == ParsedRegex::Class) m_program.lookarounds.push_back(0xF0001 + character.value); + else if (character.op == ParsedRegex::CharacterType) + m_program.lookarounds.push_back(0xF8000 | character.value); else kak_assert(false); return true; @@ -915,12 +885,28 @@ private: b = true; start_chars.map[CompiledRegex::StartChars::other] = true; return node.quantifier.allows_none(); - case ParsedRegex::Matcher: - for (Codepoint c = 0; c < CompiledRegex::StartChars::count; ++c) - if (m_program.matchers[node.value](c)) - start_chars.map[c] = true; - start_chars.map[CompiledRegex::StartChars::other] = true; // stay safe + case ParsedRegex::Class: + { + auto& character_class = m_parsed_regex.character_classes[node.value]; + for (Codepoint cp = 0; cp < CompiledRegex::StartChars::count; ++cp) + { + if (is_character_class(character_class, cp)) + start_chars.map[cp] = true; + } + start_chars.map[CompiledRegex::StartChars::other] = true; return node.quantifier.allows_none(); + } + case ParsedRegex::CharacterType: + { + const CharacterType ctype = (CharacterType)node.value; + for (Codepoint cp = 0; cp < CompiledRegex::StartChars::count; ++cp) + { + if (is_ctype(ctype, cp)) + start_chars.map[cp] = true; + } + start_chars.map[CompiledRegex::StartChars::other] = true; + return node.quantifier.allows_none(); + } case ParsedRegex::Sequence: { bool did_not_consume = false; @@ -1015,8 +1001,11 @@ void dump_regex(const CompiledRegex& program) case CompiledRegex::Save: printf("save %d\n", inst.param); break; - case CompiledRegex::Matcher: - printf("matcher %d\n", inst.param); + case CompiledRegex::Class: + printf("class %d\n", inst.param); + break; + case CompiledRegex::CharacterType: + printf("character type %d\n", inst.param); break; case CompiledRegex::LineStart: printf("line start\n"); @@ -1084,6 +1073,43 @@ CompiledRegex compile_regex(StringView re, RegexCompileFlags flags, MatchDirecti return RegexCompiler{RegexParser::parse(re), flags, direction}.get_compiled_regex(); } +bool is_character_class(const CharacterClass& character_class, Codepoint cp) +{ + if (character_class.ignore_case) + cp = to_lower(cp); + + auto it = std::lower_bound(character_class.ranges.begin(), + character_class.ranges.end(), cp, + [](auto& range, Codepoint cp) + { return range.max < cp; }); + + auto found = (it != character_class.ranges.end() and it->min <= cp) or + is_ctype(character_class.ctypes, cp); + + return found != character_class.negative; +} + +bool is_ctype(CharacterType ctype, Codepoint cp) +{ + if ((ctype & CharacterType::Digit) and iswdigit(cp)) + return true; + if ((ctype & CharacterType::Word) and is_word(cp)) + return true; + if ((ctype & CharacterType::Whitespace) and is_blank(cp)) + return true; + if ((ctype & CharacterType::HorizontalWhitespace) and is_horizontal_blank(cp)) + return true; + if ((ctype & CharacterType::NotDigit) and not iswdigit(cp)) + return true; + if ((ctype & CharacterType::NotWord) and not is_word(cp)) + return true; + if ((ctype & CharacterType::NotWhitespace) and not is_blank(cp)) + return true; + if ((ctype & CharacterType::NotHorizontalWhitespace) and not is_horizontal_blank(cp)) + return true; + return false; +} + namespace { template @@ -1257,7 +1283,7 @@ auto test_regex = UnitTest{[]{ } { - TestVM<> vm{R"((?=foo).)"}; + TestVM<> vm{R"((?=fo[\w]).)"}; kak_assert(vm.exec("barfoo", RegexExecFlags::Search)); kak_assert(StringView{vm.captures()[0], vm.captures()[1]} == "f"); } @@ -1274,7 +1300,7 @@ auto test_regex = UnitTest{[]{ } { - TestVM<> vm{R"(...(?<=f.o))"}; + TestVM<> vm{R"(...(?<=f\w.))"}; kak_assert(vm.exec("foo")); kak_assert(not vm.exec("qux")); } diff --git a/src/regex_impl.hh b/src/regex_impl.hh index 74294f4c..6c687e4f 100644 --- a/src/regex_impl.hh +++ b/src/regex_impl.hh @@ -23,6 +23,33 @@ enum class MatchDirection Backward }; +enum class CharacterType : unsigned char +{ + None = 0, + Word = 1 << 0, + Whitespace = 1 << 1, + HorizontalWhitespace = 1 << 2, + Digit = 1 << 3, + NotWord = 1 << 4, + NotWhitespace = 1 << 5, + NotHorizontalWhitespace = 1 << 6, + NotDigit = 1 << 7 +}; +constexpr bool with_bit_ops(Meta::Type) { return true; } + +struct CharacterClass +{ + struct Range { Codepoint min, max; }; + + Vector ranges; + CharacterType ctypes = CharacterType::None; + bool negative = false; + bool ignore_case = false; +}; + +bool is_character_class(const CharacterClass& character_class, Codepoint cp); +bool is_ctype(CharacterType ctype, Codepoint cp); + struct CompiledRegex : RefCountable, UseMemoryDomain { enum Op : char @@ -32,7 +59,8 @@ struct CompiledRegex : RefCountable, UseMemoryDomain Literal, Literal_IgnoreCase, AnyChar, - Matcher, + Class, + CharacterType, Jump, Split_PrioritizeParent, Split_PrioritizeChild, @@ -68,7 +96,7 @@ struct CompiledRegex : RefCountable, UseMemoryDomain explicit operator bool() const { return not instructions.empty(); } Vector instructions; - Vector, MemoryDomain::Regex> matchers; + Vector character_classes; Vector lookarounds; MatchDirection direction; size_t save_count; @@ -289,11 +317,16 @@ private: thread.saves->pos[inst.param] = get_base(pos); break; } - case CompiledRegex::Matcher: + case CompiledRegex::Class: if (pos == m_end) return StepResult::Failed; - return m_program.matchers[inst.param](*pos) ? + return is_character_class(m_program.character_classes[inst.param], *pos) ? StepResult::Consumed : StepResult::Failed; + case CompiledRegex::CharacterType: + if (pos == m_end) + return StepResult::Failed; + return is_ctype((CharacterType)inst.param, *pos) ? + StepResult::Consumed : StepResult::Failed;; case CompiledRegex::LineStart: if (not is_line_start(pos)) return StepResult::Failed; @@ -457,9 +490,14 @@ private: const Codepoint ref = *it; if (ref == 0xF000) {} // any character matches - else if (ref > 0xF0000 and ref <= 0xFFFFD) + else if (ref > 0xF0000 and ref < 0xF8000) { - if (not m_program.matchers[ref - 0xF0001](cp)) + if (not is_character_class(m_program.character_classes[ref - 0xF0001], cp)) + return false; + } + else if (ref >= 0xF8000 and ref <= 0xFFFFD) + { + if (not is_ctype((CharacterType)(ref & 0xFF), cp)) return false; } else if (ref != cp)