Skip to content

Commit 62bc583

Browse files
authored
Try #269:
2 parents 045fab2 + 5718276 commit 62bc583

File tree

10 files changed

+57
-10
lines changed

10 files changed

+57
-10
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ version = "0.7.0"
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
99
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
10+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
12+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1113
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1214
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1315
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

examples/matmul.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,9 @@ function matmul!(a, b, c)
2323
println("Matrix size mismatch!")
2424
return nothing
2525
end
26-
if isa(a, Array)
27-
kernel! = matmul_kernel!(CPU(),4)
28-
else
29-
kernel! = matmul_kernel!(CUDADevice(),256)
30-
end
26+
device = KernelAbstractions.get_device(a)
27+
n = device isa GPU ? 256 : 4
28+
kernel! = matmul_kernel!(device, n)
3129
kernel!(a, b, c, ndrange=size(c))
3230
end
3331

examples/naive_transpose.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@ function naive_transpose!(a, b)
1616
println("Matrix size mismatch!")
1717
return nothing
1818
end
19-
if isa(a, Array)
20-
kernel! = naive_transpose_kernel!(CPU(),4)
21-
else
22-
kernel! = naive_transpose_kernel!(CUDADevice(),256)
23-
end
19+
device = KernelAbstractions.get_device(a)
20+
n = device isa GPU ? 256 : 4
21+
kernel! = naive_transpose_kernel!(device, n)
2422
kernel!(a, b, ndrange=size(a))
2523
end
2624

lib/CUDAKernels/src/CUDAKernels.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ import KernelAbstractions
1010

1111
export CUDADevice
1212

13+
# Dense CUDA arrays are handled by the CUDA backend device.
function KernelAbstractions.get_device(::CUDA.CuArray)
    return CUDADevice()
end
14+
# CUSPARSE sparse arrays are handled by the CUDA backend device as well.
function KernelAbstractions.get_device(::CUDA.CUSPARSE.AbstractCuSparseArray)
    return CUDADevice()
end
15+
1316
const FREE_STREAMS = CUDA.CuStream[]
1417
const STREAMS = CUDA.CuStream[]
1518
const STREAM_GC_THRESHOLD = Ref{Int}(16)

lib/CUDAKernels/test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
44
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
55
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
66
KernelGradients = "e5faadeb-7f6c-408e-9747-a7a26e81c66a"
7+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
89
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
10+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
911
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1012
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1113
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

lib/CUDAKernels/test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,16 @@ using Enzyme
44
using CUDA
55
using CUDAKernels
66
using Test
7+
using SparseArrays
78

89
include(joinpath(dirname(pathof(KernelAbstractions)), "..", "test", "testsuite.jl"))
910
include(joinpath(dirname(pathof(KernelGradients)), "..", "test", "testsuite.jl"))
1011

12+
@testset "get_device" begin
13+
@test @inferred(KernelAbstractions.get_device(CuArray(rand(Float32, 3,3)))) == CUDADevice()
14+
@test @inferred(KernelAbstractions.get_device(CuArray(sparse(rand(Float32, 3,3))))) == CUDADevice()
15+
end
16+
1117
if parse(Bool, get(ENV, "CI", "false"))
1218
default = "CPU"
1319
else

lib/ROCKernels/src/ROCKernels.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ import KernelAbstractions
1111

1212
export ROCDevice
1313

14+
# ROC arrays are handled by the ROC backend device.
function KernelAbstractions.get_device(::AMDGPU.ROCArray)
    return ROCDevice()
end
15+
16+
1417
const FREE_QUEUES = HSAQueue[]
1518
const QUEUES = HSAQueue[]
1619
const QUEUE_GC_THRESHOLD = Ref{Int}(16)

lib/ROCKernels/test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ using Test
88
include(joinpath(dirname(pathof(KernelAbstractions)), "..", "test", "testsuite.jl"))
99
include(joinpath(dirname(pathof(KernelGradients)), "..", "test", "testsuite.jl"))
1010

11+
# Check that `get_device` maps a `ROCArray` to `ROCDevice()` and that the call is
# type-inferrable. This must be a `@testset`: `@test "name" begin ... end` is not a
# valid `@test` invocation and would fail instead of grouping the enclosed test.
@testset "get_device" begin
    @test @inferred(KernelAbstractions.get_device(ROCArray(rand(Float32, 3, 3)))) == ROCDevice()
end
14+
1115
if parse(Bool, get(ENV, "CI", "false"))
1216
default = "CPU"
1317
else

src/KernelAbstractions.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ export Device, GPU, CPU, Event, MultiEvent, NoneEvent
66
export async_copy!
77

88

9+
using LinearAlgebra
910
using MacroTools
11+
using SparseArrays
1012
using StaticArrays
1113
using Cassette
1214
using Adapt
@@ -337,6 +339,23 @@ abstract type GPU <: Device end
337339

338340
struct CPU <: Device end
339341

342+
343+
"""
344+
KernelAbstractions.get_device(A::AbstractArray)::KernelAbstractions.Device
345+
346+
Get a `KernelAbstractions.Device` instance suitable for array `A`.
347+
"""
348+
function get_device end
349+
350+
# Should cover SubArray, ReshapedArray, ReinterpretArray, Hermitian, AbstractTriangular, etc.:
351+
function get_device(A::AbstractArray)
    # Wrapper arrays delegate to the array they wrap.
    P = parent(A)
    # `Base.parent` returns its argument unchanged for arrays that are not wrappers.
    # Recursing on such an array would never terminate (StackOverflowError), so
    # fail early with an actionable error naming the missing method instead.
    if P isa typeof(A)
        throw(ArgumentError(
            "Implement `KernelAbstractions.get_device(::$(typeof(A)))`"))
    end
    return get_device(P)
end
352+
353+
# Sparse arrays: use the device of the index array returned by `rowvals`.
function get_device(S::AbstractSparseArray)
    row_indices = rowvals(S)
    return get_device(row_indices)
end
354+
# `Diagonal` does not go through `parent`; use the device of its `diag` storage.
function get_device(D::Diagonal)
    return get_device(D.diag)
end
355+
# `Tridiagonal`: use the device of its `d` field (the main-diagonal vector).
function get_device(T::Tridiagonal)
    return get_device(T.d)
end
356+
357+
# Base `Array`s map to the `CPU` device.
function get_device(::Array)
    return CPU()
end
358+
340359
include("nditeration.jl")
341360
using .NDIteration
342361
import .NDIteration: get

test/test.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using KernelAbstractions
22
using KernelAbstractions.NDIteration
33
using InteractiveUtils
4+
using LinearAlgebra
5+
using SparseArrays
46
import SpecialFunctions
57

68
identity(x) = x
@@ -64,6 +66,16 @@ end
6466
A[I] = i
6567
end
6668

69+
@testset "get_device" begin
70+
x = rand(5)
71+
A = rand(5,5)
72+
@test @inferred(KernelAbstractions.get_device(A)) == CPU()
73+
@test @inferred(KernelAbstractions.get_device(view(A, 2:4, 1:3))) == CPU()
74+
@test @inferred(KernelAbstractions.get_device(sparse(A))) == CPU()
75+
@test @inferred(KernelAbstractions.get_device(Diagonal(x))) == CPU()
76+
@test @inferred(KernelAbstractions.get_device(Tridiagonal(A))) == CPU()
77+
end
78+
6779
@testset "indextest" begin
6880
# TODO: add test for _group and _local_cartesian
6981
A = ArrayT{Int}(undef, 16, 16)

0 commit comments

Comments
 (0)