pyo3_stub_gen/generate/
stub_info.rs1use crate::{generate::*, pyproject::PyProject, type_info::*};
2use anyhow::{Context, Result};
3use std::{
4 collections::{BTreeMap, BTreeSet},
5 fs,
6 io::Write,
7 path::*,
8};
9
10#[derive(Debug, Clone, PartialEq)]
11pub struct StubInfo {
12 pub modules: BTreeMap<String, Module>,
13 pub python_root: PathBuf,
14}
15
16impl StubInfo {
17 pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
20 let pyproject = PyProject::parse_toml(path)?;
21 Ok(StubInfoBuilder::from_pyproject_toml(pyproject).build())
22 }
23
24 pub fn from_project_root(default_module_name: String, project_root: PathBuf) -> Result<Self> {
28 Ok(StubInfoBuilder::from_project_root(default_module_name, project_root).build())
29 }
30
31 pub fn generate(&self) -> Result<()> {
32 for (name, module) in self.modules.iter() {
33 let normalized_name = name.replace("-", "_");
35 let path = normalized_name.replace(".", "/");
36 let dest = if module.submodules.is_empty() {
37 self.python_root.join(format!("{path}.pyi"))
38 } else {
39 self.python_root.join(path).join("__init__.pyi")
40 };
41
42 let dir = dest.parent().context("Cannot get parent directory")?;
43 if !dir.exists() {
44 fs::create_dir_all(dir)?;
45 }
46
47 let mut f = fs::File::create(&dest)?;
48 write!(f, "{module}")?;
49 log::info!(
50 "Generate stub file of a module `{name}` at {dest}",
51 dest = dest.display()
52 );
53 }
54 Ok(())
55 }
56}
57
58struct StubInfoBuilder {
59 modules: BTreeMap<String, Module>,
60 default_module_name: String,
61 python_root: PathBuf,
62}
63
64impl StubInfoBuilder {
65 fn from_pyproject_toml(pyproject: PyProject) -> Self {
66 StubInfoBuilder::from_project_root(
67 pyproject.module_name().to_string(),
68 pyproject
69 .python_source()
70 .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())),
71 )
72 }
73
74 fn from_project_root(default_module_name: String, project_root: PathBuf) -> Self {
75 Self {
76 modules: BTreeMap::new(),
77 default_module_name,
78 python_root: project_root,
79 }
80 }
81
82 fn get_module(&mut self, name: Option<&str>) -> &mut Module {
83 let name = name.unwrap_or(&self.default_module_name).to_string();
84 let module = self.modules.entry(name.clone()).or_default();
85 module.name = name;
86 module.default_module_name = self.default_module_name.clone();
87 module
88 }
89
90 fn register_submodules(&mut self) {
91 let mut map: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
92 for module in self.modules.keys() {
93 let path = module.split('.').collect::<Vec<_>>();
94 let n = path.len();
95 if n <= 1 {
96 continue;
97 }
98 map.entry(path[..n - 1].join("."))
99 .or_default()
100 .insert(path[n - 1].to_string());
101 }
102 for (parent, children) in map {
103 if let Some(module) = self.modules.get_mut(&parent) {
104 module.submodules.extend(children);
105 }
106 }
107 }
108
109 fn add_class(&mut self, info: &PyClassInfo) {
110 self.get_module(info.module)
111 .class
112 .insert((info.struct_id)(), ClassDef::from(info));
113 }
114
115 fn add_complex_enum(&mut self, info: &PyComplexEnumInfo) {
116 self.get_module(info.module)
117 .class
118 .insert((info.enum_id)(), ClassDef::from(info));
119 }
120
121 fn add_enum(&mut self, info: &PyEnumInfo) {
122 self.get_module(info.module)
123 .enum_
124 .insert((info.enum_id)(), EnumDef::from(info));
125 }
126
127 fn add_function(&mut self, info: &PyFunctionInfo) {
128 let target = self
129 .get_module(info.module)
130 .function
131 .entry(info.name)
132 .or_default();
133 target.push(FunctionDef::from(info));
134 }
135
136 fn add_error(&mut self, info: &PyErrorInfo) {
137 self.get_module(Some(info.module))
138 .error
139 .insert(info.name, ErrorDef::from(info));
140 }
141
142 fn add_variable(&mut self, info: &PyVariableInfo) {
143 self.get_module(Some(info.module))
144 .variables
145 .insert(info.name, VariableDef::from(info));
146 }
147
148 fn add_methods(&mut self, info: &PyMethodsInfo) {
149 let struct_id = (info.struct_id)();
150 for module in self.modules.values_mut() {
151 if let Some(entry) = module.class.get_mut(&struct_id) {
152 for attr in info.attrs {
153 entry.attrs.push(MemberDef {
154 name: attr.name,
155 r#type: (attr.r#type)(),
156 doc: attr.doc,
157 default: attr.default.map(|s| s.as_str()),
158 });
159 }
160 for getter in info.getters {
161 entry.getters.push(MemberDef {
162 name: getter.name,
163 r#type: (getter.r#type)(),
164 doc: getter.doc,
165 default: getter.default.map(|s| s.as_str()),
166 });
167 }
168 for setter in info.setters {
169 entry.setters.push(MemberDef {
170 name: setter.name,
171 r#type: (setter.r#type)(),
172 doc: setter.doc,
173 default: setter.default.map(|s| s.as_str()),
174 });
175 }
176 for method in info.methods {
177 let entries = entry.methods.entry(method.name.to_string()).or_default();
178 entries.push(MethodDef::from(method));
179 }
180 return;
181 } else if let Some(entry) = module.enum_.get_mut(&struct_id) {
182 for attr in info.attrs {
183 entry.attrs.push(MemberDef {
184 name: attr.name,
185 r#type: (attr.r#type)(),
186 doc: attr.doc,
187 default: attr.default.map(|s| s.as_str()),
188 });
189 }
190 for getter in info.getters {
191 entry.getters.push(MemberDef {
192 name: getter.name,
193 r#type: (getter.r#type)(),
194 doc: getter.doc,
195 default: getter.default.map(|s| s.as_str()),
196 });
197 }
198 for setter in info.setters {
199 entry.setters.push(MemberDef {
200 name: setter.name,
201 r#type: (setter.r#type)(),
202 doc: setter.doc,
203 default: setter.default.map(|s| s.as_str()),
204 });
205 }
206 for method in info.methods {
207 entry.methods.push(MethodDef::from(method))
208 }
209 return;
210 }
211 }
212 unreachable!("Missing struct_id/enum_id = {:?}", struct_id);
213 }
214
215 fn build(mut self) -> StubInfo {
216 for info in inventory::iter::<PyClassInfo> {
217 self.add_class(info);
218 }
219 for info in inventory::iter::<PyComplexEnumInfo> {
220 self.add_complex_enum(info);
221 }
222 for info in inventory::iter::<PyEnumInfo> {
223 self.add_enum(info);
224 }
225 for info in inventory::iter::<PyFunctionInfo> {
226 self.add_function(info);
227 }
228 for info in inventory::iter::<PyErrorInfo> {
229 self.add_error(info);
230 }
231 for info in inventory::iter::<PyVariableInfo> {
232 self.add_variable(info);
233 }
234 for info in inventory::iter::<PyMethodsInfo> {
235 self.add_methods(info);
236 }
237 self.register_submodules();
238 StubInfo {
239 modules: self.modules,
240 python_root: self.python_root,
241 }
242 }
243}