Skip to content

Commit

Permalink
Add functionality to exclude certain classes from being chunked (#338)
Browse files Browse the repository at this point in the history
* feat: adds functionality to exclude certain classes from being chunking

* fix: remove comment to hint at default classes to exclude
  • Loading branch information
marvinvr authored Sep 2, 2024
1 parent 259737c commit 5666a92
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions seeds/chunker-python/src/google_labs_html_chunker/html_chunker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
# Text within these html tags will be excluded from passages by default.
_DEFAULT_HTML_TAGS_TO_EXCLUDE = frozenset({"noscript", "script", "style"})

# Text within these html classes will be excluded from passages by default.
_DEFAULT_HTML_CLASSES_TO_EXCLUDE = frozenset({})

# Html tags that indicate a section break. Sibling nodes will not be
# greedily-aggregated into a chunk across one of these tags.
_SECTION_BREAK_HTML_TAGS = frozenset({
Expand Down Expand Up @@ -56,19 +59,23 @@ class HtmlChunker:
html_tags_to_exclude: Text within any of the tags in this set will not be
included in the output passages. Defaults to {"noscript", "script",
"style"}.
html_classes_to_exclude: Text within any of the classes in this set will not be
included in the output passages.
"""

def __init__(
self,
max_words_per_aggregate_passage: int,
greedily_aggregate_sibling_nodes: bool,
html_tags_to_exclude: frozenset[str] = _DEFAULT_HTML_TAGS_TO_EXCLUDE,
html_classes_to_exclude: frozenset[str] = _DEFAULT_HTML_CLASSES_TO_EXCLUDE,
) -> None:
self.max_words_per_aggregate_passage = max_words_per_aggregate_passage
self.greedily_aggregate_sibling_nodes = greedily_aggregate_sibling_nodes
self.html_tags_to_exclude = {
tag.strip().lower() for tag in html_tags_to_exclude
}
self.html_classes_to_exclude = {class_.strip().lower() for class_ in html_classes_to_exclude}

class PassageList:
"""A list of text passages."""
Expand Down Expand Up @@ -134,7 +141,13 @@ def _process_node(self, node) -> AggregateNode:
current_node = self.AggregateNode()
if node.name:
current_node.html_tag = node.name
if node.name in self.html_tags_to_exclude or isinstance(node, bs4.Comment):
if (node.name in self.html_tags_to_exclude
or isinstance(node, bs4.Comment)
or (
not isinstance(node, bs4.NavigableString)
and any(cls in self.html_classes_to_exclude for cls in node.get("class", []))
)
):
# Exclude text within these tags.
return current_node

Expand Down Expand Up @@ -215,4 +228,4 @@ def chunk(self, html: str) -> list[str]:
root_agg_node = self._process_node(tree)
if not root_agg_node.get_passages():
root_agg_node.passage_list.add_passage_for_node(root_agg_node)
return root_agg_node.get_passages()
return root_agg_node.get_passages()

0 comments on commit 5666a92

Please sign in to comment.