Skip to content

Commit 433a194

Browse files
authored
feat: Add value_array to/from array conversion ops (#2101)
1 parent 26a079c commit 433a194

File tree

5 files changed

+625
-2
lines changed

5 files changed

+625
-2
lines changed

hugr-core/src/std_extensions/collections/array.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Fixed-length array type and operations extension.
22
33
mod array_clone;
4+
mod array_conversion;
45
mod array_discard;
56
mod array_kind;
67
mod array_op;
@@ -24,6 +25,7 @@ use crate::types::{CustomCheckFailure, Type, TypeBound, TypeName};
2425
use crate::Extension;
2526

2627
pub use array_clone::{GenericArrayClone, GenericArrayCloneDef, ARRAY_CLONE_OP_ID};
28+
pub use array_conversion::{Direction, GenericArrayConvert, GenericArrayConvertDef, FROM, INTO};
2729
pub use array_discard::{GenericArrayDiscard, GenericArrayDiscardDef, ARRAY_DISCARD_OP_ID};
2830
pub use array_kind::ArrayKind;
2931
pub use array_op::{GenericArrayOp, GenericArrayOpDef};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
//! Operations for converting between the different array extensions
2+
3+
use std::marker::PhantomData;
4+
use std::str::FromStr;
5+
use std::sync::{Arc, Weak};
6+
7+
use crate::extension::simple_op::{
8+
HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
9+
};
10+
use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef};
11+
use crate::ops::{ExtensionOp, NamedOp, OpName};
12+
use crate::types::type_param::{TypeArg, TypeParam};
13+
use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound};
14+
use crate::Extension;
15+
16+
use super::array_kind::ArrayKind;
17+
18+
/// Array conversion direction.
19+
///
20+
/// Either the current array type [INTO] the other one, or the current array type [FROM] the
21+
/// other one.
22+
pub type Direction = bool;
23+
24+
/// Array conversion direction to turn the current array type [INTO] the other one.
25+
pub const INTO: Direction = true;
26+
27+
/// Array conversion direction to obtain the current array type [FROM] the other one.
28+
pub const FROM: Direction = false;
29+
30+
/// Definition of array conversion operations.
31+
///
32+
/// Generic over the concrete array implementation of the extension containing the operation, as
33+
/// well as over another array implementation that should be converted between. Also generic over
34+
/// the conversion [Direction].
35+
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
36+
pub struct GenericArrayConvertDef<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>(
37+
PhantomData<AK>,
38+
PhantomData<OtherAK>,
39+
);
40+
41+
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>
42+
GenericArrayConvertDef<AK, DIR, OtherAK>
43+
{
44+
/// Creates a new array conversion definition.
45+
pub fn new() -> Self {
46+
GenericArrayConvertDef(PhantomData, PhantomData)
47+
}
48+
}
49+
50+
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> Default
51+
for GenericArrayConvertDef<AK, DIR, OtherAK>
52+
{
53+
fn default() -> Self {
54+
Self::new()
55+
}
56+
}
57+
58+
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> NamedOp
59+
for GenericArrayConvertDef<AK, DIR, OtherAK>
60+
{
61+
fn name(&self) -> OpName {
62+
match DIR {
63+
INTO => format!("to_{}", OtherAK::TYPE_NAME).into(),
64+
FROM => format!("from_{}", OtherAK::TYPE_NAME).into(),
65+
}
66+
}
67+
}
68+
69+
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> FromStr
70+
for GenericArrayConvertDef<AK, DIR, OtherAK>
71+
{
72+
type Err = ();
73+
74+
fn from_str(s: &str) -> Result<Self, Self::Err> {
75+
let def = GenericArrayConvertDef::new();
76+
if s == def.name() {
77+
Ok(def)
78+
} else {
79+
Err(())
80+
}
81+
}
82+
}
83+
84+
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>
85+
GenericArrayConvertDef<AK, DIR, OtherAK>
86+
{
87+
/// To avoid recursion when defining the extension, take the type definition as an argument.
88+
fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc {
89+
let params = vec![TypeParam::max_nat(), TypeBound::Any.into()];
90+
let size = TypeArg::new_var_use(0, TypeParam::max_nat());
91+
let element_ty = Type::new_var_use(1, TypeBound::Any);
92+
93+
let this_ty = AK::instantiate_ty(array_def, size.clone(), element_ty.clone())
94+
.expect("Array type instantiation failed");
95+
let other_ty =
96+
OtherAK::ty_parametric(size, element_ty).expect("Array type instantiation failed");
97+
98+
let sig = match DIR {
99+
INTO => FuncValueType::new(this_ty, other_ty),
100+
FROM => FuncValueType::new(other_ty, this_ty),
101+
};
102+
PolyFuncTypeRV::new(params, sig).into()
103+
}
104+
}
105+
106+
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> MakeOpDef
107+
for GenericArrayConvertDef<AK, DIR, OtherAK>
108+
{
109+
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
110+
where
111+
Self: Sized,
112+
{
113+
crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
114+
}
115+
116+
fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
117+
self.signature_from_def(AK::type_def())
118+
}
119+
120+
fn extension_ref(&self) -> Weak<Extension> {
121+
Arc::downgrade(AK::extension())
122+
}
123+
124+
fn extension(&self) -> ExtensionId {
125+
AK::EXTENSION_ID
126+
}
127+
128+
fn description(&self) -> String {
129+
match DIR {
130+
INTO => format!("Turns `{}` into `{}`", AK::TYPE_NAME, OtherAK::TYPE_NAME),
131+
FROM => format!("Turns `{}` into `{}`", OtherAK::TYPE_NAME, AK::TYPE_NAME),
132+
}
133+
}
134+
135+
/// Add an operation implemented as a [MakeOpDef], which can provide the data
136+
/// required to define an [OpDef], to an extension.
137+
//
138+
// This method is re-defined here since we need to pass the array type def while
139+
// computing the signature, to avoid recursive loops initializing the extension.
140+
fn add_to_extension(
141+
&self,
142+
extension: &mut Extension,
143+
extension_ref: &Weak<Extension>,
144+
) -> Result<(), crate::extension::ExtensionBuildError> {
145+
let sig = self.signature_from_def(extension.get_type(&AK::TYPE_NAME).unwrap());
146+
let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?;
147+
self.post_opdef(def);
148+
Ok(())
149+
}
150+
}
151+
152+
/// Definition of the array conversion op.
153+
///
154+
/// Generic over the concrete array implementation of the extension containing the operation, as
155+
/// well as over another array implementation that should be converted between. Also generic over
156+
/// the conversion [Direction].
157+
#[derive(Clone, Debug, PartialEq)]
158+
pub struct GenericArrayConvert<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> {
159+
/// The element type of the array.
160+
pub elem_ty: Type,
161+
/// Size of the array.
162+
pub size: u64,
163+
_kind: PhantomData<AK>,
164+
_other_kind: PhantomData<OtherAK>,
165+
}
166+
167+
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind>
168+
GenericArrayConvert<AK, DIR, OtherAK>
169+
{
170+
/// Creates a new array conversion op.
171+
pub fn new(elem_ty: Type, size: u64) -> Self {
172+
GenericArrayConvert {
173+
elem_ty,
174+
size,
175+
_kind: PhantomData,
176+
_other_kind: PhantomData,
177+
}
178+
}
179+
}
180+
181+
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> NamedOp
182+
for GenericArrayConvert<AK, DIR, OtherAK>
183+
{
184+
fn name(&self) -> OpName {
185+
match DIR {
186+
INTO => format!("to_{}", OtherAK::TYPE_NAME).into(),
187+
FROM => format!("from_{}", OtherAK::TYPE_NAME).into(),
188+
}
189+
}
190+
}
191+
192+
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> MakeExtensionOp
193+
for GenericArrayConvert<AK, DIR, OtherAK>
194+
{
195+
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
196+
where
197+
Self: Sized,
198+
{
199+
let def = GenericArrayConvertDef::<AK, DIR, OtherAK>::from_def(ext_op.def())?;
200+
def.instantiate(ext_op.args())
201+
}
202+
203+
fn type_args(&self) -> Vec<TypeArg> {
204+
vec![
205+
TypeArg::BoundedNat { n: self.size },
206+
self.elem_ty.clone().into(),
207+
]
208+
}
209+
}
210+
211+
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> MakeRegisteredOp
212+
for GenericArrayConvert<AK, DIR, OtherAK>
213+
{
214+
fn extension_id(&self) -> ExtensionId {
215+
AK::EXTENSION_ID
216+
}
217+
218+
fn extension_ref(&self) -> Weak<Extension> {
219+
Arc::downgrade(AK::extension())
220+
}
221+
}
222+
223+
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> HasDef
224+
for GenericArrayConvert<AK, DIR, OtherAK>
225+
{
226+
type Def = GenericArrayConvertDef<AK, DIR, OtherAK>;
227+
}
228+
229+
impl<AK: ArrayKind, const DIR: Direction, OtherAK: ArrayKind> HasConcrete
230+
for GenericArrayConvertDef<AK, DIR, OtherAK>
231+
{
232+
type Concrete = GenericArrayConvert<AK, DIR, OtherAK>;
233+
234+
fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
235+
match type_args {
236+
[TypeArg::BoundedNat { n }, TypeArg::Type { ty }] => {
237+
Ok(GenericArrayConvert::new(ty.clone(), *n))
238+
}
239+
_ => Err(SignatureError::InvalidTypeArgs.into()),
240+
}
241+
}
242+
}
243+
244+
#[cfg(test)]
245+
mod tests {
246+
use rstest::rstest;
247+
248+
use crate::extension::prelude::bool_t;
249+
use crate::ops::{OpTrait, OpType};
250+
use crate::std_extensions::collections::array::Array;
251+
use crate::std_extensions::collections::value_array::ValueArray;
252+
253+
use super::*;
254+
255+
#[rstest]
256+
#[case(ValueArray, Array)]
257+
fn test_convert_from_def<AK: ArrayKind, OtherAK: ArrayKind>(
258+
#[case] _kind: AK,
259+
#[case] _other_kind: OtherAK,
260+
) {
261+
let op = GenericArrayConvert::<AK, FROM, OtherAK>::new(bool_t(), 2);
262+
let optype: OpType = op.clone().into();
263+
let new_op: GenericArrayConvert<AK, FROM, OtherAK> = optype.cast().unwrap();
264+
assert_eq!(new_op, op);
265+
}
266+
267+
#[rstest]
268+
#[case(ValueArray, Array)]
269+
fn test_convert_into_def<AK: ArrayKind, OtherAK: ArrayKind>(
270+
#[case] _kind: AK,
271+
#[case] _other_kind: OtherAK,
272+
) {
273+
let op = GenericArrayConvert::<AK, INTO, OtherAK>::new(bool_t(), 2);
274+
let optype: OpType = op.clone().into();
275+
let new_op: GenericArrayConvert<AK, INTO, OtherAK> = optype.cast().unwrap();
276+
assert_eq!(new_op, op);
277+
}
278+
279+
#[rstest]
280+
#[case(ValueArray, Array)]
281+
fn test_convert_from<AK: ArrayKind, OtherAK: ArrayKind>(
282+
#[case] _kind: AK,
283+
#[case] _other_kind: OtherAK,
284+
) {
285+
let size = 2;
286+
let element_ty = bool_t();
287+
let op = GenericArrayConvert::<AK, FROM, OtherAK>::new(element_ty.clone(), size);
288+
let optype: OpType = op.into();
289+
let sig = optype.dataflow_signature().unwrap();
290+
assert_eq!(
291+
sig.io(),
292+
(
293+
&vec![OtherAK::ty(size, element_ty.clone())].into(),
294+
&vec![AK::ty(size, element_ty.clone())].into(),
295+
)
296+
);
297+
}
298+
299+
#[rstest]
300+
#[case(ValueArray, Array)]
301+
fn test_convert_into<AK: ArrayKind, OtherAK: ArrayKind>(
302+
#[case] _kind: AK,
303+
#[case] _other_kind: OtherAK,
304+
) {
305+
let size = 2;
306+
let element_ty = bool_t();
307+
let op = GenericArrayConvert::<AK, INTO, OtherAK>::new(element_ty.clone(), size);
308+
let optype: OpType = op.into();
309+
let sig = optype.dataflow_signature().unwrap();
310+
assert_eq!(
311+
sig.io(),
312+
(
313+
&vec![AK::ty(size, element_ty.clone())].into(),
314+
&vec![OtherAK::ty(size, element_ty.clone())].into(),
315+
)
316+
);
317+
}
318+
}

hugr-core/src/std_extensions/collections/value_array.rs

+13-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ use crate::types::{CustomCheckFailure, Type, TypeBound, TypeName};
1717
use crate::Extension;
1818

1919
use super::array::{
20-
ArrayKind, GenericArrayOp, GenericArrayOpDef, GenericArrayRepeat, GenericArrayRepeatDef,
21-
GenericArrayScan, GenericArrayScanDef, GenericArrayValue,
20+
Array, ArrayKind, GenericArrayConvert, GenericArrayConvertDef, GenericArrayOp,
21+
GenericArrayOpDef, GenericArrayRepeat, GenericArrayRepeatDef, GenericArrayScan,
22+
GenericArrayScanDef, GenericArrayValue, FROM, INTO,
2223
};
2324

2425
/// Reported unique name of the value array type.
@@ -56,13 +57,21 @@ pub type VArrayOpDef = GenericArrayOpDef<ValueArray>;
5657
pub type VArrayRepeatDef = GenericArrayRepeatDef<ValueArray>;
5758
/// Value array scan operation definition.
5859
pub type VArrayScanDef = GenericArrayScanDef<ValueArray>;
60+
/// Value array to default array conversion operation definition.
61+
pub type VArrayToArrayDef = GenericArrayConvertDef<ValueArray, INTO, Array>;
62+
/// Value array from default array conversion operation definition.
63+
pub type VArrayFromArrayDef = GenericArrayConvertDef<ValueArray, FROM, Array>;
5964

6065
/// Value array operations.
6166
pub type VArrayOp = GenericArrayOp<ValueArray>;
6267
/// The value array repeat operation.
6368
pub type VArrayRepeat = GenericArrayRepeat<ValueArray>;
6469
/// The value array scan operation.
6570
pub type VArrayScan = GenericArrayScan<ValueArray>;
71+
/// The value array to default array conversion operation.
72+
pub type VArrayToArray = GenericArrayConvert<ValueArray, INTO, Array>;
73+
/// The value array from default array conversion operation.
74+
pub type VArrayFromArray = GenericArrayConvert<ValueArray, FROM, Array>;
6675

6776
/// A value array extension value.
6877
pub type VArrayValue = GenericArrayValue<ValueArray>;
@@ -84,6 +93,8 @@ lazy_static! {
8493
VArrayOpDef::load_all_ops(extension, extension_ref).unwrap();
8594
VArrayRepeatDef::new().add_to_extension(extension, extension_ref).unwrap();
8695
VArrayScanDef::new().add_to_extension(extension, extension_ref).unwrap();
96+
VArrayToArrayDef::new().add_to_extension(extension, extension_ref).unwrap();
97+
VArrayFromArrayDef::new().add_to_extension(extension, extension_ref).unwrap();
8798
})
8899
};
89100
}

0 commit comments

Comments
 (0)