Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #16

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open

Dev #16

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Jun Tian <[email protected]> and contributors"]
version = "0.2.1"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Continuous spaces have some additional interface functions:

- `bounds(space)` returns upper and lower bounds in a tuple. For example, if `space` is a unit circle, `bounds(space)` will return `([-1.0, -1.0], [1.0, 1.0])`. This allows agents to choose policies that appropriately cover the space e.g. a normal distribution with a mean of `mean(bounds(space))` and a standard deviation of half the distance between the bounds.
- `clamp(x, space)` returns an element of `space` that is near `x`. i.e. if `space` is a unit circle, `clamp([2.0, 0.0], space)` might return `[1.0, 0.0]`. This allows for a convenient way for an agent to find a valid action if they sample actions from a distribution that doesn't match the space exactly (e.g. a normal distribution).
- `clamp!(x, space)`, similar to `clamp`, but clamps `x` in place.
- [Not implemented] `clamp!(x, space)`, similar to `clamp`, but clamps `x` in place.

### Hybrid spaces

Expand Down Expand Up @@ -70,12 +70,12 @@ The `TupleSpaceProduct` constructor provides a specialized Cartesian product whe

|Category|Style|Example|
|:---|:----|:-----|
|Enumerable discrete space| `FiniteSpaceStyle{()}()` | `(:cat, :dog)`, `0:1`, `["a","b","c"]` |
|One dimensional continuous space| `ContinuousSpaceStyle{()}()` | `-1.2..3.3`, `Interval(1.0, 2.0)` |
|Multi-dimensional discrete space| `FiniteSpaceStyle{(3,4)}()` | `ArraySpace((:cat, :dog), 3, 4)`, `ArraySpace(0:1, 3, 4)`, `ArraySpace(1:2, 3, 4)`, `ArraySpace(Bool, 3, 4)`|
|Multi-dimensional variable discrete space| `FiniteSpaceStyle{(2,)}()` | `product((:cat, :dog), (:litchi, :longan, :mango))`, `product(-1:1, (false, true))`|
|Multi-dimensional continuous space| `ContinuousSpaceStyle{(2,)}()` or `ContinuousSpaceStyle{(3,4)}()` | `Box([-1.0, -2.0], [2.0, 4.0])`, `product(-1.2..3.3, -4.6..5.0)`, `ArraySpace(-1.2..3.3, 3, 4)`, `ArraySpace(Float32, 3, 4)` |
|Multi-dimensional hybrid space [planned for future]| `HybridSpaceStyle{(2,),()}()` | `product(-1.2..3.3, -4.6..5.0, [:cat, :dog])`, `product(Box([-1.0, -2.0], [2.0, 4.0]), [1,2,3])`|
|Enumerable discrete space| `FiniteSpaceStyle()` | `(:cat, :dog)`, `0:1`, `["a","b","c"]` |
|One dimensional continuous space| `ContinuousSpaceStyle()` | `-1.2..3.3`, `Interval(1.0, 2.0)` |
|Multi-dimensional discrete space| `FiniteSpaceStyle()` | `ArraySpace((:cat, :dog), 3, 4)`, `ArraySpace(0:1, 3, 4)`, `ArraySpace(1:2, 3, 4)`, `ArraySpace(Bool, 3, 4)`|
|Multi-dimensional variable discrete space| `FiniteSpaceStyle()` | `product((:cat, :dog), (:litchi, :longan, :mango))`, `product(-1:1, (false, true))`|
|Multi-dimensional continuous space| `ContinuousSpaceStyle()` | `Box([-1.0, -2.0], [2.0, 4.0])`, `product(-1.2..3.3, -4.6..5.0)`, `ArraySpace(-1.2..3.3, 3, 4)`, `ArraySpace(Float32, 3, 4)` |
|Multi-dimensional hybrid space [planned for future]| `HybridProductSpaceStyle()` | `product(-1.2..3.3, -4.6..5.0, [:cat, :dog])`, `product(Box([-1.0, -2.0], [2.0, 4.0]), [1,2,3])`|

### API

Expand Down
9 changes: 6 additions & 3 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
using CommonRLSpaces
using Random
using Documenter

DocMeta.setdocmeta!(CommonRLSpaces, :DocTestSetup, :(using CommonRLSpaces); recursive=true)

makedocs(;
modules=[CommonRLSpaces],
authors="Jun Tian <[email protected]> and contributors",
repo="https://github.com/Jun Tian/CommonRLSpaces.jl/blob/{commit}{path}#{line}",
repo="https://github.com/JuliaReinforcementLearning/CommonRLSpaces.jl/blob/{commit}{path}#{line}",
sitename="CommonRLSpaces.jl",
format=Documenter.HTML(;
prettyurls=get(ENV, "CI", "false") == "true",
canonical="https://Jun Tian.github.io/CommonRLSpaces.jl",
canonical="https://github.com/JuliaReinforcementLearning/CommonRLSpaces.jl",
assets=String[],
),
pages=[
"Home" => "index.md",
"array.md",
"extensions.md"
],
)

deploydocs(;
repo="github.com/Jun Tian/CommonRLSpaces.jl",
repo="https://github.com/JuliaReinforcementLearning/CommonRLSpaces.jl",
devbranch="main",
)
10 changes: 10 additions & 0 deletions docs/src/array.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Array Spaces

```@docs
AbstractArraySpace
elsize
Box
Base.rand(::AbstractRNG, ::Random.SamplerTrivial{Box{T}}) where {T}
RepeatedSpace
ArraySpace
```
2 changes: 2 additions & 0 deletions docs/src/extensions.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Extensions

25 changes: 22 additions & 3 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,30 @@ CurrentModule = CommonRLSpaces

# CommonRLSpaces

Documentation for [CommonRLSpaces](https://github.com/Jun Tian/CommonRLSpaces.jl).
Documentation for [CommonRLSpaces](https://github.com/JuliaReinforcementLearning/CommonRLSpaces.jl).

```@index
```
## Space Styles

```@autodocs
Modules = [CommonRLSpaces]
Filter = t -> typeof(t) === DataType && t <: AbstractSpaceStyle
```

```@docs
SpaceStyle
```

## Interface

Common
- Base.in
- Base.rand - https://docs.julialang.org/en/v1/stdlib/Random/#Hooking-into-the-Random-API
- Base.eltype
- product

Finite
- Base.collect

Continuous
- bounds
- Base.clamp
4 changes: 4 additions & 0 deletions src/CommonRLSpaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,25 @@ using Reexport
using StaticArrays
using FillArrays
using Random
using Distributions
import Base: clamp

export
SpaceStyle,
AbstractSpaceStyle,
FiniteSpaceStyle,
ContinuousSpaceStyle,
HybridProductSpaceStyle,
UnknownSpaceStyle,
AbstractArraySpace,
bounds,
elsize

include("basic.jl")

export
Box,
RepeatedSpace,
ArraySpace

include("array.jl")
Expand Down
133 changes: 112 additions & 21 deletions src/array.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
"""
AbstractArraySpace

Abstract base class for Array Spaces.
"""
abstract type AbstractArraySpace end
# Maybe AbstractArraySpace should have an eltype parameter so that you could call convert(AbstractArraySpace{Float32}, space)
# Maybe AbstractArraySpace should have an eltype parameter so that you could call
# convert(AbstractArraySpace{Float32}, space)

"""
elsize(::AbstractArraySpace)

Return the size of the objects in a space.
"""
function elsize end # note: different than Base.elsize


"""
Box(lower, upper)

A Box represents a space of real-valued arrays bounded element-wise above by `upper` and below by `lower`, e.g. `Box([-1, -2], [3, 4]` represents the two-dimensional vector space that is the Cartesian product of the two closed sets: ``[-1, 3] \\times [-2, 4]``.
A Box represents a space of real-valued arrays bounded element-wise above by `upper` and
below by `lower`, e.g. `Box([-1, -2], [3, 4]` represents the two-dimensional vector space
that is the Cartesian product of the two closed sets: ``[-1, 3] \\times [-2, 4]``.

The elements of a Box are always `AbstractArray`s with `AbstractFloat` elements. `Box`es always have `ContinuousSpaceStyle`, and products of `Box`es with `Box`es or `ClosedInterval`s are `Box`es when the dimensions are compatible.
The elements of a Box are always `AbstractArray`s with `AbstractFloat` elements. `Box`es
always have `ContinuousSpaceStyle`, and products of `Box`es with `Box`es or
`ClosedInterval`s are `Box`es when the dimensions are compatible.
"""
struct Box{A<:AbstractArray{<:AbstractFloat}} <: AbstractArraySpace
lower::A
Expand All @@ -17,30 +35,74 @@ end

function Box(lower, upper; convert_to_static::Bool=false)
@assert size(lower) == size(upper)
sz = size(lower)
continuous_lower = convert(AbstractArray{float(eltype(lower))}, lower)
continuous_upper = convert(AbstractArray{float(eltype(upper))}, upper)
T = promote_type(eltype(lower), eltype(upper)) |> float
continuous_lower = convert(AbstractArray{T}, lower)
continuous_upper = convert(AbstractArray{T}, upper)
if convert_to_static
final_lower = SArray{Tuple{sz...}}(continuous_lower)
final_upper = SArray{Tuple{sz...}}(continuous_upper)
final_lower = SArray{Tuple{size(continuous_lower)...}}(continuous_lower)
final_upper = SArray{Tuple{size(continuous_upper)...}}(continuous_upper)
else
final_lower, final_upper = promote(continuous_lower, continuous_upper)
final_lower, final_upper = continuous_lower, continuous_upper
end
return Box{typeof(final_lower)}(final_lower, final_upper)
end

function Base.:(==)(b1::T, b2::T) where {T <: Box}
return (b1.lower == b2.lower) && (b1.upper == b2.upper)
end

# By default, convert builtin arrays to static
Box(lower::Array, upper::Array) = Box(lower, upper; convert_to_static=true)

SpaceStyle(::Box) = ContinuousSpaceStyle()

function Base.rand(rng::AbstractRNG, sp::Random.SamplerTrivial{<:Box})
"""
Base.rand(::AbstractRNG, ::Random.SamplerTrivial{<:Box})

Generate an array where each element is sampled from a dimension of a Box space.

* Finite intervals [a,b] are sampled from uniform distributions.
* Semi-infinite intervals (a,Inf) and (-Inf,b) are sampled from shifted exponential
distributions.
* Infinite intervals (-Inf,Inf) are sampled from normal distributions.

# Example

```@repl
using CommonRLSpaces
using Random: seed!
using Distributions: Uniform, Normal, Exponential
box = Box([-10, -Inf, 3], [10, Inf, Inf])
seed!(0)
rand(box)
seed!(0)
[rand(Uniform(-10,10)), rand(Normal()), 3+rand(Exponential())]
```
"""
function Base.rand(rng::AbstractRNG, sp::Random.SamplerTrivial{Box{T}}) where {T}
box = sp[]
return box.lower + rand_similar(rng, box.lower) .* (box.upper-box.lower)
x = [rand_interval(rng, lb, ub) for (lb, ub) in zip(box.lower, box.upper)]
return T(x)
end

rand_similar(rng::AbstractRNG, a::StaticArray) = rand(rng, typeof(a))
rand_similar(rng::AbstractRNG, a::AbstractArray) = rand(rng, eltype(a), size(a)...)
function rand_interval(rng::AbstractRNG, lb::T, ub::T) where {T <: Real}
offset, sign = zero(T), one(T)

if isfinite(lb) && isfinite(ub)
dist = Uniform(lb, ub)
elseif isfinite(lb) && !isfinite(ub)
offset = lb
dist = Exponential(one(T))
elseif !isfinite(lb) && isfinite(ub)
offset = ub
sign = -one(T)
dist = Exponential(one(T))
else
dist = Normal(zero(T), one(T))
end

return offset + sign * rand(rng, dist)
end

Base.in(x::AbstractArray, b::Box) = all(b.lower .<= x .<= b.upper)

Expand All @@ -52,25 +114,32 @@ Base.clamp(x::AbstractArray, b::Box) = clamp.(x, b.lower, b.upper)

Base.convert(t::Type{<:Box}, i::ClosedInterval) = t(SA[minimum(i)], SA[maximum(i)])

"""
RepeatedSpace(base_space, elsize)

A RepeatedSpace reperesents a space of arrays with shape `elsize`, where each element of
the array is drawn from `base_space`.
"""
struct RepeatedSpace{B, S<:Tuple} <: AbstractArraySpace
base_space::B
elsize::S
end

"""
ArraySpace(base_space, size...)

Create a space of Arrays with shape `size`, where each element of the array is drawn from `base_space`.
"""
ArraySpace(base_space, size...) = RepeatedSpace(base_space, size)
RepeatedSpace(base_size, elsize...) = RepeatedSpace(base_size, elsize)

SpaceStyle(s::RepeatedSpace) = SpaceStyle(s.base_space)

Base.rand(rng::AbstractRNG, sp::Random.SamplerTrivial{<:RepeatedSpace}) = rand(rng, sp[].base_space, sp[].elsize...)
function Base.rand(rng::AbstractRNG, sp::Random.SamplerTrivial{<:RepeatedSpace})
return rand(rng, sp[].base_space, sp[].elsize...)
end

Base.in(x::AbstractArray, s::RepeatedSpace) = all(entry in s.base_space for entry in x)

Base.eltype(s::RepeatedSpace) = AbstractArray{eltype(s.base_space), length(s.elsize)}
Base.eltype(s::RepeatedSpace{<:AbstractInterval}) = AbstractArray{Random.gentype(s.base_space), length(s.elsize)}
function Base.eltype(s::RepeatedSpace{<:AbstractInterval})
return AbstractArray{Random.gentype(s.base_space), length(s.elsize)}
end

elsize(s::RepeatedSpace) = s.elsize

function bounds(s::RepeatedSpace)
Expand All @@ -79,3 +148,25 @@ function bounds(s::RepeatedSpace)
end

Base.clamp(x::AbstractArray, s::RepeatedSpace) = map(entry -> clamp(entry, s.base_space), x)

"""
ArraySpace(base_space, size...)

Constructor for RepeatedSpace and Box.

If `base_space` is an AbstractFloat or ClosedInterval return a Box (preferred), otherwise
return a RepeatedSpace.
"""
ArraySpace(base_space, size...) = RepeatedSpace(base_space, size)

function ArraySpace(::Type{T}, size...) where {T<:AbstractFloat}
lower = fill(typemin(T), size)
upper = fill(typemax(T), size)
return Box(lower, upper)
end

function ArraySpace(i::ClosedInterval, size...)
lower = fill(minimum(i), size)
upper = fill(maximum(T), size)
return Box(lower, upper)
end
Loading