Skip to content

Commit 9afc38a

Browse files
authored
fixes PyO3#4285 -- allow full-path to pymodule with nested declarative modules (PyO3#4288)
1 parent 5860c4f commit 9afc38a

3 files changed

Lines changed: 94 additions & 6 deletions

File tree

newsfragments/4288.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
allow `#[pyo3::prelude::pymodule]` with nested declarative modules

pyo3-macros-backend/src/module.rs

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use crate::{
88
get_doc,
99
pyclass::PyClassPyO3Option,
1010
pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
11-
utils::{Ctx, LitCStr},
11+
utils::{Ctx, LitCStr, PyO3CratePath},
1212
};
1313
use proc_macro2::{Span, TokenStream};
1414
use quote::quote;
@@ -183,7 +183,18 @@ pub fn pymodule_module_impl(
183183
);
184184
ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one `#[pymodule_init]` may be specified");
185185
pymodule_init = Some(quote! { #ident(module)?; });
186-
} else if has_attribute(&item_fn.attrs, "pyfunction") {
186+
} else if has_attribute(&item_fn.attrs, "pyfunction")
187+
|| has_attribute_with_namespace(
188+
&item_fn.attrs,
189+
Some(pyo3_path),
190+
&["pyfunction"],
191+
)
192+
|| has_attribute_with_namespace(
193+
&item_fn.attrs,
194+
Some(pyo3_path),
195+
&["prelude", "pyfunction"],
196+
)
197+
{
187198
module_items.push(ident.clone());
188199
module_items_cfg_attrs.push(get_cfg_attributes(&item_fn.attrs));
189200
}
@@ -193,7 +204,18 @@ pub fn pymodule_module_impl(
193204
!has_attribute(&item_struct.attrs, "pymodule_export"),
194205
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
195206
);
196-
if has_attribute(&item_struct.attrs, "pyclass") {
207+
if has_attribute(&item_struct.attrs, "pyclass")
208+
|| has_attribute_with_namespace(
209+
&item_struct.attrs,
210+
Some(pyo3_path),
211+
&["pyclass"],
212+
)
213+
|| has_attribute_with_namespace(
214+
&item_struct.attrs,
215+
Some(pyo3_path),
216+
&["prelude", "pyclass"],
217+
)
218+
{
197219
module_items.push(item_struct.ident.clone());
198220
module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs));
199221
if !has_pyo3_module_declared::<PyClassPyO3Option>(
@@ -210,7 +232,14 @@ pub fn pymodule_module_impl(
210232
!has_attribute(&item_enum.attrs, "pymodule_export"),
211233
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
212234
);
213-
if has_attribute(&item_enum.attrs, "pyclass") {
235+
if has_attribute(&item_enum.attrs, "pyclass")
236+
|| has_attribute_with_namespace(&item_enum.attrs, Some(pyo3_path), &["pyclass"])
237+
|| has_attribute_with_namespace(
238+
&item_enum.attrs,
239+
Some(pyo3_path),
240+
&["prelude", "pyclass"],
241+
)
242+
{
214243
module_items.push(item_enum.ident.clone());
215244
module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs));
216245
if !has_pyo3_module_declared::<PyClassPyO3Option>(
@@ -227,7 +256,14 @@ pub fn pymodule_module_impl(
227256
!has_attribute(&item_mod.attrs, "pymodule_export"),
228257
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
229258
);
230-
if has_attribute(&item_mod.attrs, "pymodule") {
259+
if has_attribute(&item_mod.attrs, "pymodule")
260+
|| has_attribute_with_namespace(&item_mod.attrs, Some(pyo3_path), &["pymodule"])
261+
|| has_attribute_with_namespace(
262+
&item_mod.attrs,
263+
Some(pyo3_path),
264+
&["prelude", "pymodule"],
265+
)
266+
{
231267
module_items.push(item_mod.ident.clone());
232268
module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs));
233269
if !has_pyo3_module_declared::<PyModulePyO3Option>(
@@ -555,8 +591,48 @@ fn find_and_remove_attribute(attrs: &mut Vec<syn::Attribute>, ident: &str) -> bo
555591
found
556592
}
557593

594+
enum IdentOrStr<'a> {
595+
Str(&'a str),
596+
Ident(syn::Ident),
597+
}
598+
599+
impl<'a> PartialEq<syn::Ident> for IdentOrStr<'a> {
600+
fn eq(&self, other: &syn::Ident) -> bool {
601+
match self {
602+
IdentOrStr::Str(s) => other == s,
603+
IdentOrStr::Ident(i) => other == i,
604+
}
605+
}
606+
}
558607
fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool {
559-
attrs.iter().any(|attr| attr.path().is_ident(ident))
608+
has_attribute_with_namespace(attrs, None, &[ident])
609+
}
610+
611+
fn has_attribute_with_namespace(
612+
attrs: &[syn::Attribute],
613+
crate_path: Option<&PyO3CratePath>,
614+
idents: &[&str],
615+
) -> bool {
616+
let mut segments = vec![];
617+
if let Some(c) = crate_path {
618+
match c {
619+
PyO3CratePath::Given(paths) => {
620+
for p in &paths.segments {
621+
segments.push(IdentOrStr::Ident(p.ident.clone()));
622+
}
623+
}
624+
PyO3CratePath::Default => segments.push(IdentOrStr::Str("pyo3")),
625+
}
626+
};
627+
for i in idents {
628+
segments.push(IdentOrStr::Str(i));
629+
}
630+
631+
attrs.iter().any(|attr| {
632+
segments
633+
.iter()
634+
.eq(attr.path().segments.iter().map(|v| &v.ident))
635+
})
560636
}
561637

562638
fn set_module_attribute(attrs: &mut Vec<syn::Attribute>, module_name: &str) {

tests/test_declarative_module.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ mod declarative_module {
124124
struct Struct;
125125
}
126126

127+
#[pyo3::prelude::pymodule]
128+
mod full_path_inner {}
129+
127130
#[pymodule_init]
128131
fn init(m: &Bound<'_, PyModule>) -> PyResult<()> {
129132
m.add("double2", m.getattr("double")?)
@@ -247,3 +250,11 @@ fn test_module_names() {
247250
);
248251
})
249252
}
253+
254+
#[test]
255+
fn test_inner_module_full_path() {
256+
Python::with_gil(|py| {
257+
let m = declarative_module(py);
258+
py_assert!(py, m, "m.full_path_inner");
259+
})
260+
}

0 commit comments

Comments
 (0)