// pyo3_stub_gen/generate/stub_info.rs
use crate::{generate::*, pyproject::PyProject, type_info::*};
use anyhow::{Context, Result};
use std::{
    collections::{BTreeMap, BTreeSet},
    fs,
    io::{BufWriter, Write},
    path::*,
};

10#[derive(Debug, Clone, PartialEq)]
11pub struct StubInfo {
12 pub modules: BTreeMap<String, Module>,
13 pub python_root: PathBuf,
14}
15
16impl StubInfo {
17 pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
20 let pyproject = PyProject::parse_toml(path)?;
21 Ok(StubInfoBuilder::from_pyproject_toml(pyproject).build())
22 }
23
24 pub fn from_project_root(default_module_name: String, project_root: PathBuf) -> Result<Self> {
28 Ok(StubInfoBuilder::from_project_root(default_module_name, project_root).build())
29 }
30
31 pub fn generate(&self) -> Result<()> {
32 for (name, module) in self.modules.iter() {
33 let normalized_name = name.replace("-", "_");
35 let path = normalized_name.replace(".", "/");
36 let dest = if module.submodules.is_empty() {
37 self.python_root.join(format!("{path}.pyi"))
38 } else {
39 self.python_root.join(path).join("__init__.pyi")
40 };
41
42 let dir = dest.parent().context("Cannot get parent directory")?;
43 if !dir.exists() {
44 fs::create_dir_all(dir)?;
45 }
46
47 let mut f = fs::File::create(&dest)?;
48 write!(f, "{module}")?;
49 log::info!(
50 "Generate stub file of a module `{name}` at {dest}",
51 dest = dest.display()
52 );
53 }
54 Ok(())
55 }
56}
57
58struct StubInfoBuilder {
59 modules: BTreeMap<String, Module>,
60 default_module_name: String,
61 python_root: PathBuf,
62}
63
64impl StubInfoBuilder {
65 fn from_pyproject_toml(pyproject: PyProject) -> Self {
66 StubInfoBuilder::from_project_root(
67 pyproject.module_name().to_string(),
68 pyproject
69 .python_source()
70 .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())),
71 )
72 }
73
74 fn from_project_root(default_module_name: String, project_root: PathBuf) -> Self {
75 Self {
76 modules: BTreeMap::new(),
77 default_module_name,
78 python_root: project_root,
79 }
80 }
81
82 fn get_module(&mut self, name: Option<&str>) -> &mut Module {
83 let name = name.unwrap_or(&self.default_module_name).to_string();
84 let module = self.modules.entry(name.clone()).or_default();
85 module.name = name;
86 module.default_module_name = self.default_module_name.clone();
87 module
88 }
89
90 fn register_submodules(&mut self) {
91 let mut map: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
92 for module in self.modules.keys() {
93 let path = module.split('.').collect::<Vec<_>>();
94 let n = path.len();
95 if n <= 1 {
96 continue;
97 }
98 map.entry(path[..n - 1].join("."))
99 .or_default()
100 .insert(path[n - 1].to_string());
101 }
102 for (parent, children) in map {
103 if let Some(module) = self.modules.get_mut(&parent) {
104 module.submodules.extend(children);
105 }
106 }
107 }
108
109 fn add_class(&mut self, info: &PyClassInfo) {
110 self.get_module(info.module)
111 .class
112 .insert((info.struct_id)(), ClassDef::from(info));
113 }
114
115 fn add_complex_enum(&mut self, info: &PyComplexEnumInfo) {
116 self.get_module(info.module)
117 .class
118 .insert((info.enum_id)(), ClassDef::from(info));
119 }
120
121 fn add_enum(&mut self, info: &PyEnumInfo) {
122 self.get_module(info.module)
123 .enum_
124 .insert((info.enum_id)(), EnumDef::from(info));
125 }
126
127 fn add_function(&mut self, info: &PyFunctionInfo) {
128 let target = self
129 .get_module(info.module)
130 .function
131 .entry(info.name)
132 .or_default();
133 target.push(FunctionDef::from(info));
134 }
135
136 fn add_variable(&mut self, info: &PyVariableInfo) {
137 self.get_module(Some(info.module))
138 .variables
139 .insert(info.name, VariableDef::from(info));
140 }
141
142 fn add_module_doc(&mut self, info: &ModuleDocInfo) {
143 self.get_module(Some(info.module)).doc = (info.doc)();
144 }
145
146 fn add_methods(&mut self, info: &PyMethodsInfo) {
147 let struct_id = (info.struct_id)();
148 for module in self.modules.values_mut() {
149 if let Some(entry) = module.class.get_mut(&struct_id) {
150 for attr in info.attrs {
151 entry.attrs.push(MemberDef {
152 name: attr.name,
153 r#type: (attr.r#type)(),
154 doc: attr.doc,
155 default: attr.default.map(|f| f()),
156 deprecated: attr.deprecated.clone(),
157 });
158 }
159 for getter in info.getters {
160 entry
161 .getter_setters
162 .entry(getter.name.to_string())
163 .or_default()
164 .0 = Some(MemberDef {
165 name: getter.name,
166 r#type: (getter.r#type)(),
167 doc: getter.doc,
168 default: getter.default.map(|f| f()),
169 deprecated: getter.deprecated.clone(),
170 });
171 }
172 for setter in info.setters {
173 entry
174 .getter_setters
175 .entry(setter.name.to_string())
176 .or_default()
177 .1 = Some(MemberDef {
178 name: setter.name,
179 r#type: (setter.r#type)(),
180 doc: setter.doc,
181 default: setter.default.map(|f| f()),
182 deprecated: setter.deprecated.clone(),
183 });
184 }
185 for method in info.methods {
186 let entries = entry.methods.entry(method.name.to_string()).or_default();
187 entries.push(MethodDef::from(method));
188 }
189 return;
190 } else if let Some(entry) = module.enum_.get_mut(&struct_id) {
191 for attr in info.attrs {
192 entry.attrs.push(MemberDef {
193 name: attr.name,
194 r#type: (attr.r#type)(),
195 doc: attr.doc,
196 default: attr.default.map(|f| f()),
197 deprecated: attr.deprecated.clone(),
198 });
199 }
200 for getter in info.getters {
201 entry.getters.push(MemberDef {
202 name: getter.name,
203 r#type: (getter.r#type)(),
204 doc: getter.doc,
205 default: getter.default.map(|f| f()),
206 deprecated: getter.deprecated.clone(),
207 });
208 }
209 for setter in info.setters {
210 entry.setters.push(MemberDef {
211 name: setter.name,
212 r#type: (setter.r#type)(),
213 doc: setter.doc,
214 default: setter.default.map(|f| f()),
215 deprecated: setter.deprecated.clone(),
216 });
217 }
218 for method in info.methods {
219 entry.methods.push(MethodDef::from(method))
220 }
221 return;
222 }
223 }
224 unreachable!("Missing struct_id/enum_id = {:?}", struct_id);
225 }
226
227 fn build(mut self) -> StubInfo {
228 for info in inventory::iter::<PyClassInfo> {
229 self.add_class(info);
230 }
231 for info in inventory::iter::<PyComplexEnumInfo> {
232 self.add_complex_enum(info);
233 }
234 for info in inventory::iter::<PyEnumInfo> {
235 self.add_enum(info);
236 }
237 for info in inventory::iter::<PyFunctionInfo> {
238 self.add_function(info);
239 }
240 for info in inventory::iter::<PyVariableInfo> {
241 self.add_variable(info);
242 }
243 for info in inventory::iter::<ModuleDocInfo> {
244 self.add_module_doc(info);
245 }
246 for info in inventory::iter::<PyMethodsInfo> {
247 self.add_methods(info);
248 }
249 self.register_submodules();
250 StubInfo {
251 modules: self.modules,
252 python_root: self.python_root,
253 }
254 }
255}