pyo3_stub_gen_derive/gen_stub/
method.rs

1use crate::gen_stub::util::TypeOrOverride;
2
3use super::{
4    arg::parse_args, attr::IgnoreTarget, extract_deprecated, extract_documents,
5    extract_return_type, parse_gen_stub_type_ignore, parse_pyo3_attrs, ArgInfo, ArgsWithSignature,
6    Attr, DeprecatedInfo, Signature,
7};
8
9use proc_macro2::TokenStream as TokenStream2;
10use quote::{quote, ToTokens, TokenStreamExt};
11use syn::{
12    Error, GenericArgument, ImplItemFn, PathArguments, Result, Type, TypePath, TypeReference,
13};
14
/// Kind of a Python method collected for stub generation.
///
/// Each variant is converted to the matching `::pyo3_stub_gen::type_info::MethodType`
/// variant when `MethodInfo` is turned into tokens.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MethodType {
    /// Default kind when no static/class/new attribute was parsed.
    Instance,
    /// Selected by a parsed `Attr::StaticMethod`.
    Static,
    /// Selected by a parsed `Attr::ClassMethod`.
    Class,
    /// Selected by a parsed `Attr::New`; the method is always named `__new__`.
    New,
}
22
23#[derive(Debug)]
24pub struct MethodInfo {
25    name: String,
26    args: Vec<ArgInfo>,
27    sig: Option<Signature>,
28    r#return: Option<TypeOrOverride>,
29    doc: String,
30    r#type: MethodType,
31    is_async: bool,
32    deprecated: Option<DeprecatedInfo>,
33    type_ignored: Option<IgnoreTarget>,
34}
35
36fn replace_inner(ty: &mut Type, self_: &Type) {
37    match ty {
38        Type::Path(TypePath { path, .. }) => {
39            if let Some(last) = path.segments.last_mut() {
40                if let PathArguments::AngleBracketed(arg) = &mut last.arguments {
41                    for arg in &mut arg.args {
42                        if let GenericArgument::Type(ty) = arg {
43                            replace_inner(ty, self_);
44                        }
45                    }
46                }
47                if last.ident == "Self" {
48                    *ty = self_.clone();
49                }
50            }
51        }
52        Type::Reference(TypeReference { elem, .. }) => {
53            replace_inner(elem, self_);
54        }
55        _ => {}
56    }
57}
58
59impl MethodInfo {
60    pub fn replace_self(&mut self, self_: &Type) {
61        for mut arg in &mut self.args {
62            let (ArgInfo {
63                r#type:
64                    TypeOrOverride::RustType {
65                        r#type: ref mut ty, ..
66                    },
67                ..
68            }
69            | ArgInfo {
70                r#type:
71                    TypeOrOverride::OverrideType {
72                        r#type: ref mut ty, ..
73                    },
74                ..
75            }) = &mut arg;
76            replace_inner(ty, self_);
77        }
78        if let Some(
79            TypeOrOverride::RustType { r#type: ret }
80            | TypeOrOverride::OverrideType { r#type: ret, .. },
81        ) = self.r#return.as_mut()
82        {
83            replace_inner(ret, self_);
84        }
85    }
86}
87
88impl TryFrom<ImplItemFn> for MethodInfo {
89    type Error = Error;
90    fn try_from(item: ImplItemFn) -> Result<Self> {
91        let ImplItemFn { attrs, sig, .. } = item;
92        let doc = extract_documents(&attrs).join("\n");
93        let deprecated = extract_deprecated(&attrs);
94        let type_ignored = parse_gen_stub_type_ignore(&attrs)?;
95        let pyo3_attrs = parse_pyo3_attrs(&attrs)?;
96        let mut method_name = None;
97        let mut text_sig = Signature::overriding_operator(&sig);
98        let mut method_type = MethodType::Instance;
99        for attr in pyo3_attrs {
100            match attr {
101                Attr::Name(name) => method_name = Some(name),
102                Attr::Signature(text_sig_) => text_sig = Some(text_sig_),
103                Attr::StaticMethod => method_type = MethodType::Static,
104                Attr::ClassMethod => method_type = MethodType::Class,
105                Attr::New => method_type = MethodType::New,
106                _ => {}
107            }
108        }
109        let name = if method_type == MethodType::New {
110            "__new__".to_string()
111        } else {
112            method_name.unwrap_or(sig.ident.to_string())
113        };
114        let r#return = extract_return_type(&sig.output, &attrs)?;
115        Ok(MethodInfo {
116            name,
117            sig: text_sig,
118            args: parse_args(sig.inputs)?,
119            r#return,
120            doc,
121            r#type: method_type,
122            is_async: sig.asyncness.is_some(),
123            deprecated,
124            type_ignored,
125        })
126    }
127}
128
129impl ToTokens for MethodInfo {
130    fn to_tokens(&self, tokens: &mut TokenStream2) {
131        let Self {
132            name,
133            r#return: ret,
134            args,
135            sig,
136            doc,
137            r#type,
138            is_async,
139            deprecated,
140            type_ignored,
141        } = self;
142        let args_with_sig = ArgsWithSignature { args, sig };
143        let ret_tt = if let Some(ret) = ret {
144            match ret {
145                TypeOrOverride::RustType { r#type } => {
146                    let ty = r#type.clone();
147                    quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
148                }
149                TypeOrOverride::OverrideType {
150                    type_repr, imports, ..
151                } => {
152                    let imports = imports.iter().collect::<Vec<&String>>();
153                    quote! {
154                        || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) }
155                    }
156                }
157            }
158        } else {
159            quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
160        };
161        let type_tt = match r#type {
162            MethodType::Instance => quote! { ::pyo3_stub_gen::type_info::MethodType::Instance },
163            MethodType::Static => quote! { ::pyo3_stub_gen::type_info::MethodType::Static },
164            MethodType::Class => quote! { ::pyo3_stub_gen::type_info::MethodType::Class },
165            MethodType::New => quote! { ::pyo3_stub_gen::type_info::MethodType::New },
166        };
167        let deprecated_tt = deprecated
168            .as_ref()
169            .map(|d| quote! { Some(#d) })
170            .unwrap_or_else(|| quote! { None });
171        let type_ignored_tt = if let Some(target) = type_ignored {
172            match target {
173                IgnoreTarget::All => {
174                    quote! { Some(::pyo3_stub_gen::type_info::IgnoreTarget::All) }
175                }
176                IgnoreTarget::SpecifiedLits(rules) => {
177                    let rule_strs: Vec<String> = rules.iter().map(|lit| lit.value()).collect();
178                    quote! {
179                        Some(::pyo3_stub_gen::type_info::IgnoreTarget::Specified(
180                            &[#(#rule_strs),*] as &[&str]
181                        ))
182                    }
183                }
184            }
185        } else {
186            quote! { None }
187        };
188        tokens.append_all(quote! {
189            ::pyo3_stub_gen::type_info::MethodInfo {
190                name: #name,
191                args: #args_with_sig,
192                r#return: #ret_tt,
193                doc: #doc,
194                r#type: #type_tt,
195                is_async: #is_async,
196                deprecated: #deprecated_tt,
197                type_ignored: #type_ignored_tt,
198            }
199        })
200    }
201}