Skip to content

Commit c8090ca

Browse files
authored
feat: move ArrayOpBuilder to hugr-core (#2115)
Not doing ints or logic atm because those builders are not comprehensive. Closes #2116
1 parent 2222b8c commit c8090ca

File tree

5 files changed

+353
-194
lines changed

5 files changed

+353
-194
lines changed

hugr-core/src/std_extensions/collections/array.rs

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
mod array_op;
44
mod array_repeat;
55
mod array_scan;
6+
pub mod op_builder;
67

78
use std::sync::Arc;
89

@@ -26,6 +27,7 @@ use crate::Extension;
2627
pub use array_op::{ArrayOp, ArrayOpDef, ArrayOpDefIter};
2728
pub use array_repeat::{ArrayRepeat, ArrayRepeatDef, ARRAY_REPEAT_OP_ID};
2829
pub use array_scan::{ArrayScan, ArrayScanDef, ARRAY_SCAN_OP_ID};
30+
pub use op_builder::ArrayOpBuilder;
2931

3032
/// Reported unique name of the array type.
3133
pub const ARRAY_TYPENAME: TypeName = TypeName::new_inline("array");
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
//! Builder trait for array operations in the dataflow graph.
2+
3+
use crate::std_extensions::collections::array::{new_array_op, ArrayOpDef};
4+
use crate::{
5+
builder::{BuildError, Dataflow},
6+
extension::simple_op::HasConcrete as _,
7+
types::Type,
8+
Wire,
9+
};
10+
use itertools::Itertools as _;
11+
12+
/// Trait for building array operations in a dataflow graph.
13+
pub trait ArrayOpBuilder: Dataflow {
14+
/// Adds a new array operation to the dataflow graph and return the wire
15+
/// representing the new array.
16+
///
17+
/// # Arguments
18+
///
19+
/// * `elem_ty` - The type of the elements in the array.
20+
/// * `values` - An iterator over the values to initialize the array with.
21+
///
22+
/// # Errors
23+
///
24+
/// If building the operation fails.
25+
///
26+
/// # Returns
27+
///
28+
/// The wire representing the new array.
29+
fn add_new_array(
30+
&mut self,
31+
elem_ty: Type,
32+
values: impl IntoIterator<Item = Wire>,
33+
) -> Result<Wire, BuildError> {
34+
let inputs = values.into_iter().collect_vec();
35+
let [out] = self
36+
.add_dataflow_op(new_array_op(elem_ty, inputs.len() as u64), inputs)?
37+
.outputs_arr();
38+
Ok(out)
39+
}
40+
41+
/// Adds an array get operation to the dataflow graph.
42+
///
43+
/// # Arguments
44+
///
45+
/// * `elem_ty` - The type of the elements in the array.
46+
/// * `size` - The size of the array.
47+
/// * `input` - The wire representing the array.
48+
/// * `index` - The wire representing the index to get.
49+
///
50+
/// # Errors
51+
///
52+
/// If building the operation fails.
53+
///
54+
/// # Returns
55+
///
56+
/// The wire representing the value at the specified index in the array.
57+
fn add_array_get(
58+
&mut self,
59+
elem_ty: Type,
60+
size: u64,
61+
input: Wire,
62+
index: Wire,
63+
) -> Result<Wire, BuildError> {
64+
let op = ArrayOpDef::get.instantiate(&[size.into(), elem_ty.into()])?;
65+
let [out] = self.add_dataflow_op(op, vec![input, index])?.outputs_arr();
66+
Ok(out)
67+
}
68+
69+
/// Adds an array set operation to the dataflow graph.
70+
///
71+
/// This operation sets the value at a specified index in the array.
72+
///
73+
/// # Arguments
74+
///
75+
/// * `elem_ty` - The type of the elements in the array.
76+
/// * `size` - The size of the array.
77+
/// * `input` - The wire representing the array.
78+
/// * `index` - The wire representing the index to set.
79+
/// * `value` - The wire representing the value to set at the specified index.
80+
///
81+
/// # Errors
82+
///
83+
/// Returns an error if building the operation fails.
84+
///
85+
/// # Returns
86+
///
87+
/// The wire representing the updated array after the set operation.
88+
fn add_array_set(
89+
&mut self,
90+
elem_ty: Type,
91+
size: u64,
92+
input: Wire,
93+
index: Wire,
94+
value: Wire,
95+
) -> Result<Wire, BuildError> {
96+
let op = ArrayOpDef::set.instantiate(&[size.into(), elem_ty.into()])?;
97+
let [out] = self
98+
.add_dataflow_op(op, vec![input, index, value])?
99+
.outputs_arr();
100+
Ok(out)
101+
}
102+
103+
/// Adds an array swap operation to the dataflow graph.
104+
///
105+
/// This operation swaps the values at two specified indices in the array.
106+
///
107+
/// # Arguments
108+
///
109+
/// * `elem_ty` - The type of the elements in the array.
110+
/// * `size` - The size of the array.
111+
/// * `input` - The wire representing the array.
112+
/// * `index1` - The wire representing the first index to swap.
113+
/// * `index2` - The wire representing the second index to swap.
114+
///
115+
/// # Errors
116+
///
117+
/// Returns an error if building the operation fails.
118+
///
119+
/// # Returns
120+
///
121+
/// The wire representing the updated array after the swap operation.
122+
fn add_array_swap(
123+
&mut self,
124+
elem_ty: Type,
125+
size: u64,
126+
input: Wire,
127+
index1: Wire,
128+
index2: Wire,
129+
) -> Result<Wire, BuildError> {
130+
let op = ArrayOpDef::swap.instantiate(&[size.into(), elem_ty.into()])?;
131+
let [out] = self
132+
.add_dataflow_op(op, vec![input, index1, index2])?
133+
.outputs_arr();
134+
Ok(out)
135+
}
136+
137+
/// Adds an array pop-left operation to the dataflow graph.
138+
///
139+
/// This operation removes the leftmost element from the array.
140+
///
141+
/// # Arguments
142+
///
143+
/// * `elem_ty` - The type of the elements in the array.
144+
/// * `size` - The size of the array.
145+
/// * `input` - The wire representing the array.
146+
///
147+
/// # Errors
148+
///
149+
/// Returns an error if building the operation fails.
150+
///
151+
/// # Returns
152+
///
153+
/// The wire representing the Option<elemty, array<SIZE-1, elemty>>
154+
fn add_array_pop_left(
155+
&mut self,
156+
elem_ty: Type,
157+
size: u64,
158+
input: Wire,
159+
) -> Result<Wire, BuildError> {
160+
let op = ArrayOpDef::pop_left.instantiate(&[size.into(), elem_ty.into()])?;
161+
Ok(self.add_dataflow_op(op, vec![input])?.out_wire(0))
162+
}
163+
164+
/// Adds an array pop-right operation to the dataflow graph.
165+
///
166+
/// This operation removes the rightmost element from the array.
167+
///
168+
/// # Arguments
169+
///
170+
/// * `elem_ty` - The type of the elements in the array.
171+
/// * `size` - The size of the array.
172+
/// * `input` - The wire representing the array.
173+
///
174+
/// # Errors
175+
///
176+
/// Returns an error if building the operation fails.
177+
///
178+
/// # Returns
179+
///
180+
/// The wire representing the Option<elemty, array<SIZE-1, elemty>>
181+
fn add_array_pop_right(
182+
&mut self,
183+
elem_ty: Type,
184+
size: u64,
185+
input: Wire,
186+
) -> Result<Wire, BuildError> {
187+
let op = ArrayOpDef::pop_right.instantiate(&[size.into(), elem_ty.into()])?;
188+
Ok(self.add_dataflow_op(op, vec![input])?.out_wire(0))
189+
}
190+
191+
/// Adds an operation to discard an empty array from the dataflow graph.
192+
///
193+
/// # Arguments
194+
///
195+
/// * `elem_ty` - The type of the elements in the array.
196+
/// * `input` - The wire representing the array.
197+
///
198+
/// # Errors
199+
///
200+
/// Returns an error if building the operation fails.
201+
fn add_array_discard_empty(&mut self, elem_ty: Type, input: Wire) -> Result<(), BuildError> {
202+
self.add_dataflow_op(
203+
ArrayOpDef::discard_empty
204+
.instantiate(&[elem_ty.into()])
205+
.unwrap(),
206+
[input],
207+
)?;
208+
Ok(())
209+
}
210+
}
211+
212+
impl<D: Dataflow> ArrayOpBuilder for D {}
213+
214+
#[cfg(test)]
215+
mod test {
216+
use crate::extension::prelude::PRELUDE_ID;
217+
use crate::extension::ExtensionSet;
218+
use crate::std_extensions::collections::array::{self, array_type};
219+
use crate::{
220+
builder::{DFGBuilder, HugrBuilder},
221+
extension::prelude::{either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _},
222+
types::Signature,
223+
Hugr,
224+
};
225+
use rstest::rstest;
226+
227+
use super::*;
228+
229+
#[rstest::fixture]
230+
#[default(DFGBuilder<Hugr>)]
231+
fn all_array_ops<B: Dataflow>(
232+
#[default(DFGBuilder::new(Signature::new_endo(Type::EMPTY_TYPEROW)
233+
.with_extension_delta(ExtensionSet::from_iter([
234+
PRELUDE_ID,
235+
array::EXTENSION_ID
236+
]))).unwrap())]
237+
mut builder: B,
238+
) -> B {
239+
let us0 = builder.add_load_value(ConstUsize::new(0));
240+
let us1 = builder.add_load_value(ConstUsize::new(1));
241+
let us2 = builder.add_load_value(ConstUsize::new(2));
242+
let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap();
243+
let [arr] = {
244+
let r = builder.add_array_swap(usize_t(), 2, arr, us0, us1).unwrap();
245+
let res_sum_ty = {
246+
let array_type = array_type(2, usize_t());
247+
either_type(array_type.clone(), array_type)
248+
};
249+
builder.build_unwrap_sum(1, res_sum_ty, r).unwrap()
250+
};
251+
252+
let [elem_0] = {
253+
let r = builder.add_array_get(usize_t(), 2, arr, us0).unwrap();
254+
builder
255+
.build_unwrap_sum(1, option_type(usize_t()), r)
256+
.unwrap()
257+
};
258+
259+
let [_elem_1, arr] = {
260+
let r = builder
261+
.add_array_set(usize_t(), 2, arr, us1, elem_0)
262+
.unwrap();
263+
let res_sum_ty = {
264+
let row = vec![usize_t(), array_type(2, usize_t())];
265+
either_type(row.clone(), row)
266+
};
267+
builder.build_unwrap_sum(1, res_sum_ty, r).unwrap()
268+
};
269+
270+
let [_elem_left, arr] = {
271+
let r = builder.add_array_pop_left(usize_t(), 2, arr).unwrap();
272+
builder
273+
.build_unwrap_sum(1, option_type(vec![usize_t(), array_type(1, usize_t())]), r)
274+
.unwrap()
275+
};
276+
let [_elem_right, arr] = {
277+
let r = builder.add_array_pop_right(usize_t(), 1, arr).unwrap();
278+
builder
279+
.build_unwrap_sum(1, option_type(vec![usize_t(), array_type(0, usize_t())]), r)
280+
.unwrap()
281+
};
282+
283+
builder.add_array_discard_empty(usize_t(), arr).unwrap();
284+
builder
285+
}
286+
287+
#[rstest]
288+
fn build_all_ops(all_array_ops: DFGBuilder<Hugr>) {
289+
all_array_ops.finish_hugr().unwrap();
290+
}
291+
}

hugr-llvm/src/extension/collections/array.rs

+55-2
Original file line numberDiff line numberDiff line change
@@ -707,6 +707,7 @@ pub fn emit_scan_op<'c, H: HugrView<Node = Node>>(
707707
#[cfg(test)]
708708
mod test {
709709
use hugr_core::builder::Container as _;
710+
use hugr_core::extension::prelude::either_type;
710711
use hugr_core::extension::ExtensionSet;
711712
use hugr_core::ops::Tag;
712713
use hugr_core::std_extensions::collections::array::{self, array_type, ArrayRepeat, ArrayScan};
@@ -724,6 +725,7 @@ mod test {
724725
int_ops::{self},
725726
int_types::{self, int_type, ConstInt},
726727
},
728+
collections::array::ArrayOpBuilder,
727729
logic,
728730
},
729731
type_row,
@@ -736,15 +738,66 @@ mod test {
736738
check_emission,
737739
emit::test::SimpleHugrConfig,
738740
test::{exec_ctx, llvm_ctx, TestContext},
739-
utils::{array_op_builder, ArrayOpBuilder, IntOpBuilder, LogicOpBuilder},
741+
utils::{IntOpBuilder, LogicOpBuilder},
740742
};
741743

744+
/// Build all array ops
745+
/// Copied from `hugr_core::std_extensions::collections::array::builder::test`
746+
fn all_array_ops<B: Dataflow>(mut builder: B) -> B {
747+
let us0 = builder.add_load_value(ConstUsize::new(0));
748+
let us1 = builder.add_load_value(ConstUsize::new(1));
749+
let us2 = builder.add_load_value(ConstUsize::new(2));
750+
let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap();
751+
let [arr] = {
752+
let r = builder.add_array_swap(usize_t(), 2, arr, us0, us1).unwrap();
753+
let res_sum_ty = {
754+
let array_type = array_type(2, usize_t());
755+
either_type(array_type.clone(), array_type)
756+
};
757+
builder.build_unwrap_sum(1, res_sum_ty, r).unwrap()
758+
};
759+
760+
let [elem_0] = {
761+
let r = builder.add_array_get(usize_t(), 2, arr, us0).unwrap();
762+
builder
763+
.build_unwrap_sum(1, option_type(usize_t()), r)
764+
.unwrap()
765+
};
766+
767+
let [_elem_1, arr] = {
768+
let r = builder
769+
.add_array_set(usize_t(), 2, arr, us1, elem_0)
770+
.unwrap();
771+
let res_sum_ty = {
772+
let row = vec![usize_t(), array_type(2, usize_t())];
773+
either_type(row.clone(), row)
774+
};
775+
builder.build_unwrap_sum(1, res_sum_ty, r).unwrap()
776+
};
777+
778+
let [_elem_left, arr] = {
779+
let r = builder.add_array_pop_left(usize_t(), 2, arr).unwrap();
780+
builder
781+
.build_unwrap_sum(1, option_type(vec![usize_t(), array_type(1, usize_t())]), r)
782+
.unwrap()
783+
};
784+
let [_elem_right, arr] = {
785+
let r = builder.add_array_pop_right(usize_t(), 1, arr).unwrap();
786+
builder
787+
.build_unwrap_sum(1, option_type(vec![usize_t(), array_type(0, usize_t())]), r)
788+
.unwrap()
789+
};
790+
791+
builder.add_array_discard_empty(usize_t(), arr).unwrap();
792+
builder
793+
}
794+
742795
#[rstest]
743796
fn emit_all_ops(mut llvm_ctx: TestContext) {
744797
let hugr = SimpleHugrConfig::new()
745798
.with_extensions(STD_REG.to_owned())
746799
.finish(|mut builder| {
747-
array_op_builder::test::all_array_ops(builder.dfg_builder_endo([]).unwrap())
800+
all_array_ops(builder.dfg_builder_endo([]).unwrap())
748801
.finish_sub_container()
749802
.unwrap();
750803
builder.finish_sub_container().unwrap()

0 commit comments

Comments
 (0)