// pyo3_stub_gen/generate/stub_info.rs
use crate::{generate::*, pyproject::PyProject, type_info::*};
use anyhow::{Context, Result};
use std::{
    collections::{BTreeMap, BTreeSet},
    fs,
    io::{BufWriter, Write},
    path::*,
};
9
10#[derive(Debug, Clone, PartialEq)]
11pub struct StubInfo {
12 pub modules: BTreeMap<String, Module>,
13 pub python_root: PathBuf,
14}
15
16impl StubInfo {
17 pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
20 let pyproject = PyProject::parse_toml(path)?;
21 StubInfoBuilder::from_pyproject_toml(pyproject).build()
22 }
23
24 pub fn from_project_root(default_module_name: String, project_root: PathBuf) -> Result<Self> {
28 StubInfoBuilder::from_project_root(default_module_name, project_root).build()
29 }
30
31 pub fn generate(&self) -> Result<()> {
32 for (name, module) in self.modules.iter() {
33 let normalized_name = name.replace("-", "_");
35 let path = normalized_name.replace(".", "/");
36 let dest = if module.submodules.is_empty() {
37 self.python_root.join(format!("{path}.pyi"))
38 } else {
39 self.python_root.join(path).join("__init__.pyi")
40 };
41
42 let dir = dest.parent().context("Cannot get parent directory")?;
43 if !dir.exists() {
44 fs::create_dir_all(dir)?;
45 }
46
47 let mut f = fs::File::create(&dest)?;
48 write!(f, "{module}")?;
49 log::info!(
50 "Generate stub file of a module `{name}` at {dest}",
51 dest = dest.display()
52 );
53 }
54 Ok(())
55 }
56}
57
58struct StubInfoBuilder {
59 modules: BTreeMap<String, Module>,
60 default_module_name: String,
61 python_root: PathBuf,
62}
63
64impl StubInfoBuilder {
65 fn from_pyproject_toml(pyproject: PyProject) -> Self {
66 StubInfoBuilder::from_project_root(
67 pyproject.module_name().to_string(),
68 pyproject
69 .python_source()
70 .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())),
71 )
72 }
73
74 fn from_project_root(default_module_name: String, project_root: PathBuf) -> Self {
75 Self {
76 modules: BTreeMap::new(),
77 default_module_name,
78 python_root: project_root,
79 }
80 }
81
82 fn get_module(&mut self, name: Option<&str>) -> &mut Module {
83 let name = name.unwrap_or(&self.default_module_name).to_string();
84 let module = self.modules.entry(name.clone()).or_default();
85 module.name = name;
86 module.default_module_name = self.default_module_name.clone();
87 module
88 }
89
90 fn register_submodules(&mut self) {
91 let mut map: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
92 for module in self.modules.keys() {
93 let path = module.split('.').collect::<Vec<_>>();
94 let n = path.len();
95 if n <= 1 {
96 continue;
97 }
98 map.entry(path[..n - 1].join("."))
99 .or_default()
100 .insert(path[n - 1].to_string());
101 }
102 for (parent, children) in map {
103 if let Some(module) = self.modules.get_mut(&parent) {
104 module.submodules.extend(children);
105 }
106 }
107 }
108
109 fn add_class(&mut self, info: &PyClassInfo) {
110 self.get_module(info.module)
111 .class
112 .insert((info.struct_id)(), ClassDef::from(info));
113 }
114
115 fn add_complex_enum(&mut self, info: &PyComplexEnumInfo) {
116 self.get_module(info.module)
117 .class
118 .insert((info.enum_id)(), ClassDef::from(info));
119 }
120
121 fn add_enum(&mut self, info: &PyEnumInfo) {
122 self.get_module(info.module)
123 .enum_
124 .insert((info.enum_id)(), EnumDef::from(info));
125 }
126
127 fn add_function(&mut self, info: &PyFunctionInfo) -> Result<()> {
128 let target = self
129 .get_module(info.module)
130 .function
131 .entry(info.name)
132 .or_default();
133
134 let new_func = FunctionDef::from(info);
136 if !new_func.is_overload {
137 let non_overload_count = target.iter().filter(|f| !f.is_overload).count();
138 if non_overload_count > 0 {
139 anyhow::bail!(
140 "Multiple functions with name '{}' found without @overload decorator. \
141 Please add @overload decorator to all variants.",
142 info.name
143 );
144 }
145 }
146
147 target.push(new_func);
148 Ok(())
149 }
150
151 fn add_variable(&mut self, info: &PyVariableInfo) {
152 self.get_module(Some(info.module))
153 .variables
154 .insert(info.name, VariableDef::from(info));
155 }
156
157 fn add_module_doc(&mut self, info: &ModuleDocInfo) {
158 self.get_module(Some(info.module)).doc = (info.doc)();
159 }
160
161 fn add_methods(&mut self, info: &PyMethodsInfo) -> Result<()> {
162 let struct_id = (info.struct_id)();
163 for module in self.modules.values_mut() {
164 if let Some(entry) = module.class.get_mut(&struct_id) {
165 for attr in info.attrs {
166 entry.attrs.push(MemberDef {
167 name: attr.name,
168 r#type: (attr.r#type)(),
169 doc: attr.doc,
170 default: attr.default.map(|f| f()),
171 deprecated: attr.deprecated.clone(),
172 });
173 }
174 for getter in info.getters {
175 entry
176 .getter_setters
177 .entry(getter.name.to_string())
178 .or_default()
179 .0 = Some(MemberDef {
180 name: getter.name,
181 r#type: (getter.r#type)(),
182 doc: getter.doc,
183 default: getter.default.map(|f| f()),
184 deprecated: getter.deprecated.clone(),
185 });
186 }
187 for setter in info.setters {
188 entry
189 .getter_setters
190 .entry(setter.name.to_string())
191 .or_default()
192 .1 = Some(MemberDef {
193 name: setter.name,
194 r#type: (setter.r#type)(),
195 doc: setter.doc,
196 default: setter.default.map(|f| f()),
197 deprecated: setter.deprecated.clone(),
198 });
199 }
200 for method in info.methods {
201 let entries = entry.methods.entry(method.name.to_string()).or_default();
202
203 let new_method = MethodDef::from(method);
205 if !new_method.is_overload {
206 let non_overload_count = entries.iter().filter(|m| !m.is_overload).count();
207 if non_overload_count > 0 {
208 anyhow::bail!(
209 "Multiple methods with name '{}' in class '{}' found without @overload decorator. \
210 Please add @overload decorator to all variants.",
211 method.name, entry.name
212 );
213 }
214 }
215
216 entries.push(new_method);
217 }
218 return Ok(());
219 } else if let Some(entry) = module.enum_.get_mut(&struct_id) {
220 for attr in info.attrs {
221 entry.attrs.push(MemberDef {
222 name: attr.name,
223 r#type: (attr.r#type)(),
224 doc: attr.doc,
225 default: attr.default.map(|f| f()),
226 deprecated: attr.deprecated.clone(),
227 });
228 }
229 for getter in info.getters {
230 entry.getters.push(MemberDef {
231 name: getter.name,
232 r#type: (getter.r#type)(),
233 doc: getter.doc,
234 default: getter.default.map(|f| f()),
235 deprecated: getter.deprecated.clone(),
236 });
237 }
238 for setter in info.setters {
239 entry.setters.push(MemberDef {
240 name: setter.name,
241 r#type: (setter.r#type)(),
242 doc: setter.doc,
243 default: setter.default.map(|f| f()),
244 deprecated: setter.deprecated.clone(),
245 });
246 }
247 for method in info.methods {
248 let new_method = MethodDef::from(method);
250 if !new_method.is_overload {
251 let non_overload_count = entry
252 .methods
253 .iter()
254 .filter(|m| m.name == method.name && !m.is_overload)
255 .count();
256 if non_overload_count > 0 {
257 anyhow::bail!(
258 "Multiple methods with name '{}' in enum '{}' found without @overload decorator. \
259 Please add @overload decorator to all variants.",
260 method.name, entry.name
261 );
262 }
263 }
264
265 entry.methods.push(new_method);
266 }
267 return Ok(());
268 }
269 }
270 unreachable!("Missing struct_id/enum_id = {:?}", struct_id);
271 }
272
273 fn build(mut self) -> Result<StubInfo> {
274 for info in inventory::iter::<PyClassInfo> {
275 self.add_class(info);
276 }
277 for info in inventory::iter::<PyComplexEnumInfo> {
278 self.add_complex_enum(info);
279 }
280 for info in inventory::iter::<PyEnumInfo> {
281 self.add_enum(info);
282 }
283 for info in inventory::iter::<PyFunctionInfo> {
284 self.add_function(info)?;
285 }
286 for info in inventory::iter::<PyVariableInfo> {
287 self.add_variable(info);
288 }
289 for info in inventory::iter::<ModuleDocInfo> {
290 self.add_module_doc(info);
291 }
292 let mut methods_infos: Vec<&PyMethodsInfo> = inventory::iter::<PyMethodsInfo>().collect();
294 methods_infos.sort_by_key(|info| (info.file, info.line, info.column));
295 for info in methods_infos {
296 self.add_methods(info)?;
297 }
298 self.register_submodules();
299 Ok(StubInfo {
300 modules: self.modules,
301 python_root: self.python_root,
302 })
303 }
304}