From 1a2accdc8f40f93a12ff3d2786316b3c631baf40 Mon Sep 17 00:00:00 2001 From: Tim Semenov Date: Mon, 2 Sep 2024 07:38:25 -0700 Subject: [PATCH] Add FineWeb-Edu dataset to TFDS. PiperOrigin-RevId: 670214823 --- tensorflow_datasets/core/dataset_builder.py | 2 +- .../core/download/download_manager.py | 264 +++++++++--------- .../core/download/downloader.py | 26 +- tensorflow_datasets/core/download/resource.py | 11 +- .../core/download/resource_test.py | 6 +- .../datasets/fineweb_edu/CITATIONS.bib | 8 + .../datasets/fineweb_edu/README.md | 3 + .../datasets/fineweb_edu/TAGS.txt | 6 + .../datasets/fineweb_edu/__init__.py | 15 + .../datasets/fineweb_edu/checksums.tsv | 3 + .../fineweb_edu/dummy_data/data.parquet | Bin 0 -> 33638 bytes .../fineweb_edu_dataset_builder.py | 102 +++++++ .../fineweb_edu_dataset_builder_test.py | 35 +++ 13 files changed, 326 insertions(+), 155 deletions(-) create mode 100644 tensorflow_datasets/datasets/fineweb_edu/CITATIONS.bib create mode 100644 tensorflow_datasets/datasets/fineweb_edu/README.md create mode 100644 tensorflow_datasets/datasets/fineweb_edu/TAGS.txt create mode 100644 tensorflow_datasets/datasets/fineweb_edu/__init__.py create mode 100644 tensorflow_datasets/datasets/fineweb_edu/checksums.tsv create mode 100644 tensorflow_datasets/datasets/fineweb_edu/dummy_data/data.parquet create mode 100644 tensorflow_datasets/datasets/fineweb_edu/fineweb_edu_dataset_builder.py create mode 100644 tensorflow_datasets/datasets/fineweb_edu/fineweb_edu_dataset_builder_test.py diff --git a/tensorflow_datasets/core/dataset_builder.py b/tensorflow_datasets/core/dataset_builder.py index f38bfe6a6f0..5262d6f0c32 100644 --- a/tensorflow_datasets/core/dataset_builder.py +++ b/tensorflow_datasets/core/dataset_builder.py @@ -1315,7 +1315,7 @@ def _make_download_manager( ) return download.DownloadManager( - download_dir=download_dir, + download_dir=download_dir / self.name, extract_dir=extract_dir, manual_dir=manual_dir, url_infos=self.url_infos, diff --git a/tensorflow_datasets/core/download/download_manager.py b/tensorflow_datasets/core/download/download_manager.py index a5a1bf284b9..43221d00c4d 100644 --- a/tensorflow_datasets/core/download/download_manager.py +++ b/tensorflow_datasets/core/download/download_manager.py @@ -316,9 +316,6 @@ def downloaded_size(self): """Returns the total size of downloaded files.""" return sum(url_info.size for url_info in self._recorded_url_infos.values()) - def _get_dl_path(self, url: str, sha256: str) -> epath.Path: - return self._download_dir / resource_lib.get_dl_fname(url, sha256) - @property def register_checksums(self): """Returns whether checksums are being computed and recorded to file.""" @@ -341,7 +338,7 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]: This function: - 1. Reuse cache (`_get_cached_path`) or download the file + 1. Reuse cache (`downloader.get_cached_path`) or download the file 2. Register or validate checksums (`_register_or_validate_checksums`) 3. Rename download to final path (`_rename_and_get_final_dl_path`) @@ -352,37 +349,39 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]: path: The path to the downloaded resource. """ # Normalize the input - if isinstance(resource, str): - url = resource - else: - url = resource.url + if not isinstance(resource, resource_lib.Resource): + resource = resource_lib.Resource(url=resource) + url = resource.url assert url is not None, 'URL is undefined from resource.' 
-    expected_url_info = self._url_infos.get(url)
+    registered_url_info = self._url_infos.get(url)

     # 3 possible destinations for the path:
     # * In `manual_dir` (manually downloaded data)
-    # * In `downloads/url_path` (checksum unknown)
-    # * In `downloads/checksum_path` (checksum registered)
+    # * In `downloads/unregistered_path` (checksum unknown)
+    # * In `downloads/registered_path` (checksum registered)
     manually_downloaded_path = _get_manually_downloaded_path(
         manual_dir=self._manual_dir,
-        expected_url_info=expected_url_info,
+        url_info=registered_url_info,
     )
-    url_path = self._get_dl_path(
-        url, sha256=hashlib.sha256(url.encode('utf-8')).hexdigest()
-    )
-    checksum_path = (
-        self._get_dl_path(url, sha256=expected_url_info.checksum)
-        if expected_url_info
-        else None
+    download_dir = self._download_dir / resource.relative_download_dir
+    download_dir.mkdir(parents=True, exist_ok=True)
+    unregistered_path = download_dir / resource_lib.get_dl_fname(
+        url=url, checksum=hashlib.sha256(url.encode('utf-8')).hexdigest()
     )
+    if registered_url_info:
+      registered_path = download_dir / resource_lib.get_dl_fname(
+          url=url, checksum=registered_url_info.checksum
+      )
+    else:
+      registered_path = None

     # Get the cached path and url_info (if they exist)
     dl_result = downloader.get_cached_path(
         manually_downloaded_path=manually_downloaded_path,
-        checksum_path=checksum_path,
-        url_path=url_path,
-        expected_url_info=expected_url_info,
+        registered_path=registered_path,
+        unregistered_path=unregistered_path,
+        registered_url_info=registered_url_info,
     )
     if dl_result.path and not self._force_download:  # Download was cached
       logging.info(
@@ -394,8 +393,10 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
     else:
       # Download in an empty tmp directory (to avoid name collisions)
       # `download_tmp_dir` is cleaned up in `_rename_and_get_final_dl_path`
-      dirname = f'{resource_lib.get_dl_dirname(url)}.tmp.{uuid.uuid4().hex}'
-      download_tmp_dir = self._download_dir / dirname
+      download_tmp_dir = (
+          unregistered_path.parent
+          / f'{unregistered_path.name}.tmp.{uuid.uuid4().hex}'
+      )
       download_tmp_dir.mkdir()
       logging.info(f'Downloading {url} into {download_tmp_dir}...')
       future = self._downloader.download(
@@ -403,121 +404,155 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
       )

     # Post-process the result
-    return future.then(
-        lambda dl_result: self._register_or_validate_checksums(  # pylint: disable=g-long-lambda
-            url=url,
-            path=dl_result.path,
-            computed_url_info=dl_result.url_info,
-            expected_url_info=expected_url_info,
-            checksum_path=checksum_path,
-            url_path=url_path,
-        )
-    )
+    def callback(dl_result: downloader.DownloadResult) -> epath.Path:
+      return self._register_or_validate_checksums(
+          url=url,
+          dl_url_info=dl_result.url_info,
+          registered_url_info=registered_url_info,
+          dl_path=dl_result.path,
+          registered_path=registered_path,
+          unregistered_path=unregistered_path,
+      )
+
+    return future.then(callback)

   def _register_or_validate_checksums(
       self,
-      path: epath.Path,
       url: str,
-      expected_url_info: checksums.UrlInfo | None,
-      computed_url_info: checksums.UrlInfo | None,
-      checksum_path: epath.Path | None,
-      url_path: epath.Path,
+      dl_url_info: checksums.UrlInfo | None,
+      registered_url_info: checksums.UrlInfo | None,
+      dl_path: epath.Path,
+      registered_path: epath.Path | None,
+      unregistered_path: epath.Path,
   ) -> epath.Path:
     """Validates/records checksums and renames final downloaded path."""
-    # `path` can be:
-    # * Manually downloaded
-    # * (cached) checksum_path
-    # * (cached) url_path
-    # * `tmp_dir/file` (downloaded path)
-
-    if computed_url_info:
+    if dl_url_info:
       # Used both in `.downloaded_size` and `_record_url_infos()`
-      self._recorded_url_infos[url] = computed_url_info
+      self._recorded_url_infos[url] = dl_url_info

     if self._register_checksums:
-      if not computed_url_info:
+      if not dl_url_info:
         raise ValueError(
             f'Cannot register checksums for {url}: no computed checksum. '
             '--register_checksums with manually downloaded data is not '
             'supported.'
         )
       # Note:
-      # * We save even if `expected_url_info == computed_url_info` as
-      #   `expected_url_info` might have been loaded from another dataset.
+      # * We save even if `registered_url_info == dl_url_info` as
+      #   `registered_url_info` might have been loaded from another dataset.
       # * `register_checksums_path` was validated in `__init__` so this
       #   shouldn't fail.
       self._record_url_infos()

       # Checksum path should now match the new registered checksum (even if
       # checksums were previously registered)
-      expected_url_info = computed_url_info
-      checksum_path = self._get_dl_path(url, computed_url_info.checksum)
+      registered_url_info = dl_url_info
+      registered_path = unregistered_path.parent / resource_lib.get_dl_fname(
+          url, dl_url_info.checksum
+      )
     else:
-      # Validate checksums when possible
+      # Validate the checksums when possible
       # Note:
-      # * If path is cached at `url_path` but cached
-      #   `computed_url_info != expected_url_info`, a new download has
-      #   been triggered (as _get_cached_path returns None)
+      # * If path is cached at `unregistered_path` but
+      #   `dl_url_info != registered_url_info`, a new download has
+      #   been triggered (as `downloader.get_cached_path` returns None)
       # * If path was downloaded but checksums don't match expected, then
       #   the download isn't cached (re-running build will retrigger a new
       #   download). This is expected as it might mean the downloaded file
       #   was corrupted. Note: The tmp file isn't deleted to allow inspection.
-      _validate_checksums(
+      self._validate_checksums(
           url=url,
-          path=path,
-          expected_url_info=expected_url_info,
-          computed_url_info=computed_url_info,
-          force_checksums_validation=self._force_checksums_validation,
+          dl_url_info=dl_url_info,
+          registered_url_info=registered_url_info,
+          dl_path=dl_path,
       )

     return self._rename_and_get_final_dl_path(
         url=url,
-        path=path,
-        expected_url_info=expected_url_info,
-        computed_url_info=computed_url_info,
-        checksum_path=checksum_path,
-        url_path=url_path,
+        dl_url_info=dl_url_info,
+        registered_url_info=registered_url_info,
+        dl_path=dl_path,
+        registered_path=registered_path,
+        unregistered_path=unregistered_path,
     )

+  def _validate_checksums(
+      self,
+      url: str,
+      dl_url_info: checksums.UrlInfo | None,
+      registered_url_info: checksums.UrlInfo | None,
+      dl_path: epath.Path,
+  ) -> None:
+    """Validates that the downloaded url_info matches the registered one."""
+    # If checksum validation is forced, both the downloaded and registered
+    # url_info should exist.
+    if self._force_checksums_validation:
+      # Checksum of the downloaded file unknown (for manually downloaded file)
+      if not dl_url_info:
+        dl_url_info = checksums.compute_url_info(dl_path)
+      # Checksums have not been registered
+      if not registered_url_info:
+        raise ValueError(
+            f'Missing checksums for url {url}, yet '
+            '`force_checksums_validation=True`. '
+            'Did you forget to register checksums?'
+        )
+
+    if (
+        registered_url_info
+        and dl_url_info
+        and registered_url_info != dl_url_info
+    ):
+      msg = (
+          f'Artifact {url}, downloaded to {dl_path}, has wrong checksum:\n'
+          f'* Expected: {registered_url_info}\n'
+          f'* Got: {dl_url_info}\n'
+          'To debug, see: '
+          'https://www.tensorflow.org/datasets/overview#fixing_nonmatchingchecksumerror'
+      )
+      raise NonMatchingChecksumError(msg)
+
   def _rename_and_get_final_dl_path(
       self,
       url: str,
-      path: epath.Path,
-      expected_url_info: checksums.UrlInfo | None,
-      computed_url_info: checksums.UrlInfo | None,
-      checksum_path: epath.Path | None,
-      url_path: epath.Path,
+      dl_url_info: checksums.UrlInfo | None,
+      registered_url_info: checksums.UrlInfo | None,
+      dl_path: epath.Path,
+      registered_path: epath.Path | None,
+      unregistered_path: epath.Path,
   ) -> epath.Path:
     """Renames the downloaded file if checksums were recorded."""
-    # `path` can be:
-    # * Manually downloaded
-    # * (cached) checksum_path
-    # * (cached) url_path
-    # * `tmp_dir/file` (downloaded path)
-    if self._manual_dir and path.is_relative_to(self._manual_dir):
-      return path  # Manually downloaded data
-    elif path == checksum_path:  # Path already at final destination
-      assert computed_url_info == expected_url_info  # Sanity check
-      return checksum_path  # pytype: disable=bad-return-type
-    elif path == url_path:
-      if checksum_path:
+    # Manually downloaded data
+    if self._manual_dir and dl_path.is_relative_to(self._manual_dir):
+      return dl_path
+
+    # Cached at the final destination
+    elif dl_path == registered_path:
+      assert dl_url_info == registered_url_info  # Sanity check
+      return dl_path
+
+    # Cached at the tmp destination
+    elif dl_path == unregistered_path:
+      if registered_path:
         # Checksums were registered: rename -> registered_path
-        resource_lib.replace_info_file(path, checksum_path)
-        return path.replace(checksum_path)
+        resource_lib.replace_info_file(dl_path, registered_path)
+        return dl_path.replace(registered_path)
       else:
         # Checksums not registered: do nothing
-        return path
-    else:  # Path was downloaded in tmp dir
-      dst_path = checksum_path or url_path
+        return dl_path
+
+    # Downloaded at the tmp destination
+    else:
+      path = registered_path or unregistered_path
       resource_lib.write_info_file(
           url=url,
-          path=dst_path,
+          path=path,
           dataset_name=self._dataset_name,
-          original_fname=path.name,
-          url_info=computed_url_info,
+          original_fname=dl_path.name,
+          url_info=dl_url_info,
       )
-      path.replace(dst_path)
-      path.parent.rmdir()  # Cleanup tmp dir (will fail if dir not empty)
-      return dst_path
+      dl_path.replace(path)
+      dl_path.parent.rmdir()  # Cleanup tmp dir (will fail if dir not empty)
+      return path

 @utils.build_synchronize_decorator()
 @utils.memoize()
@@ -711,59 +746,22 @@ def manual_dir(self) -> epath.Path:

 def _get_manually_downloaded_path(
     manual_dir: epath.Path | None,
-    expected_url_info: checksums.UrlInfo | None,
+    url_info: checksums.UrlInfo | None,
 ) -> epath.Path | None:
   """Checks if file is already downloaded in manual_dir."""
   if not manual_dir:  # Manual dir not passed
     return None
-  if not expected_url_info or not expected_url_info.filename:
+  if not url_info or not url_info.filename:
     return None  # Filename unknown.
- manual_path = manual_dir / expected_url_info.filename + manual_path = manual_dir / url_info.filename if not manual_path.exists(): # File not manually downloaded return None return manual_path -def _validate_checksums( - url: str, - path: epath.Path, - computed_url_info: checksums.UrlInfo | None, - expected_url_info: checksums.UrlInfo | None, - force_checksums_validation: bool, -) -> None: - """Validate computed_url_info match expected_url_info.""" - # If force-checksums validations, both expected and computed url_info - # should exists - if force_checksums_validation: - # Checksum of the downloaded file unknown (for manually downloaded file) - if not computed_url_info: - computed_url_info = checksums.compute_url_info(path) - # Checksums have not been registered - if not expected_url_info: - raise ValueError( - f'Missing checksums url: {url}, yet ' - '`force_checksums_validation=True`. ' - 'Did you forget to register checksums?' - ) - - if ( - expected_url_info - and computed_url_info - and expected_url_info != computed_url_info - ): - msg = ( - f'Artifact {url}, downloaded to {path}, has wrong checksum:\n' - f'* Expected: {expected_url_info}\n' - f'* Got: {computed_url_info}\n' - 'To debug, see: ' - 'https://www.tensorflow.org/datasets/overview#fixing_nonmatchingchecksumerror' - ) - raise NonMatchingChecksumError(msg) - - def _map_promise(map_fn, all_inputs): """Map the function into each element and resolve the promise.""" all_promises = tree.map_structure(map_fn, all_inputs) # Apply the function diff --git a/tensorflow_datasets/core/download/downloader.py b/tensorflow_datasets/core/download/downloader.py index 3faa0deda70..112a6ab5079 100644 --- a/tensorflow_datasets/core/download/downloader.py +++ b/tensorflow_datasets/core/download/downloader.py @@ -77,9 +77,9 @@ def _read_url_info(url_path: epath.Path) -> checksums_lib.UrlInfo: def get_cached_path( manually_downloaded_path: epath.Path | None, - checksum_path: epath.Path | None, - url_path: epath.Path, - expected_url_info: checksums_lib.UrlInfo | None, + registered_path: epath.Path | None, + unregistered_path: epath.Path, + registered_url_info: checksums_lib.UrlInfo | None, ) -> DownloadResult: """Returns the downloaded path and computed url-info. @@ -90,29 +90,31 @@ def get_cached_path( Args: manually_downloaded_path: Manually downloaded in `dl_manager.manual_dir` - checksum_path: Cached in the final destination (if checksum known) - url_path: Cached in the tmp destination (if checksum unknown). - expected_url_info: Registered checksum (if known) + registered_path: Cached at the final destination (if checksum known) + unregistered_path: Cached at the tmp destination (if checksum unknown). + registered_url_info: Registered checksum (if known) """ # User has manually downloaded the file. 
  if manually_downloaded_path and manually_downloaded_path.exists():
     return DownloadResult(path=manually_downloaded_path, url_info=None)
   # Download has been cached (checksum known)
-  elif checksum_path and resource_lib.Resource.exists_locally(checksum_path):
+  elif registered_path and resource_lib.Resource.exists_locally(
+      registered_path
+  ):
     # `path = f(checksum)` was found, so url_info matches
-    return DownloadResult(checksum_path, url_info=expected_url_info)
+    return DownloadResult(path=registered_path, url_info=registered_url_info)
   # Download has been cached (checksum unknown)
-  elif resource_lib.Resource.exists_locally(url_path):
+  elif resource_lib.Resource.exists_locally(unregistered_path):
     # Info restored from `.INFO` file
-    computed_url_info = _read_url_info(url_path)
+    url_info = _read_url_info(unregistered_path)
     # If checksums are now registered but do not match, trigger a new
     # download (e.g. previous file corrupted, checksums updated)
-    if expected_url_info and computed_url_info != expected_url_info:
+    if registered_url_info and url_info != registered_url_info:
       return DownloadResult(path=None, url_info=None)
     else:
-      return DownloadResult(path=url_path, url_info=computed_url_info)
+      return DownloadResult(path=unregistered_path, url_info=url_info)
   # Else file not found (or has bad checksums): (re)download.
   else:
diff --git a/tensorflow_datasets/core/download/resource.py b/tensorflow_datasets/core/download/resource.py
index 545f842ae85..2c7266ab82e 100644
--- a/tensorflow_datasets/core/download/resource.py
+++ b/tensorflow_datasets/core/download/resource.py
@@ -19,7 +19,6 @@
 import codecs
 from collections.abc import Mapping
 import enum
-import hashlib
 import itertools
 import json
 import os
@@ -191,12 +190,6 @@ def get_dl_fname(url: str, checksum: str) -> str:
   return f'{name}{checksum}{extension}'


-def get_dl_dirname(url: str) -> str:
-  """Returns name of temp dir for given url."""
-  checksum = hashlib.sha256(url.encode()).hexdigest()
-  return get_dl_fname(url, checksum)
-
-
 def _get_info_path(path: epath.Path) -> epath.Path:
   """Returns path of INFO file associated with resource at path."""
   return path.with_suffix(path.suffix + '.INFO')
@@ -290,6 +283,7 @@ def __init__(
       url: str | None = None,
       extract_method: ExtractMethod | None = None,
       path: epath.PathLike | None = None,
+      relative_download_dir: epath.PathLike = '',
   ):
     """Resource constructor.

@@ -299,10 +293,13 @@ def __init__(
         set, will be guessed from downloaded file name `original_fname`.
       path: Path of resource on local disk. Can be None if resource has not
         been downloaded yet. In such a case, `url` must be set.
+      relative_download_dir: Optional subdirectory, relative to
+        `download_dir`, into which the resource is downloaded.
     """
     self.url = url
     self._extract_method = extract_method
     self.path: epath.Path = epath.Path(path) if path else None  # pytype: disable=annotation-type-mismatch  # attribute-variable-annotations
+    self.relative_download_dir = relative_download_dir

   @classmethod
   def exists_locally(cls, path: epath.Path) -> bool:
diff --git a/tensorflow_datasets/core/download/resource_test.py b/tensorflow_datasets/core/download/resource_test.py
index 42eea8ef6b8..c63cc53ec21 100644
--- a/tensorflow_datasets/core/download/resource_test.py
+++ b/tensorflow_datasets/core/download/resource_test.py
@@ -13,7 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for resource module.""" +import hashlib + from tensorflow_datasets import testing from tensorflow_datasets.core.download import resource @@ -85,7 +86,8 @@ class DlDirNameTest(testing.TestCase): def test_(self): for url, expected in zip(self.urls, self.expected): - res = resource.get_dl_dirname(url) + checksum = hashlib.sha256(url.encode()).hexdigest() + res = resource.get_dl_fname(url, checksum) self.assertEqual(res, expected) diff --git a/tensorflow_datasets/datasets/fineweb_edu/CITATIONS.bib b/tensorflow_datasets/datasets/fineweb_edu/CITATIONS.bib new file mode 100644 index 00000000000..c2c0d205978 --- /dev/null +++ b/tensorflow_datasets/datasets/fineweb_edu/CITATIONS.bib @@ -0,0 +1,8 @@ +@software{lozhkov2024fineweb-edu, + author = {Lozhkov, Anton and Ben Allal, Loubna and von Werra, Leandro and Wolf, Thomas}, + title = {FineWeb-Edu}, + month = May, + year = 2024, + doi = { 10.57967/hf/2497 }, + url = {https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu} +} diff --git a/tensorflow_datasets/datasets/fineweb_edu/README.md b/tensorflow_datasets/datasets/fineweb_edu/README.md new file mode 100644 index 00000000000..1928ff12ed1 --- /dev/null +++ b/tensorflow_datasets/datasets/fineweb_edu/README.md @@ -0,0 +1,3 @@ +📚 FineWeb-Edu dataset consists of 1.3T tokens and 5.4T tokens +(FineWeb-Edu-score-2) of educational web pages filtered from 🍷 FineWeb dataset. +This is the 1.3 trillion version. diff --git a/tensorflow_datasets/datasets/fineweb_edu/TAGS.txt b/tensorflow_datasets/datasets/fineweb_edu/TAGS.txt new file mode 100644 index 00000000000..4e9a3399259 --- /dev/null +++ b/tensorflow_datasets/datasets/fineweb_edu/TAGS.txt @@ -0,0 +1,6 @@ +content.data-type.tabular # Contains tabular data. +content.data-type.text # Contains text data. +content.language.en # Contains text in language English / en. +content.monolingual # Contains text in 1 natural language. +ml.task.language-modelling # Relates to Language Modelling, a machine learning task. +ml.task.text-generation # Relates to Text Generation, a machine learning task. diff --git a/tensorflow_datasets/datasets/fineweb_edu/__init__.py b/tensorflow_datasets/datasets/fineweb_edu/__init__.py new file mode 100644 index 00000000000..5310ec58c7d --- /dev/null +++ b/tensorflow_datasets/datasets/fineweb_edu/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2024 The TensorFlow Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tensorflow_datasets/datasets/fineweb_edu/checksums.tsv b/tensorflow_datasets/datasets/fineweb_edu/checksums.tsv new file mode 100644 index 00000000000..9181fe6c64f --- /dev/null +++ b/tensorflow_datasets/datasets/fineweb_edu/checksums.tsv @@ -0,0 +1,3 @@ +# TODO(finewebedu): If your dataset downloads files, then the checksums +# will be automatically added here when running +# `tfds build --register_checksums`. 
diff --git a/tensorflow_datasets/datasets/fineweb_edu/dummy_data/data.parquet b/tensorflow_datasets/datasets/fineweb_edu/dummy_data/data.parquet
new file mode 100644
index 0000000000000000000000000000000000000000..3c0d58c20d61f1f3546e6525893f92844f98b4b1
GIT binary patch
literal 33638
[33638 bytes of base85-encoded binary data for the dummy parquet file omitted]
diff --git a/tensorflow_datasets/datasets/fineweb_edu/fineweb_edu_dataset_builder.py b/tensorflow_datasets/datasets/fineweb_edu/fineweb_edu_dataset_builder.py
new file mode 100644
--- /dev/null
+++ b/tensorflow_datasets/datasets/fineweb_edu/fineweb_edu_dataset_builder.py
@@ -0,0 +1,102 @@
+# coding=utf-8
+# Copyright 2024 The TensorFlow Datasets Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""FineWebEdu dataset."""
+
+from etils import epath
+import numpy as np
+import pyarrow.parquet as pq
+import tensorflow_datasets.public_api as tfds
+
+_URL_TEMPLATE = ...  # Shard URL template, formatted with `shard_idx` (value elided).
+_NUM_SHARDS = ...  # Number of parquet shards to download (value elided).
+
+
+class Builder(tfds.core.GeneratorBasedBuilder):
+  """DatasetBuilder for the FineWeb-Edu dataset."""
+
+  VERSION = tfds.core.Version('1.0.0')
+  RELEASE_NOTES = {
+      '1.0.0': 'Initial release.',
+  }
+
+  def _info(self) -> tfds.core.DatasetInfo:
+    """Returns the dataset metadata."""
+    return self.dataset_info_from_configs(
+        features=tfds.features.FeaturesDict({
+            'text': tfds.features.Text(),
+            'id': tfds.features.Text(),
+            'dump': tfds.features.Text(),
+            'url': tfds.features.Text(),
+            'file_path': tfds.features.Text(),
+            'language': tfds.features.Text(),
+            'language_score': tfds.features.Scalar(dtype=np.float64),
+            'token_count': tfds.features.Scalar(dtype=np.int64),
+            'score': tfds.features.Scalar(dtype=np.float64),
+            'int_score': tfds.features.Scalar(dtype=np.int64),
+        }),
+        supervised_keys=None,
+        homepage='https://huggingface.co/datasets/HuggingFaceFW/fineweb-edu',
+    )
+
+  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
+    """Returns SplitGenerators."""
+    shard_url_by_shard_idx = {
+        shard_idx: _URL_TEMPLATE.format(shard_idx=shard_idx)
+        for shard_idx in range(_NUM_SHARDS)
+    }
+    shard_filepath_by_shard_idx = dl_manager.download(shard_url_by_shard_idx)
+
+    return {
+        'train': self._generate_examples(shard_filepath_by_shard_idx),
+    }
+
+  def _generate_examples(self, shard_filepath_by_shard_idx):
+    """Yields examples."""
+    beam = tfds.core.lazy_imports.apache_beam
+
+    shard_filepaths_with_offsets: list[tuple[epath.Path, int]] = []
+    offset = 0
+
+    for shard_idx in range(_NUM_SHARDS):
+      shard_filepath = shard_filepath_by_shard_idx[shard_idx]
+      shard_filepaths_with_offsets.append((shard_filepath, offset))
+
+      parquet_file = pq.ParquetFile(shard_filepath)
+      offset += parquet_file.metadata.num_rows
+
+    def _process_shard(shard_filepath: epath.Path, offset: int):
+      parquet_file = pq.ParquetFile(shard_filepath)
+      # Keep a running row index across batches so example keys stay unique
+      # within a shard.
+      row_idx = 0
+
+      for batch in parquet_file.iter_batches():
+        df = batch.to_pandas()
+
+        for row in df.itertuples():
+          yield offset + row_idx, {
+              'text': row.text,
+              'id': row.id,
+              'dump': row.dump,
+              'url': row.url,
+              'file_path': row.file_path,
+              'language': row.language,
+              'language_score': row.language_score,
+              'token_count': row.token_count,
+              'score': row.score,
+              'int_score': row.int_score,
+          }
+          row_idx += 1
+
+    return beam.Create(shard_filepaths_with_offsets) | beam.FlatMapTuple(
+        _process_shard
+    )
diff --git a/tensorflow_datasets/datasets/fineweb_edu/fineweb_edu_dataset_builder_test.py b/tensorflow_datasets/datasets/fineweb_edu/fineweb_edu_dataset_builder_test.py
new file mode 100644
index 00000000000..4035606a137
--- /dev/null
+++ b/tensorflow_datasets/datasets/fineweb_edu/fineweb_edu_dataset_builder_test.py
@@ -0,0 +1,35 @@
+# coding=utf-8
+# Copyright 2024 The TensorFlow Datasets Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FineWebEdu dataset.""" + +from tensorflow_datasets.datasets.fineweb_edu import fineweb_edu_dataset_builder +import tensorflow_datasets.public_api as tfds + + +class FinewebeduTest(tfds.testing.DatasetBuilderTestCase): + """Tests for FineWebEdu dataset.""" + + fineweb_edu_dataset_builder._NUM_SHARDS = 1 + + DATASET_CLASS = fineweb_edu_dataset_builder.Builder + SPLITS = {'train': 1} + + DL_DOWNLOAD_RESULT = {0: 'data.parquet'} + SKIP_CHECKSUMS = True + + +if __name__ == '__main__': + tfds.testing.test_main()
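For context, once this patch is in and the dataset has been generated (for
example via `tfds build fineweb_edu`; generation runs an Apache Beam
pipeline, so a real build needs a configured Beam runner), it is consumed
like any other TFDS dataset. A short usage sketch under those assumptions:

    import tensorflow_datasets as tfds

    # 'train' is the only split defined by the builder above.
    ds = tfds.load('fineweb_edu', split='train')
    for example in ds.take(1):
      # Each example mirrors the FeaturesDict in `_info()`: text, id, dump,
      # url, file_path, language, language_score, token_count, score,
      # int_score.
      print(example['id'].numpy(), int(example['token_count']))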