pyo3_stub_gen/generate/
module.rs

1use crate::generate::*;
2use crate::stub_type::ImportRef;
3use itertools::Itertools;
4use std::{
5    any::TypeId,
6    collections::{BTreeMap, BTreeSet},
7    fmt,
8};
9
/// A set of names this module re-exports from another module.
///
/// Every entry contributes an import line to the stub: `from <source_module> import *`
/// when `use_wildcard_import` is set, otherwise `from <source_module> import <items...>`.
/// In both cases `items` are also listed in `__all__`.
#[derive(Clone, Debug, PartialEq)]
pub struct ModuleReExport {
    /// Module path the names come from; written verbatim after `from`.
    pub source_module: String,
    /// Names re-exported from `source_module` (sorted when written).
    pub items: Vec<String>,
    /// Emit a wildcard import instead of listing `items` explicitly.
    pub use_wildcard_import: bool,
}
17
18/// Type info for a Python (sub-)module. This corresponds to a single `*.pyi` file.
19#[derive(Debug, Clone, PartialEq, Default)]
20pub struct Module {
21    pub doc: String,
22    pub class: BTreeMap<TypeId, ClassDef>,
23    pub enum_: BTreeMap<TypeId, EnumDef>,
24    pub function: BTreeMap<&'static str, Vec<FunctionDef>>,
25    pub variables: BTreeMap<&'static str, VariableDef>,
26    pub type_aliases: BTreeMap<&'static str, TypeAliasDef>,
27    pub name: String,
28    pub default_module_name: String,
29    /// Direct submodules of this module.
30    pub submodules: BTreeSet<String>,
31    /// Module re-exports for __all__
32    pub module_re_exports: Vec<ModuleReExport>,
33    /// Verbatim entries to add to __all__
34    pub verbatim_all_entries: BTreeSet<String>,
35    /// Explicitly excluded entries from __all__
36    pub excluded_all_entries: BTreeSet<String>,
37}
38
39impl Module {
40    /// Format module with configuration for type alias syntax, returning a String
41    pub fn format_with_config(&self, use_type_statement: bool) -> String {
42        use std::fmt::Write;
43        let mut output = String::new();
44
45        // Use a custom formatter struct
46        struct ModuleFormatter<'a> {
47            module: &'a Module,
48            use_type_statement: bool,
49        }
50
51        impl<'a> fmt::Display for ModuleFormatter<'a> {
52            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
53                // Write header and docstring
54                writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
55                writeln!(f, "# ruff: noqa: E501, F401, F403, F405")?;
56                if !self.module.doc.is_empty() {
57                    docstring::write_docstring(f, &self.module.doc, "")?;
58                }
59                writeln!(f)?;
60
61                let mut imports = self.module.import();
62
63                // Conditionally add TypeAlias import
64                if !self.use_type_statement && !self.module.type_aliases.is_empty() {
65                    imports.insert(ImportRef::Type(crate::stub_type::TypeRef {
66                        module: crate::stub_type::ModuleRef::Named("typing".to_string()),
67                        name: "TypeAlias".to_string(),
68                    }));
69                }
70
71                // Check for overload decorator
72                let any_overloaded = self.module.function.values().any(|functions| {
73                    let has_overload = functions.iter().any(|func| func.is_overload);
74                    functions.len() > 1 && has_overload
75                });
76                if any_overloaded {
77                    imports.insert("typing".into());
78                }
79
80                // Generate imports (same logic as Display impl)
81                let mut type_ref_grouped: BTreeMap<String, Vec<String>> = BTreeMap::new();
82                for import_ref in imports.into_iter().sorted() {
83                    match import_ref {
84                        ImportRef::Module(module_ref) => {
85                            let name = module_ref.get().unwrap_or(&self.module.default_module_name);
86                            if name != self.module.name && !name.is_empty() {
87                                let is_internal_module = if let Some(root) =
88                                    self.module.default_module_name.split('.').next()
89                                {
90                                    name.starts_with(root)
91                                } else {
92                                    false
93                                };
94
95                                if is_internal_module && name.contains('.') {
96                                    let last_dot_pos = name.rfind('.').unwrap();
97                                    let parent_module = &name[..last_dot_pos];
98                                    let child_module = &name[last_dot_pos + 1..];
99
100                                    if !self.module.submodules.contains(child_module) {
101                                        writeln!(
102                                            f,
103                                            "from {} import {}",
104                                            parent_module, child_module
105                                        )?;
106                                    }
107                                } else {
108                                    writeln!(f, "import {name}")?;
109                                }
110                            }
111                        }
112                        ImportRef::Type(type_ref) => {
113                            let module_name = type_ref
114                                .module
115                                .get()
116                                .unwrap_or(&self.module.default_module_name);
117                            if module_name != self.module.name {
118                                type_ref_grouped
119                                    .entry(module_name.to_string())
120                                    .or_default()
121                                    .push(type_ref.name);
122                            }
123                        }
124                    }
125                }
126                for (module_name, type_names) in type_ref_grouped {
127                    let mut sorted_type_names = type_names.clone();
128                    sorted_type_names.sort();
129                    writeln!(
130                        f,
131                        "from {} import {}",
132                        module_name,
133                        sorted_type_names.join(", ")
134                    )?;
135                }
136
137                // Add imports for module re-exports
138                let mut sorted_re_exports = self.module.module_re_exports.clone();
139                sorted_re_exports.sort_by(|a, b| a.source_module.cmp(&b.source_module));
140                for re_export in &sorted_re_exports {
141                    if re_export.use_wildcard_import {
142                        writeln!(f, "from {} import *", re_export.source_module)?;
143                    } else {
144                        let mut sorted_items = re_export.items.clone();
145                        sorted_items.sort();
146                        writeln!(
147                            f,
148                            "from {} import {}",
149                            re_export.source_module,
150                            sorted_items.join(", ")
151                        )?;
152                    }
153                }
154                for submod in &self.module.submodules {
155                    writeln!(f, "from . import {submod}")?;
156                }
157
158                // Generate __all__ list
159                self.module.write_all_list(f)?;
160
161                writeln!(f)?;
162
163                // Generate type aliases with configuration
164                for alias in self.module.type_aliases.values() {
165                    alias.fmt_with_config(&self.module.name, f, self.use_type_statement)?;
166                    writeln!(f)?;
167                }
168
169                // Generate variables
170                for var in self.module.variables.values() {
171                    var.fmt_for_module(&self.module.name, f)?;
172                    writeln!(f)?;
173                }
174
175                // Generate classes
176                for class in self.module.class.values().sorted_by_key(|class| class.name) {
177                    class.fmt_for_module(&self.module.name, f)?;
178                }
179
180                // Generate enums
181                for enum_ in self.module.enum_.values().sorted_by_key(|enum_| enum_.name) {
182                    enum_.fmt_for_module(&self.module.name, f)?;
183                }
184
185                // Generate functions
186                for functions in self.module.function.values() {
187                    let has_overload = functions.iter().any(|func| func.is_overload);
188                    let should_add_overload = functions.len() > 1 && has_overload;
189
190                    let mut sorted_functions = functions.clone();
191                    sorted_functions
192                        .sort_by_key(|func| (func.file, func.line, func.column, func.index));
193                    for function in sorted_functions {
194                        if should_add_overload {
195                            writeln!(f, "@typing.overload")?;
196                        }
197                        function.fmt_for_module(&self.module.name, f)?;
198                    }
199                }
200
201                Ok(())
202            }
203        }
204
205        write!(
206            &mut output,
207            "{}",
208            ModuleFormatter {
209                module: self,
210                use_type_statement
211            }
212        )
213        .unwrap();
214        output
215    }
216
217    fn write_all_list(&self, f: &mut fmt::Formatter) -> fmt::Result {
218        let mut all_items: BTreeSet<String> = BTreeSet::new();
219
220        // Collect public items from this module
221        for class in self.class.values() {
222            if !class.name.starts_with('_') {
223                all_items.insert(class.name.to_string());
224            }
225        }
226        for enum_ in self.enum_.values() {
227            if !enum_.name.starts_with('_') {
228                all_items.insert(enum_.name.to_string());
229            }
230        }
231        for func_name in self.function.keys() {
232            if !func_name.starts_with('_') {
233                all_items.insert(func_name.to_string());
234            }
235        }
236        for var_name in self.variables.keys() {
237            if !var_name.starts_with('_') {
238                all_items.insert(var_name.to_string());
239            }
240        }
241        for alias_name in self.type_aliases.keys() {
242            if !alias_name.starts_with('_') {
243                all_items.insert(alias_name.to_string());
244            }
245        }
246        // FIX: Add underscore filtering for submodules
247        for submod in &self.submodules {
248            if !submod.starts_with('_') {
249                all_items.insert(submod.to_string());
250            }
251        }
252
253        // Add items from re-exports
254        for re_export in &self.module_re_exports {
255            all_items.extend(re_export.items.iter().cloned());
256        }
257
258        // Add verbatim entries (allows explicit inclusion of underscore items)
259        all_items.extend(self.verbatim_all_entries.iter().cloned());
260
261        // Remove explicitly excluded items
262        for excluded in &self.excluded_all_entries {
263            all_items.remove(excluded);
264        }
265
266        // Always write __all__ list (even if empty for consistency)
267        if all_items.is_empty() {
268            writeln!(f, "__all__ = []")?;
269        } else {
270            writeln!(f, "__all__ = [")?;
271            for item in all_items {
272                writeln!(f, "    \"{}\",", item)?;
273            }
274            writeln!(f, "]")?;
275        }
276
277        Ok(())
278    }
279}
280
281impl Import for Module {
282    fn import(&self) -> HashSet<ImportRef> {
283        let mut imports = HashSet::new();
284        for class in self.class.values() {
285            imports.extend(class.import());
286        }
287        for enum_ in self.enum_.values() {
288            imports.extend(enum_.import());
289        }
290        for function in self.function.values().flatten() {
291            imports.extend(function.import());
292        }
293        for variable in self.variables.values() {
294            imports.extend(variable.import());
295        }
296        for type_alias in self.type_aliases.values() {
297            imports.extend(type_alias.import());
298        }
299        imports
300    }
301}
302
303impl fmt::Display for Module {
304    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
305        // Write header and docstring directly (no type names here)
306        writeln!(f, "# This file is automatically generated by pyo3_stub_gen")?;
307        writeln!(f, "# ruff: noqa: E501, F401, F403, F405")?;
308        if !self.doc.is_empty() {
309            docstring::write_docstring(f, &self.doc, "")?;
310        }
311        writeln!(f)?;
312
313        let mut imports = self.import();
314        // Check if any function group needs @overload decorator
315        let any_overloaded = self.function.values().any(|functions| {
316            let has_overload = functions.iter().any(|f| f.is_overload);
317            functions.len() > 1 && has_overload
318        });
319        if any_overloaded {
320            imports.insert("typing".into());
321        }
322
323        // To gather `from submod import A, B, C` style imports
324        let mut type_ref_grouped: BTreeMap<String, Vec<String>> = BTreeMap::new();
325        for import_ref in imports.into_iter().sorted() {
326            match import_ref {
327                ImportRef::Module(module_ref) => {
328                    let name = module_ref.get().unwrap_or(&self.default_module_name);
329                    if name != self.name && !name.is_empty() {
330                        // Check if this is a module within the current package
331                        // by checking if the module name starts with the package name
332                        let is_internal_module =
333                            if let Some(root) = self.default_module_name.split('.').next() {
334                                name.starts_with(root)
335                            } else {
336                                false
337                            };
338
339                        // For nested modules like "package.module.submodule" within the current package
340                        // Generate: from package.module import submodule
341                        // For external modules like "collections.abc", use: import collections.abc
342                        if is_internal_module && name.contains('.') {
343                            let last_dot_pos = name.rfind('.').unwrap();
344                            let parent_module = &name[..last_dot_pos];
345                            let child_module = &name[last_dot_pos + 1..];
346
347                            // Skip if this is a direct submodule (already imported via submodule imports)
348                            if !self.submodules.contains(child_module) {
349                                writeln!(f, "from {} import {}", parent_module, child_module)?;
350                            }
351                        } else {
352                            // External module or top-level module - use standard import
353                            writeln!(f, "import {name}")?;
354                        }
355                    }
356                }
357                ImportRef::Type(type_ref) => {
358                    let module_name = type_ref.module.get().unwrap_or(&self.default_module_name);
359                    if module_name != self.name {
360                        type_ref_grouped
361                            .entry(module_name.to_string())
362                            .or_default()
363                            .push(type_ref.name);
364                    }
365                }
366            }
367        }
368        for (module_name, type_names) in type_ref_grouped {
369            let mut sorted_type_names = type_names.clone();
370            sorted_type_names.sort();
371            writeln!(
372                f,
373                "from {} import {}",
374                module_name,
375                sorted_type_names.join(", ")
376            )?;
377        }
378        // Add imports for module re-exports (sorted for deterministic output)
379        let mut sorted_re_exports = self.module_re_exports.clone();
380        sorted_re_exports.sort_by(|a, b| a.source_module.cmp(&b.source_module));
381        for re_export in &sorted_re_exports {
382            if re_export.use_wildcard_import {
383                // Wildcard: from source import *
384                writeln!(f, "from {} import *", re_export.source_module)?;
385            } else {
386                // Specific items: from source import item1, item2
387                let mut sorted_items = re_export.items.clone();
388                sorted_items.sort();
389                writeln!(
390                    f,
391                    "from {} import {}",
392                    re_export.source_module,
393                    sorted_items.join(", ")
394                )?;
395            }
396        }
397        for submod in &self.submodules {
398            writeln!(f, "from . import {submod}")?;
399        }
400
401        // Generate __all__ list
402        self.write_all_list(f)?;
403
404        writeln!(f)?;
405
406        for alias in self.type_aliases.values() {
407            alias.fmt_for_module(&self.name, f)?;
408            writeln!(f)?;
409        }
410        for var in self.variables.values() {
411            var.fmt_for_module(&self.name, f)?;
412            writeln!(f)?;
413        }
414        for class in self.class.values().sorted_by_key(|class| class.name) {
415            class.fmt_for_module(&self.name, f)?;
416        }
417        for enum_ in self.enum_.values().sorted_by_key(|class| class.name) {
418            enum_.fmt_for_module(&self.name, f)?;
419        }
420        for functions in self.function.values() {
421            // Check if we should add @overload to all functions
422            let has_overload = functions.iter().any(|func| func.is_overload);
423            let should_add_overload = functions.len() > 1 && has_overload;
424
425            // Sort by source location and index for deterministic ordering
426            let mut sorted_functions = functions.clone();
427            sorted_functions.sort_by_key(|func| (func.file, func.line, func.column, func.index));
428            for function in sorted_functions {
429                if should_add_overload {
430                    writeln!(f, "@typing.overload")?;
431                }
432                function.fmt_for_module(&self.name, f)?;
433            }
434        }
435        Ok(())
436    }
437}