pyo3_stub_gen_derive/gen_stub/
method.rs1use crate::gen_stub::util::TypeOrOverride;
2
3use super::{
4 arg::parse_args, attr::IgnoreTarget, extract_deprecated, extract_documents,
5 extract_return_type, parse_gen_stub_type_ignore, parse_pyo3_attrs, ArgInfo, ArgsWithSignature,
6 Attr, DeprecatedInfo, Signature,
7};
8
9use proc_macro2::TokenStream as TokenStream2;
10use quote::{quote, ToTokens, TokenStreamExt};
11use syn::{
12 Error, GenericArgument, ImplItemFn, PathArguments, Result, Type, TypePath, TypeReference,
13};
14
/// Kind of a method, as selected by the attributes parsed from it.
///
/// Only unit variants, so total equality (`Eq`) holds trivially.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MethodType {
    /// Regular method; the default when no `staticmethod`, `classmethod`
    /// or `new` attribute is found.
    Instance,
    /// Method marked as a static method.
    Static,
    /// Method marked as a class method.
    Class,
    /// Constructor; always stubbed under the name `__new__`.
    New,
}
22
23#[derive(Debug)]
24pub struct MethodInfo {
25 name: String,
26 args: Vec<ArgInfo>,
27 sig: Option<Signature>,
28 r#return: Option<TypeOrOverride>,
29 doc: String,
30 r#type: MethodType,
31 is_async: bool,
32 deprecated: Option<DeprecatedInfo>,
33 type_ignored: Option<IgnoreTarget>,
34}
35
36fn replace_inner(ty: &mut Type, self_: &Type) {
37 match ty {
38 Type::Path(TypePath { path, .. }) => {
39 if let Some(last) = path.segments.last_mut() {
40 if let PathArguments::AngleBracketed(arg) = &mut last.arguments {
41 for arg in &mut arg.args {
42 if let GenericArgument::Type(ty) = arg {
43 replace_inner(ty, self_);
44 }
45 }
46 }
47 if last.ident == "Self" {
48 *ty = self_.clone();
49 }
50 }
51 }
52 Type::Reference(TypeReference { elem, .. }) => {
53 replace_inner(elem, self_);
54 }
55 _ => {}
56 }
57}
58
59impl MethodInfo {
60 pub fn replace_self(&mut self, self_: &Type) {
61 for mut arg in &mut self.args {
62 let (ArgInfo {
63 r#type:
64 TypeOrOverride::RustType {
65 r#type: ref mut ty, ..
66 },
67 ..
68 }
69 | ArgInfo {
70 r#type:
71 TypeOrOverride::OverrideType {
72 r#type: ref mut ty, ..
73 },
74 ..
75 }) = &mut arg;
76 replace_inner(ty, self_);
77 }
78 if let Some(
79 TypeOrOverride::RustType { r#type: ret }
80 | TypeOrOverride::OverrideType { r#type: ret, .. },
81 ) = self.r#return.as_mut()
82 {
83 replace_inner(ret, self_);
84 }
85 }
86}
87
88impl TryFrom<ImplItemFn> for MethodInfo {
89 type Error = Error;
90 fn try_from(item: ImplItemFn) -> Result<Self> {
91 let ImplItemFn { attrs, sig, .. } = item;
92 let doc = extract_documents(&attrs).join("\n");
93 let deprecated = extract_deprecated(&attrs);
94 let type_ignored = parse_gen_stub_type_ignore(&attrs)?;
95 let pyo3_attrs = parse_pyo3_attrs(&attrs)?;
96 let mut method_name = None;
97 let mut text_sig = Signature::overriding_operator(&sig);
98 let mut method_type = MethodType::Instance;
99 for attr in pyo3_attrs {
100 match attr {
101 Attr::Name(name) => method_name = Some(name),
102 Attr::Signature(text_sig_) => text_sig = Some(text_sig_),
103 Attr::StaticMethod => method_type = MethodType::Static,
104 Attr::ClassMethod => method_type = MethodType::Class,
105 Attr::New => method_type = MethodType::New,
106 _ => {}
107 }
108 }
109 let name = if method_type == MethodType::New {
110 "__new__".to_string()
111 } else {
112 method_name.unwrap_or(sig.ident.to_string())
113 };
114 let r#return = extract_return_type(&sig.output, &attrs)?;
115 Ok(MethodInfo {
116 name,
117 sig: text_sig,
118 args: parse_args(sig.inputs)?,
119 r#return,
120 doc,
121 r#type: method_type,
122 is_async: sig.asyncness.is_some(),
123 deprecated,
124 type_ignored,
125 })
126 }
127}
128
129impl ToTokens for MethodInfo {
130 fn to_tokens(&self, tokens: &mut TokenStream2) {
131 let Self {
132 name,
133 r#return: ret,
134 args,
135 sig,
136 doc,
137 r#type,
138 is_async,
139 deprecated,
140 type_ignored,
141 } = self;
142 let args_with_sig = ArgsWithSignature { args, sig };
143 let ret_tt = if let Some(ret) = ret {
144 match ret {
145 TypeOrOverride::RustType { r#type } => {
146 let ty = r#type.clone();
147 quote! { <#ty as pyo3_stub_gen::PyStubType>::type_output }
148 }
149 TypeOrOverride::OverrideType {
150 type_repr, imports, ..
151 } => {
152 let imports = imports.iter().collect::<Vec<&String>>();
153 quote! {
154 || ::pyo3_stub_gen::TypeInfo { name: #type_repr.to_string(), import: ::std::collections::HashSet::from([#(#imports.into(),)*]) }
155 }
156 }
157 }
158 } else {
159 quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
160 };
161 let type_tt = match r#type {
162 MethodType::Instance => quote! { ::pyo3_stub_gen::type_info::MethodType::Instance },
163 MethodType::Static => quote! { ::pyo3_stub_gen::type_info::MethodType::Static },
164 MethodType::Class => quote! { ::pyo3_stub_gen::type_info::MethodType::Class },
165 MethodType::New => quote! { ::pyo3_stub_gen::type_info::MethodType::New },
166 };
167 let deprecated_tt = deprecated
168 .as_ref()
169 .map(|d| quote! { Some(#d) })
170 .unwrap_or_else(|| quote! { None });
171 let type_ignored_tt = if let Some(target) = type_ignored {
172 match target {
173 IgnoreTarget::All => {
174 quote! { Some(::pyo3_stub_gen::type_info::IgnoreTarget::All) }
175 }
176 IgnoreTarget::SpecifiedLits(rules) => {
177 let rule_strs: Vec<String> = rules.iter().map(|lit| lit.value()).collect();
178 quote! {
179 Some(::pyo3_stub_gen::type_info::IgnoreTarget::Specified(
180 &[#(#rule_strs),*] as &[&str]
181 ))
182 }
183 }
184 }
185 } else {
186 quote! { None }
187 };
188 tokens.append_all(quote! {
189 ::pyo3_stub_gen::type_info::MethodInfo {
190 name: #name,
191 args: #args_with_sig,
192 r#return: #ret_tt,
193 doc: #doc,
194 r#type: #type_tt,
195 is_async: #is_async,
196 deprecated: #deprecated_tt,
197 type_ignored: #type_ignored_tt,
198 }
199 })
200 }
201}