pyo3_stub_gen/generate/stub_info.rs

1use crate::{generate::*, pyproject::PyProject, type_info::*};
2use anyhow::{Context, Result};
3use std::{
4    collections::{BTreeMap, BTreeSet},
5    fs,
6    io::Write,
7    path::*,
8};
9
10#[derive(Debug, Clone, PartialEq)]
11pub struct StubInfo {
12    pub modules: BTreeMap<String, Module>,
13    pub python_root: PathBuf,
14}
15
16impl StubInfo {
17    /// Initialize [StubInfo] from a `pyproject.toml` file in `CARGO_MANIFEST_DIR`.
18    /// This is automatically set up by the [crate::define_stub_info_gatherer] macro.
19    pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
20        let pyproject = PyProject::parse_toml(path)?;
21        Ok(StubInfoBuilder::from_pyproject_toml(pyproject).build())
22    }
23
24    /// Initialize [StubInfo] with a specific module name and project root.
25    /// This must be placed in your PyO3 library crate, i.e. the same crate where [inventory::submit]ted,
26    /// not in the `gen_stub` executables due to [inventory]'s mechanism.
27    pub fn from_project_root(default_module_name: String, project_root: PathBuf) -> Result<Self> {
28        Ok(StubInfoBuilder::from_project_root(default_module_name, project_root).build())
29    }
30
31    pub fn generate(&self) -> Result<()> {
32        for (name, module) in self.modules.iter() {
33            // Convert dashes to underscores for Python compatibility
34            let normalized_name = name.replace("-", "_");
35            let path = normalized_name.replace(".", "/");
36            let dest = if module.submodules.is_empty() {
37                self.python_root.join(format!("{path}.pyi"))
38            } else {
39                self.python_root.join(path).join("__init__.pyi")
40            };
41
42            let dir = dest.parent().context("Cannot get parent directory")?;
43            if !dir.exists() {
44                fs::create_dir_all(dir)?;
45            }
46
47            let mut f = fs::File::create(&dest)?;
48            write!(f, "{module}")?;
49            log::info!(
50                "Generate stub file of a module `{name}` at {dest}",
51                dest = dest.display()
52            );
53        }
54        Ok(())
55    }
56}
57
58struct StubInfoBuilder {
59    modules: BTreeMap<String, Module>,
60    default_module_name: String,
61    python_root: PathBuf,
62}
63
64impl StubInfoBuilder {
65    fn from_pyproject_toml(pyproject: PyProject) -> Self {
66        StubInfoBuilder::from_project_root(
67            pyproject.module_name().to_string(),
68            pyproject
69                .python_source()
70                .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())),
71        )
72    }
73
74    fn from_project_root(default_module_name: String, project_root: PathBuf) -> Self {
75        Self {
76            modules: BTreeMap::new(),
77            default_module_name,
78            python_root: project_root,
79        }
80    }
81
82    fn get_module(&mut self, name: Option<&str>) -> &mut Module {
83        let name = name.unwrap_or(&self.default_module_name).to_string();
84        let module = self.modules.entry(name.clone()).or_default();
85        module.name = name;
86        module.default_module_name = self.default_module_name.clone();
87        module
88    }
89
90    fn register_submodules(&mut self) {
91        let mut map: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
92        for module in self.modules.keys() {
93            let path = module.split('.').collect::<Vec<_>>();
94            let n = path.len();
95            if n <= 1 {
96                continue;
97            }
98            map.entry(path[..n - 1].join("."))
99                .or_default()
100                .insert(path[n - 1].to_string());
101        }
102        for (parent, children) in map {
103            if let Some(module) = self.modules.get_mut(&parent) {
104                module.submodules.extend(children);
105            }
106        }
107    }
108
109    fn add_class(&mut self, info: &PyClassInfo) {
110        self.get_module(info.module)
111            .class
112            .insert((info.struct_id)(), ClassDef::from(info));
113    }
114
115    fn add_complex_enum(&mut self, info: &PyComplexEnumInfo) {
116        self.get_module(info.module)
117            .class
118            .insert((info.enum_id)(), ClassDef::from(info));
119    }
120
121    fn add_enum(&mut self, info: &PyEnumInfo) {
122        self.get_module(info.module)
123            .enum_
124            .insert((info.enum_id)(), EnumDef::from(info));
125    }
126
127    fn add_function(&mut self, info: &PyFunctionInfo) {
128        let target = self
129            .get_module(info.module)
130            .function
131            .entry(info.name)
132            .or_default();
133        target.push(FunctionDef::from(info));
134    }
135
136    fn add_variable(&mut self, info: &PyVariableInfo) {
137        self.get_module(Some(info.module))
138            .variables
139            .insert(info.name, VariableDef::from(info));
140    }
141
142    fn add_module_doc(&mut self, info: &ModuleDocInfo) {
143        self.get_module(Some(info.module)).doc = (info.doc)();
144    }
145
146    fn add_methods(&mut self, info: &PyMethodsInfo) {
147        let struct_id = (info.struct_id)();
148        for module in self.modules.values_mut() {
149            if let Some(entry) = module.class.get_mut(&struct_id) {
150                for attr in info.attrs {
151                    entry.attrs.push(MemberDef {
152                        name: attr.name,
153                        r#type: (attr.r#type)(),
154                        doc: attr.doc,
155                        default: attr.default.map(|f| f()),
156                        deprecated: attr.deprecated.clone(),
157                    });
158                }
159                for getter in info.getters {
160                    entry
161                        .getter_setters
162                        .entry(getter.name.to_string())
163                        .or_default()
164                        .0 = Some(MemberDef {
165                        name: getter.name,
166                        r#type: (getter.r#type)(),
167                        doc: getter.doc,
168                        default: getter.default.map(|f| f()),
169                        deprecated: getter.deprecated.clone(),
170                    });
171                }
172                for setter in info.setters {
173                    entry
174                        .getter_setters
175                        .entry(setter.name.to_string())
176                        .or_default()
177                        .1 = Some(MemberDef {
178                        name: setter.name,
179                        r#type: (setter.r#type)(),
180                        doc: setter.doc,
181                        default: setter.default.map(|f| f()),
182                        deprecated: setter.deprecated.clone(),
183                    });
184                }
185                for method in info.methods {
186                    let entries = entry.methods.entry(method.name.to_string()).or_default();
187                    entries.push(MethodDef::from(method));
188                }
189                return;
190            } else if let Some(entry) = module.enum_.get_mut(&struct_id) {
191                for attr in info.attrs {
192                    entry.attrs.push(MemberDef {
193                        name: attr.name,
194                        r#type: (attr.r#type)(),
195                        doc: attr.doc,
196                        default: attr.default.map(|f| f()),
197                        deprecated: attr.deprecated.clone(),
198                    });
199                }
200                for getter in info.getters {
201                    entry.getters.push(MemberDef {
202                        name: getter.name,
203                        r#type: (getter.r#type)(),
204                        doc: getter.doc,
205                        default: getter.default.map(|f| f()),
206                        deprecated: getter.deprecated.clone(),
207                    });
208                }
209                for setter in info.setters {
210                    entry.setters.push(MemberDef {
211                        name: setter.name,
212                        r#type: (setter.r#type)(),
213                        doc: setter.doc,
214                        default: setter.default.map(|f| f()),
215                        deprecated: setter.deprecated.clone(),
216                    });
217                }
218                for method in info.methods {
219                    entry.methods.push(MethodDef::from(method))
220                }
221                return;
222            }
223        }
224        unreachable!("Missing struct_id/enum_id = {:?}", struct_id);
225    }
226
227    fn build(mut self) -> StubInfo {
228        for info in inventory::iter::<PyClassInfo> {
229            self.add_class(info);
230        }
231        for info in inventory::iter::<PyComplexEnumInfo> {
232            self.add_complex_enum(info);
233        }
234        for info in inventory::iter::<PyEnumInfo> {
235            self.add_enum(info);
236        }
237        for info in inventory::iter::<PyFunctionInfo> {
238            self.add_function(info);
239        }
240        for info in inventory::iter::<PyVariableInfo> {
241            self.add_variable(info);
242        }
243        for info in inventory::iter::<ModuleDocInfo> {
244            self.add_module_doc(info);
245        }
246        for info in inventory::iter::<PyMethodsInfo> {
247            self.add_methods(info);
248        }
249        self.register_submodules();
250        StubInfo {
251            modules: self.modules,
252            python_root: self.python_root,
253        }
254    }
255}