Skip to content

Fixes Issue#2: fix type references #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ members = [
"examples/pure",
"examples/mixed",
"examples/mixed_sub",
"examples/mixed_sub_import_type",
"examples/mixed_sub_multiple",
]
resolver = "2"
Expand Down
16 changes: 16 additions & 0 deletions examples/mixed_sub_import_type/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[package]
name = "mixed_sub_import_type"
version.workspace = true
edition.workspace = true

[lib]
crate-type = ["cdylib", "rlib"]

[dependencies]
env_logger.workspace = true
pyo3-stub-gen = { path = "../../pyo3-stub-gen" }
pyo3.workspace = true

[[bin]]
name = "stub_gen"
doc = false
16 changes: 16 additions & 0 deletions examples/mixed_sub_import_type/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[build-system]
requires = ["maturin>=1.1,<2.0"]
build-backend = "maturin"

[project]
name = "mixed_sub_import_type"
version = "0.1"
requires-python = ">=3.9"

[project.optional-dependencies]
test = ["pytest", "pyright", "ruff"]

[tool.maturin]
python-source = "python"
module-name = "mixed_sub_import_type.main_mod"
features = ["pyo3/extension-module"]
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# This file is automatically generated by pyo3_stub_gen
# ruff: noqa: E501, F401

import builtins
from . import int
from . import sub_mod

class A:
def show_x(self) -> None: ...

class B:
def show_x(self) -> None: ...

def create_a(x:builtins.int) -> A: ...

def create_b(x:builtins.int) -> B: ...

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# This file is automatically generated by pyo3_stub_gen
# ruff: noqa: E501, F401

import builtins

def dummy_int_fun(x:builtins.int) -> builtins.int: ...

Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# This file is automatically generated by pyo3_stub_gen
# ruff: noqa: E501, F401

from mixed_sub_import_type.main_mod import A, B

class C:
def show_x(self) -> None: ...

def create_c(a:A, b:B) -> C: ...

Empty file.
8 changes: 8 additions & 0 deletions examples/mixed_sub_import_type/src/bin/stub_gen.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
use pyo3_stub_gen::Result;

fn main() -> Result<()> {
env_logger::Builder::from_env(env_logger::Env::default().filter_or("RUST_LOG", "info")).init();
let stub = mixed_sub_import_type::stub_info()?;
stub.generate()?;
Ok(())
}
119 changes: 119 additions & 0 deletions examples/mixed_sub_import_type/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
use pyo3::prelude::*;
use pyo3_stub_gen::{define_stub_info_gatherer, derive::*};

// Specify the module name explicitly
#[gen_stub_pyclass]
#[pyclass(module = "mixed_sub_import_type.main_mod")]
#[derive(Debug, Clone)]
struct A {
x: usize,
}

#[gen_stub_pymethods]
#[pymethods]
impl A {
fn show_x(&self) {
println!("x = {}", self.x);
}
}

#[gen_stub_pyfunction(module = "mixed_sub_import_type.main_mod")]
#[pyfunction]
fn create_a(x: usize) -> A {
A { x }
}

// Do not specify the module name explicitly
// This will be placed in the main module
#[gen_stub_pyclass]
#[pyclass]
#[derive(Debug, Clone)]
struct B {
x: usize,
}

#[gen_stub_pymethods]
#[pymethods]
impl B {
fn show_x(&self) {
println!("x = {}", self.x);
}
}

#[gen_stub_pyfunction]
#[pyfunction]
fn create_b(x: usize) -> B {
B { x }
}

// Class in submodule
#[gen_stub_pyclass]
#[pyclass(module = "mixed_sub_import_type.main_mod.sub_mod")]
#[derive(Debug)]
struct C {
a: A,
b: B
}

#[gen_stub_pymethods]
#[pymethods]
impl C {
fn show_x(&self) {
println!("a.x");
self.a.show_x();
println!("b.x");
self.b.show_x()
}
}

#[gen_stub_pyfunction(module = "mixed_sub_import_type.main_mod.sub_mod")]
#[pyfunction]
fn create_c(a: A, b: B) -> C {
C { a, b }
}

#[gen_stub_pyfunction(module = "mixed_sub_import_type.main_mod.int")]
#[pyfunction]
fn dummy_int_fun(x: usize) -> usize {
x
}

#[pymodule]
fn main_mod(m: &Bound<PyModule>) -> PyResult<()> {
m.add_class::<A>()?;
m.add_class::<B>()?;
m.add_function(wrap_pyfunction!(create_a, m)?)?;
m.add_function(wrap_pyfunction!(create_b, m)?)?;
sub_mod(m)?;
int_mod(m)?;
Ok(())
}

fn sub_mod(parent: &Bound<PyModule>) -> PyResult<()> {
let py = parent.py();
let sub = PyModule::new(py, "sub_mod")?;
sub.add_class::<C>()?;
sub.add_function(wrap_pyfunction!(create_c, &sub)?)?;
parent.add_submodule(&sub)?;
Ok(())
}

/// A dummy module to pollute namespace with unqualified `int`
fn int_mod(parent: &Bound<PyModule>) -> PyResult<()> {
let py = parent.py();
let sub = PyModule::new(py, "int")?;
sub.add_function(wrap_pyfunction!(dummy_int_fun, &sub)?)?;
parent.add_submodule(&sub)?;
Ok(())
}

define_stub_info_gatherer!(stub_info);

/// Test of unit test for testing link problem
#[cfg(test)]
mod test {
#[test]
fn test() {
assert_eq!(2 + 2, 4);
}
}
17 changes: 17 additions & 0 deletions examples/mixed_sub_import_type/tests/test_mixed_sub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from mixed_sub_import_type import main_mod


def test_main_mod():
a = main_mod.create_a(1)
a.show_x()

b = main_mod.create_b(1)
b.show_x()


def test_sub_mod():
a = main_mod.create_a(1)
b = main_mod.create_b(1)

c = main_mod.sub_mod.create_c(a, b)
c.show_x()
2 changes: 1 addition & 1 deletion pyo3-stub-gen-derive/src/gen_stub/stub_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl ToTokens for StubType {
#[automatically_derived]
impl ::pyo3_stub_gen::PyStubType for #ty {
fn type_output() -> ::pyo3_stub_gen::TypeInfo {
::pyo3_stub_gen::TypeInfo::with_module(#name, #module_tt)
::pyo3_stub_gen::TypeInfo::with_type(#name, #module_tt)
}
}
})
Expand Down
4 changes: 2 additions & 2 deletions pyo3-stub-gen/src/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ pub use module::*;
pub use stub_info::*;
pub use variable::*;

use crate::stub_type::ModuleRef;
use crate::stub_type::ImportRef;
use std::collections::HashSet;

fn indent() -> &'static str {
" "
}

pub trait Import {
fn import(&self) -> HashSet<ModuleRef>;
fn import(&self) -> HashSet<ImportRef>;
}
4 changes: 2 additions & 2 deletions pyo3-stub-gen/src/generate/arg.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{generate::Import, stub_type::ModuleRef, type_info::*, TypeInfo};
use crate::{generate::Import, stub_type::ImportRef, type_info::*, TypeInfo};
use std::{collections::HashSet, fmt};

#[derive(Debug, Clone, PartialEq)]
Expand All @@ -9,7 +9,7 @@ pub struct Arg {
}

impl Import for Arg {
fn import(&self) -> HashSet<ModuleRef> {
fn import(&self) -> HashSet<ImportRef> {
self.r#type.import.clone()
}
}
Expand Down
2 changes: 1 addition & 1 deletion pyo3-stub-gen/src/generate/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub struct ClassDef {
}

impl Import for ClassDef {
fn import(&self) -> HashSet<ModuleRef> {
fn import(&self) -> HashSet<ImportRef> {
let mut import = HashSet::new();
for base in &self.bases {
import.extend(base.import.clone());
Expand Down
2 changes: 1 addition & 1 deletion pyo3-stub-gen/src/generate/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub struct FunctionDef {
}

impl Import for FunctionDef {
fn import(&self) -> HashSet<ModuleRef> {
fn import(&self) -> HashSet<ImportRef> {
let mut import = self.r#return.import.clone();
for arg in &self.args {
import.extend(arg.import().into_iter());
Expand Down
2 changes: 1 addition & 1 deletion pyo3-stub-gen/src/generate/member.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub struct MemberDef {
}

impl Import for MemberDef {
fn import(&self) -> HashSet<ModuleRef> {
fn import(&self) -> HashSet<ImportRef> {
self.r#type.import.clone()
}
}
Expand Down
2 changes: 1 addition & 1 deletion pyo3-stub-gen/src/generate/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub struct MethodDef {
}

impl Import for MethodDef {
fn import(&self) -> HashSet<ModuleRef> {
fn import(&self) -> HashSet<ImportRef> {
let mut import = self.r#return.import.clone();
for arg in &self.args {
import.extend(arg.import().into_iter());
Expand Down
35 changes: 30 additions & 5 deletions pyo3-stub-gen/src/generate/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub struct Module {
}

impl Import for Module {
fn import(&self) -> HashSet<ModuleRef> {
fn import(&self) -> HashSet<ImportRef> {
let mut imports = HashSet::new();
for class in self.class.values() {
imports.extend(class.import());
Expand All @@ -38,12 +38,37 @@ impl fmt::Display for Module {
writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
writeln!(f, "# ruff: noqa: E501, F401")?;
writeln!(f)?;
for import in self.import().into_iter().sorted() {
let name = import.get().unwrap_or(&self.default_module_name);
if name != self.name {
writeln!(f, "import {}", name)?;
let package_name = self.default_module_name.split('.').next().unwrap();
let mut type_ref_grouped: BTreeMap<String, Vec<String>> = BTreeMap::new();
for common_ref in self.import().into_iter().sorted() {
match common_ref {
ImportRef::Module(module_ref) => {
let name = module_ref.get().unwrap_or(&self.default_module_name);
if name != self.name {
println!("Name: {}", name,);
writeln!(f, "import {}", name)?;
}
}
ImportRef::Type(type_ref) => {
let module_name = type_ref.module.get().unwrap_or(&self.default_module_name);
if module_name != self.name {
if module_name.starts_with(package_name) {
type_ref_grouped
.entry(module_name.to_string())
.or_default()
.push(type_ref.name);
} else {
writeln!(f, "import {}", module_name)?;
}
}
}
}
}
for (module_name, type_names) in type_ref_grouped {
let mut sorted_type_names = type_names.clone();
sorted_type_names.sort();
writeln!(f, "from {} import {}", module_name, sorted_type_names.join(", "))?;
}
for submod in &self.submodules {
writeln!(f, "from . import {}", submod)?;
}
Expand Down
Loading