"""
The matcher module contains the UrlMatcher class.
"""
from __future__ import annotations
from collections.abc import Iterable, Iterator, Mapping
from dataclasses import dataclass, field
from itertools import chain
from typing import Any
from url_matcher.patterns import PatternMatcher, get_pattern_domain, hierarchical_str
from url_matcher.util import get_domain
[docs]
@dataclass(init=False, frozen=True)
class Patterns:
include: tuple[str, ...]
exclude: tuple[str, ...]
priority: int
[docs]
def __init__(self, include: list[str], exclude: list[str] | None = None, priority: int = 500):
# The initialization is manually set so that we can support an API of
# accepting and returning lists. However, tuples are being used underneath
# that class so that the attributes are truly immutable, in addition to
# being frozen=True.
# Using lists are far less likely to have human typing mistakes compared to
# tuples since the trailing `,` char can easily be missed out. For
# example:
# * ("element") is not the same as ("element",) which is a tuple.
# Lastly, the manner of how we set the attribute values below is in line
# with how Python's own `dataclasses` library assign attributes to frozen
# classes. Here's a reference:
# * https://github.com/python/cpython/blob/v3.10.2/Lib/dataclasses.py#L1117-L1120
object.__setattr__(self, "include", tuple(include))
object.__setattr__(self, "exclude", tuple(exclude or []))
object.__setattr__(self, "priority", priority)
[docs]
def get_domains(self) -> list[str]:
domains = [get_pattern_domain(pattern) for pattern in self.include]
# remove duplicate domains preserving the order
return list(dict.fromkeys(domain for domain in domains if domain))
[docs]
def get_includes_without_domain(self) -> list[str]:
return [pattern for pattern in self.include if get_pattern_domain(pattern) is None]
[docs]
def all_includes_have_domain(self) -> bool:
"""Return true if all the include patterns have a domain"""
return not self.get_includes_without_domain()
[docs]
def is_universal_pattern(self) -> bool:
"""Return true if there are no include patterns or they are empty. A universal pattern matches any domain"""
return not any(pattern for pattern in self.include)
[docs]
def get_includes_for(self, domain: str) -> list[str]:
return [pattern for pattern in self.include if get_pattern_domain(pattern) == domain]
@dataclass
class PatternsMatcher:
identifier: Any
patterns: Patterns
include_matchers: list[PatternMatcher] = field(init=False)
exclude_matchers: list[PatternMatcher] = field(init=False)
def __post_init__(self) -> None:
self.include_matchers = [PatternMatcher(pattern) for pattern in self.patterns.include]
self.exclude_matchers = [PatternMatcher(pattern) for pattern in self.patterns.exclude]
def match(self, url: str) -> bool:
if self.include_matchers:
for include in self.include_matchers:
if include.match(url):
break
else:
return False
return not any(exclude.match(url) for exclude in self.exclude_matchers)
class IncludePatternsWithoutDomainError(ValueError):
def __init__(self, *args: Any, identifier: Any, patterns: Patterns, wrong_patterns: list[str]):
super().__init__(*args)
self.id = identifier
self.patterns = patterns
self.wrong_patterns = wrong_patterns
[docs]
class URLMatcher:
[docs]
def __init__(self, data: Mapping[Any, Patterns] | Iterable[tuple[Any, Patterns]] | None = None):
"""
A class that matches URLs against a list of patterns, returning
the identifier of the rule that matched the URL.
Example usage::
matcher = URLMatcher()
matcher.add_or_update(1, Patterns(include=["example.com/product"]))
matcher.add_or_update(2, Patterns(include=["other.com"]))
assert matcher.match("http://example.com/product/a_product.html") == 1
assert matcher.match("http://other.com/a_different_page") == 2
:param data: A map or a list of tuples with identifier, patterns pairs to
initialize the object from
"""
self.matchers_by_domain: dict[str, list[PatternsMatcher]] = {}
self.matchers_universal: list[PatternsMatcher] = []
self.patterns: dict[Any, Patterns] = {}
if data:
items = data.items() if isinstance(data, Mapping) else data
for identifier, patterns in items:
self.add_or_update(identifier, patterns)
[docs]
def add_or_update(self, identifier: Any, patterns: Patterns) -> None:
if not patterns.all_includes_have_domain() and not patterns.is_universal_pattern():
wrong_patterns = [p for p in patterns.get_includes_without_domain() if p]
raise IncludePatternsWithoutDomainError(
f"All include patterns must belong to a domain "
f"but the patterns {wrong_patterns} doesn't. "
f"For example, the include pattern '/product/* "
f"is invalid whereas the pattern 'example.com/product/*' isn't. "
f"The only exception is the empty pattern which matches everything "
f"and it is allowed. "
f"identifier: {identifier}.",
identifier=identifier,
patterns=patterns,
wrong_patterns=wrong_patterns,
)
if identifier in self.patterns:
self.remove(identifier)
self.patterns[identifier] = patterns
matcher = PatternsMatcher(identifier, patterns)
for domain in patterns.get_domains():
self._add_matcher(domain, matcher)
if patterns.is_universal_pattern():
self._add_matcher("", matcher)
[docs]
def remove(self, identifier: Any) -> None:
patterns = self.patterns.get(identifier)
if not patterns:
return
del self.patterns[identifier]
for domain in patterns.get_domains():
self._del_matcher(domain, identifier)
if patterns.is_universal_pattern():
self._del_matcher("", identifier)
[docs]
def get(self, identifier: Any) -> Patterns | None:
return self.patterns.get(identifier)
[docs]
def match(self, url: str, *, include_universal: bool = True) -> Any | None:
return next(self.match_all(url, include_universal=include_universal), None)
[docs]
def match_all(self, url: str, *, include_universal: bool = True) -> Iterator[Any]:
domain = get_domain(url)
matchers: Iterable[PatternsMatcher] = self.matchers_by_domain.get(domain) or []
if include_universal:
matchers = chain(matchers, self.matchers_universal)
for matcher in matchers:
if matcher.match(url):
yield matcher.identifier
[docs]
def match_universal(self) -> Iterator[Any]:
return (m.identifier for m in self.matchers_universal)
def _sort_domain(self, domain: str) -> None:
"""
Sort all the rules within a domain so that the matching can be done in sequence:
the first rule matching wins.
A total ordering is defined. This is ensured by using including
the identifier in the sorting criteria
Sorting criteria:
* Priority (descending)
* Sorted list of includes for this domain (descending)
* Rule identifier (descending)
"""
def sort_key(matcher: PatternsMatcher) -> tuple[int, list[str], Any]:
sorted_includes = sorted(map(hierarchical_str, matcher.patterns.get_includes_for(domain)))
return (matcher.patterns.priority, sorted_includes, matcher.identifier)
self.matchers_by_domain[domain].sort(key=sort_key, reverse=True)
self.matchers_universal.sort(key=sort_key, reverse=True)
def _del_matcher(self, domain: str, identifier: Any) -> None:
matchers = self.matchers_by_domain[domain]
for idx in range(len(matchers)):
if matchers[idx].identifier == identifier:
del matchers[idx]
break
if not matchers:
del self.matchers_by_domain[domain]
for idx in range(len(self.matchers_universal)):
if self.matchers_universal[idx].identifier == identifier:
del self.matchers_universal[idx]
break
def _add_matcher(self, domain: str, matcher: PatternsMatcher) -> None:
# FIXME: This can be made much more efficient if we insert the data directly in order instead of resorting.
# The bisect module could be used for this purpose.
# I'm leaving it for the future as insertion time is not critical.
self.matchers_by_domain.setdefault(domain, []).append(matcher)
if domain == "":
self.matchers_universal.append(matcher)
self._sort_domain(domain)