pyo3_stub_gen_derive/gen_stub/
method.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
use super::{
    arg::parse_args, escape_return_type, extract_documents, parse_pyo3_attrs, quote_option,
    ArgInfo, Attr, Signature,
};

use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens, TokenStreamExt};
use syn::{
    Error, GenericArgument, ImplItemFn, PathArguments, Result, Type, TypePath, TypeReference,
};

/// Stub-generation metadata for a single method, collected from an `ImplItemFn`
/// and later emitted (via `ToTokens`) as a `::pyo3_stub_gen::type_info::MethodInfo` literal.
#[derive(Debug)]
pub struct MethodInfo {
    /// Method name written to the stub: the name carried by a parsed `Attr::Name`
    /// attribute if one is present, otherwise the Rust identifier.
    name: String,
    /// Arguments parsed from the function signature's inputs by `parse_args`.
    args: Vec<ArgInfo>,
    /// Text signature: taken from an `Attr::Signature` attribute if present, otherwise
    /// whatever `Signature::overriding_operator` yields for the Rust signature.
    sig: Option<Signature>,
    /// Return type produced by `escape_return_type`; `None` makes the emitted code use
    /// `no_return_type_output` instead of `<T as PyStubType>::type_output`.
    r#return: Option<Type>,
    /// Doc-comment lines extracted from the method's attributes, joined with `\n`.
    doc: String,
    /// Set when a parsed `Attr::StaticMethod` attribute is present.
    is_static: bool,
    /// Set when a parsed `Attr::ClassMethod` attribute is present.
    is_class: bool,
}

fn replace_inner(ty: &mut Type, self_: &Type) {
    match ty {
        Type::Path(TypePath { path, .. }) => {
            if let Some(last) = path.segments.iter_mut().last() {
                if let PathArguments::AngleBracketed(arg) = &mut last.arguments {
                    for arg in &mut arg.args {
                        if let GenericArgument::Type(ty) = arg {
                            replace_inner(ty, self_);
                        }
                    }
                }
                if last.ident == "Self" {
                    *ty = self_.clone();
                }
            }
        }
        Type::Reference(TypeReference { elem, .. }) => {
            replace_inner(elem, self_);
        }
        _ => {}
    }
}

impl MethodInfo {
    /// Substitutes the concrete type `self_` for every `Self` appearing in this method's
    /// argument types and then in its return type (if it has one).
    pub fn replace_self(&mut self, self_: &Type) {
        self.args
            .iter_mut()
            .map(|arg| &mut arg.r#type)
            .chain(self.r#return.as_mut())
            .for_each(|ty| replace_inner(ty, self_));
    }
}

impl TryFrom<ImplItemFn> for MethodInfo {
    type Error = Error;

    /// Collects stub metadata (name, arguments, text signature, return type, docs and
    /// static/class flags) from one method item.
    ///
    /// # Errors
    /// Propagates any error from `parse_pyo3_attrs` (attribute parsing) or from
    /// `parse_args` (argument-list parsing).
    fn try_from(item: ImplItemFn) -> Result<Self> {
        let ImplItemFn { attrs, sig, .. } = item;
        let doc = extract_documents(&attrs).join("\n");
        let attrs = parse_pyo3_attrs(&attrs)?;
        let mut method_name = None;
        // Default text signature derived from the Rust signature; an explicit
        // signature attribute below overrides it.
        let mut text_sig = Signature::overriding_operator(&sig);
        let mut is_static = false;
        let mut is_class = false;
        for attr in attrs {
            match attr {
                Attr::Name(name) => method_name = Some(name),
                Attr::Signature(text_sig_) => text_sig = Some(text_sig_),
                Attr::StaticMethod => is_static = true,
                Attr::ClassMethod => is_class = true,
                _ => {}
            }
        }
        // An explicit name attribute wins; otherwise fall back to the Rust identifier.
        // pyo3 exposes raw identifiers (`fn r#match`) to Python without the `r#`
        // prefix, so strip it to keep the stub in sync with the real Python name.
        let name = method_name.unwrap_or_else(|| {
            let ident = sig.ident.to_string();
            ident.strip_prefix("r#").map(str::to_owned).unwrap_or(ident)
        });
        let r#return = escape_return_type(&sig.output);
        Ok(MethodInfo {
            name,
            sig: text_sig,
            args: parse_args(sig.inputs)?,
            r#return,
            doc,
            is_static,
            is_class,
        })
    }
}

impl ToTokens for MethodInfo {
    /// Appends a `::pyo3_stub_gen::type_info::MethodInfo { .. }` struct-literal expression
    /// describing this method to `tokens`.
    ///
    /// The return type is rendered via `<T as PyStubType>::type_output` when one is
    /// present, and via `type_info::no_return_type_output` otherwise.
    fn to_tokens(&self, tokens: &mut TokenStream2) {
        let Self {
            name,
            r#return: ret,
            args,
            sig,
            doc,
            is_class,
            is_static,
        } = self;
        let sig_tt = quote_option(sig);
        // Always use the absolute `::pyo3_stub_gen` path, like the other emitted paths,
        // so a user item that happens to be named `pyo3_stub_gen` cannot capture it.
        let ret_tt = if let Some(ret) = ret {
            quote! { <#ret as ::pyo3_stub_gen::PyStubType>::type_output }
        } else {
            quote! { ::pyo3_stub_gen::type_info::no_return_type_output }
        };
        tokens.append_all(quote! {
            ::pyo3_stub_gen::type_info::MethodInfo {
                name: #name,
                args: &[ #(#args),* ],
                r#return: #ret_tt,
                signature: #sig_tt,
                doc: #doc,
                is_static: #is_static,
                is_class: #is_class,
            }
        })
    }
}