Commit

added functionality to handle observation and action scalers for d3rlpy api
joshuaspear committed Sep 27, 2024
1 parent 084f780 commit 4de0676
Showing 8 changed files with 647 additions and 23 deletions.
63 changes: 51 additions & 12 deletions src/offline_rl_ope/api/d3rlpy/Misc.py
@@ -1,8 +1,10 @@
from abc import abstractmethod
from typing import Optional, Union
import torch
from jaxtyping import jaxtyped
from typeguard import typechecked as typechecker
from d3rlpy.models.torch.policies import build_squashed_gaussian_distribution

from d3rlpy.preprocessing import ActionScaler, ObservationScaler
from .types import (
D3rlpyAlgoPredictProtocal, D3rlpyPolicyProtocal,
)
@@ -15,16 +17,30 @@
"D3RlPyStochasticWrapper"
]

class D3RlPyDeterministicWrapper:

class D3RlPyWrapperBase:
def __init__(
self,
predict_func:D3rlpyAlgoPredictProtocal,
action_dim:int
predict_func:Union[D3rlpyAlgoPredictProtocal,D3rlpyPolicyProtocal]
):
self.predict_func = predict_func
self.action_dim = action_dim

@abstractmethod
def __call__(self, x:StateTensor)->TorchPolicyReturn:
pass


class D3RlPyDeterministicWrapper(D3RlPyWrapperBase):

def __init__(
self,
predict_func: Union[D3rlpyAlgoPredictProtocal, D3rlpyPolicyProtocal],
action_dim: int
):
self.action_dim = action_dim
super().__init__(
predict_func=predict_func
)

@jaxtyped(typechecker=typechecker)
def __call__(self, x:StateTensor)->TorchPolicyReturn:
pred = self.predict_func(x.cpu().numpy()).reshape(
@@ -45,26 +61,49 @@ def __init__(
assert action_dim==1, "D3RlPy action dimension is 1 for discrete tasks"
super().__init__(
predict_func=predict_func,
action_dim=action_dim
action_dim=action_dim
)

class D3RlPyStochasticWrapper:


class D3RlPyStochasticWrapper(D3RlPyWrapperBase):

def __init__(
self,
policy_func:D3rlpyPolicyProtocal,
observation_scaler:Optional[ObservationScaler] = None,
action_scaler:Optional[ActionScaler]=None
) -> None:
self.policy_func = policy_func
super().__init__(
predict_func=policy_func
)
if action_scaler is not None:
assert action_scaler.built, "Action scaler is not built"
self.action_scaler = action_scaler
if observation_scaler is not None:
assert observation_scaler.built, "Observation scaler is not built"
self.observation_scaler = observation_scaler



@jaxtyped(typechecker=typechecker)
def __call__(
self,
state: StateTensor,
action: ActionTensor
) -> TorchPolicyReturn:
dist = build_squashed_gaussian_distribution(self.policy_func(state))
if self.observation_scaler is not None:
x_scaled = self.observation_scaler.transform(x=state)
else:
x_scaled = state
dist = build_squashed_gaussian_distribution(
self.predict_func(x_scaled)
)
if self.action_scaler is not None:
scaled_action = self.action_scaler.transform(x=action)
else:
scaled_action = action
with torch.no_grad():
res = torch.exp(dist.log_prob(action))
res = torch.exp(dist.log_prob(scaled_action))
return TorchPolicyReturn(
actions=action,
action_prs=res
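For context, a minimal usage sketch of the scaler-aware D3RlPyStochasticWrapper added above. It assumes a continuous-control d3rlpy algorithm (SAC on the pendulum dataset, purely illustrative) configured with an observation and an action scaler and fitted briefly so that the impl and both scalers are built; the batch shapes and fit settings are assumptions for illustration, not part of the commit.

import torch
from d3rlpy.algos import SACConfig
from d3rlpy.datasets import get_pendulum
from d3rlpy.logging import NoopAdapterFactory
from d3rlpy.preprocessing import MinMaxActionScaler, StandardObservationScaler
from offline_rl_ope.api.d3rlpy.Misc import D3RlPyStochasticWrapper

# A short fit builds the impl and fits both scalers on the dataset.
dataset, _ = get_pendulum()
algo = SACConfig(
    observation_scaler=StandardObservationScaler(),
    action_scaler=MinMaxActionScaler(),
).create()
algo.fit(
    dataset,
    n_steps=10,
    n_steps_per_epoch=10,
    logger_adapter=NoopAdapterFactory(),
    show_progress=False,
)

# The wrapper scales states and actions the same way the policy saw them
# during training before evaluating the squashed-Gaussian probabilities.
wrapper = D3RlPyStochasticWrapper(
    policy_func=algo.impl.policy,
    observation_scaler=algo.config.observation_scaler,
    action_scaler=algo.config.action_scaler,
)

states = torch.rand(16, 3)   # pendulum observations are 3-dimensional
actions = torch.rand(16, 1)  # pendulum actions are 1-dimensional
out = wrapper(state=states, action=actions)  # TorchPolicyReturn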
25 changes: 16 additions & 9 deletions src/offline_rl_ope/api/d3rlpy/Policy.py
@@ -1,7 +1,12 @@
from typing import Optional
from d3rlpy.interface import QLearningAlgoProtocol
from typing import Optional, Union
from d3rlpy.algos.qlearning.base import QLearningAlgoBase

from .Misc import D3RlPyDeterministicWrapper, D3RlPyStochasticWrapper
from jaxtyping import jaxtyped
from typeguard import typechecked as typechecker

from .Misc import (
D3RlPyDeterministicDiscreteWrapper, D3RlPyStochasticWrapper
)
from ...components.Policy import (
Policy, GreedyDeterministic)

@@ -26,10 +31,9 @@ def __init__(
if self.deterministic:
assert action_dim is not None
self.action_dim = action_dim


def __create_deterministic(self, algo):
policy_func = D3RlPyDeterministicWrapper(

def __create_deterministic(self, algo:QLearningAlgoBase):
policy_func = D3RlPyDeterministicDiscreteWrapper(
predict_func=algo.predict,
action_dim=self.action_dim
)
@@ -39,9 +43,11 @@ def __create_deterministic(self, algo):
)
return eval_policy

def __create_stochastic(self, algo:QLearningAlgoProtocol):
def __create_stochastic(self, algo:QLearningAlgoBase):
policy_func = D3RlPyStochasticWrapper(
policy_func=algo.impl.policy,
observation_scaler=algo._config.observation_scaler,
action_scaler=algo._config.action_scaler
)

eval_policy = Policy(
@@ -50,7 +56,8 @@ def __create_stochastic(self, algo:QLearningAlgoProtocol):
)
return eval_policy

def create(self, algo: QLearningAlgoProtocol)->Policy:
@jaxtyped(typechecker=typechecker)
def create(self, algo: QLearningAlgoBase)->Policy:
if self.deterministic:
res = self.__create_deterministic(algo=algo)
else:
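Similarly, a sketch of the deterministic discrete path taken by __create_deterministic, assuming a trained discrete-action algorithm (DQN on the cartpole dataset, again illustrative). The wrapper asserts action_dim == 1 because d3rlpy's predict returns a single greedy action index per observation.

import torch
from d3rlpy.algos import DQNConfig
from d3rlpy.datasets import get_cartpole
from d3rlpy.logging import NoopAdapterFactory
from offline_rl_ope.api.d3rlpy.Misc import D3RlPyDeterministicDiscreteWrapper

dataset, _ = get_cartpole()
algo = DQNConfig().create()
algo.fit(
    dataset,
    n_steps=10,
    n_steps_per_epoch=10,
    logger_adapter=NoopAdapterFactory(),
    show_progress=False,
)

# predict() consumes numpy observations and returns one action index per
# row, so the wrapper is constructed with action_dim=1.
policy_func = D3RlPyDeterministicDiscreteWrapper(
    predict_func=algo.predict,
    action_dim=1,
)
states = torch.rand(16, 4)    # cartpole observations are 4-dimensional
greedy = policy_func(states)  # TorchPolicyReturn with the greedy actions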
Empty file added tests/api/__init__.py
Empty file.
Empty file added tests/api/d3rlpy/__init__.py
Empty file.
72 changes: 72 additions & 0 deletions tests/api/d3rlpy/algo_test.py
@@ -0,0 +1,72 @@
from d3rlpy.algos import QLearningAlgoBase, QLearningAlgoImplBase
from d3rlpy.base import LearnableConfig
from d3rlpy.logging import NoopAdapterFactory
from d3rlpy.dataset import ReplayBuffer
import numpy.typing as npt
from typing import Sequence, Union, Any, overload, cast
import torch
from unittest.mock import Mock
import numpy as np


NDArray = npt.NDArray[Any]
Float32NDArray = npt.NDArray[np.float32]
Int32NDArray = npt.NDArray[np.int32]
UInt8NDArray = npt.NDArray[np.uint8]
DType = npt.DTypeLike

Observation = Union[NDArray, Sequence[NDArray]]
ObservationSequence = Union[NDArray, Sequence[NDArray]]
Shape = Union[Sequence[int], Sequence[Sequence[int]]]
TorchObservation = Union[torch.Tensor, Sequence[torch.Tensor]]



@overload
def create_observations(
    observation_shape: Sequence[int], length: int, dtype: DType = np.float32
) -> NDArray: ...


@overload
def create_observations(
    observation_shape: Sequence[Sequence[int]],
    length: int,
    dtype: DType = np.float32,
) -> Sequence[NDArray]: ...


def create_observations(
    observation_shape: Shape, length: int, dtype: DType = np.float32
) -> ObservationSequence:
    observations: ObservationSequence
    if isinstance(observation_shape[0], (list, tuple)):
        observations = [
            np.random.random((length, *shape)).astype(dtype)
            for shape in cast(Sequence[Sequence[int]], observation_shape)
        ]
    else:
        observations = np.random.random((length, *observation_shape)).astype(
            dtype
        )
    return observations

def init_trained_algo(
    algo: QLearningAlgoBase[QLearningAlgoImplBase, LearnableConfig],
    dataset:ReplayBuffer
):
    algo.update = Mock(return_value={"loss": np.random.random()}) # type: ignore
    n_batch = algo.config.batch_size
    n_steps = 10
    n_steps_per_epoch = 5
    n_epochs = n_steps // n_steps_per_epoch
    # data_size = n_episodes * episode_length

    # check fit
    results = algo.fit(
        dataset,
        n_steps=n_steps,
        n_steps_per_epoch=n_steps_per_epoch,
        logger_adapter=NoopAdapterFactory(),
        show_progress=False,
    )
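One possible way the helpers above could be exercised, assuming a small synthetic discrete-action dataset; the MDPDataset construction and array shapes are illustrative only. Because init_trained_algo mocks update(), the fit() call runs the training loop without real gradient steps.

import numpy as np
from d3rlpy.algos import DQNConfig
from d3rlpy.dataset import MDPDataset

# One synthetic episode of 100 steps with 4-dimensional observations.
observations = create_observations((4,), length=100)
terminals = np.zeros(100, dtype=np.float32)
terminals[-1] = 1.0
dataset = MDPDataset(
    observations=observations,
    actions=np.random.randint(0, 2, size=100),  # two discrete actions
    rewards=np.random.random(100).astype(np.float32),
    terminals=terminals,
)

algo = DQNConfig(batch_size=8).create()
init_trained_algo(algo, dataset)  # fit() runs while update() is a Mock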
