From 15d85e20a365743971efac565485af97a992e15f Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 11 Apr 2025 15:32:44 +0200 Subject: [PATCH 1/2] Implement BroadcastStyle using Adapt's unwrap_type --- ext/OffsetArraysAdaptExt.jl | 2 ++ src/OffsetArrays.jl | 6 ------ 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/ext/OffsetArraysAdaptExt.jl b/ext/OffsetArraysAdaptExt.jl index 55969fd..b05c170 100644 --- a/ext/OffsetArraysAdaptExt.jl +++ b/ext/OffsetArraysAdaptExt.jl @@ -12,6 +12,8 @@ Adapt.adapt_structure(to, O::OffsetArray) = OffsetArrays.parent_call(x -> Adapt. # To support Adapt 3.0 which doesn't have parent_type defined Adapt.parent_type(::Type{OffsetArray{T,N,AA}}) where {T,N,AA} = AA Adapt.unwrap_type(W::Type{<:OffsetArray}) = unwrap_type(parent_type(W)) + + Base.Broadcast.BroadcastStyle(W::Type{<:OffsetArray}) = BroadcastStyle(unwrap_type(W)) end end diff --git a/src/OffsetArrays.jl b/src/OffsetArrays.jl index 9aa9f58..55b4c23 100644 --- a/src/OffsetArrays.jl +++ b/src/OffsetArrays.jl @@ -280,12 +280,6 @@ parenttype(A::OffsetArray) = parenttype(typeof(A)) Base.parent(A::OffsetArray) = A.parent -# TODO: Ideally we would delegate to the parent's broadcasting implementation, but that -# is currently broken in sufficiently many implementation, namely RecursiveArrayTools, DistributedArrays -# and StaticArrays, that it will take concentrated effort to get this working across the ecosystem. -# The goal would be to have `OffsetArray(CuArray) .+ 1 == OffsetArray{CuArray}`. -# Base.Broadcast.BroadcastStyle(::Type{<:OffsetArray{<:Any, <:Any, AA}}) where AA = Base.Broadcast.BroadcastStyle(AA) - @inline Base.size(A::OffsetArray) = size(parent(A)) # specializing length isn't necessary, as length(A) = prod(size(A)), # but specializing length enables constant-propagation for statically sized arrays From b3a3edcae629e98dc94b31c5b0e15f76d44636f4 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 14 Apr 2025 08:44:22 +0200 Subject: [PATCH 2/2] Update ext/OffsetArraysAdaptExt.jl --- ext/OffsetArraysAdaptExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/OffsetArraysAdaptExt.jl b/ext/OffsetArraysAdaptExt.jl index b05c170..40a30ab 100644 --- a/ext/OffsetArraysAdaptExt.jl +++ b/ext/OffsetArraysAdaptExt.jl @@ -13,7 +13,7 @@ Adapt.adapt_structure(to, O::OffsetArray) = OffsetArrays.parent_call(x -> Adapt. Adapt.parent_type(::Type{OffsetArray{T,N,AA}}) where {T,N,AA} = AA Adapt.unwrap_type(W::Type{<:OffsetArray}) = unwrap_type(parent_type(W)) - Base.Broadcast.BroadcastStyle(W::Type{<:OffsetArray}) = BroadcastStyle(unwrap_type(W)) + Base.Broadcast.BroadcastStyle(W::Type{<:OffsetArray}) = Base.Broadcast.BroadcastStyle(unwrap_type(W)) end end