diff --git a/seeds/chunker-python/src/google_labs_html_chunker/html_chunker.py b/seeds/chunker-python/src/google_labs_html_chunker/html_chunker.py index c83f827f..4bace918 100644 --- a/seeds/chunker-python/src/google_labs_html_chunker/html_chunker.py +++ b/seeds/chunker-python/src/google_labs_html_chunker/html_chunker.py @@ -17,6 +17,9 @@ # Text within these html tags will be excluded from passages by default. _DEFAULT_HTML_TAGS_TO_EXCLUDE = frozenset({"noscript", "script", "style"}) +# Text within elements that have any of these html classes will be excluded from passages by default. +_DEFAULT_HTML_CLASSES_TO_EXCLUDE = frozenset({}) + # Html tags that indicate a section break. Sibling nodes will not be # greedily-aggregated into a chunk across one of these tags. _SECTION_BREAK_HTML_TAGS = frozenset({ @@ -56,6 +59,8 @@ class HtmlChunker: html_tags_to_exclude: Text within any of the tags in this set will not be included in the output passages. Defaults to {"noscript", "script", "style"}. + html_classes_to_exclude: Text within any element that has a class in this set will + not be included in the output passages. Defaults to the empty set. 
""" def __init__( @@ -63,12 +68,14 @@ def __init__( max_words_per_aggregate_passage: int, greedily_aggregate_sibling_nodes: bool, html_tags_to_exclude: frozenset[str] = _DEFAULT_HTML_TAGS_TO_EXCLUDE, + html_classes_to_exclude: frozenset[str] = _DEFAULT_HTML_CLASSES_TO_EXCLUDE, ) -> None: self.max_words_per_aggregate_passage = max_words_per_aggregate_passage self.greedily_aggregate_sibling_nodes = greedily_aggregate_sibling_nodes self.html_tags_to_exclude = { tag.strip().lower() for tag in html_tags_to_exclude } + self.html_classes_to_exclude = {class_.strip().lower() for class_ in html_classes_to_exclude} class PassageList: """A list of text passages.""" @@ -134,7 +141,13 @@ def _process_node(self, node) -> AggregateNode: current_node = self.AggregateNode() if node.name: current_node.html_tag = node.name - if node.name in self.html_tags_to_exclude or isinstance(node, bs4.Comment): + if (node.name in self.html_tags_to_exclude + or isinstance(node, bs4.Comment) + or ( + not isinstance(node, bs4.NavigableString) + and any(cls in self.html_classes_to_exclude for cls in node.get("class", [])) + ) + ): # Exclude text within these tags. return current_node @@ -215,4 +228,4 @@ def chunk(self, html: str) -> list[str]: root_agg_node = self._process_node(tree) if not root_agg_node.get_passages(): root_agg_node.passage_list.add_passage_for_node(root_agg_node) - return root_agg_node.get_passages() \ No newline at end of file + return root_agg_node.get_passages()