Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 22 additions & 47 deletions crates/krnl-macros/src/kernel_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ impl Kernel {
.inputs
.iter()
.filter(|x| x.is_spec())
.map(|x| &x.ty); //self.sig.generics.const_params().map(|x| &x.ty);
.map(|x| &x.ty);
let args = self.sig.inputs.iter().filter_map(|x| x.host_arg_ty());
let kernel_import_visit = self.kernel_import_visit();
let visit_build_args = self
Expand Down Expand Up @@ -416,7 +416,6 @@ impl Kernel {
}
});
quote! {
let __krnl_items = *__krnl_items as usize;
let mut __krnl_item_id = __krnl_global_thread_id;
while __krnl_item_id < __krnl_items {
#(#item_loads_stores)*
Expand Down Expand Up @@ -517,14 +516,6 @@ impl Kernel {
]
.into_iter()
.map(FnArg::unwrap_pat_type);
let items = if self.is_item() {
Some(FnArg::unwrap_pat_type(parse_quote! {
#[spirv(push_constant)]
__krnl_items: &u32
}))
} else {
None
};
let mut binding = 0;
let args = self
.sig
Expand All @@ -537,7 +528,7 @@ impl Kernel {
.iter()
.enumerate()
.map(move |(i, x)| x.entry_point_arg(&mut binding));
builtins.chain(items).chain(args).chain(group_buffers)
builtins.chain(args).chain(group_buffers)
}
fn is_item(&self) -> bool {
self.sig.inputs.iter().any(|x| x.is_item())
Expand Down Expand Up @@ -640,26 +631,27 @@ impl Kernel {
let __krnl_global_threads = __krnl_groups * __krnl_threads;
let __krnl_global_thread_id = __krnl_group_id * __krnl_threads + __krnl_thread_id;
};
/*
let specs = self
.sig
.generics
.const_params()
.map(|ConstParam { ident, ty, .. }| -> Stmt {
parse_quote! {
let #ident = {
unsafe { krnl::kernel::__private::__spec_constant::<#ty>(&#ident) };
*#ident
};
let items = {
let mut items = self
.sig
.inputs
.iter()
.filter(|x| x.is_item())
.map(|x| x.ident.clone())
.peekable();
if let Some(first) = items.next() {
if items.peek().is_some() {
quote! {
let __krnl_items = #first.len() #(.min(#items.len()))*;
}
} else {
quote! {
let __krnl_items = #first.len();
}
}
});
*/
let items = if self.is_item() {
quote! {
unsafe { krnl::kernel::__private::__push_constant(__krnl_items) };
} else {
TokenStream::new()
}
} else {
TokenStream::new()
};
let inputs = self.sig.inputs.iter().map(move |x| x.device_decl());
quote! {
Expand Down Expand Up @@ -715,14 +707,7 @@ impl Kernel {
}
}
}
let items = if self.is_item() {
Some(KernelImportInput::items())
} else {
None
};
let inputs = items
.into_iter()
.chain(self.sig.inputs.iter().filter_map(|x| x.kernel_import()));
let inputs = self.sig.inputs.iter().filter_map(|x| x.kernel_import());
let is_generic = !self.sig.generics.params.is_empty();
let safety = if self.sig.unsafety.is_some() {
quote!(())
Expand Down Expand Up @@ -1198,16 +1183,6 @@ impl KernelImportInput {
ty: parse2(const_param.ty.to_token_stream()).unwrap(),
}
}
fn items() -> Self {
KernelImportInput {
ident: format_ident!("__krnl_items"),
kind: format_ident!("__Push"),
colon_token: Colon {
spans: [Span::call_site()],
},
ty: parse_quote!(u32),
}
}
}

impl Clone for KernelImportInput {
Expand Down
1 change: 0 additions & 1 deletion crates/krnl-macros/src/root_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ pub fn root() -> Result<TokenStream> {
let kernel_tokens: TokenStream = kernels
.into_iter()
.flat_map(|(name, binary)| {
dbg!(&name);
let binary: Punctuated<Literal, syn::token::Comma> =
binary.into_iter().map(Literal::u8_unsuffixed).collect();
let (module, name) = name.rsplit_once("::").unwrap();
Expand Down
Binary file modified crates/krnl/krnl.spv
Binary file not shown.
33 changes: 21 additions & 12 deletions crates/krnl/src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
context::{Buffer as RawBuffer, Context, Slice as RawSlice, SliceMut as RawSliceMut},
};
use bytemuck::Pod;
use core::ops::RangeBounds;
use rayon::iter::{
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator,
};
Expand Down Expand Up @@ -184,6 +185,9 @@ impl<'a, T> From<&'a [T]> for Slice<'a, T> {
}

impl<'a, T> Slice<'a, T> {
/// Borrows the raw `crate::context::Slice` stored in this slice's `data.raw` field.
///
/// Crate-internal accessor; it performs no copying or conversion.
pub(crate) fn as_context_slice(&self) -> &crate::context::Slice<'a, T> {
&self.data.raw
}
pub fn into_host_slice(self) -> Option<&'a [T]> {
#[allow(irrefutable_let_patterns)]
if let RawSlice::Host(slice) = self.data.raw {
Expand All @@ -192,6 +196,13 @@ impl<'a, T> Slice<'a, T> {
None
}
}
/// Consumes this slice and returns the sub-slice selected by `bounds`.
///
/// `bounds` is forwarded unchanged to the raw slice's own `slice` method and
/// the result is re-wrapped, so any range validation happens there.
// NOTE(review): bounds are presumably element indices, not bytes — confirm
// against the raw slice implementation.
pub fn slice(self, bounds: impl RangeBounds<usize>) -> Self {
    let raw = self.data.raw.slice(bounds);
    let data = SliceRepr { raw };
    Self { data }
}
}

impl<'a, T> SliceMut<'a, T> {
Expand All @@ -203,6 +214,16 @@ impl<'a, T> SliceMut<'a, T> {
None
}
}
/// Consumes this mutable slice and returns the mutable sub-slice selected by
/// `bounds`.
///
/// `bounds` is forwarded unchanged to the raw slice's `slice_mut` method and
/// the result is re-wrapped, so any range validation happens there.
// NOTE(review): bounds are presumably element indices, not bytes — confirm
// against the raw slice implementation.
pub fn slice_mut(self, bounds: impl RangeBounds<usize>) -> Self {
    let raw = self.data.raw.slice_mut(bounds);
    let data = SliceMutRepr { raw };
    Self { data }
}
/// Mutably borrows the raw `crate::context::SliceMut` stored in this slice's
/// `data.raw` field.
///
/// Crate-internal accessor; it performs no copying or conversion.
pub(crate) fn as_context_slice_mut(&mut self) -> &mut crate::context::SliceMut<'a, T> {
&mut self.data.raw
}
}

impl<T: Pod, S: DataOwned<Elem = T>> BufferBase<S> {
Expand Down Expand Up @@ -266,18 +287,6 @@ impl<S: Data> BufferBase<S> {
}
}

impl<'a, T> Slice<'a, T> {
pub(crate) fn as_context_slice(&self) -> &crate::context::Slice<'a, T> {
&self.data.raw
}
}

impl<'a, T> SliceMut<'a, T> {
pub(crate) fn as_context_slice_mut(&mut self) -> &mut crate::context::SliceMut<'a, T> {
&mut self.data.raw
}
}

impl<'a, T: Pod> SliceMut<'a, T> {
pub fn try_bitcast_mut<Y: Pod>(self) -> Option<SliceMut<'a, Y>> {
self.data.raw.try_bitcast_mut().map(|raw| SliceMut {
Expand Down
24 changes: 24 additions & 0 deletions crates/krnl/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use core::ops::RangeBounds;

use crate::Result;
use bytemuck::Pod;
use derive_more::From;
Expand Down Expand Up @@ -100,6 +102,17 @@ impl<T> Slice<'_, T> {
Self::Device(x) => Slice::Device(x.as_slice()),
}
}
/// Narrows this slice to the sub-range described by `bounds`.
///
/// * Host data is indexed directly with the `(start, end)` bound pair, so
///   `bounds` counts elements there.
/// * Device data (only with the `device` feature) delegates to the device
///   slice's own `slice`.
///
/// # Panics
///
/// On the host path, standard slice indexing panics if the range is out of
/// bounds or its start lies past its end.
pub(crate) fn slice(self, bounds: impl RangeBounds<usize>) -> Self {
    match self {
        #[cfg(feature = "device")]
        Self::Device(device) => Self::Device(device.slice(bounds)),
        Self::Host(host) => {
            // Turn the borrowed bounds into owned ones so the pair can be
            // used directly as a slice index.
            let range = (bounds.start_bound().cloned(), bounds.end_bound().cloned());
            Self::Host(&host[range])
        }
    }
}
}

impl<T: Pod> Slice<'_, T> {
Expand Down Expand Up @@ -174,6 +187,17 @@ impl<T> SliceMut<'_, T> {
Self::Device(x) => SliceMut::Device(x.as_slice_mut()),
}
}
/// Narrows this mutable slice to the sub-range described by `bounds`.
///
/// * Host data is indexed directly with the `(start, end)` bound pair, so
///   `bounds` counts elements there.
/// * Device data (only with the `device` feature) delegates to the device
///   slice's own `slice_mut`.
///
/// # Panics
///
/// On the host path, standard slice indexing panics if the range is out of
/// bounds or its start lies past its end.
pub(crate) fn slice_mut(self, bounds: impl RangeBounds<usize>) -> Self {
    match self {
        #[cfg(feature = "device")]
        Self::Device(device) => Self::Device(device.slice_mut(bounds)),
        Self::Host(host) => {
            // Turn the borrowed bounds into owned ones so the pair can be
            // used directly as a slice index.
            let range = (bounds.start_bound().cloned(), bounds.end_bound().cloned());
            Self::Host(&mut host[range])
        }
    }
}
}

impl<'a, T: Pod> SliceMut<'a, T> {
Expand Down
66 changes: 49 additions & 17 deletions crates/krnl/src/context/device.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use crate::Result;
#[cfg(feature = "device")]
use crate::kernel::{KernelCreateInfo, KernelDesc, KernelKey};
use bytemuck::{Pod, cast_slice, cast_slice_mut};
#[cfg(feature = "device")]
use core::ops::RangeBounds;
use parking_lot::Mutex;
use std::{
marker::PhantomData,
Expand All @@ -12,8 +14,8 @@ use std::{
mod backend;
#[cfg(feature = "device")]
use backend::{
Backend as _, Buffer as _, Device as _, DeviceOwned, DeviceSpecifier, Event as _, Kernel as _,
Slice as _,
Backend as _, Buffer as _, BufferRange, Device as _, DeviceOwned, DeviceSpecifier, Event as _,
Kernel as _, Slice as _,
backend_impl::{
Backend, Buffer as RawBuffer, Device as RawDevice, Event as RawEvent, Kernel as RawKernel,
Slice as RawSlice,
Expand Down Expand Up @@ -331,6 +333,22 @@ impl<T> Slice<'_, T> {
_m: PhantomData,
}
}
/// Narrows this device slice to the elements selected by `bounds`.
///
/// `bounds` counts elements of `T`. Each bound is scaled by `size_of::<T>()`,
/// then applied to the range `0..len * size_of::<T>()` via
/// `BufferRange::slice`, and the result is handed to the raw slice.
// NOTE(review): out-of-range handling is decided by `BufferRange::slice` and
// the raw backend slice, neither of which is visible here — confirm.
pub(super) fn slice(self, bounds: impl RangeBounds<usize>) -> Self {
    let elem_size = size_of::<T>();
    let byte_bounds = (
        bounds.start_bound().map(|&index| index * elem_size),
        bounds.end_bound().map(|&index| index * elem_size),
    );
    // The full extent of this slice, expressed in the same scaled units.
    let whole = BufferRange {
        start: 0,
        end: self.len() * elem_size,
    };
    Self {
        raw: self.raw.slice(whole.slice(byte_bounds)),
        _m: PhantomData,
    }
}
}

#[cfg(feature = "device")]
Expand Down Expand Up @@ -373,6 +391,22 @@ impl<T> SliceMut<'_, T> {
_m: PhantomData,
}
}
/// Narrows this mutable device slice to the elements selected by `bounds`.
///
/// `bounds` counts elements of `T`. Each bound is scaled by `size_of::<T>()`,
/// then applied to the range `0..len * size_of::<T>()` via
/// `BufferRange::slice`, and the result is handed to the raw slice's `slice`
/// method.
// NOTE(review): out-of-range handling is decided by `BufferRange::slice` and
// the raw backend slice, neither of which is visible here — confirm.
pub(super) fn slice_mut(self, bounds: impl RangeBounds<usize>) -> Self {
    let elem_size = size_of::<T>();
    let byte_bounds = (
        bounds.start_bound().map(|&index| index * elem_size),
        bounds.end_bound().map(|&index| index * elem_size),
    );
    // The full extent of this slice, expressed in the same scaled units.
    let whole = BufferRange {
        start: 0,
        end: self.len() * elem_size,
    };
    Self {
        raw: self.raw.slice(whole.slice(byte_bounds)),
        _m: PhantomData,
    }
}
}

#[cfg(feature = "device")]
Expand Down Expand Up @@ -437,26 +471,24 @@ pub(crate) struct BufferBindingVec(Vec<RawBufferBinding>);

#[cfg(feature = "device")]
impl BufferBindingVec {
pub(crate) fn set_slice<T: Pod>(&mut self, index: usize, slice: &Slice<T>) {
self.0[index] = RawBufferBinding {
slice: slice.raw.clone(),
mutable: false,
};
}

/// Creates an empty binding list whose inner `Vec` is pre-allocated to hold
/// at least `capacity` bindings.
pub(crate) fn with_capacity(capacity: usize) -> Self {
Self(Vec::with_capacity(capacity))
}
pub(crate) fn push_slice<T: Pod>(&mut self, slice: &Slice<T>) {
self.0.push(RawBufferBinding {
slice: slice.raw.clone(),
mutable: false,
});
self.0.push(RawBufferBinding::new(
slice.raw.clone(),
false,
size_of::<T>(),
));
}
pub(crate) fn push_slice_mut<T: Pod>(&mut self, slice: &mut SliceMut<T>) {
self.0.push(RawBufferBinding {
slice: slice.raw.clone(),
mutable: true,
});
self.0.push(RawBufferBinding::new(
slice.raw.clone(),
true,
size_of::<T>(),
));
}
/// Returns an iterator yielding `offset()` of every stored binding, in the
/// order the bindings are held in the inner `Vec`.
// NOTE(review): the unit of the returned `u32` offsets is defined by
// `RawBufferBinding::offset`, which is not visible here — confirm.
pub(crate) fn buffer_offsets(&self) -> impl Iterator<Item = u32> {
self.0.iter().map(|x| x.offset())
}
}
Loading