pyo3_stub_gen/generate/
stub_info.rs1use crate::{generate::*, pyproject::PyProject, type_info::*};
2use anyhow::{Context, Result};
3use std::{
4 collections::{BTreeMap, BTreeSet},
5 fs,
6 io::Write,
7 path::*,
8};
9
10#[derive(Debug, Clone, PartialEq)]
11pub struct StubInfo {
12 pub modules: BTreeMap<String, Module>,
13 pub python_root: PathBuf,
14}
15
16impl StubInfo {
17 pub fn from_pyproject_toml(path: impl AsRef<Path>) -> Result<Self> {
20 let pyproject = PyProject::parse_toml(path)?;
21 Ok(StubInfoBuilder::from_pyproject_toml(pyproject).build())
22 }
23
24 pub fn from_project_root(default_module_name: String, project_root: PathBuf) -> Result<Self> {
28 Ok(StubInfoBuilder::from_project_root(default_module_name, project_root).build())
29 }
30
31 pub fn generate(&self) -> Result<()> {
32 for (name, module) in self.modules.iter() {
33 let path = name.replace(".", "/");
34 let dest = if module.submodules.is_empty() {
35 self.python_root.join(format!("{path}.pyi"))
36 } else {
37 self.python_root.join(path).join("__init__.pyi")
38 };
39
40 let dir = dest.parent().context("Cannot get parent directory")?;
41 if !dir.exists() {
42 fs::create_dir_all(dir)?;
43 }
44
45 let mut f = fs::File::create(&dest)?;
46 write!(f, "{}", module)?;
47 log::info!(
48 "Generate stub file of a module `{name}` at {dest}",
49 dest = dest.display()
50 );
51 }
52 Ok(())
53 }
54}
55
56struct StubInfoBuilder {
57 modules: BTreeMap<String, Module>,
58 default_module_name: String,
59 python_root: PathBuf,
60}
61
62impl StubInfoBuilder {
63 fn from_pyproject_toml(pyproject: PyProject) -> Self {
64 StubInfoBuilder::from_project_root(
65 pyproject.module_name().to_string(),
66 pyproject
67 .python_source()
68 .unwrap_or(PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap())),
69 )
70 }
71
72 fn from_project_root(default_module_name: String, project_root: PathBuf) -> Self {
73 Self {
74 modules: BTreeMap::new(),
75 default_module_name,
76 python_root: project_root,
77 }
78 }
79
80 fn get_module(&mut self, name: Option<&str>) -> &mut Module {
81 let name = name.unwrap_or(&self.default_module_name).to_string();
82 let module = self.modules.entry(name.clone()).or_default();
83 module.name = name;
84 module.default_module_name = self.default_module_name.clone();
85 module
86 }
87
88 fn register_submodules(&mut self) {
89 let mut map: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
90 for module in self.modules.keys() {
91 let path = module.split('.').collect::<Vec<_>>();
92 let n = path.len();
93 if n <= 1 {
94 continue;
95 }
96 map.entry(path[..n - 1].join("."))
97 .or_default()
98 .insert(path[n - 1].to_string());
99 }
100 for (parent, children) in map {
101 if let Some(module) = self.modules.get_mut(&parent) {
102 module.submodules.extend(children);
103 }
104 }
105 }
106
107 fn add_class(&mut self, info: &PyClassInfo) {
108 self.get_module(info.module)
109 .class
110 .insert((info.struct_id)(), ClassDef::from(info));
111 }
112
113 fn add_enum(&mut self, info: &PyEnumInfo) {
114 self.get_module(info.module)
115 .enum_
116 .insert((info.enum_id)(), EnumDef::from(info));
117 }
118
119 fn add_function(&mut self, info: &PyFunctionInfo) {
120 self.get_module(info.module)
121 .function
122 .insert(info.name, FunctionDef::from(info));
123 }
124
125 fn add_error(&mut self, info: &PyErrorInfo) {
126 self.get_module(Some(info.module))
127 .error
128 .insert(info.name, ErrorDef::from(info));
129 }
130
131 fn add_variable(&mut self, info: &PyVariableInfo) {
132 self.get_module(Some(info.module))
133 .variables
134 .insert(info.name, VariableDef::from(info));
135 }
136
137 fn add_methods(&mut self, info: &PyMethodsInfo) {
138 let struct_id = (info.struct_id)();
139 for module in self.modules.values_mut() {
140 if let Some(entry) = module.class.get_mut(&struct_id) {
141 for getter in info.getters {
142 entry.members.push(MemberDef {
143 name: getter.name,
144 r#type: (getter.r#type)(),
145 doc: getter.doc,
146 });
147 }
148 for method in info.methods {
149 entry.methods.push(MethodDef::from(method))
150 }
151 return;
152 } else if let Some(entry) = module.enum_.get_mut(&struct_id) {
153 for getter in info.getters {
154 entry.members.push(MemberDef {
155 name: getter.name,
156 r#type: (getter.r#type)(),
157 doc: getter.doc,
158 });
159 }
160 for method in info.methods {
161 entry.methods.push(MethodDef::from(method))
162 }
163 return;
164 }
165 }
166 unreachable!("Missing struct_id/enum_id = {:?}", struct_id);
167 }
168
169 fn build(mut self) -> StubInfo {
170 for info in inventory::iter::<PyClassInfo> {
171 self.add_class(info);
172 }
173 for info in inventory::iter::<PyEnumInfo> {
174 self.add_enum(info);
175 }
176 for info in inventory::iter::<PyFunctionInfo> {
177 self.add_function(info);
178 }
179 for info in inventory::iter::<PyErrorInfo> {
180 self.add_error(info);
181 }
182 for info in inventory::iter::<PyVariableInfo> {
183 self.add_variable(info);
184 }
185 for info in inventory::iter::<PyMethodsInfo> {
186 self.add_methods(info);
187 }
188 self.register_submodules();
189 StubInfo {
190 modules: self.modules,
191 python_root: self.python_root,
192 }
193 }
194}