pyo3_stub_gen/generate/module.rs

1use crate::generate::*;
2use crate::stub_type::ImportRef;
3use itertools::Itertools;
4use std::{
5    any::TypeId,
6    collections::{BTreeMap, BTreeSet},
7    fmt,
8};
9
/// A `from <source_module> import ...` re-export that also contributes names to `__all__`.
///
/// When `items` is empty the re-export is a wildcard whose names are resolved
/// later; `additional_items` (for example `__version__`) are merged into
/// `items` after that resolution.
#[derive(Clone, Debug, PartialEq)]
pub struct ModuleReExport {
    /// Module the items come from, exactly as written after `from`.
    pub source_module: String,
    /// Names to re-export; an empty list means "wildcard, resolve later".
    pub items: Vec<String>,
    /// Extra names kept alongside a resolved wildcard (e.g. `__version__`).
    pub additional_items: Vec<String>,
}
20
21/// Type info for a Python (sub-)module. This corresponds to a single `*.pyi` file.
22#[derive(Debug, Clone, PartialEq, Default)]
23pub struct Module {
24    pub doc: String,
25    pub class: BTreeMap<TypeId, ClassDef>,
26    pub enum_: BTreeMap<TypeId, EnumDef>,
27    pub function: BTreeMap<&'static str, Vec<FunctionDef>>,
28    pub variables: BTreeMap<&'static str, VariableDef>,
29    pub type_aliases: BTreeMap<&'static str, TypeAliasDef>,
30    pub name: String,
31    pub default_module_name: String,
32    /// Direct submodules of this module.
33    pub submodules: BTreeSet<String>,
34    /// Module re-exports for __all__
35    pub module_re_exports: Vec<ModuleReExport>,
36    /// Verbatim entries to add to __all__
37    pub verbatim_all_entries: BTreeSet<String>,
38    /// Explicitly excluded entries from __all__
39    pub excluded_all_entries: BTreeSet<String>,
40}
41
42impl Module {
43    /// Check if this module has no content to generate.
44    ///
45    /// Returns true if the module has no classes, enums, functions, variables,
46    /// type aliases, submodules, re-exports, docstrings, or verbatim entries.
47    /// Modules that are empty should be skipped during generation.
48    pub fn is_empty(&self) -> bool {
49        self.doc.is_empty()
50            && self.class.is_empty()
51            && self.enum_.is_empty()
52            && self.function.is_empty()
53            && self.variables.is_empty()
54            && self.type_aliases.is_empty()
55            && self.submodules.is_empty()
56            && self.module_re_exports.is_empty()
57            && self.verbatim_all_entries.is_empty()
58    }
59
60    /// Check if this module can have `__init__.py` generated.
61    ///
62    /// Returns true if the module has no PyO3-generated items (classes, enums,
63    /// functions, variables, type aliases). Such modules can only contain
64    /// re-exports and docstrings, which can be represented in `__init__.py`.
65    pub fn is_init_py_compatible(&self) -> bool {
66        self.class.is_empty()
67            && self.enum_.is_empty()
68            && self.function.is_empty()
69            && self.variables.is_empty()
70            && self.type_aliases.is_empty()
71    }
72
73    /// Get the names of all declared items in this module.
74    ///
75    /// Returns a list of item names that were declared via `gen_stub_*` macros.
76    pub fn declared_item_names(&self) -> Vec<String> {
77        let mut names = Vec::new();
78
79        if !self.doc.is_empty() {
80            names.push("module_doc".to_string());
81        }
82        for class in self.class.values() {
83            names.push(format!("class {}", class.name));
84        }
85        for enum_def in self.enum_.values() {
86            names.push(format!("enum {}", enum_def.name));
87        }
88        for func_name in self.function.keys() {
89            names.push(format!("function {}", func_name));
90        }
91        for var_name in self.variables.keys() {
92            names.push(format!("variable {}", var_name));
93        }
94        for alias_name in self.type_aliases.keys() {
95            names.push(format!("type_alias {}", alias_name));
96        }
97        for re_export in &self.module_re_exports {
98            names.push(format!("re-export from {}", re_export.source_module));
99        }
100
101        names.sort();
102        names
103    }
104
105    /// Format module with configuration for type alias syntax, returning a String
106    pub fn format_with_config(&self, use_type_statement: bool) -> String {
107        use std::fmt::Write;
108        let mut output = String::new();
109
110        // Use a custom formatter struct
111        struct ModuleFormatter<'a> {
112            module: &'a Module,
113            use_type_statement: bool,
114        }
115
116        impl<'a> fmt::Display for ModuleFormatter<'a> {
117            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
118                // Write header and docstring
119                writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
120                writeln!(f, "# ruff: noqa: E501, F401, F403, F405")?;
121                if !self.module.doc.is_empty() {
122                    docstring::write_docstring(f, &self.module.doc, "")?;
123                }
124                writeln!(f)?;
125
126                let mut imports = self.module.import();
127
128                // Conditionally add TypeAlias import
129                if !self.use_type_statement && !self.module.type_aliases.is_empty() {
130                    imports.insert(ImportRef::Type(crate::stub_type::TypeRef {
131                        module: crate::stub_type::ModuleRef::Named("typing".to_string()),
132                        name: "TypeAlias".to_string(),
133                    }));
134                }
135
136                // Check for overload decorator
137                let any_overloaded = self.module.function.values().any(|functions| {
138                    let has_overload = functions.iter().any(|func| func.is_overload);
139                    functions.len() > 1 && has_overload
140                });
141                if any_overloaded {
142                    imports.insert("typing".into());
143                }
144
145                // Generate imports (same logic as Display impl)
146                let mut type_ref_grouped: BTreeMap<String, Vec<String>> = BTreeMap::new();
147                for import_ref in imports.into_iter().sorted() {
148                    match import_ref {
149                        ImportRef::Module(module_ref) => {
150                            let name = module_ref.get().unwrap_or(&self.module.default_module_name);
151                            if name != self.module.name && !name.is_empty() {
152                                let is_internal_module = if let Some(root) =
153                                    self.module.default_module_name.split('.').next()
154                                {
155                                    name.starts_with(root)
156                                } else {
157                                    false
158                                };
159
160                                if is_internal_module && name.contains('.') {
161                                    let last_dot_pos = name.rfind('.').unwrap();
162                                    let parent_module = &name[..last_dot_pos];
163                                    let child_module = &name[last_dot_pos + 1..];
164
165                                    if !self.module.submodules.contains(child_module) {
166                                        writeln!(
167                                            f,
168                                            "from {} import {}",
169                                            parent_module, child_module
170                                        )?;
171                                    }
172                                } else {
173                                    writeln!(f, "import {name}")?;
174                                }
175                            }
176                        }
177                        ImportRef::Type(type_ref) => {
178                            let module_name = type_ref
179                                .module
180                                .get()
181                                .unwrap_or(&self.module.default_module_name);
182                            if module_name != self.module.name {
183                                type_ref_grouped
184                                    .entry(module_name.to_string())
185                                    .or_default()
186                                    .push(type_ref.name);
187                            }
188                        }
189                    }
190                }
191                for (module_name, type_names) in type_ref_grouped {
192                    let mut sorted_type_names = type_names.clone();
193                    sorted_type_names.sort();
194                    writeln!(
195                        f,
196                        "from {} import {}",
197                        module_name,
198                        sorted_type_names.join(", ")
199                    )?;
200                }
201
202                // Add imports for module re-exports (always explicit, not wildcard)
203                let mut sorted_re_exports = self.module.module_re_exports.clone();
204                sorted_re_exports.sort_by(|a, b| a.source_module.cmp(&b.source_module));
205                for re_export in &sorted_re_exports {
206                    if re_export.items.is_empty() {
207                        continue;
208                    }
209                    let mut sorted_items = re_export.items.clone();
210                    sorted_items.sort();
211                    writeln!(
212                        f,
213                        "from {} import {}",
214                        re_export.source_module,
215                        sorted_items.join(", ")
216                    )?;
217                }
218                for submod in &self.module.submodules {
219                    writeln!(f, "from . import {submod}")?;
220                }
221
222                // Generate __all__ list
223                self.module.write_all_list(f)?;
224
225                writeln!(f)?;
226
227                // Generate type aliases with configuration
228                for alias in self.module.type_aliases.values() {
229                    alias.fmt_with_config(&self.module.name, f, self.use_type_statement)?;
230                    writeln!(f)?;
231                }
232
233                // Generate variables
234                for var in self.module.variables.values() {
235                    var.fmt_for_module(&self.module.name, f)?;
236                    writeln!(f)?;
237                }
238
239                // Generate classes
240                for class in self.module.class.values().sorted_by_key(|class| class.name) {
241                    class.fmt_for_module(&self.module.name, f)?;
242                }
243
244                // Generate enums
245                for enum_ in self.module.enum_.values().sorted_by_key(|enum_| enum_.name) {
246                    enum_.fmt_for_module(&self.module.name, f)?;
247                }
248
249                // Generate functions
250                for functions in self.module.function.values() {
251                    let has_overload = functions.iter().any(|func| func.is_overload);
252                    let should_add_overload = functions.len() > 1 && has_overload;
253
254                    let mut sorted_functions = functions.clone();
255                    sorted_functions
256                        .sort_by_key(|func| (func.file, func.line, func.column, func.index));
257                    for function in sorted_functions {
258                        if should_add_overload {
259                            writeln!(f, "@typing.overload")?;
260                        }
261                        function.fmt_for_module(&self.module.name, f)?;
262                    }
263                }
264
265                Ok(())
266            }
267        }
268
269        write!(
270            &mut output,
271            "{}",
272            ModuleFormatter {
273                module: self,
274                use_type_statement
275            }
276        )
277        .unwrap();
278        output
279    }
280
281    /// Collect all items for the `__all__` list in `.pyi` stub files.
282    ///
283    /// This collects public items from classes, enums, functions, variables,
284    /// type aliases, submodules, re-exports, and verbatim entries.
285    /// Items starting with `_` are excluded unless added via `export_verbatim!`.
286    fn collect_all_items(&self) -> BTreeSet<String> {
287        let mut all_items: BTreeSet<String> = BTreeSet::new();
288
289        // Collect public items from this module
290        for class in self.class.values() {
291            if !class.name.starts_with('_') {
292                all_items.insert(class.name.to_string());
293            }
294        }
295        for enum_ in self.enum_.values() {
296            if !enum_.name.starts_with('_') {
297                all_items.insert(enum_.name.to_string());
298            }
299        }
300        for func_name in self.function.keys() {
301            if !func_name.starts_with('_') {
302                all_items.insert(func_name.to_string());
303            }
304        }
305        for var_name in self.variables.keys() {
306            if !var_name.starts_with('_') {
307                all_items.insert(var_name.to_string());
308            }
309        }
310        for alias_name in self.type_aliases.keys() {
311            if !alias_name.starts_with('_') {
312                all_items.insert(alias_name.to_string());
313            }
314        }
315        for submod in &self.submodules {
316            if !submod.starts_with('_') {
317                all_items.insert(submod.to_string());
318            }
319        }
320
321        // Add items from re-exports
322        for re_export in &self.module_re_exports {
323            all_items.extend(re_export.items.iter().cloned());
324        }
325
326        // Add verbatim entries (allows explicit inclusion of underscore items)
327        all_items.extend(self.verbatim_all_entries.iter().cloned());
328
329        // Remove explicitly excluded items
330        for excluded in &self.excluded_all_entries {
331            all_items.remove(excluded);
332        }
333
334        all_items
335    }
336
337    /// Format module as `__init__.py` content.
338    ///
339    /// This generates a Python `__init__.py` file with re-exports and `__all__` list.
340    /// Unlike `format_with_config()` which generates `.pyi` stub files, this generates
341    /// actual Python code for runtime use.
342    pub fn format_init_py(&self) -> String {
343        use std::fmt::Write;
344        let mut output = String::new();
345
346        // Header
347        writeln!(
348            output,
349            "# This file is automatically generated by pyo3_stub_gen"
350        )
351        .unwrap();
352        writeln!(output, "# ruff: noqa: F401").unwrap();
353        writeln!(output).unwrap();
354
355        // Re-export imports (sorted for deterministic output)
356        // Always use explicit imports (not wildcard) for better tooling support
357        let mut sorted_re_exports = self.module_re_exports.clone();
358        sorted_re_exports.sort_by(|a, b| a.source_module.cmp(&b.source_module));
359        for re_export in &sorted_re_exports {
360            if re_export.items.is_empty() {
361                continue;
362            }
363            let mut sorted_items = re_export.items.clone();
364            sorted_items.sort();
365            writeln!(
366                output,
367                "from {} import {}",
368                re_export.source_module,
369                sorted_items.join(", ")
370            )
371            .unwrap();
372        }
373
374        // Collect __all__ items from re-exports only
375        // (Unlike __init__.pyi, __init__.py can only contain re-exported items)
376        let mut all_items: BTreeSet<String> = BTreeSet::new();
377        for re_export in &self.module_re_exports {
378            all_items.extend(re_export.items.iter().cloned());
379        }
380        all_items.extend(self.verbatim_all_entries.iter().cloned());
381        for excluded in &self.excluded_all_entries {
382            all_items.remove(excluded);
383        }
384        if all_items.is_empty() {
385            writeln!(output, "__all__ = []").unwrap();
386        } else {
387            writeln!(output, "__all__ = [").unwrap();
388            for item in all_items {
389                writeln!(output, "    \"{}\",", item).unwrap();
390            }
391            writeln!(output, "]").unwrap();
392        }
393
394        output
395    }
396
397    fn write_all_list(&self, f: &mut fmt::Formatter) -> fmt::Result {
398        let all_items = self.collect_all_items();
399
400        // Always write __all__ list (even if empty for consistency)
401        if all_items.is_empty() {
402            writeln!(f, "__all__ = []")?;
403        } else {
404            writeln!(f, "__all__ = [")?;
405            for item in all_items {
406                writeln!(f, "    \"{}\",", item)?;
407            }
408            writeln!(f, "]")?;
409        }
410
411        Ok(())
412    }
413}
414
415impl Import for Module {
416    fn import(&self) -> HashSet<ImportRef> {
417        let mut imports = HashSet::new();
418        for class in self.class.values() {
419            imports.extend(class.import());
420        }
421        for enum_ in self.enum_.values() {
422            imports.extend(enum_.import());
423        }
424        for function in self.function.values().flatten() {
425            imports.extend(function.import());
426        }
427        for variable in self.variables.values() {
428            imports.extend(variable.import());
429        }
430        for type_alias in self.type_aliases.values() {
431            imports.extend(type_alias.import());
432        }
433        imports
434    }
435}
436
437impl fmt::Display for Module {
438    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
439        // Write header and docstring directly (no type names here)
440        writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
441        writeln!(f, "# ruff: noqa: E501, F401, F403, F405")?;
442        if !self.doc.is_empty() {
443            docstring::write_docstring(f, &self.doc, "")?;
444        }
445        writeln!(f)?;
446
447        let mut imports = self.import();
448        // Check if any function group needs @overload decorator
449        let any_overloaded = self.function.values().any(|functions| {
450            let has_overload = functions.iter().any(|f| f.is_overload);
451            functions.len() > 1 && has_overload
452        });
453        if any_overloaded {
454            imports.insert("typing".into());
455        }
456
457        // To gather `from submod import A, B, C` style imports
458        let mut type_ref_grouped: BTreeMap<String, Vec<String>> = BTreeMap::new();
459        for import_ref in imports.into_iter().sorted() {
460            match import_ref {
461                ImportRef::Module(module_ref) => {
462                    let name = module_ref.get().unwrap_or(&self.default_module_name);
463                    if name != self.name && !name.is_empty() {
464                        // Check if this is a module within the current package
465                        // by checking if the module name starts with the package name
466                        let is_internal_module =
467                            if let Some(root) = self.default_module_name.split('.').next() {
468                                name.starts_with(root)
469                            } else {
470                                false
471                            };
472
473                        // For nested modules like "package.module.submodule" within the current package
474                        // Generate: from package.module import submodule
475                        // For external modules like "collections.abc", use: import collections.abc
476                        if is_internal_module && name.contains('.') {
477                            let last_dot_pos = name.rfind('.').unwrap();
478                            let parent_module = &name[..last_dot_pos];
479                            let child_module = &name[last_dot_pos + 1..];
480
481                            // Skip if this is a direct submodule (already imported via submodule imports)
482                            if !self.submodules.contains(child_module) {
483                                writeln!(f, "from {} import {}", parent_module, child_module)?;
484                            }
485                        } else {
486                            // External module or top-level module - use standard import
487                            writeln!(f, "import {name}")?;
488                        }
489                    }
490                }
491                ImportRef::Type(type_ref) => {
492                    let module_name = type_ref.module.get().unwrap_or(&self.default_module_name);
493                    if module_name != self.name {
494                        type_ref_grouped
495                            .entry(module_name.to_string())
496                            .or_default()
497                            .push(type_ref.name);
498                    }
499                }
500            }
501        }
502        for (module_name, type_names) in type_ref_grouped {
503            let mut sorted_type_names = type_names.clone();
504            sorted_type_names.sort();
505            writeln!(
506                f,
507                "from {} import {}",
508                module_name,
509                sorted_type_names.join(", ")
510            )?;
511        }
512        // Add imports for module re-exports (always explicit, not wildcard)
513        let mut sorted_re_exports = self.module_re_exports.clone();
514        sorted_re_exports.sort_by(|a, b| a.source_module.cmp(&b.source_module));
515        for re_export in &sorted_re_exports {
516            if re_export.items.is_empty() {
517                continue;
518            }
519            let mut sorted_items = re_export.items.clone();
520            sorted_items.sort();
521            writeln!(
522                f,
523                "from {} import {}",
524                re_export.source_module,
525                sorted_items.join(", ")
526            )?;
527        }
528        for submod in &self.submodules {
529            writeln!(f, "from . import {submod}")?;
530        }
531
532        // Generate __all__ list
533        self.write_all_list(f)?;
534
535        writeln!(f)?;
536
537        for alias in self.type_aliases.values() {
538            alias.fmt_for_module(&self.name, f)?;
539            writeln!(f)?;
540        }
541        for var in self.variables.values() {
542            var.fmt_for_module(&self.name, f)?;
543            writeln!(f)?;
544        }
545        for class in self.class.values().sorted_by_key(|class| class.name) {
546            class.fmt_for_module(&self.name, f)?;
547        }
548        for enum_ in self.enum_.values().sorted_by_key(|class| class.name) {
549            enum_.fmt_for_module(&self.name, f)?;
550        }
551        for functions in self.function.values() {
552            // Check if we should add @overload to all functions
553            let has_overload = functions.iter().any(|func| func.is_overload);
554            let should_add_overload = functions.len() > 1 && has_overload;
555
556            // Sort by source location and index for deterministic ordering
557            let mut sorted_functions = functions.clone();
558            sorted_functions.sort_by_key(|func| (func.file, func.line, func.column, func.index));
559            for function in sorted_functions {
560                if should_add_overload {
561                    writeln!(f, "@typing.overload")?;
562                }
563                function.fmt_for_module(&self.name, f)?;
564            }
565        }
566        Ok(())
567    }
568}