Skip to content

Commit 9370dc7

Browse files
committed
feat: Add array clone and discard ops
1 parent dc638c4 commit 9370dc7

File tree

5 files changed

+748
-0
lines changed

5 files changed

+748
-0
lines changed

hugr-core/src/std_extensions/collections/array.rs

+14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
//! Fixed-length array type and operations extension.
22
3+
mod array_clone;
4+
mod array_discard;
35
mod array_kind;
46
mod array_op;
57
mod array_repeat;
@@ -20,6 +22,8 @@ use crate::types::type_param::{TypeArg, TypeParam};
2022
use crate::types::{CustomCheckFailure, Type, TypeBound, TypeName};
2123
use crate::Extension;
2224

25+
pub use array_clone::{GenericArrayClone, GenericArrayCloneDef};
26+
pub use array_discard::{GenericArrayDiscard, GenericArrayDiscardDef};
2327
pub use array_kind::ArrayKind;
2428
pub use array_op::{GenericArrayOp, GenericArrayOpDef};
2529
pub use array_repeat::{GenericArrayRepeat, GenericArrayRepeatDef, ARRAY_REPEAT_OP_ID};
@@ -57,13 +61,21 @@ impl ArrayKind for Array {
5761

5862
/// Array operation definitions.
5963
pub type ArrayOpDef = GenericArrayOpDef<Array>;
64+
/// Array clone operation definition.
65+
pub type ArrayCloneDef = GenericArrayCloneDef<Array>;
66+
/// Array discard operation definition.
67+
pub type ArrayDiscardDef = GenericArrayDiscardDef<Array>;
6068
/// Array repeat operation definition.
6169
pub type ArrayRepeatDef = GenericArrayRepeatDef<Array>;
6270
/// Array scan operation definition.
6371
pub type ArrayScanDef = GenericArrayScanDef<Array>;
6472

6573
/// Array operations.
6674
pub type ArrayOp = GenericArrayOp<Array>;
75+
/// The array clone operation.
76+
pub type ArrayClone = GenericArrayRepeat<Array>;
77+
/// The array discard operation.
78+
pub type ArrayDiscard = GenericArrayRepeat<Array>;
6779
/// The array repeat operation.
6880
pub type ArrayRepeat = GenericArrayRepeat<Array>;
6981
/// The array scan operation.
@@ -87,6 +99,8 @@ lazy_static! {
8799
.unwrap();
88100

89101
ArrayOpDef::load_all_ops(extension, extension_ref).unwrap();
102+
ArrayCloneDef::new().add_to_extension(extension, extension_ref).unwrap();
103+
ArrayDiscardDef::new().add_to_extension(extension, extension_ref).unwrap();
90104
ArrayRepeatDef::new().add_to_extension(extension, extension_ref).unwrap();
91105
ArrayScanDef::new().add_to_extension(extension, extension_ref).unwrap();
92106
})
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
//! Definition of the array clone operation.
2+
3+
use std::marker::PhantomData;
4+
use std::str::FromStr;
5+
use std::sync::{Arc, Weak};
6+
7+
use crate::extension::simple_op::{
8+
HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError,
9+
};
10+
use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef};
11+
use crate::ops::{ExtensionOp, NamedOp, OpName};
12+
use crate::types::type_param::{TypeArg, TypeParam};
13+
use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound};
14+
use crate::Extension;
15+
16+
use super::array_kind::ArrayKind;
17+
18+
/// Name of the operation to clone an array
19+
pub const ARRAY_CLONE_OP_ID: OpName = OpName::new_inline("clone");
20+
21+
/// Definition of the array clone operation. Generic over the concrete array implementation.
22+
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
23+
pub struct GenericArrayCloneDef<AK: ArrayKind>(PhantomData<AK>);
24+
25+
impl<AK: ArrayKind> GenericArrayCloneDef<AK> {
26+
/// Creates a new clone operation definition.
27+
pub fn new() -> Self {
28+
GenericArrayCloneDef(PhantomData)
29+
}
30+
}
31+
32+
impl<AK: ArrayKind> Default for GenericArrayCloneDef<AK> {
33+
fn default() -> Self {
34+
Self::new()
35+
}
36+
}
37+
38+
impl<AK: ArrayKind> NamedOp for GenericArrayCloneDef<AK> {
39+
fn name(&self) -> OpName {
40+
ARRAY_CLONE_OP_ID
41+
}
42+
}
43+
44+
impl<AK: ArrayKind> FromStr for GenericArrayCloneDef<AK> {
45+
type Err = ();
46+
47+
fn from_str(s: &str) -> Result<Self, Self::Err> {
48+
if s == ARRAY_CLONE_OP_ID {
49+
Ok(GenericArrayCloneDef::new())
50+
} else {
51+
Err(())
52+
}
53+
}
54+
}
55+
56+
impl<AK: ArrayKind> GenericArrayCloneDef<AK> {
57+
/// To avoid recursion when defining the extension, take the type definition as an argument.
58+
fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc {
59+
let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()];
60+
let size = TypeArg::new_var_use(0, TypeParam::max_nat());
61+
let element_ty = Type::new_var_use(1, TypeBound::Copyable);
62+
let array_ty = AK::instantiate_ty(array_def, size, element_ty)
63+
.expect("Array type instantiation failed");
64+
PolyFuncTypeRV::new(
65+
params,
66+
FuncValueType::new(array_ty.clone(), vec![array_ty; 2]),
67+
)
68+
.into()
69+
}
70+
}
71+
72+
impl<AK: ArrayKind> MakeOpDef for GenericArrayCloneDef<AK> {
73+
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
74+
where
75+
Self: Sized,
76+
{
77+
crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id())
78+
}
79+
80+
fn init_signature(&self, _extension_ref: &Weak<Extension>) -> SignatureFunc {
81+
self.signature_from_def(AK::type_def())
82+
}
83+
84+
fn extension_ref(&self) -> Weak<Extension> {
85+
Arc::downgrade(AK::extension())
86+
}
87+
88+
fn extension(&self) -> ExtensionId {
89+
AK::EXTENSION_ID
90+
}
91+
92+
fn description(&self) -> String {
93+
"Clones an array with copyable elements".into()
94+
}
95+
96+
/// Add an operation implemented as a [MakeOpDef], which can provide the data
97+
/// required to define an [OpDef], to an extension.
98+
//
99+
// This method is re-defined here since we need to pass the array type def while
100+
// computing the signature, to avoid recursive loops initializing the extension.
101+
fn add_to_extension(
102+
&self,
103+
extension: &mut Extension,
104+
extension_ref: &Weak<Extension>,
105+
) -> Result<(), crate::extension::ExtensionBuildError> {
106+
let sig = self.signature_from_def(extension.get_type(&AK::TYPE_NAME).unwrap());
107+
let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?;
108+
self.post_opdef(def);
109+
Ok(())
110+
}
111+
}
112+
113+
/// Definition of the array clone op. Generic over the concrete array implementation.
114+
#[derive(Clone, Debug, PartialEq)]
115+
pub struct GenericArrayClone<AK: ArrayKind> {
116+
/// The element type of the array.
117+
pub elem_ty: Type,
118+
/// Size of the array.
119+
pub size: u64,
120+
_kind: PhantomData<AK>,
121+
}
122+
123+
impl<AK: ArrayKind> GenericArrayClone<AK> {
124+
/// Creates a new array clone op.
125+
pub fn new(elem_ty: Type, size: u64) -> Option<Self> {
126+
elem_ty.copyable().then_some(GenericArrayClone {
127+
elem_ty,
128+
size,
129+
_kind: PhantomData,
130+
})
131+
}
132+
}
133+
134+
impl<AK: ArrayKind> NamedOp for GenericArrayClone<AK> {
135+
fn name(&self) -> OpName {
136+
ARRAY_CLONE_OP_ID
137+
}
138+
}
139+
140+
impl<AK: ArrayKind> MakeExtensionOp for GenericArrayClone<AK> {
141+
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
142+
where
143+
Self: Sized,
144+
{
145+
let def = GenericArrayCloneDef::<AK>::from_def(ext_op.def())?;
146+
def.instantiate(ext_op.args())
147+
}
148+
149+
fn type_args(&self) -> Vec<TypeArg> {
150+
vec![
151+
TypeArg::BoundedNat { n: self.size },
152+
self.elem_ty.clone().into(),
153+
]
154+
}
155+
}
156+
157+
impl<AK: ArrayKind> MakeRegisteredOp for GenericArrayClone<AK> {
158+
fn extension_id(&self) -> ExtensionId {
159+
AK::EXTENSION_ID
160+
}
161+
162+
fn extension_ref(&self) -> Weak<Extension> {
163+
Arc::downgrade(AK::extension())
164+
}
165+
}
166+
167+
impl<AK: ArrayKind> HasDef for GenericArrayClone<AK> {
168+
type Def = GenericArrayCloneDef<AK>;
169+
}
170+
171+
impl<AK: ArrayKind> HasConcrete for GenericArrayCloneDef<AK> {
172+
type Concrete = GenericArrayClone<AK>;
173+
174+
fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
175+
match type_args {
176+
[TypeArg::BoundedNat { n }, TypeArg::Type { ty }] if ty.copyable() => {
177+
Ok(GenericArrayClone::new(ty.clone(), *n).unwrap())
178+
}
179+
_ => Err(SignatureError::InvalidTypeArgs.into()),
180+
}
181+
}
182+
}
183+
184+
#[cfg(test)]
185+
mod tests {
186+
use rstest::rstest;
187+
188+
use crate::extension::prelude::bool_t;
189+
use crate::std_extensions::collections::array::Array;
190+
use crate::{
191+
extension::prelude::qb_t,
192+
ops::{OpTrait, OpType},
193+
};
194+
195+
use super::*;
196+
197+
#[rstest]
198+
#[case(Array)]
199+
fn test_clone_def<AK: ArrayKind>(#[case] _kind: AK) {
200+
let op = GenericArrayClone::<AK>::new(bool_t(), 2).unwrap();
201+
let optype: OpType = op.clone().into();
202+
let new_op: GenericArrayClone<AK> = optype.cast().unwrap();
203+
assert_eq!(new_op, op);
204+
205+
assert_eq!(GenericArrayClone::<AK>::new(qb_t(), 2), None);
206+
}
207+
208+
#[rstest]
209+
#[case(Array)]
210+
fn test_clone<AK: ArrayKind>(#[case] _kind: AK) {
211+
let size = 2;
212+
let element_ty = bool_t();
213+
let op = GenericArrayClone::<AK>::new(element_ty.clone(), size).unwrap();
214+
let optype: OpType = op.into();
215+
let sig = optype.dataflow_signature().unwrap();
216+
assert_eq!(
217+
sig.io(),
218+
(
219+
&vec![AK::ty(size, element_ty.clone())].into(),
220+
&vec![AK::ty(size, element_ty.clone()); 2].into(),
221+
)
222+
);
223+
}
224+
}

0 commit comments

Comments
 (0)