Skip to content

Commit 8c6d342

Browse files
authored
feat!: Make array.get give back the passed array (#2110)
Since the array type is now linear following #2097, we need to change the signature of the `array.get` op to give the passed array back to the user. BREAKING CHANGE: `std.collections.array.get` now also returns the passed array as an extra output
1 parent 433a194 commit 8c6d342

File tree

12 files changed

+293
-171
lines changed

12 files changed

+293
-171
lines changed

hugr-core/src/std_extensions/collections/array/array_op.rs

+15-4
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pub enum GenericArrayOpDef<AK: ArrayKind> {
3131
/// where `SIZE` must be statically known (not a variable)
3232
new_array,
3333
/// Copies an element out of the array ([TypeBound::Copyable] elements only):
34-
/// `get<size,elemty>: array<size, elemty>, index -> option<elemty>`
34+
/// `get<size,elemty>: array<size, elemty>, index -> option<elemty>, array`
3535
get,
3636
/// Exchanges an element of the array with an external value:
3737
/// `set<size, elemty>: array<size, elemty>, index, elemty -> either(elemty, array | elemty, array)`
@@ -148,7 +148,10 @@ impl<AK: ArrayKind> GenericArrayOpDef<AK> {
148148
let option_type: Type = option_type(copy_elem_ty).into();
149149
PolyFuncTypeRV::new(
150150
params,
151-
FuncValueType::new(vec![copy_array_ty, usize_t], option_type),
151+
FuncValueType::new(
152+
vec![copy_array_ty.clone(), usize_t],
153+
vec![option_type, copy_array_ty],
154+
),
152155
)
153156
}
154157
set => {
@@ -385,7 +388,11 @@ mod tests {
385388
sig.io(),
386389
(
387390
&vec![AK::ty(size, element_ty.clone()), usize_t()].into(),
388-
&vec![option_type(element_ty.clone()).into()].into()
391+
&vec![
392+
option_type(element_ty.clone()).into(),
393+
AK::ty(size, element_ty.clone())
394+
]
395+
.into()
389396
)
390397
);
391398
}
@@ -500,7 +507,11 @@ mod tests {
500507
sig.io(),
501508
(
502509
&vec![AK::ty(size, element_ty.clone()), usize_t()].into(),
503-
&vec![option_type(element_ty.clone()).into()].into()
510+
&vec![
511+
option_type(element_ty.clone()).into(),
512+
AK::ty(size, element_ty.clone())
513+
]
514+
.into()
504515
)
505516
);
506517
}

hugr-core/src/std_extensions/collections/array/op_builder.rs

+6-5
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,18 @@ pub trait ArrayOpBuilder<AK: ArrayKind>: Dataflow {
5959
///
6060
/// # Returns
6161
///
62-
/// The wire representing the value at the specified index in the array.
62+
/// * The wire representing the value at the specified index in the array
63+
/// * The wire representing the array
6364
fn add_array_get(
6465
&mut self,
6566
elem_ty: Type,
6667
size: u64,
6768
input: Wire,
6869
index: Wire,
69-
) -> Result<Wire, BuildError> {
70+
) -> Result<(Wire, Wire), BuildError> {
7071
let op = GenericArrayOpDef::<AK>::get.instantiate(&[size.into(), elem_ty.into()])?;
71-
let [out] = self.add_dataflow_op(op, vec![input, index])?.outputs_arr();
72-
Ok(out)
72+
let [out, arr] = self.add_dataflow_op(op, vec![input, index])?.outputs_arr();
73+
Ok((out, arr))
7374
}
7475

7576
/// Adds an array set operation to the dataflow graph.
@@ -256,7 +257,7 @@ mod test {
256257
};
257258

258259
let [elem_0] = {
259-
let r = builder.add_array_get(usize_t(), 2, arr, us0).unwrap();
260+
let (r, _) = builder.add_array_get(usize_t(), 2, arr, us0).unwrap();
260261
builder
261262
.build_unwrap_sum(1, option_type(usize_t()), r)
262263
.unwrap()

hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__value_array__test__emit_all_ops@llvm14.snap

+16-16
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,12 @@ cond_28_case_1: ; preds = %37
114114
br label %cond_exit_28
115115

116116
cond_exit_28: ; preds = %cond_28_case_1, %cond_28_case_0
117-
%"023.0" = phi i64 [ 0, %cond_28_case_0 ], [ %38, %cond_28_case_1 ]
117+
%"024.0" = phi i64 [ 0, %cond_28_case_0 ], [ %38, %cond_28_case_1 ]
118118
%42 = icmp ult i64 1, 2
119119
br i1 %42, label %46, label %43
120120

121121
43: ; preds = %cond_exit_28
122-
%44 = insertvalue { i1, i64, [2 x i64] } { i1 false, i64 poison, [2 x i64] poison }, i64 %"023.0", 1
122+
%44 = insertvalue { i1, i64, [2 x i64] } { i1 false, i64 poison, [2 x i64] poison }, i64 %"024.0", 1
123123
%45 = insertvalue { i1, i64, [2 x i64] } %44, [2 x i64] %"08.0", 2
124124
br label %55
125125

@@ -129,28 +129,28 @@ cond_exit_28: ; preds = %cond_28_case_1, %co
129129
store [2 x i64] %"08.0", [2 x i64]* %48, align 4
130130
%49 = getelementptr inbounds i64, i64* %47, i64 1
131131
%50 = load i64, i64* %49, align 4
132-
store i64 %"023.0", i64* %49, align 4
132+
store i64 %"024.0", i64* %49, align 4
133133
%51 = bitcast i64* %47 to [2 x i64]*
134134
%52 = load [2 x i64], [2 x i64]* %51, align 4
135135
%53 = insertvalue { i1, i64, [2 x i64] } { i1 true, i64 poison, [2 x i64] poison }, i64 %50, 1
136136
%54 = insertvalue { i1, i64, [2 x i64] } %53, [2 x i64] %52, 2
137137
br label %55
138138

139139
55: ; preds = %43, %46
140-
%"033.0" = phi { i1, i64, [2 x i64] } [ %54, %46 ], [ %45, %43 ]
141-
%56 = extractvalue { i1, i64, [2 x i64] } %"033.0", 0
140+
%"034.0" = phi { i1, i64, [2 x i64] } [ %54, %46 ], [ %45, %43 ]
141+
%56 = extractvalue { i1, i64, [2 x i64] } %"034.0", 0
142142
switch i1 %56, label %57 [
143143
i1 true, label %60
144144
]
145145

146146
57: ; preds = %55
147-
%58 = extractvalue { i1, i64, [2 x i64] } %"033.0", 1
148-
%59 = extractvalue { i1, i64, [2 x i64] } %"033.0", 2
147+
%58 = extractvalue { i1, i64, [2 x i64] } %"034.0", 1
148+
%59 = extractvalue { i1, i64, [2 x i64] } %"034.0", 2
149149
br label %cond_40_case_0
150150

151151
60: ; preds = %55
152-
%61 = extractvalue { i1, i64, [2 x i64] } %"033.0", 1
153-
%62 = extractvalue { i1, i64, [2 x i64] } %"033.0", 2
152+
%61 = extractvalue { i1, i64, [2 x i64] } %"034.0", 1
153+
%62 = extractvalue { i1, i64, [2 x i64] } %"034.0", 2
154154
br label %cond_40_case_1
155155

156156
cond_40_case_0: ; preds = %57
@@ -164,11 +164,11 @@ cond_40_case_1: ; preds = %60
164164
br label %cond_exit_40
165165

166166
cond_exit_40: ; preds = %cond_40_case_1, %cond_40_case_0
167-
%"036.0" = phi i64 [ 0, %cond_40_case_0 ], [ %61, %cond_40_case_1 ]
168-
%"1.0" = phi [2 x i64] [ zeroinitializer, %cond_40_case_0 ], [ %62, %cond_40_case_1 ]
167+
%"037.0" = phi i64 [ 0, %cond_40_case_0 ], [ %61, %cond_40_case_1 ]
168+
%"138.0" = phi [2 x i64] [ zeroinitializer, %cond_40_case_0 ], [ %62, %cond_40_case_1 ]
169169
%66 = alloca i64, i32 2, align 8
170170
%67 = bitcast i64* %66 to [2 x i64]*
171-
store [2 x i64] %"1.0", [2 x i64]* %67, align 4
171+
store [2 x i64] %"138.0", [2 x i64]* %67, align 4
172172
%68 = getelementptr i64, i64* %66, i32 1
173173
%69 = load i64, i64* %66, align 4
174174
%70 = bitcast i64* %68 to [1 x i64]*
@@ -199,11 +199,11 @@ cond_51_case_1: ; preds = %76
199199
br label %cond_exit_51
200200

201201
cond_exit_51: ; preds = %cond_51_case_1, %cond_51_case_0
202-
%"056.0" = phi i64 [ 0, %cond_51_case_0 ], [ %77, %cond_51_case_1 ]
203-
%"157.0" = phi [1 x i64] [ zeroinitializer, %cond_51_case_0 ], [ %78, %cond_51_case_1 ]
202+
%"058.0" = phi i64 [ 0, %cond_51_case_0 ], [ %77, %cond_51_case_1 ]
203+
%"159.0" = phi [1 x i64] [ zeroinitializer, %cond_51_case_0 ], [ %78, %cond_51_case_1 ]
204204
%82 = alloca i64, align 8
205205
%83 = bitcast i64* %82 to [1 x i64]*
206-
store [1 x i64] %"157.0", [1 x i64]* %83, align 4
206+
store [1 x i64] %"159.0", [1 x i64]* %83, align 4
207207
%84 = getelementptr i64, i64* %82, i32 0
208208
%85 = load i64, i64* %84, align 4
209209
%86 = bitcast i64* %82 to [0 x i64]*
@@ -232,7 +232,7 @@ cond_62_case_1: ; preds = %91
232232
br label %cond_exit_62
233233

234234
cond_exit_62: ; preds = %cond_62_case_1, %cond_62_case_0
235-
%"071.0" = phi i64 [ 0, %cond_62_case_0 ], [ %92, %cond_62_case_1 ]
235+
%"073.0" = phi i64 [ 0, %cond_62_case_0 ], [ %92, %cond_62_case_1 ]
236236
ret void
237237
}
238238

0 commit comments

Comments
 (0)