pyo3_stub_gen_derive/gen_stub/
attr.rs

1use super::Signature;
2use proc_macro2::TokenTree;
3use quote::ToTokens;
4use syn::{Attribute, Expr, ExprLit, Ident, Lit, Meta, MetaList, Result};
5
6pub fn extract_documents(attrs: &[Attribute]) -> Vec<String> {
7    let mut docs = Vec::new();
8    for attr in attrs {
9        // `#[doc = "..."]` case
10        if attr.path().is_ident("doc") {
11            if let Meta::NameValue(syn::MetaNameValue {
12                value:
13                    Expr::Lit(ExprLit {
14                        lit: Lit::Str(doc), ..
15                    }),
16                ..
17            }) = &attr.meta
18            {
19                let doc = doc.value();
20                // Remove head space
21                //
22                // ```
23                // /// This is special document!
24                //    ^ This space is trimmed here
25                // ```
26                docs.push(if !doc.is_empty() && doc.starts_with(' ') {
27                    doc[1..].to_string()
28                } else {
29                    doc
30                });
31            }
32        }
33    }
34    docs
35}
36
/// A single PyO3 option recognised by this module.
///
/// `#[pyo3(...)]` style attributes appear in `#[pyclass]` and `#[pymethods]` proc-macros.
///
/// As the reference of PyO3 says
/// (<https://docs.rs/pyo3/latest/pyo3/attr.pyclass.html>):
///
/// > All of these parameters can either be passed directly on the `#[pyclass(...)]` annotation,
/// > or as one or more accompanying `#[pyo3(...)]` annotations,
///
/// `#[pyclass(name = "MyClass", module = "MyModule")]` is therefore treated as
/// `#[pyclass]` + `#[pyo3(name = "MyClass")]` + `#[pyo3(module = "MyModule")]`,
/// i.e. two `Attr`s are created for this case.
///
#[derive(Debug, Clone, PartialEq)]
pub enum Attr {
    // Attributes that appear in `#[pyo3(...)]` form or its equivalent
    /// The value of a `name = "..."` option.
    Name(String),
    /// The `get` option, e.g. `#[pyo3(get)]` on a field.
    Get,
    /// The `get_all` option.
    GetAll,
    /// The value of a `module = "..."` option.
    Module(String),
    /// The parsed `(...)` group of a `signature = (...)` option.
    Signature(Signature),

    // Attributes that appear on components within `#[pymethods]`
    // <https://docs.rs/pyo3/latest/pyo3/attr.pymethods.html>
    /// The `#[new]` attribute.
    New,
    /// `#[getter]` (`None`) or `#[getter(ident)]` (`Some` holding the identifier).
    Getter(Option<String>),
    /// The `#[staticmethod]` attribute.
    StaticMethod,
    /// The `#[classmethod]` attribute.
    ClassMethod,
}
65
66pub fn parse_pyo3_attrs(attrs: &[Attribute]) -> Result<Vec<Attr>> {
67    let mut out = Vec::new();
68    for attr in attrs {
69        let mut new = parse_pyo3_attr(attr)?;
70        out.append(&mut new);
71    }
72    Ok(out)
73}
74
75pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
76    let mut pyo3_attrs = Vec::new();
77    let path = attr.path();
78    if path.is_ident("pyclass")
79        || path.is_ident("pymethods")
80        || path.is_ident("pyfunction")
81        || path.is_ident("pyo3")
82    {
83        // Inner tokens of `#[pyo3(...)]` may not be nested meta
84        // which can be parsed by `Attribute::parse_nested_meta`
85        // due to the case of `#[pyo3(signature = (...))]`.
86        // https://pyo3.rs/v0.19.1/function/signature
87        if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
88            use TokenTree::*;
89            let tokens: Vec<TokenTree> = tokens.clone().into_iter().collect();
90            // Since `(...)` part with `signature` becomes `TokenTree::Group`,
91            // we can split entire stream by `,` first, and then pattern match to each cases.
92            for tt in tokens.split(|tt| {
93                if let Punct(p) = tt {
94                    p.as_char() == ','
95                } else {
96                    false
97                }
98            }) {
99                match tt {
100                    [Ident(ident)] => {
101                        if ident == "get" {
102                            pyo3_attrs.push(Attr::Get);
103                        }
104                        if ident == "get_all" {
105                            pyo3_attrs.push(Attr::GetAll);
106                        }
107                    }
108                    [Ident(ident), Punct(_), Literal(lit)] => {
109                        if ident == "name" {
110                            pyo3_attrs
111                                .push(Attr::Name(lit.to_string().trim_matches('"').to_string()));
112                        }
113                        if ident == "module" {
114                            pyo3_attrs
115                                .push(Attr::Module(lit.to_string().trim_matches('"').to_string()));
116                        }
117                    }
118                    [Ident(ident), Punct(_), Group(group)] => {
119                        if ident == "signature" {
120                            pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?));
121                        }
122                    }
123                    _ => {}
124                }
125            }
126        }
127    } else if path.is_ident("new") {
128        pyo3_attrs.push(Attr::New);
129    } else if path.is_ident("staticmethod") {
130        pyo3_attrs.push(Attr::StaticMethod);
131    } else if path.is_ident("classmethod") {
132        pyo3_attrs.push(Attr::ClassMethod);
133    } else if path.is_ident("getter") {
134        if let Ok(inner) = attr.parse_args::<Ident>() {
135            pyo3_attrs.push(Attr::Getter(Some(inner.to_string())));
136        } else {
137            pyo3_attrs.push(Attr::Getter(None));
138        }
139    }
140
141    Ok(pyo3_attrs)
142}
143
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Fields, ItemStruct};

    /// Parses a `#[pyclass(...)]` struct whose field carries `#[pyo3(get)]`, then checks
    /// the options extracted from both the struct and the field attribute.
    #[test]
    fn test_parse_pyo3_attr() -> Result<()> {
        const SOURCE: &str = r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#;
        let item: ItemStruct = parse_str(SOURCE)?;

        // Struct-level `#[pyclass(...)]`: `mapping` is not tracked, so only
        // `module` and `name` are expected, in the order they were written.
        let class_attrs = parse_pyo3_attr(&item.attrs[0])?;
        let expected_class_attrs = vec![
            Attr::Module("my_module".to_string()),
            Attr::Name("Placeholder".to_string()),
        ];
        assert_eq!(class_attrs, expected_class_attrs);

        // Field-level `#[pyo3(get)]` on the first (and only) named field.
        let fields_named = match item.fields {
            Fields::Named(fields_named) => fields_named,
            _ => unreachable!(),
        };
        let field_attrs = parse_pyo3_attr(&fields_named.named[0].attrs[0])?;
        assert_eq!(field_attrs, vec![Attr::Get]);

        Ok(())
    }
}