// pyo3_stub_gen/generate/stub_info.rs

1use crate::{generate::*, pyproject::PyProject, type_info::*};
2use anyhow::{Context, Result};
3use std::{
4    collections::{BTreeMap, BTreeSet},
5    fs,
6    io::Write,
7    path::*,
8};
9
10#[derive(Debug, Clone, PartialEq)]
11pub struct StubInfo {
12    pub modules: BTreeMap<String, Module>,
13    pub python_root: PathBuf,
14}
15
16impl StubInfo {
17    /// Initialize [StubInfo] from a `pyproject.toml` file in `CARGO_MANIFEST_DIR`.
18    /// This is automatically set up by the [crate::define_stub_info_gatherer] macro.
19    pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
20        let pyproject = PyProject::parse_toml(path)?;
21        Ok(StubInfoBuilder::from_pyproject_toml(pyproject).build())
22    }
23
24    /// Initialize [StubInfo] with a specific module name and project root.
25    /// This must be placed in your PyO3 library crate, i.e. the same crate where [inventory::submit]ted,
26    /// not in the `gen_stub` executables due to [inventory]'s mechanism.
27    pub fn from_project_root(default_module_name: String, project_root: PathBuf) -> Result<Self> {
28        Ok(StubInfoBuilder::from_project_root(default_module_name, project_root).build())
29    }
30
31    pub fn generate(&self) -> Result<()> {
32        for (name, module) in self.modules.iter() {
33            let path = name.replace(".", "/");
34            let dest = if module.submodules.is_empty() {
35                self.python_root.join(format!("{path}.pyi"))
36            } else {
37                self.python_root.join(path).join("__init__.pyi")
38            };
39
40            let dir = dest.parent().context("Cannot get parent directory")?;
41            if !dir.exists() {
42                fs::create_dir_all(dir)?;
43            }
44
45            let mut f = fs::File::create(&dest)?;
46            write!(f, "{}", module)?;
47            log::info!(
48                "Generate stub file of a module `{name}` at {dest}",
49                dest = dest.display()
50            );
51        }
52        Ok(())
53    }
54}
55
56struct StubInfoBuilder {
57    modules: BTreeMap<String, Module>,
58    default_module_name: String,
59    python_root: PathBuf,
60}
61
62impl StubInfoBuilder {
63    fn from_pyproject_toml(pyproject: PyProject) -> Self {
64        StubInfoBuilder::from_project_root(
65            pyproject.module_name().to_string(),
66            pyproject
67                .python_source()
68                .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())),
69        )
70    }
71
72    fn from_project_root(default_module_name: String, project_root: PathBuf) -> Self {
73        Self {
74            modules: BTreeMap::new(),
75            default_module_name,
76            python_root: project_root,
77        }
78    }
79
80    fn get_module(&mut self, name: Option<&str>) -> &mut Module {
81        let name = name.unwrap_or(&self.default_module_name).to_string();
82        let module = self.modules.entry(name.clone()).or_default();
83        module.name = name;
84        module.default_module_name = self.default_module_name.clone();
85        module
86    }
87
88    fn register_submodules(&mut self) {
89        let mut map: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
90        for module in self.modules.keys() {
91            let path = module.split('.').collect::<Vec<_>>();
92            let n = path.len();
93            if n <= 1 {
94                continue;
95            }
96            map.entry(path[..n - 1].join("."))
97                .or_default()
98                .insert(path[n - 1].to_string());
99        }
100        for (parent, children) in map {
101            if let Some(module) = self.modules.get_mut(&parent) {
102                module.submodules.extend(children);
103            }
104        }
105    }
106
107    fn add_class(&mut self, info: &PyClassInfo) {
108        self.get_module(info.module)
109            .class
110            .insert((info.struct_id)(), ClassDef::from(info));
111    }
112
113    fn add_enum(&mut self, info: &PyEnumInfo) {
114        self.get_module(info.module)
115            .enum_
116            .insert((info.enum_id)(), EnumDef::from(info));
117    }
118
119    fn add_function(&mut self, info: &PyFunctionInfo) {
120        self.get_module(info.module)
121            .function
122            .insert(info.name, FunctionDef::from(info));
123    }
124
125    fn add_error(&mut self, info: &PyErrorInfo) {
126        self.get_module(Some(info.module))
127            .error
128            .insert(info.name, ErrorDef::from(info));
129    }
130
131    fn add_variable(&mut self, info: &PyVariableInfo) {
132        self.get_module(Some(info.module))
133            .variables
134            .insert(info.name, VariableDef::from(info));
135    }
136
137    fn add_methods(&mut self, info: &PyMethodsInfo) {
138        let struct_id = (info.struct_id)();
139        for module in self.modules.values_mut() {
140            if let Some(entry) = module.class.get_mut(&struct_id) {
141                for getter in info.getters {
142                    entry.members.push(MemberDef {
143                        name: getter.name,
144                        r#type: (getter.r#type)(),
145                        doc: getter.doc,
146                    });
147                }
148                for method in info.methods {
149                    entry.methods.push(MethodDef::from(method))
150                }
151                return;
152            } else if let Some(entry) = module.enum_.get_mut(&struct_id) {
153                for getter in info.getters {
154                    entry.members.push(MemberDef {
155                        name: getter.name,
156                        r#type: (getter.r#type)(),
157                        doc: getter.doc,
158                    });
159                }
160                for method in info.methods {
161                    entry.methods.push(MethodDef::from(method))
162                }
163                return;
164            }
165        }
166        unreachable!("Missing struct_id/enum_id = {:?}", struct_id);
167    }
168
169    fn build(mut self) -> StubInfo {
170        for info in inventory::iter::<PyClassInfo> {
171            self.add_class(info);
172        }
173        for info in inventory::iter::<PyEnumInfo> {
174            self.add_enum(info);
175        }
176        for info in inventory::iter::<PyFunctionInfo> {
177            self.add_function(info);
178        }
179        for info in inventory::iter::<PyErrorInfo> {
180            self.add_error(info);
181        }
182        for info in inventory::iter::<PyVariableInfo> {
183            self.add_variable(info);
184        }
185        for info in inventory::iter::<PyMethodsInfo> {
186            self.add_methods(info);
187        }
188        self.register_submodules();
189        StubInfo {
190            modules: self.modules,
191            python_root: self.python_root,
192        }
193    }
194}