diff --git a/Cargo.lock b/Cargo.lock index 7d498cc..908ee5c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -394,6 +394,15 @@ dependencies = [ "pyo3-stub-gen", ] +[[package]] +name = "mixed_sub_import_type" +version = "0.9.0" +dependencies = [ + "env_logger", + "pyo3", + "pyo3-stub-gen", +] + [[package]] name = "mixed_sub_multiple" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index ba2aff5..63ad179 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "examples/pure", "examples/mixed", "examples/mixed_sub", + "examples/mixed_sub_import_type", "examples/mixed_sub_multiple", ] resolver = "2" diff --git a/examples/mixed_sub_import_type/Cargo.toml b/examples/mixed_sub_import_type/Cargo.toml new file mode 100644 index 0000000..856f67a --- /dev/null +++ b/examples/mixed_sub_import_type/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "mixed_sub_import_type" +version.workspace = true +edition.workspace = true + +[lib] +crate-type = ["cdylib", "rlib"] + +[dependencies] +env_logger.workspace = true +pyo3-stub-gen = { path = "../../pyo3-stub-gen" } +pyo3.workspace = true + +[[bin]] +name = "stub_gen" +doc = false diff --git a/examples/mixed_sub_import_type/pyproject.toml b/examples/mixed_sub_import_type/pyproject.toml new file mode 100644 index 0000000..3e19390 --- /dev/null +++ b/examples/mixed_sub_import_type/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["maturin>=1.1,<2.0"] +build-backend = "maturin" + +[project] +name = "mixed_sub_import_type" +version = "0.1" +requires-python = ">=3.9" + +[project.optional-dependencies] +test = ["pytest", "pyright", "ruff"] + +[tool.maturin] +python-source = "python" +module-name = "mixed_sub_import_type.main_mod" +features = ["pyo3/extension-module"] diff --git a/examples/mixed_sub_import_type/python/mixed_sub_import_type/__init__.py b/examples/mixed_sub_import_type/python/mixed_sub_import_type/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/mixed_sub_import_type/python/mixed_sub_import_type/main_mod/__init__.pyi b/examples/mixed_sub_import_type/python/mixed_sub_import_type/main_mod/__init__.pyi new file mode 100644 index 0000000..9378bdb --- /dev/null +++ b/examples/mixed_sub_import_type/python/mixed_sub_import_type/main_mod/__init__.pyi @@ -0,0 +1,17 @@ +# This file is automatically generated by pyo3_stub_gen +# ruff: noqa: E501, F401 + +import builtins +from . import int +from . import sub_mod + +class A: + def show_x(self) -> None: ... + +class B: + def show_x(self) -> None: ... + +def create_a(x:builtins.int) -> A: ... + +def create_b(x:builtins.int) -> B: ... + diff --git a/examples/mixed_sub_import_type/python/mixed_sub_import_type/main_mod/int.pyi b/examples/mixed_sub_import_type/python/mixed_sub_import_type/main_mod/int.pyi new file mode 100644 index 0000000..1cc6213 --- /dev/null +++ b/examples/mixed_sub_import_type/python/mixed_sub_import_type/main_mod/int.pyi @@ -0,0 +1,7 @@ +# This file is automatically generated by pyo3_stub_gen +# ruff: noqa: E501, F401 + +import builtins + +def dummy_int_fun(x:builtins.int) -> builtins.int: ... + diff --git a/examples/mixed_sub_import_type/python/mixed_sub_import_type/main_mod/sub_mod.pyi b/examples/mixed_sub_import_type/python/mixed_sub_import_type/main_mod/sub_mod.pyi new file mode 100644 index 0000000..dc75f6b --- /dev/null +++ b/examples/mixed_sub_import_type/python/mixed_sub_import_type/main_mod/sub_mod.pyi @@ -0,0 +1,10 @@ +# This file is automatically generated by pyo3_stub_gen +# ruff: noqa: E501, F401 + +from mixed_sub_import_type.main_mod import A, B + +class C: + def show_x(self) -> None: ... + +def create_c(a:A, b:B) -> C: ... + diff --git a/examples/mixed_sub_import_type/python/mixed_sub_import_type/py.typed b/examples/mixed_sub_import_type/python/mixed_sub_import_type/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/examples/mixed_sub_import_type/src/bin/stub_gen.rs b/examples/mixed_sub_import_type/src/bin/stub_gen.rs new file mode 100644 index 0000000..85b233b --- /dev/null +++ b/examples/mixed_sub_import_type/src/bin/stub_gen.rs @@ -0,0 +1,8 @@ +use pyo3_stub_gen::Result; + +fn main() -> Result<()> { + env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")).init(); + let stub = mixed_sub_import_type::stub_info()?; + stub.generate()?; + Ok(()) +} diff --git a/examples/mixed_sub_import_type/src/lib.rs b/examples/mixed_sub_import_type/src/lib.rs new file mode 100644 index 0000000..bf8b02f --- /dev/null +++ b/examples/mixed_sub_import_type/src/lib.rs @@ -0,0 +1,119 @@ +use pyo3::prelude::*; +use pyo3_stub_gen::{define_stub_info_gatherer, derive::*}; + +// Specify the module name explicitly +#[gen_stub_pyclass] +#[pyclass(module = "mixed_sub_import_type.main_mod")] +#[derive(Debug, Clone)] +struct A { + x: usize, +} + +#[gen_stub_pymethods] +#[pymethods] +impl A { + fn show_x(&self) { + println!("x = {}", self.x); + } +} + +#[gen_stub_pyfunction(module = "mixed_sub_import_type.main_mod")] +#[pyfunction] +fn create_a(x: usize) -> A { + A { x } +} + +// Do not specify the module name explicitly +// This will be placed in the main module +#[gen_stub_pyclass] +#[pyclass] +#[derive(Debug, Clone)] +struct B { + x: usize, +} + +#[gen_stub_pymethods] +#[pymethods] +impl B { + fn show_x(&self) { + println!("x = {}", self.x); + } +} + +#[gen_stub_pyfunction] +#[pyfunction] +fn create_b(x: usize) -> B { + B { x } +} + +// Class in submodule +#[gen_stub_pyclass] +#[pyclass(module = "mixed_sub_import_type.main_mod.sub_mod")] +#[derive(Debug)] +struct C { + a: A, + b: B +} + +#[gen_stub_pymethods] +#[pymethods] +impl C { + fn show_x(&self) { + println!("a.x"); + self.a.show_x(); + println!("b.x"); + self.b.show_x() + } +} + +#[gen_stub_pyfunction(module = "mixed_sub_import_type.main_mod.sub_mod")] +#[pyfunction] +fn create_c(a: A, b: B) -> C { + C { a, b } +} + +#[gen_stub_pyfunction(module = "mixed_sub_import_type.main_mod.int")] +#[pyfunction] +fn dummy_int_fun(x: usize) -> usize { + x +} + +#[pymodule] +fn main_mod(m: &Bound) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_function(wrap_pyfunction!(create_a, m)?)?; + m.add_function(wrap_pyfunction!(create_b, m)?)?; + sub_mod(m)?; + int_mod(m)?; + Ok(()) +} + +fn sub_mod(parent: &Bound) -> PyResult<()> { + let py = parent.py(); + let sub = PyModule::new(py, "sub_mod")?; + sub.add_class::()?; + sub.add_function(wrap_pyfunction!(create_c, &sub)?)?; + parent.add_submodule(&sub)?; + Ok(()) +} + +/// A dummy module to pollute namespace with unqualified `int` +fn int_mod(parent: &Bound) -> PyResult<()> { + let py = parent.py(); + let sub = PyModule::new(py, "int")?; + sub.add_function(wrap_pyfunction!(dummy_int_fun, &sub)?)?; + parent.add_submodule(&sub)?; + Ok(()) +} + +define_stub_info_gatherer!(stub_info); + +/// Test of unit test for testing link problem +#[cfg(test)] +mod test { + #[test] + fn test() { + assert_eq!(2 + 2, 4); + } +} diff --git a/examples/mixed_sub_import_type/tests/test_mixed_sub.py b/examples/mixed_sub_import_type/tests/test_mixed_sub.py new file mode 100644 index 0000000..524f681 --- /dev/null +++ b/examples/mixed_sub_import_type/tests/test_mixed_sub.py @@ -0,0 +1,17 @@ +from mixed_sub_import_type import main_mod + + +def test_main_mod(): + a = main_mod.create_a(1) + a.show_x() + + b = main_mod.create_b(1) + b.show_x() + + +def test_sub_mod(): + a = main_mod.create_a(1) + b = main_mod.create_b(1) + + c = main_mod.sub_mod.create_c(a, b) + c.show_x() diff --git a/pyo3-stub-gen-derive/src/gen_stub/stub_type.rs b/pyo3-stub-gen-derive/src/gen_stub/stub_type.rs index cbff729..4b2999b 100644 --- a/pyo3-stub-gen-derive/src/gen_stub/stub_type.rs +++ b/pyo3-stub-gen-derive/src/gen_stub/stub_type.rs @@ -21,7 +21,7 @@ impl ToTokens for StubType { #[automatically_derived] impl ::pyo3_stub_gen::PyStubType for #ty { fn type_output() -> ::pyo3_stub_gen::TypeInfo { - ::pyo3_stub_gen::TypeInfo::with_module(#name, #module_tt) + ::pyo3_stub_gen::TypeInfo::with_type(#name, #module_tt) } } }) diff --git a/pyo3-stub-gen/src/generate.rs b/pyo3-stub-gen/src/generate.rs index 7244572..dc7142b 100644 --- a/pyo3-stub-gen/src/generate.rs +++ b/pyo3-stub-gen/src/generate.rs @@ -23,7 +23,7 @@ pub use module::*; pub use stub_info::*; pub use variable::*; -use crate::stub_type::ModuleRef; +use crate::stub_type::ImportRef; use std::collections::HashSet; fn indent() -> &'static str { @@ -31,5 +31,5 @@ fn indent() -> &'static str { } pub trait Import { - fn import(&self) -> HashSet; + fn import(&self) -> HashSet; } diff --git a/pyo3-stub-gen/src/generate/arg.rs b/pyo3-stub-gen/src/generate/arg.rs index 9669681..162e21c 100644 --- a/pyo3-stub-gen/src/generate/arg.rs +++ b/pyo3-stub-gen/src/generate/arg.rs @@ -1,4 +1,4 @@ -use crate::{generate::Import, stub_type::ModuleRef, type_info::*, TypeInfo}; +use crate::{generate::Import, stub_type::ImportRef, type_info::*, TypeInfo}; use std::{collections::HashSet, fmt}; #[derive(Debug, Clone, PartialEq)] @@ -9,7 +9,7 @@ pub struct Arg { } impl Import for Arg { - fn import(&self) -> HashSet { + fn import(&self) -> HashSet { self.r#type.import.clone() } } diff --git a/pyo3-stub-gen/src/generate/class.rs b/pyo3-stub-gen/src/generate/class.rs index d422ad6..dfe2dbf 100644 --- a/pyo3-stub-gen/src/generate/class.rs +++ b/pyo3-stub-gen/src/generate/class.rs @@ -12,7 +12,7 @@ pub struct ClassDef { } impl Import for ClassDef { - fn import(&self) -> HashSet { + fn import(&self) -> HashSet { let mut import = HashSet::new(); for base in &self.bases { import.extend(base.import.clone()); diff --git a/pyo3-stub-gen/src/generate/function.rs b/pyo3-stub-gen/src/generate/function.rs index 58072f4..8578734 100644 --- a/pyo3-stub-gen/src/generate/function.rs +++ b/pyo3-stub-gen/src/generate/function.rs @@ -11,7 +11,7 @@ pub struct FunctionDef { } impl Import for FunctionDef { - fn import(&self) -> HashSet { + fn import(&self) -> HashSet { let mut import = self.r#return.import.clone(); for arg in &self.args { import.extend(arg.import().into_iter()); diff --git a/pyo3-stub-gen/src/generate/member.rs b/pyo3-stub-gen/src/generate/member.rs index 699c344..a051c06 100644 --- a/pyo3-stub-gen/src/generate/member.rs +++ b/pyo3-stub-gen/src/generate/member.rs @@ -13,7 +13,7 @@ pub struct MemberDef { } impl Import for MemberDef { - fn import(&self) -> HashSet { + fn import(&self) -> HashSet { self.r#type.import.clone() } } diff --git a/pyo3-stub-gen/src/generate/method.rs b/pyo3-stub-gen/src/generate/method.rs index 30ad5c7..f9eba60 100644 --- a/pyo3-stub-gen/src/generate/method.rs +++ b/pyo3-stub-gen/src/generate/method.rs @@ -14,7 +14,7 @@ pub struct MethodDef { } impl Import for MethodDef { - fn import(&self) -> HashSet { + fn import(&self) -> HashSet { let mut import = self.r#return.import.clone(); for arg in &self.args { import.extend(arg.import().into_iter()); diff --git a/pyo3-stub-gen/src/generate/module.rs b/pyo3-stub-gen/src/generate/module.rs index e443fed..525d63c 100644 --- a/pyo3-stub-gen/src/generate/module.rs +++ b/pyo3-stub-gen/src/generate/module.rs @@ -21,7 +21,7 @@ pub struct Module { } impl Import for Module { - fn import(&self) -> HashSet { + fn import(&self) -> HashSet { let mut imports = HashSet::new(); for class in self.class.values() { imports.extend(class.import()); @@ -38,12 +38,37 @@ impl fmt::Display for Module { writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?; writeln!(f, "# ruff: noqa: E501, F401")?; writeln!(f)?; - for import in self.import().into_iter().sorted() { - let name = import.get().unwrap_or(&self.default_module_name); - if name != self.name { - writeln!(f, "import {}", name)?; + let package_name = self.default_module_name.split('.').next().unwrap(); + let mut type_ref_grouped: BTreeMap> = BTreeMap::new(); + for common_ref in self.import().into_iter().sorted() { + match common_ref { + ImportRef::Module(module_ref) => { + let name = module_ref.get().unwrap_or(&self.default_module_name); + if name != self.name { + println!("Name: {}", name,); + writeln!(f, "import {}", name)?; + } + } + ImportRef::Type(type_ref) => { + let module_name = type_ref.module.get().unwrap_or(&self.default_module_name); + if module_name != self.name { + if module_name.starts_with(package_name) { + type_ref_grouped + .entry(module_name.to_string()) + .or_default() + .push(type_ref.name); + } else { + writeln!(f, "import {}", module_name)?; + } + } + } } } + for (module_name, type_names) in type_ref_grouped { + let mut sorted_type_names = type_names.clone(); + sorted_type_names.sort(); + writeln!(f, "from {} import {}", module_name, sorted_type_names.join(", "))?; + } for submod in &self.submodules { writeln!(f, "from . import {}", submod)?; } diff --git a/pyo3-stub-gen/src/stub_type.rs b/pyo3-stub-gen/src/stub_type.rs index 3b4a59a..91db727 100644 --- a/pyo3-stub-gen/src/stub_type.rs +++ b/pyo3-stub-gen/src/stub_type.rs @@ -7,6 +7,33 @@ mod numpy; use maplit::hashset; use std::{collections::HashSet, fmt, ops}; +use std::cmp::Ordering; + +/// Indicates what to import. +/// Module: The purpose is to import the entire module(eg import builtins). +/// Type: The purpose is to import the types in the module(eg from moduleX import typeX). +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ImportRef { + Module(ModuleRef), + Type(TypeRef), +} + +impl PartialOrd for ImportRef { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for ImportRef { + fn cmp(&self, other: &Self) -> Ordering { + match (self, other) { + (ImportRef::Module(a), ImportRef::Module(b)) => a.get().cmp(&b.get()), + (ImportRef::Type(a), ImportRef::Type(b)) => a.cmp(b), + (ImportRef::Module(_), ImportRef::Type(_)) => Ordering::Greater, + (ImportRef::Type(_), ImportRef::Module(_)) => Ordering::Less, + } + } +} #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash)] pub enum ModuleRef { @@ -39,6 +66,22 @@ impl From<&str> for ModuleRef { } } + +/// Indicates the type of import(eg class enum). +/// from module import type. +/// name, type name. module, module name(which type defined). + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash)] +pub struct TypeRef { + pub module: ModuleRef, + pub name: String, +} + +impl TypeRef { + pub fn new(module_ref: ModuleRef, name: String) -> Self { + Self{module: module_ref, name} + } +} + /// Type information for creating Python stub files annotated by [PyStubType] trait. #[derive(Debug, Clone, PartialEq, Eq)] pub struct TypeInfo { @@ -49,7 +92,7 @@ pub struct TypeInfo { /// /// For example, when `name` is `typing.Sequence[int]`, `import` should contain `typing`. /// This makes it possible to use user-defined types in the stub file. - pub import: HashSet, + pub import: HashSet, } impl fmt::Display for TypeInfo { @@ -73,14 +116,14 @@ impl TypeInfo { pub fn any() -> Self { Self { name: "typing.Any".to_string(), - import: hashset! { "typing".into() }, + import: hashset! { ImportRef::Module("builtins".into()) }, } } /// A `list[Type]` type annotation. pub fn list_of() -> Self { let TypeInfo { name, mut import } = T::type_output(); - import.insert("builtins".into()); + import.insert(ImportRef::Module("builtins".into())); TypeInfo { name: format!("builtins.list[{}]", name), import, @@ -90,7 +133,7 @@ impl TypeInfo { /// A `set[Type]` type annotation. pub fn set_of() -> Self { let TypeInfo { name, mut import } = T::type_output(); - import.insert("builtins".into()); + import.insert(ImportRef::Module("builtins".into())); TypeInfo { name: format!("builtins.set[{}]", name), import, @@ -108,7 +151,7 @@ impl TypeInfo { import: import_v, } = V::type_output(); import.extend(import_v); - import.insert("builtins".into()); + import.insert(ImportRef::Module("builtins".into())); TypeInfo { name: format!("builtins.set[{}, {}]", name_k, name_v), import, @@ -119,7 +162,7 @@ impl TypeInfo { pub fn builtin(name: &str) -> Self { Self { name: format!("builtins.{name}"), - import: hashset! { "builtins".into() }, + import: hashset! { ImportRef::Module("builtins".into()) }, } } @@ -138,12 +181,29 @@ impl TypeInfo { /// ``` pub fn with_module(name: &str, module: ModuleRef) -> Self { let mut import = HashSet::new(); - import.insert(module); + import.insert(ImportRef::Module(module)); Self { name: name.to_string(), import, } } + + /// A type annotation of a type that must be imported. + /// + /// ``` + /// Class A is defined in module A, referenced in module B. "from ModuleA import ClassA" + /// pyo3_stub_gen::TypeInfo::with_type("ClassA", "ModuleA".into()); + /// ``` + pub fn with_type(type_name: &str, module: ModuleRef) -> Self { + let mut import = HashSet::new(); + let type_ref = TypeRef::new(module, type_name.to_string()); + import.insert(ImportRef::Type(type_ref)); + + Self { + name: type_name.to_string(), + import, + } + } } impl ops::BitOr for TypeInfo { @@ -232,19 +292,19 @@ mod test { use std::collections::HashMap; use test_case::test_case; - #[test_case(bool::type_input(), "builtins.bool", hashset! { "builtins".into() } ; "bool_input")] - #[test_case(<&str>::type_input(), "builtins.str", hashset! { "builtins".into() } ; "str_input")] - #[test_case(Vec::::type_input(), "typing.Sequence[builtins.int]", hashset! { "typing".into(), "builtins".into() } ; "Vec_u32_input")] - #[test_case(Vec::::type_output(), "builtins.list[builtins.int]", hashset! { "builtins".into() } ; "Vec_u32_output")] - #[test_case(HashMap::::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { "typing".into(), "builtins".into() } ; "HashMap_u32_String_input")] - #[test_case(HashMap::::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "HashMap_u32_String_output")] - #[test_case(indexmap::IndexMap::::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { "typing".into(), "builtins".into() } ; "IndexMap_u32_String_input")] - #[test_case(indexmap::IndexMap::::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { "builtins".into() } ; "IndexMap_u32_String_output")] - #[test_case(HashMap::>::type_input(), "typing.Mapping[builtins.int, typing.Sequence[builtins.int]]", hashset! { "builtins".into(), "typing".into() } ; "HashMap_u32_Vec_u32_input")] - #[test_case(HashMap::>::type_output(), "builtins.dict[builtins.int, builtins.list[builtins.int]]", hashset! { "builtins".into() } ; "HashMap_u32_Vec_u32_output")] - #[test_case(HashSet::::type_input(), "builtins.set[builtins.int]", hashset! { "builtins".into() } ; "HashSet_u32_input")] - #[test_case(indexmap::IndexSet::::type_input(), "builtins.set[builtins.int]", hashset! { "builtins".into() } ; "IndexSet_u32_input")] - fn test(tinfo: TypeInfo, name: &str, import: HashSet) { + #[test_case(bool::type_input(), "builtins.bool", hashset! { ImportRef::Module("builtins".into()) } ; "bool_input")] + #[test_case(<&str>::type_input(), "builtins.str", hashset! { ImportRef::Module("builtins".into()) } ; "str_input")] + #[test_case(Vec::::type_input(), "typing.Sequence[builtins.int]", hashset! { ImportRef::Module("typing".into()), ImportRef::Module("builtins".into()) } ; "Vec_u32_input")] + #[test_case(Vec::::type_output(), "builtins.list[builtins.int]", hashset! { ImportRef::Module("builtins".into()) } ; "Vec_u32_output")] + #[test_case(HashMap::::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { ImportRef::Module("typing".into()), ImportRef::Module("builtins".into()) } ; "HashMap_u32_String_input")] + #[test_case(HashMap::::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { ImportRef::Module("builtins".into()) } ; "HashMap_u32_String_output")] + #[test_case(indexmap::IndexMap::::type_input(), "typing.Mapping[builtins.int, builtins.str]", hashset! { ImportRef::Module("typing".into()), ImportRef::Module("builtins".into()) } ; "IndexMap_u32_String_input")] + #[test_case(indexmap::IndexMap::::type_output(), "builtins.dict[builtins.int, builtins.str]", hashset! { ImportRef::Module("builtins".into()) } ; "IndexMap_u32_String_output")] + #[test_case(HashMap::>::type_input(), "typing.Mapping[builtins.int, typing.Sequence[builtins.int]]", hashset! { ImportRef::Module("builtins".into()), ImportRef::Module("typing".into()) } ; "HashMap_u32_Vec_u32_input")] + #[test_case(HashMap::>::type_output(), "builtins.dict[builtins.int, builtins.list[builtins.int]]", hashset! { ImportRef::Module("builtins".into()) } ; "HashMap_u32_Vec_u32_output")] + #[test_case(HashSet::::type_input(), "builtins.set[builtins.int]", hashset! { ImportRef::Module("builtins".into()) } ; "HashSet_u32_input")] + #[test_case(indexmap::IndexSet::::type_input(), "builtins.set[builtins.int]", hashset! { ImportRef::Module("builtins".into()) } ; "IndexSet_u32_input")] + fn test(tinfo: TypeInfo, name: &str, import: HashSet) { assert_eq!(tinfo.name, name); if import.is_empty() { assert!(tinfo.import.is_empty()); diff --git a/pyo3-stub-gen/src/stub_type/collections.rs b/pyo3-stub-gen/src/stub_type/collections.rs index b5931d9..d73e84e 100644 --- a/pyo3-stub-gen/src/stub_type/collections.rs +++ b/pyo3-stub-gen/src/stub_type/collections.rs @@ -4,7 +4,7 @@ use std::collections::{BTreeMap, BTreeSet, HashMap}; impl PyStubType for Option { fn type_input() -> TypeInfo { let TypeInfo { name, mut import } = T::type_input(); - import.insert("typing".into()); + import.insert(ImportRef::Module("typing".into())); TypeInfo { name: format!("typing.Optional[{}]", name), import, @@ -12,7 +12,7 @@ impl PyStubType for Option { } fn type_output() -> TypeInfo { let TypeInfo { name, mut import } = T::type_output(); - import.insert("typing".into()); + import.insert(ImportRef::Module("typing".into())); TypeInfo { name: format!("typing.Optional[{}]", name), import, @@ -41,7 +41,7 @@ impl PyStubType for Result { impl PyStubType for Vec { fn type_input() -> TypeInfo { let TypeInfo { name, mut import } = T::type_input(); - import.insert("typing".into()); + import.insert(ImportRef::Module("typing".into())); TypeInfo { name: format!("typing.Sequence[{}]", name), import, @@ -55,7 +55,7 @@ impl PyStubType for Vec { impl PyStubType for [T; N] { fn type_input() -> TypeInfo { let TypeInfo { name, mut import } = T::type_input(); - import.insert("typing".into()); + import.insert(ImportRef::Module("typing".into())); TypeInfo { name: format!("typing.Sequence[{}]", name), import, @@ -96,7 +96,7 @@ macro_rules! impl_map_inner { import: value_import, } = Value::type_input(); import.extend(value_import); - import.insert("typing".into()); + import.insert(ImportRef::Module("typing".into())); TypeInfo { name: format!("typing.Mapping[{}, {}]", key_name, value_name), import, @@ -112,7 +112,7 @@ macro_rules! impl_map_inner { import: value_import, } = Value::type_output(); import.extend(value_import); - import.insert("builtins".into()); + import.insert(ImportRef::Module("builtins".into())); TypeInfo { name: format!("builtins.dict[{}, {}]", key_name, value_name), import, diff --git a/pyo3-stub-gen/src/stub_type/numpy.rs b/pyo3-stub-gen/src/stub_type/numpy.rs index 83f59c6..71c42c7 100644 --- a/pyo3-stub-gen/src/stub_type/numpy.rs +++ b/pyo3-stub-gen/src/stub_type/numpy.rs @@ -1,4 +1,4 @@ -use super::{PyStubType, TypeInfo}; +use super::{PyStubType, TypeInfo, ImportRef}; use maplit::hashset; use numpy::{ ndarray::Dimension, Element, PyArray, PyArrayDescr, PyReadonlyArray, PyReadwriteArray, @@ -15,7 +15,7 @@ macro_rules! impl_numpy_scalar { fn type_() -> TypeInfo { TypeInfo { name: format!("numpy.{}", $name), - import: hashset!["numpy".into()], + import: hashset![ImportRef::Module("numpy".into())], } } } @@ -38,7 +38,7 @@ impl_numpy_scalar!(num_complex::Complex64, "complex128"); impl PyStubType for PyArray { fn type_output() -> TypeInfo { let TypeInfo { name, mut import } = T::type_(); - import.insert("numpy.typing".into()); + import.insert(ImportRef::Module("numpy.typing".into())); TypeInfo { name: format!("numpy.typing.NDArray[{name}]"), import, @@ -50,7 +50,7 @@ impl PyStubType for PyUntypedArray { fn type_output() -> TypeInfo { TypeInfo { name: "numpy.typing.NDArray[typing.Any]".into(), - import: hashset!["numpy.typing".into(), "typing".into()], + import: hashset![ImportRef::Module("numpy.typing".into()), ImportRef::Module("typing".into())], } } } @@ -79,7 +79,7 @@ impl PyStubType for PyArrayDescr { fn type_output() -> TypeInfo { TypeInfo { name: "numpy.dtype".into(), - import: hashset!["numpy".into()], + import: hashset![ImportRef::Module("numpy".into())], } } } diff --git a/pyo3-stub-gen/src/stub_type/pyo3.rs b/pyo3-stub-gen/src/stub_type/pyo3.rs index 5fc077f..bf6707c 100644 --- a/pyo3-stub-gen/src/stub_type/pyo3.rs +++ b/pyo3-stub-gen/src/stub_type/pyo3.rs @@ -12,7 +12,7 @@ impl PyStubType for PyAny { fn type_output() -> TypeInfo { TypeInfo { name: "typing.Any".to_string(), - import: hashset! { "typing".into() }, + import: hashset! { ImportRef::Module("typing".into()) }, } } } @@ -87,7 +87,7 @@ macro_rules! impl_simple { fn type_output() -> TypeInfo { TypeInfo { name: concat!($mod, ".", $pytype).to_string(), - import: hashset! { $mod.into() }, + import: hashset! { ImportRef::Module($mod.into()) }, } } }