// pyo3_stub_gen_derive/gen_stub/method.rs

1use crate::gen_stub::util::TypeOrOverride;
2
3use super::{
4    arg::parse_args, extract_deprecated, extract_documents, extract_return_type, parse_pyo3_attrs,
5    ArgInfo, ArgsWithSignature, Attr, DeprecatedInfo, Signature,
6};
7
8use proc_macro2::TokenStream as TokenStream2;
9use quote::{quote, ToTokens, TokenStreamExt};
10use syn::{
11    Error, GenericArgument, ImplItemFn, PathArguments, Result, Type, TypePath, TypeReference,
12};
13
/// Kind of method, chosen from the PyO3 attributes attached to the Rust function.
///
/// `Instance` is the default; `Attr::StaticMethod`, `Attr::ClassMethod` and `Attr::New`
/// select the other variants. Each variant maps one-to-one onto
/// `::pyo3_stub_gen::type_info::MethodType` in the generated code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MethodType {
    /// Ordinary method (no kind-selecting attribute present).
    Instance,
    /// Selected by `Attr::StaticMethod`.
    Static,
    /// Selected by `Attr::ClassMethod`.
    Class,
    /// Selected by `Attr::New`; such methods are always named `__new__`.
    New,
}
21
22#[derive(Debug)]
23pub struct MethodInfo {
24    name: String,
25    args: Vec<ArgInfo>,
26    sig: Option<Signature>,
27    r#return: Option<TypeOrOverride>,
28    doc: String,
29    r#type: MethodType,
30    is_async: bool,
31    deprecated: Option<DeprecatedInfo>,
32}
33
34fn replace_inner(ty: &mut Type, self_: &Type) {
35    match ty {
36        Type::Path(TypePath { path, .. }) => {
37            if let Some(last) = path.segments.last_mut() {
38                if let PathArguments::AngleBracketed(arg) = &mut last.arguments {
39                    for arg in &mut arg.args {
40                        if let GenericArgument::Type(ty) = arg {
41                            replace_inner(ty, self_);
42                        }
43                    }
44                }
45                if last.ident == "Self" {
46                    *ty = self_.clone();
47                }
48            }
49        }
50        Type::Reference(TypeReference { elem, .. }) => {
51            replace_inner(elem, self_);
52        }
53        _ => {}
54    }
55}
56
57impl MethodInfo {
58    pub fn replace_self(&mut self, self_: &Type) {
59        for mut arg in &mut self.args {
60            let (ArgInfo {
61                r#type:
62                    TypeOrOverride::RustType {
63                        r#type: ref mut ty, ..
64                    },
65                ..
66            }
67            | ArgInfo {
68                r#type:
69                    TypeOrOverride::OverrideType {
70                        r#type: ref mut ty, ..
71                    },
72                ..
73            }) = &mut arg;
74            replace_inner(ty, self_);
75        }
76        if let Some(
77            TypeOrOverride::RustType { r#type: ret }
78            | TypeOrOverride::OverrideType { r#type: ret, .. },
79        ) = self.r#return.as_mut()
80        {
81            replace_inner(ret, self_);
82        }
83    }
84}
85
86impl TryFrom<ImplItemFn> for MethodInfo {
87    type Error = Error;
88    fn try_from(item: ImplItemFn) -> Result<Self> {
89        let ImplItemFn { attrs, sig, .. } = item;
90        let doc = extract_documents(&attrs).join("\n");
91        let deprecated = extract_deprecated(&attrs);
92        let pyo3_attrs = parse_pyo3_attrs(&attrs)?;
93        let mut method_name = None;
94        let mut text_sig = Signature::overriding_operator(&sig);
95        let mut method_type = MethodType::Instance;
96        for attr in pyo3_attrs {
97            match attr {
98                Attr::Name(name) => method_name = Some(name),
99                Attr::Signature(text_sig_) => text_sig = Some(text_sig_),
100                Attr::StaticMethod => method_type = MethodType::Static,
101                Attr::ClassMethod => method_type = MethodType::Class,
102                Attr::New => method_type = MethodType::New,
103                _ => {}
104            }
105        }
106        let name = if method_type == MethodType::New {
107            "__new__".to_string()
108        } else {
109            method_name.unwrap_or(sig.ident.to_string())
110        };
111        let r#return = extract_return_type(&sig.output, &attrs)?;
112        Ok(MethodInfo {
113            name,
114            sig: text_sig,
115            args: parse_args(sig.inputs)?,
116            r#return,
117            doc,
118            r#type: method_type,
119            is_async: sig.asyncness.is_some(),
120            deprecated,
121        })
122    }
123}
124
125impl ToTokens for MethodInfo {
126    fn to_tokens(&self, tokens: &mut TokenStream2) {
127        let Self {
128            name,
129            r#return: ret,
130            args,
131            sig,
132            doc,
133            r#type,
134            is_async,
135            deprecated,
136        } = self;
137        let args_with_sig = ArgsWithSignature { args, sig };
138        let ret_tt = if let Some(ret) = ret {
139            match ret {
140                TypeOrOverride::RustType { r#type } => {
141                    let ty = r#type.clone();
142                    quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
143                }
144                TypeOrOverride::OverrideType {
145                    type_repr, imports, ..
146                } => {
147                    let imports = imports.iter().collect::<Vec<&String>>();
148                    quote! {
149                        || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) }
150                    }
151                }
152            }
153        } else {
154            quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
155        };
156        let type_tt = match r#type {
157            MethodType::Instance => quote! { ::pyo3_stub_gen::type_info::MethodType::Instance },
158            MethodType::Static => quote! { ::pyo3_stub_gen::type_info::MethodType::Static },
159            MethodType::Class => quote! { ::pyo3_stub_gen::type_info::MethodType::Class },
160            MethodType::New => quote! { ::pyo3_stub_gen::type_info::MethodType::New },
161        };
162        let deprecated_tt = deprecated
163            .as_ref()
164            .map(|d| quote! { Some(#d) })
165            .unwrap_or_else(|| quote! { None });
166        tokens.append_all(quote! {
167            ::pyo3_stub_gen::type_info::MethodInfo {
168                name: #name,
169                args: #args_with_sig,
170                r#return: #ret_tt,
171                doc: #doc,
172                r#type: #type_tt,
173                is_async: #is_async,
174                deprecated: #deprecated_tt,
175            }
176        })
177    }
178}