pyo3_stub_gen_derive/gen_stub/
method.rs1use crate::gen_stub::util::TypeOrOverride;
2
3use super::{
4 arg::parse_args, extract_deprecated, extract_documents, extract_return_type, parse_pyo3_attrs,
5 ArgInfo, ArgsWithSignature, Attr, DeprecatedInfo, Signature,
6};
7
8use proc_macro2::TokenStream as TokenStream2;
9use quote::{quote, ToTokens, TokenStreamExt};
10use syn::{
11 Error, GenericArgument, ImplItemFn, PathArguments, Result, Type, TypePath, TypeReference,
12};
13
/// How a method is bound on the Python side, decided from its pyo3 attributes.
///
/// `MethodInfo::try_from` starts from [`MethodType::Instance`] and switches to another
/// variant when it sees the corresponding attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MethodType {
    /// Default when none of the receiver-changing attributes below is present.
    Instance,
    /// Selected by `Attr::StaticMethod`.
    Static,
    /// Selected by `Attr::ClassMethod`.
    Class,
    /// Selected by `Attr::New`; the method is always exposed as `__new__`.
    New,
}
21
22#[derive(Debug)]
23pub struct MethodInfo {
24 name: String,
25 args: Vec<ArgInfo>,
26 sig: Option<Signature>,
27 r#return: Option<TypeOrOverride>,
28 doc: String,
29 r#type: MethodType,
30 is_async: bool,
31 deprecated: Option<DeprecatedInfo>,
32}
33
34fn replace_inner(ty: &mut Type, self_: &Type) {
35 match ty {
36 Type::Path(TypePath { path, .. }) => {
37 if let Some(last) = path.segments.last_mut() {
38 if let PathArguments::AngleBracketed(arg) = &mut last.arguments {
39 for arg in &mut arg.args {
40 if let GenericArgument::Type(ty) = arg {
41 replace_inner(ty, self_);
42 }
43 }
44 }
45 if last.ident == "Self" {
46 *ty = self_.clone();
47 }
48 }
49 }
50 Type::Reference(TypeReference { elem, .. }) => {
51 replace_inner(elem, self_);
52 }
53 _ => {}
54 }
55}
56
57impl MethodInfo {
58 pub fn replace_self(&mut self, self_: &Type) {
59 for mut arg in &mut self.args {
60 let (ArgInfo {
61 r#type:
62 TypeOrOverride::RustType {
63 r#type: ref mut ty, ..
64 },
65 ..
66 }
67 | ArgInfo {
68 r#type:
69 TypeOrOverride::OverrideType {
70 r#type: ref mut ty, ..
71 },
72 ..
73 }) = &mut arg;
74 replace_inner(ty, self_);
75 }
76 if let Some(
77 TypeOrOverride::RustType { r#type: ret }
78 | TypeOrOverride::OverrideType { r#type: ret, .. },
79 ) = self.r#return.as_mut()
80 {
81 replace_inner(ret, self_);
82 }
83 }
84}
85
86impl TryFrom<ImplItemFn> for MethodInfo {
87 type Error = Error;
88 fn try_from(item: ImplItemFn) -> Result<Self> {
89 let ImplItemFn { attrs, sig, .. } = item;
90 let doc = extract_documents(&attrs).join("\n");
91 let deprecated = extract_deprecated(&attrs);
92 let pyo3_attrs = parse_pyo3_attrs(&attrs)?;
93 let mut method_name = None;
94 let mut text_sig = Signature::overriding_operator(&sig);
95 let mut method_type = MethodType::Instance;
96 for attr in pyo3_attrs {
97 match attr {
98 Attr::Name(name) => method_name = Some(name),
99 Attr::Signature(text_sig_) => text_sig = Some(text_sig_),
100 Attr::StaticMethod => method_type = MethodType::Static,
101 Attr::ClassMethod => method_type = MethodType::Class,
102 Attr::New => method_type = MethodType::New,
103 _ => {}
104 }
105 }
106 let name = if method_type == MethodType::New {
107 "__new__".to_string()
108 } else {
109 method_name.unwrap_or(sig.ident.to_string())
110 };
111 let r#return = extract_return_type(&sig.output, &attrs)?;
112 Ok(MethodInfo {
113 name,
114 sig: text_sig,
115 args: parse_args(sig.inputs)?,
116 r#return,
117 doc,
118 r#type: method_type,
119 is_async: sig.asyncness.is_some(),
120 deprecated,
121 })
122 }
123}
124
125impl ToTokens for MethodInfo {
126 fn to_tokens(&self, tokens: &mut TokenStream2) {
127 let Self {
128 name,
129 r#return: ret,
130 args,
131 sig,
132 doc,
133 r#type,
134 is_async,
135 deprecated,
136 } = self;
137 let args_with_sig = ArgsWithSignature { args, sig };
138 let ret_tt = if let Some(ret) = ret {
139 match ret {
140 TypeOrOverride::RustType { r#type } => {
141 let ty = r#type.clone();
142 quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
143 }
144 TypeOrOverride::OverrideType {
145 type_repr, imports, ..
146 } => {
147 let imports = imports.iter().collect::<Vec<&String>>();
148 quote! {
149 || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) }
150 }
151 }
152 }
153 } else {
154 quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
155 };
156 let type_tt = match r#type {
157 MethodType::Instance => quote! { ::pyo3_stub_gen::type_info::MethodType::Instance },
158 MethodType::Static => quote! { ::pyo3_stub_gen::type_info::MethodType::Static },
159 MethodType::Class => quote! { ::pyo3_stub_gen::type_info::MethodType::Class },
160 MethodType::New => quote! { ::pyo3_stub_gen::type_info::MethodType::New },
161 };
162 let deprecated_tt = deprecated
163 .as_ref()
164 .map(|d| quote! { Some(#d) })
165 .unwrap_or_else(|| quote! { None });
166 tokens.append_all(quote! {
167 ::pyo3_stub_gen::type_info::MethodInfo {
168 name: #name,
169 args: #args_with_sig,
170 r#return: #ret_tt,
171 doc: #doc,
172 r#type: #type_tt,
173 is_async: #is_async,
174 deprecated: #deprecated_tt,
175 }
176 })
177 }
178}