pyo3_stub_gen_derive/gen_stub/
attr.rs

1use super::{RenamingRule, Signature};
2use proc_macro2::TokenTree;
3use quote::ToTokens;
4use syn::{Attribute, Expr, ExprLit, Ident, Lit, Meta, MetaList, Result, Type};
5
6pub fn extract_documents(attrs: &[Attribute]) -> Vec<String> {
7    let mut docs = Vec::new();
8    for attr in attrs {
9        // `#[doc = "..."]` case
10        if attr.path().is_ident("doc") {
11            if let Meta::NameValue(syn::MetaNameValue {
12                value:
13                    Expr::Lit(ExprLit {
14                        lit: Lit::Str(doc), ..
15                    }),
16                ..
17            }) = &attr.meta
18            {
19                let doc = doc.value();
20                // Remove head space
21                //
22                // ```
23                // /// This is special document!
24                //    ^ This space is trimmed here
25                // ```
26                docs.push(if !doc.is_empty() && doc.starts_with(' ') {
27                    doc[1..].to_string()
28                } else {
29                    doc
30                });
31            }
32        }
33    }
34    docs
35}
36
37/// `#[pyo3(...)]` style attributes appear in `#[pyclass]` and `#[pymethods]` proc-macros
38///
39/// As the reference of PyO3 says:
40///
41/// https://docs.rs/pyo3/latest/pyo3/attr.pyclass.html
42/// > All of these parameters can either be passed directly on the `#[pyclass(...)]` annotation,
43/// > or as one or more accompanying `#[pyo3(...)]` annotations,
44///
45/// `#[pyclass(name = "MyClass", module = "MyModule")]` will be decomposed into
46/// `#[pyclass]` + `#[pyo3(name = "MyClass")]` + `#[pyo3(module = "MyModule")]`,
47/// i.e. two `Attr`s will be created for this case.
48///
49#[derive(Debug, Clone, PartialEq)]
50pub enum Attr {
51    // Attributes appears in `#[pyo3(...)]` form or its equivalence
52    Name(String),
53    Get,
54    GetAll,
55    Module(String),
56    Signature(Signature),
57    RenameAll(RenamingRule),
58    Extends(Type),
59
60    // Attributes appears in components within `#[pymethods]`
61    // <https://docs.rs/pyo3/latest/pyo3/attr.pymethods.html>
62    New,
63    Getter(Option<String>),
64    StaticMethod,
65    ClassMethod,
66}
67
68pub fn parse_pyo3_attrs(attrs: &[Attribute]) -> Result<Vec<Attr>> {
69    let mut out = Vec::new();
70    for attr in attrs {
71        let mut new = parse_pyo3_attr(attr)?;
72        out.append(&mut new);
73    }
74    Ok(out)
75}
76
77pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
78    let mut pyo3_attrs = Vec::new();
79    let path = attr.path();
80    let is_full_path_pyo3_attr = path.segments.len() == 2
81        && path
82            .segments
83            .first()
84            .is_some_and(|seg| seg.ident.eq("pyo3"))
85        && path.segments.last().is_some_and(|seg| {
86            seg.ident.eq("pyclass") || seg.ident.eq("pymethods") || seg.ident.eq("pyfunction")
87        });
88    if path.is_ident("pyclass")
89        || path.is_ident("pymethods")
90        || path.is_ident("pyfunction")
91        || path.is_ident("pyo3")
92        || is_full_path_pyo3_attr
93    {
94        // Inner tokens of `#[pyo3(...)]` may not be nested meta
95        // which can be parsed by `Attribute::parse_nested_meta`
96        // due to the case of `#[pyo3(signature = (...))]`.
97        // https://pyo3.rs/v0.19.1/function/signature
98        if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
99            use TokenTree::*;
100            let tokens: Vec<TokenTree> = tokens.clone().into_iter().collect();
101            // Since `(...)` part with `signature` becomes `TokenTree::Group`,
102            // we can split entire stream by `,` first, and then pattern match to each cases.
103            for tt in tokens.split(|tt| {
104                if let Punct(p) = tt {
105                    p.as_char() == ','
106                } else {
107                    false
108                }
109            }) {
110                match tt {
111                    [Ident(ident)] => {
112                        if ident == "get" {
113                            pyo3_attrs.push(Attr::Get);
114                        }
115                        if ident == "get_all" {
116                            pyo3_attrs.push(Attr::GetAll);
117                        }
118                    }
119                    [Ident(ident), Punct(_), Literal(lit)] => {
120                        if ident == "name" {
121                            pyo3_attrs
122                                .push(Attr::Name(lit.to_string().trim_matches('"').to_string()));
123                        }
124                        if ident == "module" {
125                            pyo3_attrs
126                                .push(Attr::Module(lit.to_string().trim_matches('"').to_string()));
127                        }
128                        if ident == "rename_all" {
129                            let name = lit.to_string().trim_matches('"').to_string();
130                            if let Some(renaming_rule) = RenamingRule::try_new(&name) {
131                                pyo3_attrs.push(Attr::RenameAll(renaming_rule));
132                            }
133                        }
134                    }
135                    [Ident(ident), Punct(_), Group(group)] => {
136                        if ident == "signature" {
137                            pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?));
138                        }
139                    }
140                    [Ident(ident), Punct(_), Ident(ident2)] => {
141                        if ident == "extends" {
142                            pyo3_attrs.push(Attr::Extends(syn::parse2(ident2.to_token_stream())?));
143                        }
144                    }
145                    _ => {}
146                }
147            }
148        }
149    } else if path.is_ident("new") {
150        pyo3_attrs.push(Attr::New);
151    } else if path.is_ident("staticmethod") {
152        pyo3_attrs.push(Attr::StaticMethod);
153    } else if path.is_ident("classmethod") {
154        pyo3_attrs.push(Attr::ClassMethod);
155    } else if path.is_ident("getter") {
156        if let Ok(inner) = attr.parse_args::<Ident>() {
157            pyo3_attrs.push(Attr::Getter(Some(inner.to_string())));
158        } else {
159            pyo3_attrs.push(Attr::Getter(None));
160        }
161    }
162
163    Ok(pyo3_attrs)
164}
165
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Fields, ItemStruct};

    /// Options spread over `#[pyclass(...)]` and `#[pyo3(...)]` are collected in source
    /// order; the unrecognized `mapping` flag produces no `Attr`.
    #[test]
    fn test_parse_pyo3_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[pyo3(rename_all = "SCREAMING_SNAKE_CASE")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;

        // Struct-level attributes: `#[pyclass]` + `#[pyo3]`
        let class_attrs = parse_pyo3_attrs(&item.attrs)?;
        let expected = vec![
            Attr::Module("my_module".to_string()),
            Attr::Name("Placeholder".to_string()),
            Attr::RenameAll(RenamingRule::ScreamingSnakeCase),
        ];
        assert_eq!(class_attrs, expected);

        // Field-level attribute: `#[pyo3(get)]`
        let Fields::Named(fields) = item.fields else {
            unreachable!()
        };
        let field_attrs = parse_pyo3_attr(&fields.named[0].attrs[0])?;
        assert_eq!(field_attrs, vec![Attr::Get]);
        Ok(())
    }

    /// The `pyo3::`-qualified spelling `#[pyo3::pyclass(...)]` is parsed like `#[pyclass(...)]`.
    #[test]
    fn test_parse_pyo3_attr_full_path() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyo3::pyclass(mapping, module = "my_module", name = "Placeholder")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;

        // Struct-level attribute: `#[pyo3::pyclass]`
        let class_attrs = parse_pyo3_attr(&item.attrs[0])?;
        let expected = vec![
            Attr::Module("my_module".to_string()),
            Attr::Name("Placeholder".to_string()),
        ];
        assert_eq!(class_attrs, expected);

        // Field-level attribute: `#[pyo3(get)]`
        let Fields::Named(fields) = item.fields else {
            unreachable!()
        };
        let field_attrs = parse_pyo3_attr(&fields.named[0].attrs[0])?;
        assert_eq!(field_attrs, vec![Attr::Get]);
        Ok(())
    }
}