Skip to content

Commit 918c02e

Browse files
committed
Fixes #2, auto generate python files and fix type ref
1 parent cdc574d commit 918c02e

File tree

4 files changed

+90
-39
lines changed

4 files changed

+90
-39
lines changed

pyo3-stub-gen-derive/src/gen_stub/stub_type.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ impl ToTokens for StubType {
2121
#[automatically_derived]
2222
impl ::pyo3_stub_gen::PyStubType for #ty {
2323
fn type_output() -> ::pyo3_stub_gen::TypeInfo {
24-
::pyo3_stub_gen::TypeInfo::with_module(#name, #module_tt)
24+
::pyo3_stub_gen::TypeInfo::with_type(#name, #module_tt)
2525
}
2626
}
2727
})

pyo3-stub-gen/src/generate/module.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,26 @@ impl fmt::Display for Module {
3838
writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
3939
writeln!(f, "# ruff: noqa: E501, F401")?;
4040
writeln!(f)?;
41-
for import in self.import().into_iter().sorted() {
42-
let name = import.get().unwrap_or(&self.default_module_name);
43-
if name != self.name {
44-
writeln!(f, "import {}", name)?;
41+
let package_name = self.default_module_name.split('.').next().unwrap();
42+
let mut grouped: BTreeMap<String, Vec<String>> = BTreeMap::new();
43+
for mut module_ref in self.import().into_iter().sorted() {
44+
if module_ref.module.is_empty() {
45+
module_ref.module = self.default_module_name.clone();
4546
}
47+
if module_ref.module != self.name {
48+
// writeln!(f, "from {} import {}", module_ref.module, module_ref.name)?;
49+
if module_ref.module.starts_with(package_name) {
50+
grouped
51+
.entry(module_ref.module)
52+
.or_default()
53+
.push(module_ref.name);
54+
} else {
55+
writeln!(f, "import {}", module_ref.module)?;
56+
}
57+
}
58+
}
59+
for (module, names) in grouped {
60+
writeln!(f, "from {} import {}", module, names.join(", "))?;
4661
}
4762
for submod in &self.submodules {
4863
writeln!(f, "from . import {}", submod)?;

pyo3-stub-gen/src/generate/stub_info.rs

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,32 @@ use std::{
66
io::Write,
77
path::*,
88
};
9-
9+
fn get_destination_paths(module: &Module, python_root: &Path, default_module: &str) -> (PathBuf, Option<PathBuf>) {
10+
let path = module.name.replace(".", "/");
11+
if module.submodules.is_empty() {
12+
if module.name != default_module {
13+
(
14+
python_root.join(format!("{path}.pyi")),
15+
Some(python_root.join(format!("{path}.py"))),
16+
)
17+
} else {
18+
(
19+
python_root.join(format!("{path}.pyi")),
20+
None,
21+
)
22+
}
23+
} else {
24+
(
25+
python_root.join(path).join("__init__.pyi"),
26+
None,
27+
)
28+
}
29+
}
1030
#[derive(Debug, Clone, PartialEq)]
1131
pub struct StubInfo {
1232
pub modules: BTreeMap<String, Module>,
1333
pub python_root: PathBuf,
34+
pub default_module: String
1435
}
1536

1637
impl StubInfo {
@@ -30,13 +51,7 @@ impl StubInfo {
3051

3152
pub fn generate(&self) -> Result<()> {
3253
for (name, module) in self.modules.iter() {
33-
let path = name.replace(".", "/");
34-
let dest = if module.submodules.is_empty() {
35-
self.python_root.join(format!("{path}.pyi"))
36-
} else {
37-
self.python_root.join(path).join("__init__.pyi")
38-
};
39-
54+
let (dest, dest_py) = get_destination_paths(module, &self.python_root, &self.default_module);
4055
let dir = dest.parent().context("Cannot get parent directory")?;
4156
if !dir.exists() {
4257
fs::create_dir_all(dir)?;
@@ -48,6 +63,15 @@ impl StubInfo {
4863
"Generate stub file of a module `{name}` at {dest}",
4964
dest = dest.display()
5065
);
66+
67+
if let Some(dest_py) = dest_py {
68+
let mut f_py = fs::File::create(&dest_py)?;
69+
write!(f_py, "# Fixed pylance reportMissingModuleSource warning when using \n# \"from {} import xxx \" and \"import {}\"", module.name, module.name)?;
70+
log::info!(
71+
"Generate python file of a (not main)module `{name}` at {dest}",
72+
dest = dest_py.display()
73+
);
74+
}
5175
}
5276
Ok(())
5377
}
@@ -189,6 +213,7 @@ impl StubInfoBuilder {
189213
StubInfo {
190214
modules: self.modules,
191215
python_root: self.python_root,
216+
default_module: self.default_module_name
192217
}
193218
}
194219
}

pyo3-stub-gen/src/stub_type.rs

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,27 @@ mod numpy;
77

88
use maplit::hashset;
99
use std::{collections::HashSet, fmt, ops};
10-
11-
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
12-
pub enum ModuleRef {
13-
Named(String),
14-
15-
/// Default module that PyO3 creates.
16-
///
17-
/// - For pure Rust project, the default module name is the crate name specified in `Cargo.toml`
18-
/// or `project.name` specified in `pyproject.toml`
19-
/// - For mixed Rust/Python project, the default module name is `tool.maturin.module-name` specified in `pyproject.toml`
20-
///
21-
/// Because the default module name cannot be known at compile time, it will be resolved at the time of the stub file generation.
22-
/// This is a placeholder for the default module name.
23-
#[default]
24-
Default,
10+
/**
11+
* Indicates the dependent module or type;
12+
* dependent module('import module', eg builtins): name="" , module = module name;
13+
* dependent type(eg class enum)('from module import type'): name=type name , module = module name(which type defined).
14+
*/
15+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash)]
16+
pub struct ModuleRef {
17+
pub name: String,
18+
pub module: String,
2519
}
26-
27-
impl ModuleRef {
28-
pub fn get(&self) -> Option<&str> {
29-
match self {
30-
Self::Named(name) => Some(name),
31-
Self::Default => None,
32-
}
20+
/// From the dependent module
21+
impl From<&str> for ModuleRef {
22+
fn from(m: &str) -> Self {
23+
Self{name:"".to_string(), module:m.to_string()}
3324
}
3425
}
3526

36-
impl From<&str> for ModuleRef {
37-
fn from(s: &str) -> Self {
38-
Self::Named(s.to_string())
27+
impl fmt::Display for ModuleRef {
28+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
29+
writeln!(f, "name: {}, module: {}", self.name, self.module)?;
30+
Ok(())
3931
}
4032
}
4133

@@ -136,14 +128,33 @@ impl TypeInfo {
136128
/// ```
137129
/// pyo3_stub_gen::TypeInfo::with_module("pathlib.Path", "pathlib".into());
138130
/// ```
139-
pub fn with_module(name: &str, module: ModuleRef) -> Self {
131+
pub fn with_module(name: &str, module: ModuleRef) -> Self
132+
{
140133
let mut import = HashSet::new();
141134
import.insert(module);
142135
Self {
143136
name: name.to_string(),
144137
import,
145138
}
146139
}
140+
141+
/// A type annotation of a type that must be imported.
142+
///
143+
/// ```
144+
/// ClassA defined in ModuleA
145+
/// pyo3_stub_gen::TypeInfo::with_type("ClassA", "ModuleA");
146+
/// ```
147+
pub fn with_type(type_name: &str, module: ModuleRef) -> Self {
148+
let mut import = HashSet::new();
149+
let mut type_ref: ModuleRef = module.clone();
150+
type_ref.name = type_name.to_string();
151+
import.insert(type_ref);
152+
153+
Self {
154+
name: type_name.to_string(),
155+
import,
156+
}
157+
}
147158
}
148159

149160
impl ops::BitOr for TypeInfo {

0 commit comments

Comments
 (0)