diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index aa43b403e..0332ff351 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -3,6 +3,7 @@ mod array_op; mod array_repeat; mod array_scan; +pub mod op_builder; use std::sync::Arc; @@ -26,6 +27,7 @@ use crate::Extension; pub use array_op::{ArrayOp, ArrayOpDef, ArrayOpDefIter}; pub use array_repeat::{ArrayRepeat, ArrayRepeatDef, ARRAY_REPEAT_OP_ID}; pub use array_scan::{ArrayScan, ArrayScanDef, ARRAY_SCAN_OP_ID}; +pub use op_builder::ArrayOpBuilder; /// Reported unique name of the array type. pub const ARRAY_TYPENAME: TypeName = TypeName::new_inline("array"); diff --git a/hugr-core/src/std_extensions/collections/array/op_builder.rs b/hugr-core/src/std_extensions/collections/array/op_builder.rs new file mode 100644 index 000000000..46338dd43 --- /dev/null +++ b/hugr-core/src/std_extensions/collections/array/op_builder.rs @@ -0,0 +1,291 @@ +//! Builder trait for array operations in the dataflow graph. + +use crate::std_extensions::collections::array::{new_array_op, ArrayOpDef}; +use crate::{ + builder::{BuildError, Dataflow}, + extension::simple_op::HasConcrete as _, + types::Type, + Wire, +}; +use itertools::Itertools as _; + +/// Trait for building array operations in a dataflow graph. +pub trait ArrayOpBuilder: Dataflow { + /// Adds a new array operation to the dataflow graph and return the wire + /// representing the new array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `values` - An iterator over the values to initialize the array with. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// The wire representing the new array. + fn add_new_array( + &mut self, + elem_ty: Type, + values: impl IntoIterator, + ) -> Result { + let inputs = values.into_iter().collect_vec(); + let [out] = self + .add_dataflow_op(new_array_op(elem_ty, inputs.len() as u64), inputs)? + .outputs_arr(); + Ok(out) + } + + /// Adds an array get operation to the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index` - The wire representing the index to get. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// The wire representing the value at the specified index in the array. + fn add_array_get( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index: Wire, + ) -> Result { + let op = ArrayOpDef::get.instantiate(&[size.into(), elem_ty.into()])?; + let [out] = self.add_dataflow_op(op, vec![input, index])?.outputs_arr(); + Ok(out) + } + + /// Adds an array set operation to the dataflow graph. + /// + /// This operation sets the value at a specified index in the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index` - The wire representing the index to set. + /// * `value` - The wire representing the value to set at the specified index. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the updated array after the set operation. + fn add_array_set( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index: Wire, + value: Wire, + ) -> Result { + let op = ArrayOpDef::set.instantiate(&[size.into(), elem_ty.into()])?; + let [out] = self + .add_dataflow_op(op, vec![input, index, value])? + .outputs_arr(); + Ok(out) + } + + /// Adds an array swap operation to the dataflow graph. + /// + /// This operation swaps the values at two specified indices in the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index1` - The wire representing the first index to swap. + /// * `index2` - The wire representing the second index to swap. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the updated array after the swap operation. + fn add_array_swap( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index1: Wire, + index2: Wire, + ) -> Result { + let op = ArrayOpDef::swap.instantiate(&[size.into(), elem_ty.into()])?; + let [out] = self + .add_dataflow_op(op, vec![input, index1, index2])? + .outputs_arr(); + Ok(out) + } + + /// Adds an array pop-left operation to the dataflow graph. + /// + /// This operation removes the leftmost element from the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the Option> + fn add_array_pop_left( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result { + let op = ArrayOpDef::pop_left.instantiate(&[size.into(), elem_ty.into()])?; + Ok(self.add_dataflow_op(op, vec![input])?.out_wire(0)) + } + + /// Adds an array pop-right operation to the dataflow graph. + /// + /// This operation removes the rightmost element from the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the Option> + fn add_array_pop_right( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result { + let op = ArrayOpDef::pop_right.instantiate(&[size.into(), elem_ty.into()])?; + Ok(self.add_dataflow_op(op, vec![input])?.out_wire(0)) + } + + /// Adds an operation to discard an empty array from the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + fn add_array_discard_empty(&mut self, elem_ty: Type, input: Wire) -> Result<(), BuildError> { + self.add_dataflow_op( + ArrayOpDef::discard_empty + .instantiate(&[elem_ty.into()]) + .unwrap(), + [input], + )?; + Ok(()) + } +} + +impl ArrayOpBuilder for D {} + +#[cfg(test)] +mod test { + use crate::extension::prelude::PRELUDE_ID; + use crate::extension::ExtensionSet; + use crate::std_extensions::collections::array::{self, array_type}; + use crate::{ + builder::{DFGBuilder, HugrBuilder}, + extension::prelude::{either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _}, + types::Signature, + Hugr, + }; + use rstest::rstest; + + use super::*; + + #[rstest::fixture] + #[default(DFGBuilder)] + fn all_array_ops( + #[default(DFGBuilder::new(Signature::new_endo(Type::EMPTY_TYPEROW) + .with_extension_delta(ExtensionSet::from_iter([ + PRELUDE_ID, + array::EXTENSION_ID + ]))).unwrap())] + mut builder: B, + ) -> B { + let us0 = builder.add_load_value(ConstUsize::new(0)); + let us1 = builder.add_load_value(ConstUsize::new(1)); + let us2 = builder.add_load_value(ConstUsize::new(2)); + let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); + let [arr] = { + let r = builder.add_array_swap(usize_t(), 2, arr, us0, us1).unwrap(); + let res_sum_ty = { + let array_type = array_type(2, usize_t()); + either_type(array_type.clone(), array_type) + }; + builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() + }; + + let [elem_0] = { + let r = builder.add_array_get(usize_t(), 2, arr, us0).unwrap(); + builder + .build_unwrap_sum(1, option_type(usize_t()), r) + .unwrap() + }; + + let [_elem_1, arr] = { + let r = builder + .add_array_set(usize_t(), 2, arr, us1, elem_0) + .unwrap(); + let res_sum_ty = { + let row = vec![usize_t(), array_type(2, usize_t())]; + either_type(row.clone(), row) + }; + builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() + }; + + let [_elem_left, arr] = { + let r = builder.add_array_pop_left(usize_t(), 2, arr).unwrap(); + builder + .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(1, usize_t())]), r) + .unwrap() + }; + let [_elem_right, arr] = { + let r = builder.add_array_pop_right(usize_t(), 1, arr).unwrap(); + builder + .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(0, usize_t())]), r) + .unwrap() + }; + + builder.add_array_discard_empty(usize_t(), arr).unwrap(); + builder + } + + #[rstest] + fn build_all_ops(all_array_ops: DFGBuilder) { + all_array_ops.finish_hugr().unwrap(); + } +} diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/array.rs index 65e0599ea..55dcecefc 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/array.rs @@ -707,6 +707,7 @@ pub fn emit_scan_op<'c, H: HugrView>( #[cfg(test)] mod test { use hugr_core::builder::Container as _; + use hugr_core::extension::prelude::either_type; use hugr_core::extension::ExtensionSet; use hugr_core::ops::Tag; use hugr_core::std_extensions::collections::array::{self, array_type, ArrayRepeat, ArrayScan}; @@ -724,6 +725,7 @@ mod test { int_ops::{self}, int_types::{self, int_type, ConstInt}, }, + collections::array::ArrayOpBuilder, logic, }, type_row, @@ -736,15 +738,66 @@ mod test { check_emission, emit::test::SimpleHugrConfig, test::{exec_ctx, llvm_ctx, TestContext}, - utils::{array_op_builder, ArrayOpBuilder, IntOpBuilder, LogicOpBuilder}, + utils::{IntOpBuilder, LogicOpBuilder}, }; + /// Build all array ops + /// Copied from `hugr_core::std_extensions::collections::array::builder::test` + fn all_array_ops(mut builder: B) -> B { + let us0 = builder.add_load_value(ConstUsize::new(0)); + let us1 = builder.add_load_value(ConstUsize::new(1)); + let us2 = builder.add_load_value(ConstUsize::new(2)); + let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); + let [arr] = { + let r = builder.add_array_swap(usize_t(), 2, arr, us0, us1).unwrap(); + let res_sum_ty = { + let array_type = array_type(2, usize_t()); + either_type(array_type.clone(), array_type) + }; + builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() + }; + + let [elem_0] = { + let r = builder.add_array_get(usize_t(), 2, arr, us0).unwrap(); + builder + .build_unwrap_sum(1, option_type(usize_t()), r) + .unwrap() + }; + + let [_elem_1, arr] = { + let r = builder + .add_array_set(usize_t(), 2, arr, us1, elem_0) + .unwrap(); + let res_sum_ty = { + let row = vec![usize_t(), array_type(2, usize_t())]; + either_type(row.clone(), row) + }; + builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() + }; + + let [_elem_left, arr] = { + let r = builder.add_array_pop_left(usize_t(), 2, arr).unwrap(); + builder + .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(1, usize_t())]), r) + .unwrap() + }; + let [_elem_right, arr] = { + let r = builder.add_array_pop_right(usize_t(), 1, arr).unwrap(); + builder + .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(0, usize_t())]), r) + .unwrap() + }; + + builder.add_array_discard_empty(usize_t(), arr).unwrap(); + builder + } + #[rstest] fn emit_all_ops(mut llvm_ctx: TestContext) { let hugr = SimpleHugrConfig::new() .with_extensions(STD_REG.to_owned()) .finish(|mut builder| { - array_op_builder::test::all_array_ops(builder.dfg_builder_endo([]).unwrap()) + all_array_ops(builder.dfg_builder_endo([]).unwrap()) .finish_sub_container() .unwrap(); builder.finish_sub_container().unwrap() diff --git a/hugr-llvm/src/utils.rs b/hugr-llvm/src/utils.rs index 7285aa141..047a12c88 100644 --- a/hugr-llvm/src/utils.rs +++ b/hugr-llvm/src/utils.rs @@ -1,5 +1,6 @@ //! Module for utilities that do not depend on LLVM. These are candidates for //! upstreaming. +#[deprecated(note = "This module is deprecated and will be removed in a future release.")] pub mod array_op_builder; pub mod fat; pub mod inline_constant_functions; @@ -7,7 +8,8 @@ pub mod int_op_builder; pub mod logic_op_builder; pub mod type_map; -pub use array_op_builder::ArrayOpBuilder; +#[deprecated(note = "Import from hugr_core::std_extensions::collections::array.")] +pub use hugr_core::std_extensions::collections::array::ArrayOpBuilder; pub use inline_constant_functions::inline_constant_functions; pub use int_op_builder::IntOpBuilder; pub use logic_op_builder::LogicOpBuilder; diff --git a/hugr-llvm/src/utils/array_op_builder.rs b/hugr-llvm/src/utils/array_op_builder.rs index dfe2faba4..c6bc2922c 100644 --- a/hugr-llvm/src/utils/array_op_builder.rs +++ b/hugr-llvm/src/utils/array_op_builder.rs @@ -1,191 +1,2 @@ -use hugr_core::std_extensions::collections::array::{new_array_op, ArrayOpDef}; -use hugr_core::{ - builder::{BuildError, Dataflow}, - extension::simple_op::HasConcrete as _, - types::Type, - Wire, -}; -use itertools::Itertools as _; - -pub trait ArrayOpBuilder: Dataflow { - fn add_new_array( - &mut self, - elem_ty: Type, - values: impl IntoIterator, - ) -> Result { - let inputs = values.into_iter().collect_vec(); - let [out] = self - .add_dataflow_op(new_array_op(elem_ty, inputs.len() as u64), inputs)? - .outputs_arr(); - Ok(out) - } - - fn add_array_get( - &mut self, - elem_ty: Type, - size: u64, - input: Wire, - index: Wire, - ) -> Result { - // TODO Add an OpLoadError variant to BuildError. - let op = ArrayOpDef::get - .instantiate(&[size.into(), elem_ty.into()]) - .unwrap(); - let [out] = self.add_dataflow_op(op, vec![input, index])?.outputs_arr(); - Ok(out) - } - - fn add_array_set( - &mut self, - elem_ty: Type, - size: u64, - input: Wire, - index: Wire, - value: Wire, - ) -> Result { - // TODO Add an OpLoadError variant to BuildError - let op = ArrayOpDef::set - .instantiate(&[size.into(), elem_ty.into()]) - .unwrap(); - let [out] = self - .add_dataflow_op(op, vec![input, index, value])? - .outputs_arr(); - Ok(out) - } - - fn add_array_swap( - &mut self, - elem_ty: Type, - size: u64, - input: Wire, - index1: Wire, - index2: Wire, - ) -> Result { - // TODO Add an OpLoadError variant to BuildError - let op = ArrayOpDef::swap - .instantiate(&[size.into(), elem_ty.into()]) - .unwrap(); - let [out] = self - .add_dataflow_op(op, vec![input, index1, index2])? - .outputs_arr(); - Ok(out) - } - - fn add_array_pop_left( - &mut self, - elem_ty: Type, - size: u64, - input: Wire, - ) -> Result { - // TODO Add an OpLoadError variant to BuildError - let op = ArrayOpDef::pop_left - .instantiate(&[size.into(), elem_ty.into()]) - .unwrap(); - Ok(self.add_dataflow_op(op, vec![input])?.out_wire(0)) - } - - fn add_array_pop_right( - &mut self, - elem_ty: Type, - size: u64, - input: Wire, - ) -> Result { - // TODO Add an OpLoadError variant to BuildError - let op = ArrayOpDef::pop_right - .instantiate(&[size.into(), elem_ty.into()]) - .unwrap(); - Ok(self.add_dataflow_op(op, vec![input])?.out_wire(0)) - } - - fn add_array_discard_empty(&mut self, elem_ty: Type, input: Wire) -> Result<(), BuildError> { - // TODO Add an OpLoadError variant to BuildError - self.add_dataflow_op( - ArrayOpDef::discard_empty - .instantiate(&[elem_ty.into()]) - .unwrap(), - [input], - )?; - Ok(()) - } -} - -impl ArrayOpBuilder for D {} - -#[cfg(test)] -pub mod test { - use hugr_core::extension::prelude::PRELUDE_ID; - use hugr_core::extension::ExtensionSet; - use hugr_core::std_extensions::collections::array::{self, array_type}; - use hugr_core::{ - builder::{DFGBuilder, HugrBuilder}, - extension::prelude::{either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _}, - types::Signature, - Hugr, - }; - use rstest::rstest; - - use super::*; - - #[rstest::fixture] - #[default(DFGBuilder)] - pub fn all_array_ops( - #[default(DFGBuilder::new(Signature::new_endo(Type::EMPTY_TYPEROW) - .with_extension_delta(ExtensionSet::from_iter([ - PRELUDE_ID, - array::EXTENSION_ID - ]))).unwrap())] - mut builder: B, - ) -> B { - let us0 = builder.add_load_value(ConstUsize::new(0)); - let us1 = builder.add_load_value(ConstUsize::new(1)); - let us2 = builder.add_load_value(ConstUsize::new(2)); - let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); - let [arr] = { - let r = builder.add_array_swap(usize_t(), 2, arr, us0, us1).unwrap(); - let res_sum_ty = { - let array_type = array_type(2, usize_t()); - either_type(array_type.clone(), array_type) - }; - builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() - }; - - let [elem_0] = { - let r = builder.add_array_get(usize_t(), 2, arr, us0).unwrap(); - builder - .build_unwrap_sum(1, option_type(usize_t()), r) - .unwrap() - }; - - let [_elem_1, arr] = { - let r = builder - .add_array_set(usize_t(), 2, arr, us1, elem_0) - .unwrap(); - let res_sum_ty = { - let row = vec![usize_t(), array_type(2, usize_t())]; - either_type(row.clone(), row) - }; - builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() - }; - - let [_elem_left, arr] = { - let r = builder.add_array_pop_left(usize_t(), 2, arr).unwrap(); - builder - .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(1, usize_t())]), r) - .unwrap() - }; - let [_elem_right, arr] = { - let r = builder.add_array_pop_right(usize_t(), 1, arr).unwrap(); - builder - .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(0, usize_t())]), r) - .unwrap() - }; - - builder.add_array_discard_empty(usize_t(), arr).unwrap(); - builder - } - - #[rstest] - fn build_all_ops(all_array_ops: DFGBuilder) { - all_array_ops.finish_hugr().unwrap(); - } -} +#[deprecated(note = "Import from hugr_core::std_extensions::collections::array.")] +pub use hugr_core::std_extensions::collections::array::op_builder::*;