pyo3_stub_gen_derive/gen_stub/
member.rs1use crate::gen_stub::{
2 attr::{parse_gen_stub_default, parse_gen_stub_override_type, OverrideTypeAttribute},
3 extract_documents,
4 util::TypeOrOverride,
5};
6
7use super::{extract_return_type, parse_pyo3_attrs, Attr};
8
9use crate::gen_stub::arg::ArgInfo;
10use proc_macro2::TokenStream as TokenStream2;
11use quote::{quote, ToTokens, TokenStreamExt};
12use syn::{Attribute, Error, Expr, Field, FnArg, ImplItemConst, ImplItemFn, Result};
13
14#[derive(Debug, Clone)]
15pub struct MemberInfo {
16 doc: String,
17 name: String,
18 r#type: TypeOrOverride,
19 default: Option<Expr>,
20 deprecated: Option<crate::gen_stub::attr::DeprecatedInfo>,
21}
22
23impl MemberInfo {
24 pub fn is_getter(attrs: &[Attribute]) -> Result<bool> {
25 let attrs = parse_pyo3_attrs(attrs)?;
26 Ok(attrs.iter().any(|attr| matches!(attr, Attr::Getter(_))))
27 }
28 pub fn is_setter(attrs: &[Attribute]) -> Result<bool> {
29 let attrs = parse_pyo3_attrs(attrs)?;
30 Ok(attrs.iter().any(|attr| matches!(attr, Attr::Setter(_))))
31 }
32
33 pub fn is_classattr(attrs: &[Attribute]) -> Result<bool> {
34 let attrs = parse_pyo3_attrs(attrs)?;
35 Ok(attrs.iter().any(|attr| matches!(attr, Attr::ClassAttr)))
36 }
37 pub fn is_get(field: &Field) -> Result<bool> {
38 let Field { attrs, .. } = field;
39 Ok(parse_pyo3_attrs(attrs)?
40 .iter()
41 .any(|attr| matches!(attr, Attr::Get)))
42 }
43 pub fn is_set(field: &Field) -> Result<bool> {
44 let Field { attrs, .. } = field;
45 Ok(parse_pyo3_attrs(attrs)?
46 .iter()
47 .any(|attr| matches!(attr, Attr::Set)))
48 }
49}
50
51impl MemberInfo {
52 pub fn new_getter(item: ImplItemFn) -> Result<Self> {
62 assert!(Self::is_getter(&item.attrs)?);
63 let ImplItemFn { attrs, sig, .. } = &item;
64 let default = parse_gen_stub_default(attrs)?;
65 let doc = extract_documents(attrs).join("\n");
66 let pyo3_attrs = parse_pyo3_attrs(attrs)?;
67
68 let mut name = None;
70 for attr in &pyo3_attrs {
71 if let Attr::Getter(getter_name) = attr {
72 let fn_name = sig.ident.to_string();
73 let fn_getter_name = match fn_name.strip_prefix("get_") {
74 Some(s) => s.to_owned(),
75 None => fn_name,
76 };
77 name = Some(getter_name.clone().unwrap_or(fn_getter_name));
78 break;
79 }
80 }
81
82 for attr in &pyo3_attrs {
84 if let Attr::Name(pyo3_name) = attr {
85 name = Some(pyo3_name.clone());
86 break;
87 }
88 }
89
90 let name = name.ok_or_else(|| Error::new_spanned(&item, "Not a getter"))?;
91 let r#type = extract_return_type(&sig.output, attrs)?
92 .ok_or_else(|| Error::new_spanned(&item, "Getter must return a type"))?;
93 Ok(MemberInfo {
94 doc,
95 name,
96 r#type,
97 default,
98 deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
99 })
100 }
101 pub fn new_setter(item: ImplItemFn) -> Result<Self> {
111 assert!(Self::is_setter(&item.attrs)?);
112 let ImplItemFn { attrs, sig, .. } = &item;
113 let default = parse_gen_stub_default(attrs)?;
114 let doc = extract_documents(attrs).join("\n");
115 let pyo3_attrs = parse_pyo3_attrs(attrs)?;
116
117 let mut name = None;
119 let mut r#type = None;
120 for attr in &pyo3_attrs {
121 if let Attr::Setter(setter_name) = attr {
122 let fn_name = sig.ident.to_string();
123 let fn_setter_name = match fn_name.strip_prefix("set_") {
124 Some(s) => s.to_owned(),
125 None => fn_name,
126 };
127 name = Some(setter_name.clone().unwrap_or(fn_setter_name));
128 r#type = Some(
129 sig.inputs
130 .get(1)
131 .ok_or(syn::Error::new_spanned(&item, "Setter must input a type"))
132 .and_then(|arg| {
133 if let FnArg::Typed(t) = arg {
134 Ok(match parse_gen_stub_override_type(&t.attrs)? {
135 Some(OverrideTypeAttribute { type_repr, imports }) => {
136 TypeOrOverride::OverrideType {
137 r#type: *t.ty.clone(),
138 type_repr,
139 imports,
140 rust_type_markers: vec![],
141 }
142 }
143 _ => TypeOrOverride::RustType {
144 r#type: *t.ty.clone(),
145 },
146 })
147 } else {
148 Err(syn::Error::new_spanned(&item, "Setter must input a type"))
149 }
150 })?,
151 );
152 break;
153 }
154 }
155
156 for attr in &pyo3_attrs {
158 if let Attr::Name(pyo3_name) = attr {
159 name = Some(pyo3_name.clone());
160 break;
161 }
162 }
163
164 let name = name.ok_or_else(|| Error::new_spanned(&item, "Not a setter"))?;
165 let r#type = r#type.ok_or_else(|| Error::new_spanned(&item, "Setter type not found"))?;
166 Ok(MemberInfo {
167 doc,
168 name,
169 r#type,
170 default,
171 deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
172 })
173 }
174 pub fn new_classattr_fn(item: ImplItemFn) -> Result<Self> {
175 assert!(Self::is_classattr(&item.attrs)?);
176 let ImplItemFn { attrs, sig, .. } = &item;
177 let default = parse_gen_stub_default(attrs)?;
178 let doc = extract_documents(attrs).join("\n");
179 let mut name = sig.ident.to_string();
180 for attr in parse_pyo3_attrs(attrs)? {
181 if let Attr::Name(_name) = attr {
182 name = _name;
183 }
184 }
185 Ok(MemberInfo {
186 doc,
187 name,
188 r#type: extract_return_type(&sig.output, attrs)?.expect("Getter must return a type"),
189 default,
190 deprecated: crate::gen_stub::attr::extract_deprecated(attrs),
191 })
192 }
193 pub fn new_classattr_const(item: ImplItemConst) -> Result<Self> {
194 assert!(Self::is_classattr(&item.attrs)?);
195 let ImplItemConst {
196 attrs,
197 ident,
198 ty,
199 expr,
200 ..
201 } = item;
202 let doc = extract_documents(&attrs).join("\n");
203 let mut name = ident.to_string();
204 for attr in parse_pyo3_attrs(&attrs)? {
205 if let Attr::Name(_name) = attr {
206 name = _name;
207 }
208 }
209 Ok(MemberInfo {
210 doc,
211 name,
212 r#type: TypeOrOverride::RustType { r#type: ty },
213 default: Some(expr),
214 deprecated: crate::gen_stub::attr::extract_deprecated(&attrs),
215 })
216 }
217}
218
219impl TryFrom<Field> for MemberInfo {
220 type Error = Error;
221 fn try_from(field: Field) -> Result<Self> {
222 let Field {
223 ident, ty, attrs, ..
224 } = field;
225 let mut field_name = None;
226 for attr in parse_pyo3_attrs(&attrs)? {
227 if let Attr::Name(name) = attr {
228 field_name = Some(name);
229 }
230 }
231 let doc = extract_documents(&attrs).join("\n");
232 let default = parse_gen_stub_default(&attrs)?;
233 let deprecated = crate::gen_stub::attr::extract_deprecated(&attrs);
234 Ok(Self {
235 name: field_name.unwrap_or(ident.unwrap().to_string()),
236 r#type: TypeOrOverride::RustType { r#type: ty },
237 doc,
238 default,
239 deprecated,
240 })
241 }
242}
243
244impl ToTokens for MemberInfo {
245 fn to_tokens(&self, tokens: &mut TokenStream2) {
246 let Self {
247 name,
248 r#type,
249 doc,
250 default,
251 deprecated,
252 } = self;
253 let default = default
254 .as_ref()
255 .map(|value| {
256 if value.to_token_stream().to_string() == "None" {
257 quote! {
258 "None".to_string()
259 }
260 } else {
261 let (TypeOrOverride::RustType { r#type: ty }
262 | TypeOrOverride::OverrideType { r#type: ty, .. }) = r#type;
263 quote! {
264 let v: #ty = #value;
265 ::pyo3_stub_gen::util::fmt_py_obj(v)
266 }
267 }
268 })
269 .map_or(quote! {None}, |default| {
270 quote! {Some({
271 fn _fmt() -> String {
272 #default
273 }
274 _fmt
275 })}
276 });
277 let deprecated_info = deprecated
278 .as_ref()
279 .map(|deprecated| {
280 quote! {
281 Some(::pyo3_stub_gen::type_info::DeprecatedInfo {
282 since: #deprecated.since,
283 note: #deprecated.note,
284 })
285 }
286 })
287 .unwrap_or_else(|| quote! { None });
288 match r#type {
289 TypeOrOverride::RustType { r#type: ty } => tokens.append_all(quote! {
290 ::pyo3_stub_gen::type_info::MemberInfo {
291 name: #name,
292 r#type: <#ty as ::pyo3_stub_gen::PyStubType>::type_output,
293 doc: #doc,
294 default: #default,
295 deprecated: #deprecated_info,
296 }
297 }),
298 TypeOrOverride::OverrideType {
299 type_repr,
300 imports,
301 rust_type_markers,
302 ..
303 } => {
304 let imports = imports.iter().collect::<Vec<&String>>();
305
306 let (type_name_code, type_refs_code) = if rust_type_markers.is_empty() {
308 (
309 quote! { #type_repr.to_string() },
310 quote! { ::std::collections::HashMap::new() },
311 )
312 } else {
313 let marker_types: Vec<syn::Type> = rust_type_markers
315 .iter()
316 .filter_map(|s| syn::parse_str(s).ok())
317 .collect();
318
319 let rust_names = rust_type_markers.iter().collect::<Vec<_>>();
320
321 (
322 quote! {
323 {
324 let mut type_name = #type_repr.to_string();
325 #(
326 let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
327 type_name = type_name.replace(#rust_names, &type_info.name);
328 )*
329 type_name
330 }
331 },
332 quote! {
333 {
334 let mut type_refs = ::std::collections::HashMap::new();
335 #(
336 let type_info = <#marker_types as ::pyo3_stub_gen::PyStubType>::type_input();
337 if let Some(module) = type_info.source_module {
338 type_refs.insert(
339 type_info.name.split('[').next().unwrap_or(&type_info.name).split('.').last().unwrap_or(&type_info.name).to_string(),
340 ::pyo3_stub_gen::TypeIdentifierRef {
341 module: module.into(),
342 import_kind: ::pyo3_stub_gen::ImportKind::Module,
343 }
344 );
345 }
346 type_refs.extend(type_info.type_refs);
347 )*
348 type_refs
349 }
350 },
351 )
352 };
353
354 tokens.append_all(quote! {
355 ::pyo3_stub_gen::type_info::MemberInfo {
356 name: #name,
357 r#type: || ::pyo3_stub_gen::TypeInfo { name: #type_name_code, source_module: None, import: ::std::collections::HashSet::from([#(#imports.into(),)*]), type_refs: #type_refs_code },
358 doc: #doc,
359 default: #default,
360 deprecated: #deprecated_info,
361 }
362 })
363 }
364 }
365 }
366}
367
368impl From<MemberInfo> for ArgInfo {
369 fn from(value: MemberInfo) -> Self {
370 let MemberInfo { name, r#type, .. } = value;
371
372 Self { name, r#type }
373 }
374}