pyo3_stub_gen_derive/gen_stub/
method.rs1use crate::gen_stub::util::TypeOrOverride;
2
3use super::{
4 arg::parse_args, attr::IgnoreTarget, extract_deprecated, extract_documents,
5 extract_return_type, parameter::Parameters, parse_gen_stub_type_ignore, parse_pyo3_attrs,
6 ArgInfo, Attr, DeprecatedInfo, Signature,
7};
8
9use proc_macro2::TokenStream as TokenStream2;
10use quote::{quote, ToTokens, TokenStreamExt};
11use syn::{
12 Error, GenericArgument, ImplItemFn, PathArguments, Result, Type, TypePath, TypeReference,
13};
14
/// Binding kind of a method described by [`MethodInfo`].
///
/// `MethodInfo`'s `TryFrom<ImplItemFn>` starts from [`MethodType::Instance`] and
/// switches to another variant when the matching parsed pyo3 attribute is present.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MethodType {
    /// Default kind, used when none of the attributes below is found.
    Instance,
    /// Selected by `Attr::StaticMethod`.
    Static,
    /// Selected by `Attr::ClassMethod`.
    Class,
    /// Selected by `Attr::New`; such methods are always named `__new__`.
    New,
}
22
23#[derive(Debug)]
24pub struct MethodInfo {
25 pub(super) name: String,
26 pub(super) parameters: Parameters,
27 pub(super) r#return: Option<TypeOrOverride>,
28 pub(super) doc: String,
29 pub(super) r#type: MethodType,
30 pub(super) is_async: bool,
31 pub(super) deprecated: Option<DeprecatedInfo>,
32 pub(super) type_ignored: Option<IgnoreTarget>,
33 pub(super) is_overload: bool,
34}
35
36fn replace_inner(ty: &mut Type, self_: &Type) {
37 match ty {
38 Type::Path(TypePath { path, .. }) => {
39 if let Some(last) = path.segments.last_mut() {
40 if let PathArguments::AngleBracketed(arg) = &mut last.arguments {
41 for arg in &mut arg.args {
42 if let GenericArgument::Type(ty) = arg {
43 replace_inner(ty, self_);
44 }
45 }
46 }
47 if last.ident == "Self" {
48 *ty = self_.clone();
49 }
50 }
51 }
52 Type::Reference(TypeReference { elem, .. }) => {
53 replace_inner(elem, self_);
54 }
55 _ => {}
56 }
57}
58
59impl MethodInfo {
60 pub fn replace_self(&mut self, self_: &Type) {
61 for param in self.parameters.iter_mut() {
62 let arg_info = &mut param.arg_info;
63 let (ArgInfo {
64 r#type:
65 TypeOrOverride::RustType {
66 r#type: ref mut ty, ..
67 },
68 ..
69 }
70 | ArgInfo {
71 r#type:
72 TypeOrOverride::OverrideType {
73 r#type: ref mut ty, ..
74 },
75 ..
76 }) = arg_info;
77 replace_inner(ty, self_);
78 }
79 if let Some(
80 TypeOrOverride::RustType { r#type: ret }
81 | TypeOrOverride::OverrideType { r#type: ret, .. },
82 ) = self.r#return.as_mut()
83 {
84 replace_inner(ret, self_);
85 }
86 }
87}
88
89impl TryFrom<ImplItemFn> for MethodInfo {
90 type Error = Error;
91 fn try_from(item: ImplItemFn) -> Result<Self> {
92 let ImplItemFn { attrs, sig, .. } = item;
93 let doc = extract_documents(&attrs).join("\n");
94 let deprecated = extract_deprecated(&attrs);
95 let type_ignored = parse_gen_stub_type_ignore(&attrs)?;
96 let pyo3_attrs = parse_pyo3_attrs(&attrs)?;
97 let mut method_name = None;
98 let mut text_sig = Signature::overriding_operator(&sig);
99 let mut method_type = MethodType::Instance;
100 for attr in pyo3_attrs {
101 match attr {
102 Attr::Name(name) => method_name = Some(name),
103 Attr::Signature(text_sig_) => text_sig = Some(text_sig_),
104 Attr::StaticMethod => method_type = MethodType::Static,
105 Attr::ClassMethod => method_type = MethodType::Class,
106 Attr::New => method_type = MethodType::New,
107 _ => {}
108 }
109 }
110 let name = if method_type == MethodType::New {
111 "__new__".to_string()
112 } else {
113 method_name.unwrap_or(sig.ident.to_string())
114 };
115 let r#return = extract_return_type(&sig.output, &attrs)?;
116
117 let args = parse_args(sig.inputs)?;
119 let parameters = if let Some(text_sig) = text_sig {
120 Parameters::new_with_sig(&args, &text_sig)?
121 } else {
122 Parameters::new(&args)
123 };
124
125 Ok(MethodInfo {
126 name,
127 parameters,
128 r#return,
129 doc,
130 r#type: method_type,
131 is_async: sig.asyncness.is_some(),
132 deprecated,
133 type_ignored,
134 is_overload: false,
135 })
136 }
137}
138
139impl ToTokens for MethodInfo {
140 fn to_tokens(&self, tokens: &mut TokenStream2) {
141 let Self {
142 name,
143 r#return: ret,
144 parameters,
145 doc,
146 r#type,
147 is_async,
148 deprecated,
149 type_ignored,
150 is_overload,
151 } = self;
152
153 let ret_tt = if let Some(ret) = ret {
154 match ret {
155 TypeOrOverride::RustType { r#type } => {
156 let ty = r#type.clone();
157 quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
158 }
159 TypeOrOverride::OverrideType {
160 type_repr, imports, ..
161 } => {
162 let imports = imports.iter().collect::<Vec<&String>>();
163 quote! {
164 || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) }
165 }
166 }
167 }
168 } else {
169 quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
170 };
171 let type_tt = match r#type {
172 MethodType::Instance => quote! { ::pyo3_stub_gen::type_info::MethodType::Instance },
173 MethodType::Static => quote! { ::pyo3_stub_gen::type_info::MethodType::Static },
174 MethodType::Class => quote! { ::pyo3_stub_gen::type_info::MethodType::Class },
175 MethodType::New => quote! { ::pyo3_stub_gen::type_info::MethodType::New },
176 };
177 let deprecated_tt = deprecated
178 .as_ref()
179 .map(|d| quote! { Some(#d) })
180 .unwrap_or_else(|| quote! { None });
181 let type_ignored_tt = if let Some(target) = type_ignored {
182 match target {
183 IgnoreTarget::All => {
184 quote! { Some(::pyo3_stub_gen::type_info::IgnoreTarget::All) }
185 }
186 IgnoreTarget::SpecifiedLits(rules) => {
187 let rule_strs: Vec<String> = rules.iter().map(|lit| lit.value()).collect();
188 quote! {
189 Some(::pyo3_stub_gen::type_info::IgnoreTarget::Specified(
190 &[#(#rule_strs),*] as &[&str]
191 ))
192 }
193 }
194 }
195 } else {
196 quote! { None }
197 };
198 tokens.append_all(quote! {
199 ::pyo3_stub_gen::type_info::MethodInfo {
200 name: #name,
201 parameters: #parameters,
202 r#return: #ret_tt,
203 doc: #doc,
204 r#type: #type_tt,
205 is_async: #is_async,
206 deprecated: #deprecated_tt,
207 type_ignored: #type_ignored_tt,
208 is_overload: #is_overload,
209 }
210 })
211 }
212}