pyo3_stub_gen_derive/gen_stub/method.rs

1use crate::gen_stub::util::TypeOrOverride;
2
3use super::{
4    arg::parse_args, attr::IgnoreTarget, extract_deprecated, extract_documents,
5    extract_return_type, parameter::Parameters, parse_gen_stub_type_ignore, parse_pyo3_attrs,
6    ArgInfo, Attr, DeprecatedInfo, Signature,
7};
8
9use proc_macro2::TokenStream as TokenStream2;
10use quote::{quote, ToTokens, TokenStreamExt};
11use syn::{
12    Error, GenericArgument, ImplItemFn, PathArguments, Result, Type, TypePath, TypeReference,
13};
14
/// Kind of a Python method, as determined from the pyo3 attributes on the Rust method.
///
/// `Eq` is derived alongside `PartialEq` because equality on this fieldless enum is total.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MethodType {
    /// Ordinary method; the default when no other method-kind attribute is present.
    Instance,
    /// Method marked with the static-method attribute (`Attr::StaticMethod`).
    Static,
    /// Method marked with the class-method attribute (`Attr::ClassMethod`).
    Class,
    /// Constructor (`Attr::New`); always exposed to Python under the name `__new__`.
    New,
}
22
/// Stub-generation metadata for a single method parsed from an `impl` item, later emitted
/// (via `ToTokens`) as a `::pyo3_stub_gen::type_info::MethodInfo` struct literal.
#[derive(Debug)]
pub struct MethodInfo {
    /// Python-visible name: `__new__` for constructors, otherwise the `name` attribute
    /// override if present, else the Rust identifier.
    pub(super) name: String,
    /// Parameters built from the Rust arguments, reconciled with a text signature when one
    /// is available (explicit `signature` attribute or an operator default).
    pub(super) parameters: Parameters,
    /// Return type (Rust type or user override); `None` is emitted as
    /// `no_return_type_output`.
    pub(super) r#return: Option<TypeOrOverride>,
    /// Doc comments of the Rust method, joined with `\n`.
    pub(super) doc: String,
    /// Method kind (instance, static, class, or constructor).
    pub(super) r#type: MethodType,
    /// Whether the Rust function is declared `async`.
    pub(super) is_async: bool,
    /// Deprecation information extracted from the method's attributes, if any.
    pub(super) deprecated: Option<DeprecatedInfo>,
    /// Optional type-ignore target (all, or a list of rule strings) parsed by
    /// `parse_gen_stub_type_ignore`.
    pub(super) type_ignored: Option<IgnoreTarget>,
    /// Whether this entry describes an overload. `TryFrom<ImplItemFn>` always sets `false`;
    /// NOTE(review): presumably set to `true` elsewhere for overload variants — confirm.
    pub(super) is_overload: bool,
}
35
36fn replace_inner(ty: &mut Type, self_: &Type) {
37    match ty {
38        Type::Path(TypePath { path, .. }) => {
39            if let Some(last) = path.segments.last_mut() {
40                if let PathArguments::AngleBracketed(arg) = &mut last.arguments {
41                    for arg in &mut arg.args {
42                        if let GenericArgument::Type(ty) = arg {
43                            replace_inner(ty, self_);
44                        }
45                    }
46                }
47                if last.ident == "Self" {
48                    *ty = self_.clone();
49                }
50            }
51        }
52        Type::Reference(TypeReference { elem, .. }) => {
53            replace_inner(elem, self_);
54        }
55        _ => {}
56    }
57}
58
59impl MethodInfo {
60    pub fn replace_self(&mut self, self_: &Type) {
61        for param in self.parameters.iter_mut() {
62            let arg_info = &mut param.arg_info;
63            let (ArgInfo {
64                r#type:
65                    TypeOrOverride::RustType {
66                        r#type: ref mut ty, ..
67                    },
68                ..
69            }
70            | ArgInfo {
71                r#type:
72                    TypeOrOverride::OverrideType {
73                        r#type: ref mut ty, ..
74                    },
75                ..
76            }) = arg_info;
77            replace_inner(ty, self_);
78        }
79        if let Some(
80            TypeOrOverride::RustType { r#type: ret }
81            | TypeOrOverride::OverrideType { r#type: ret, .. },
82        ) = self.r#return.as_mut()
83        {
84            replace_inner(ret, self_);
85        }
86    }
87}
88
89impl TryFrom<ImplItemFn> for MethodInfo {
90    type Error = Error;
91    fn try_from(item: ImplItemFn) -> Result<Self> {
92        let ImplItemFn { attrs, sig, .. } = item;
93        let doc = extract_documents(&attrs).join("\n");
94        let deprecated = extract_deprecated(&attrs);
95        let type_ignored = parse_gen_stub_type_ignore(&attrs)?;
96        let pyo3_attrs = parse_pyo3_attrs(&attrs)?;
97        let mut method_name = None;
98        let mut text_sig = Signature::overriding_operator(&sig);
99        let mut method_type = MethodType::Instance;
100        for attr in pyo3_attrs {
101            match attr {
102                Attr::Name(name) => method_name = Some(name),
103                Attr::Signature(text_sig_) => text_sig = Some(text_sig_),
104                Attr::StaticMethod => method_type = MethodType::Static,
105                Attr::ClassMethod => method_type = MethodType::Class,
106                Attr::New => method_type = MethodType::New,
107                _ => {}
108            }
109        }
110        let name = if method_type == MethodType::New {
111            "__new__".to_string()
112        } else {
113            method_name.unwrap_or(sig.ident.to_string())
114        };
115        let r#return = extract_return_type(&sig.output, &attrs)?;
116
117        // Build parameters from args and signature
118        let args = parse_args(sig.inputs)?;
119        let parameters = if let Some(text_sig) = text_sig {
120            Parameters::new_with_sig(&args, &text_sig)?
121        } else {
122            Parameters::new(&args)
123        };
124
125        Ok(MethodInfo {
126            name,
127            parameters,
128            r#return,
129            doc,
130            r#type: method_type,
131            is_async: sig.asyncness.is_some(),
132            deprecated,
133            type_ignored,
134            is_overload: false,
135        })
136    }
137}
138
139impl ToTokens for MethodInfo {
140    fn to_tokens(&self, tokens: &mut TokenStream2) {
141        let Self {
142            name,
143            r#return: ret,
144            parameters,
145            doc,
146            r#type,
147            is_async,
148            deprecated,
149            type_ignored,
150            is_overload,
151        } = self;
152
153        let ret_tt = if let Some(ret) = ret {
154            match ret {
155                TypeOrOverride::RustType { r#type } => {
156                    let ty = r#type.clone();
157                    quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
158                }
159                TypeOrOverride::OverrideType {
160                    type_repr, imports, ..
161                } => {
162                    let imports = imports.iter().collect::<Vec<&String>>();
163                    quote! {
164                        || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) }
165                    }
166                }
167            }
168        } else {
169            quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
170        };
171        let type_tt = match r#type {
172            MethodType::Instance => quote! { ::pyo3_stub_gen::type_info::MethodType::Instance },
173            MethodType::Static => quote! { ::pyo3_stub_gen::type_info::MethodType::Static },
174            MethodType::Class => quote! { ::pyo3_stub_gen::type_info::MethodType::Class },
175            MethodType::New => quote! { ::pyo3_stub_gen::type_info::MethodType::New },
176        };
177        let deprecated_tt = deprecated
178            .as_ref()
179            .map(|d| quote! { Some(#d) })
180            .unwrap_or_else(|| quote! { None });
181        let type_ignored_tt = if let Some(target) = type_ignored {
182            match target {
183                IgnoreTarget::All => {
184                    quote! { Some(::pyo3_stub_gen::type_info::IgnoreTarget::All) }
185                }
186                IgnoreTarget::SpecifiedLits(rules) => {
187                    let rule_strs: Vec<String> = rules.iter().map(|lit| lit.value()).collect();
188                    quote! {
189                        Some(::pyo3_stub_gen::type_info::IgnoreTarget::Specified(
190                            &[#(#rule_strs),*] as &[&str]
191                        ))
192                    }
193                }
194            }
195        } else {
196            quote! { None }
197        };
198        tokens.append_all(quote! {
199            ::pyo3_stub_gen::type_info::MethodInfo {
200                name: #name,
201                parameters: #parameters,
202                r#return: #ret_tt,
203                doc: #doc,
204                r#type: #type_tt,
205                is_async: #is_async,
206                deprecated: #deprecated_tt,
207                type_ignored: #type_ignored_tt,
208                is_overload: #is_overload,
209            }
210        })
211    }
212}