add sharding mixin
gianlucadetommaso committed Jul 4, 2023
1 parent 3c83290 commit 85303f3
Showing 39 changed files with 539 additions and 713 deletions.
4 changes: 2 additions & 2 deletions fortuna/calib_model/base.py
@@ -137,11 +137,11 @@ def _calibrate(
rng=self.rng.get(),
state=state,
loss_fun=loss,
training_dataloader=calib_data_loader,
training_data_loader=calib_data_loader,
training_dataset_size=n_calib_data,
n_epochs=config.optimizer.n_epochs,
metrics=config.monitor.metrics,
validation_dataloader=val_data_loader,
validation_data_loader=val_data_loader,
validation_dataset_size=n_val_data,
verbose=config.monitor.verbose,
callbacks=config.callbacks,
55 changes: 28 additions & 27 deletions fortuna/calib_model/calib_mixin.py
@@ -1,7 +1,7 @@
import os
from typing import Optional

from flax.training import checkpoints
# from flax.training import checkpoints

from fortuna.calib_model.state import CalibState
from fortuna.training.mixins.checkpointing import WithCheckpointingMixin
@@ -12,29 +12,30 @@


class WithCalibCheckpointingMixin(WithCheckpointingMixin):
def restore_checkpoint(
self,
restore_checkpoint_dir: Path,
optimizer: Optional[OptaxOptimizer] = None,
prefix: str = "",
**kwargs,
) -> CalibState:
if not os.path.isdir(restore_checkpoint_dir) and not os.path.isfile(
restore_checkpoint_dir
):
raise ValueError(
f"`restore_checkpoint_dir={restore_checkpoint_dir}` was not found."
)
d = checkpoints.restore_checkpoint(
ckpt_dir=str(restore_checkpoint_dir),
target=None,
step=None,
prefix=prefix,
parallel=True,
)
if d is None:
raise ValueError(
f"No checkpoint was found in `restore_checkpoint_dir={restore_checkpoint_dir}`."
)

return CalibState.init_from_dict(d, optimizer, **kwargs)
pass
# def restore_checkpoint(
# self,
# restore_checkpoint_dir: Path,
# optimizer: Optional[OptaxOptimizer] = None,
# prefix: str = "",
# **kwargs,
# ) -> CalibState:
# if not os.path.isdir(restore_checkpoint_dir) and not os.path.isfile(
# restore_checkpoint_dir
# ):
# raise ValueError(
# f"`restore_checkpoint_dir={restore_checkpoint_dir}` was not found."
# )
# d = checkpoints.restore_checkpoint(
# ckpt_dir=str(restore_checkpoint_dir),
# target=None,
# step=None,
# prefix=prefix,
# parallel=True,
# )
# if d is None:
# raise ValueError(
# f"No checkpoint was found in `restore_checkpoint_dir={restore_checkpoint_dir}`."
# )
#
# return CalibState.init_from_dict(d, optimizer, **kwargs)
4 changes: 2 additions & 2 deletions fortuna/data/dataset/huggingface_datasets.py
@@ -112,12 +112,12 @@ def get_data_loader(
drop_last: bool
if True, the last batch (which potentially is smaller than the default batch size) is dropped.
verbose: bool
Whether to show a progress bar while iterating over the dataloader or not.
Whether to show a progress bar while iterating over the data_loader or not.
Returns
-------
HuggingFaceDataLoader
The dataloader
The data_loader
"""
iterable = IterableData.from_callable(
lambda *args, **kwargs: self._get_data_loader(
36 changes: 34 additions & 2 deletions fortuna/data/loader/base.py
@@ -9,6 +9,7 @@
Tuple,
Type,
TypeVar,
Union
)

from flax import jax_utils
@@ -24,6 +25,10 @@
Status,
Targets,
)
from fortuna.utils.prefetch import prefetch_to_mesh
from fortuna.partitioner.partition_manager.base import PartitionManager
from jax import device_put
from jax.sharding import NamedSharding, PartitionSpec

T = TypeVar("T")

@@ -185,7 +190,7 @@ def from_tensorflow_data_loader(cls: Type[T], tf_data_loader) -> T:
T
A concrete instance of a subclass of :class:`~fortuna.data.loader.BaseDataLoader`.
"""
return cls(iterable=IterableData.from_tf_dataloader(tf_data_loader))
return cls(iterable=IterableData.from_tf_data_loader(tf_data_loader))

@classmethod
def from_torch_data_loader(cls: Type[T], torch_data_loader) -> T:
@@ -203,7 +208,7 @@ def from_torch_data_loader(cls: Type[T], torch_data_loader) -> T:
T
A concrete instance of a subclass of :class:`~fortuna.data.loader.BaseDataLoader`.
"""
return cls(iterable=IterableData.from_torch_dataloader(torch_data_loader))
return cls(iterable=IterableData.from_torch_data_loader(torch_data_loader))

@classmethod
def from_inputs_loaders(
@@ -545,3 +550,30 @@ def __iter__(self, *args, **kwargs):
loader = map(lambda batch: tree_map(self._reshape_inputs, batch), self._loader)
loader = jax_utils.prefetch_to_device(loader, 2)
yield from loader


class ShardedPrefetchedLoader:
def __init__(
self,
loader,
partition_manager: Optional[PartitionManager] = None,
shard: bool = True,
partition_spec: Optional[PartitionSpec] = None
):
self._loader = loader
self.partition_manager = partition_manager
self.shard = shard
self.partition_spec = partition_spec
if partition_manager is None and shard:
raise ValueError("`partition_manager` cannot be None when `shard` is set to True.")

def _shard(self, data: Union[Batch, InputData, Targets]):
return device_put(data, NamedSharding(self.partition_manager.partitioner.mesh, self.partition_spec))

def __iter__(self, *args, **kwargs):
if self.shard:
loader = map(lambda data: tree_map(self._shard, data), self._loader)
loader = prefetch_to_mesh(loader, 2, self.partition_manager.partitioner.mesh, self.partition_spec)
else:
loader = jax_utils.prefetch_to_device(self._loader, 2)
yield from loader
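
For context on the new ShardedPrefetchedLoader above: per batch it applies device_put with a NamedSharding built from the partition manager's mesh and a PartitionSpec. Below is a minimal, self-contained sketch of that per-batch sharding step in plain JAX; the single-axis mesh named "dp", the spec, and the toy batches are illustrative assumptions, not Fortuna's actual defaults.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.tree_util import tree_map

# Illustrative mesh: all available devices along one data-parallel axis "dp".
mesh = Mesh(np.array(jax.devices()), axis_names=("dp",))

# Shard the leading (batch) dimension across "dp"; other dimensions are replicated.
sharding = NamedSharding(mesh, PartitionSpec("dp"))

def shard_batch(batch):
    # Mirrors ShardedPrefetchedLoader._shard: place every array of the batch
    # on the mesh according to the sharding.
    return tree_map(lambda x: jax.device_put(x, sharding), batch)

# Toy loader of (inputs, targets) batches; the batch size should be divisible
# by the number of devices along "dp".
toy_loader = [
    (np.ones((8, 4), np.float32), np.zeros((8,), np.int32)),
    (np.ones((8, 4), np.float32), np.ones((8,), np.int32)),
]

for inputs, targets in map(shard_batch, toy_loader):
    print(inputs.sharding)  # NamedSharding over the "dp" axis
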
2 changes: 1 addition & 1 deletion fortuna/data/loader/huggingface_loaders.py
@@ -35,7 +35,7 @@ def __init__(
Parameters
----------
iterable : Union[Iterable[Dict[str, Array]], Iterable[Tuple[Dict[str, Array],Array]]]
A data loader obtained via :func:`~HuggingFaceClassificationDataset.get_dataloader`.
A data loader obtained via :func:`~HuggingFaceClassificationDataset.get_data_loader`.
num_unique_labels: int
Number of unique target labels in the task (classification only)
num_inputs: Optional[int]
8 changes: 4 additions & 4 deletions fortuna/data/loader/utils.py
@@ -44,9 +44,9 @@ def _inner():
return cls(_inner)

@classmethod
def from_tf_dataloader(cls, tf_dataloader) -> IterableData:
def from_tf_data_loader(cls, tf_data_loader) -> IterableData:
def _inner():
for batch_inputs, batch_targets in tf_dataloader:
for batch_inputs, batch_targets in tf_data_loader:
if not isinstance(batch_inputs, dict):
batch_inputs = batch_inputs.numpy()
else:
@@ -57,9 +57,9 @@ def _inner():
return cls(_inner)

@classmethod
def from_torch_dataloader(cls, torch_dataloader) -> IterableData:
def from_torch_data_loader(cls, torch_data_loader) -> IterableData:
def _inner():
for batch_inputs, batch_targets in torch_dataloader:
for batch_inputs, batch_targets in torch_data_loader:
if not isinstance(batch_inputs, dict):
batch_inputs = batch_inputs.numpy()
else:
5 changes: 3 additions & 2 deletions fortuna/likelihood/base.py
@@ -215,9 +215,10 @@ def _batched_log_joint_prob(
mutable=mutable,
rng=rng,
)
if "mutable" in return_aux:
if mutable is not None:
outputs, aux = outs
mutable = aux["mutable"]
if mutable in return_aux:
mutable = aux["mutable"]
else:
outputs = outs

2 changes: 1 addition & 1 deletion fortuna/model/model_manager/base.py
@@ -9,7 +9,7 @@

from flax import linen as nn
from flax.core import FrozenDict
from flax.training.checkpoints import PyTree
from optax._src.base import PyTree
from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp

2 changes: 1 addition & 1 deletion fortuna/model/model_manager/classification.py
@@ -10,7 +10,7 @@

from flax.core import FrozenDict
import flax.linen as nn
from flax.training.checkpoints import PyTree
from optax._src.base import PyTree
import jax
from jax import random
from jax._src.prng import PRNGKeyArray
3 changes: 2 additions & 1 deletion fortuna/model/model_manager/regression.py
@@ -7,7 +7,7 @@

from flax.core import FrozenDict
import flax.linen as nn
from flax.training.checkpoints import PyTree
from optax._src.base import PyTree
import jax
from jax import random
from jax._src.prng import PRNGKeyArray
@@ -65,6 +65,7 @@ def apply(
lik_log_var_rngs = None

if mutable is not None:
mutable = mutable.unfreeze()
mutable["model"] = mutable.get("model")
mutable["lik_log_var"] = mutable.get("lik_log_var")

2 changes: 1 addition & 1 deletion fortuna/model/model_manager/transformers/classification.py
@@ -8,7 +8,7 @@

from flax import linen as nn
from flax.core import FrozenDict
from flax.training.checkpoints import PyTree
from optax._src.base import PyTree
import jax
from jax import (
numpy as jnp,
6 changes: 2 additions & 4 deletions fortuna/output_calib_model/base.py
@@ -10,9 +10,7 @@

from fortuna.output_calib_model.config.base import Config
from fortuna.output_calib_model.loss import Loss
from fortuna.output_calib_model.output_calib_mixin import (
WithOutputCalibCheckpointingMixin,
)
from fortuna.training.mixins.checkpointing import WithCheckpointingMixin
from fortuna.output_calib_model.output_calib_model_calibrator import (
JittedOutputCalibModelCalibrator,
MultiDeviceOutputCalibModelCalibrator,
@@ -34,7 +32,7 @@
from fortuna.utils.random import RandomNumberGenerator


class OutputCalibModel(WithOutputCalibCheckpointingMixin, abc.ABC):
class OutputCalibModel(WithCheckpointingMixin, abc.ABC):
"""
Abstract calibration model class.
"""
40 changes: 0 additions & 40 deletions fortuna/output_calib_model/output_calib_mixin.py

This file was deleted.

2 changes: 1 addition & 1 deletion fortuna/output_calibrator/output_calib_manager/base.py
@@ -6,7 +6,7 @@

from flax.core import FrozenDict
import flax.linen as nn
from flax.training.checkpoints import PyTree
from optax._src.base import PyTree
from jax import random
from jax._src.prng import PRNGKeyArray
import jax.numpy as jnp
2 changes: 1 addition & 1 deletion fortuna/partitioner/base.py
@@ -19,7 +19,7 @@ def __init__(
n_devices: Optional[int] = None,
):
if axis_dims is None:
axis_dims = {"dp": 1, "fsdp": 1, "mp": 1}
axis_dims = {"dp": 1, "fsdp": 1, "mp": -1}
if rules is None:
rules = {}
self.specs = {
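
The new default for the model-parallel axis is -1. Under the usual convention (an assumption here; this diff does not show how Fortuna consumes the value), -1 means the axis size is inferred from however many devices remain, exactly as in a numpy reshape. A hedged sketch of building a device mesh from such axis dims:

import jax
import numpy as np
from jax.sharding import Mesh

def build_mesh(axis_dims: dict) -> Mesh:
    # Assumed convention: a single -1 entry is inferred from the device count,
    # like -1 in numpy.reshape. With 8 devices and {"dp": 1, "fsdp": 1, "mp": -1},
    # the resulting mesh is 1 x 1 x 8.
    devices = np.array(jax.devices()).reshape(tuple(axis_dims.values()))
    return Mesh(devices, axis_names=tuple(axis_dims.keys()))

mesh = build_mesh({"dp": 1, "fsdp": 1, "mp": -1})
print(mesh.shape)  # e.g. {'dp': 1, 'fsdp': 1, 'mp': n_devices}
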
24 changes: 12 additions & 12 deletions fortuna/prob_model/base.py
@@ -14,8 +14,7 @@
from fortuna.prob_model.calib_config.base import CalibConfig
from fortuna.prob_model.fit_config.base import FitConfig
from fortuna.prob_model.prob_model_calibrator import (
JittedProbModelOutputCalibrator,
MultiDeviceProbModelOutputCalibrator,
ShardedProbModelOutputCalibrator,
ProbModelOutputCalibrator,
)
from fortuna.typing import (
@@ -137,7 +136,7 @@ def _calibrate(
"Pre-compute ensemble of outputs on the calibration data loader."
)

distribute = jax.local_devices()[0].platform != "cpu"
shard = not calib_config.processor.disable_jit

(
calib_ensemble_outputs_loader,
@@ -146,7 +145,7 @@
inputs_loader=calib_data_loader.to_inputs_loader(),
n_output_samples=calib_config.processor.n_posterior_samples,
return_size=True,
distribute=distribute,
shard=shard,
)
if calib_config.monitor.verbose:
logging.info(
@@ -157,19 +156,20 @@
inputs_loader=val_data_loader.to_inputs_loader(),
n_output_samples=calib_config.processor.n_posterior_samples,
return_size=True,
distribute=distribute,
shard=shard,
)
if val_data_loader is not None
else (None, None)
)

trainer_cls = select_trainer_given_devices(
devices=calib_config.processor.devices,
base_trainer_cls=ProbModelOutputCalibrator,
jitted_trainer_cls=JittedProbModelOutputCalibrator,
multi_device_trainer_cls=MultiDeviceProbModelOutputCalibrator,
disable_jit=calib_config.processor.disable_jit,
)
# trainer_cls = select_trainer_given_devices(
# devices=calib_config.processor.devices,
# base_trainer_cls=ProbModelOutputCalibrator,
# jitted_trainer_cls=JittedProbModelOutputCalibrator,
# multi_device_trainer_cls=MultiDeviceProbModelOutputCalibrator,
# disable_jit=calib_config.processor.disable_jit,
# )
trainer_cls = ShardedProbModelOutputCalibrator

calibrator = trainer_cls(
calib_outputs_loader=calib_ensemble_outputs_loader,
(remaining file diffs not loaded)