pyo3_stub_gen_derive/gen_stub/
attr.rs

1use std::collections::HashSet;
2
3use super::{RenamingRule, Signature};
4use proc_macro2::{TokenStream as TokenStream2, TokenTree};
5use quote::{quote, ToTokens, TokenStreamExt};
6use syn::{
7    parenthesized,
8    parse::{Parse, ParseStream},
9    punctuated::Punctuated,
10    Attribute, Expr, ExprLit, Ident, Lit, LitStr, Meta, MetaList, Result, Token, Type,
11};
12
13pub fn extract_documents(attrs: &[Attribute]) -> Vec<String> {
14    let mut docs = Vec::new();
15    for attr in attrs {
16        // `#[doc = "..."]` case
17        if attr.path().is_ident("doc") {
18            if let Meta::NameValue(syn::MetaNameValue {
19                value:
20                    Expr::Lit(ExprLit {
21                        lit: Lit::Str(doc), ..
22                    }),
23                ..
24            }) = &attr.meta
25            {
26                let doc = doc.value();
27                // Remove head space
28                //
29                // ```
30                // /// This is special document!
31                //    ^ This space is trimmed here
32                // ```
33                docs.push(if !doc.is_empty() && doc.starts_with(' ') {
34                    doc[1..].to_string()
35                } else {
36                    doc
37                });
38            }
39        }
40    }
41    docs
42}
43
44/// Extract `#[deprecated(...)]` attribute
45pub fn extract_deprecated(attrs: &[Attribute]) -> Option<DeprecatedInfo> {
46    for attr in attrs {
47        if attr.path().is_ident("deprecated") {
48            if let Ok(list) = attr.meta.require_list() {
49                let mut since = None;
50                let mut note = None;
51
52                list.parse_nested_meta(|meta| {
53                    if meta.path.is_ident("since") {
54                        let value = meta.value()?;
55                        let lit: LitStr = value.parse()?;
56                        since = Some(lit.value());
57                    } else if meta.path.is_ident("note") {
58                        let value = meta.value()?;
59                        let lit: LitStr = value.parse()?;
60                        note = Some(lit.value());
61                    }
62                    Ok(())
63                })
64                .ok()?;
65
66                return Some(DeprecatedInfo { since, note });
67            }
68        }
69    }
70    None
71}
72
73/// `#[pyo3(...)]` style attributes appear in `#[pyclass]` and `#[pymethods]` proc-macros
74///
75/// As the reference of PyO3 says:
76///
77/// https://docs.rs/pyo3/latest/pyo3/attr.pyclass.html
78/// > All of these parameters can either be passed directly on the `#[pyclass(...)]` annotation,
79/// > or as one or more accompanying `#[pyo3(...)]` annotations,
80///
81/// `#[pyclass(name = "MyClass", module = "MyModule")]` will be decomposed into
82/// `#[pyclass]` + `#[pyo3(name = "MyClass")]` + `#[pyo3(module = "MyModule")]`,
83/// i.e. two `Attr`s will be created for this case.
84///
85#[derive(Debug, Clone, PartialEq)]
86pub struct DeprecatedInfo {
87    pub since: Option<String>,
88    pub note: Option<String>,
89}
90
91impl ToTokens for DeprecatedInfo {
92    fn to_tokens(&self, tokens: &mut TokenStream2) {
93        let since = self
94            .since
95            .as_ref()
96            .map(|s| quote! { Some(#s) })
97            .unwrap_or_else(|| quote! { None });
98        let note = self
99            .note
100            .as_ref()
101            .map(|n| quote! { Some(#n) })
102            .unwrap_or_else(|| quote! { None });
103        tokens.append_all(quote! {
104            ::pyo3_stub_gen::type_info::DeprecatedInfo {
105                since: #since,
106                note: #note,
107            }
108        })
109    }
110}
111
112#[derive(Debug, Clone, PartialEq)]
113#[expect(clippy::enum_variant_names)]
114pub enum Attr {
115    // Attributes appears in `#[pyo3(...)]` form or its equivalence
116    Name(String),
117    Get,
118    GetAll,
119    Set,
120    SetAll,
121    Module(String),
122    Constructor(Signature),
123    Signature(Signature),
124    RenameAll(RenamingRule),
125    Extends(Type),
126
127    // Attributes appears in components within `#[pymethods]`
128    // <https://docs.rs/pyo3/latest/pyo3/attr.pymethods.html>
129    New,
130    Getter(Option<String>),
131    Setter(Option<String>),
132    StaticMethod,
133    ClassMethod,
134    ClassAttr,
135}
136
137pub fn parse_pyo3_attrs(attrs: &[Attribute]) -> Result<Vec<Attr>> {
138    let mut out = Vec::new();
139    for attr in attrs {
140        let mut new = parse_pyo3_attr(attr)?;
141        out.append(&mut new);
142    }
143    Ok(out)
144}
145
146pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
147    let mut pyo3_attrs = Vec::new();
148    let path = attr.path();
149    let is_full_path_pyo3_attr = path.segments.len() == 2
150        && path
151            .segments
152            .first()
153            .is_some_and(|seg| seg.ident.eq("pyo3"))
154        && path.segments.last().is_some_and(|seg| {
155            seg.ident.eq("pyclass") || seg.ident.eq("pymethods") || seg.ident.eq("pyfunction")
156        });
157    if path.is_ident("pyclass")
158        || path.is_ident("pymethods")
159        || path.is_ident("pyfunction")
160        || path.is_ident("pyo3")
161        || is_full_path_pyo3_attr
162    {
163        // Inner tokens of `#[pyo3(...)]` may not be nested meta
164        // which can be parsed by `Attribute::parse_nested_meta`
165        // due to the case of `#[pyo3(signature = (...))]`.
166        // https://pyo3.rs/v0.19.1/function/signature
167        if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
168            use TokenTree::*;
169            let tokens: Vec<TokenTree> = tokens.clone().into_iter().collect();
170            // Since `(...)` part with `signature` becomes `TokenTree::Group`,
171            // we can split entire stream by `,` first, and then pattern match to each cases.
172            for tt in tokens.split(|tt| {
173                if let Punct(p) = tt {
174                    p.as_char() == ','
175                } else {
176                    false
177                }
178            }) {
179                match tt {
180                    [Ident(ident)] => {
181                        if ident == "get" {
182                            pyo3_attrs.push(Attr::Get);
183                        }
184                        if ident == "get_all" {
185                            pyo3_attrs.push(Attr::GetAll);
186                        }
187                        if ident == "set" {
188                            pyo3_attrs.push(Attr::Set);
189                        }
190                        if ident == "set_all" {
191                            pyo3_attrs.push(Attr::SetAll);
192                        }
193                    }
194                    [Ident(ident), Punct(_), Literal(lit)] => {
195                        if ident == "name" {
196                            pyo3_attrs
197                                .push(Attr::Name(lit.to_string().trim_matches('"').to_string()));
198                        }
199                        if ident == "module" {
200                            pyo3_attrs
201                                .push(Attr::Module(lit.to_string().trim_matches('"').to_string()));
202                        }
203                        if ident == "rename_all" {
204                            let name = lit.to_string().trim_matches('"').to_string();
205                            if let Some(renaming_rule) = RenamingRule::try_new(&name) {
206                                pyo3_attrs.push(Attr::RenameAll(renaming_rule));
207                            }
208                        }
209                    }
210                    [Ident(ident), Punct(_), Group(group)] => {
211                        if ident == "signature" {
212                            pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?));
213                        } else if ident == "constructor" {
214                            pyo3_attrs
215                                .push(Attr::Constructor(syn::parse2(group.to_token_stream())?));
216                        }
217                    }
218                    [Ident(ident), Punct(_), Ident(ident2)] => {
219                        if ident == "extends" {
220                            pyo3_attrs.push(Attr::Extends(syn::parse2(ident2.to_token_stream())?));
221                        }
222                    }
223                    _ => {}
224                }
225            }
226        }
227    } else if path.is_ident("new") {
228        pyo3_attrs.push(Attr::New);
229    } else if path.is_ident("staticmethod") {
230        pyo3_attrs.push(Attr::StaticMethod);
231    } else if path.is_ident("classmethod") {
232        pyo3_attrs.push(Attr::ClassMethod);
233    } else if path.is_ident("classattr") {
234        pyo3_attrs.push(Attr::ClassAttr);
235    } else if path.is_ident("getter") {
236        if let Ok(inner) = attr.parse_args::<Ident>() {
237            pyo3_attrs.push(Attr::Getter(Some(inner.to_string())));
238        } else {
239            pyo3_attrs.push(Attr::Getter(None));
240        }
241    } else if path.is_ident("setter") {
242        if let Ok(inner) = attr.parse_args::<Ident>() {
243            pyo3_attrs.push(Attr::Setter(Some(inner.to_string())));
244        } else {
245            pyo3_attrs.push(Attr::Setter(None));
246        }
247    }
248
249    Ok(pyo3_attrs)
250}
251
252#[derive(Debug, Clone, PartialEq)]
253pub enum StubGenAttr {
254    /// Default value for getter
255    Default(Expr),
256    /// Skip a function in #[pymethods]
257    Skip,
258    /// Override the python type for a function argument or return type
259    OverrideType(OverrideTypeAttribute),
260}
261
262pub fn prune_attrs(attrs: &mut Vec<Attribute>) {
263    attrs.retain(|attr| !attr.path().is_ident("gen_stub"));
264}
265
266pub fn parse_gen_stub_override_type(attrs: &[Attribute]) -> Result<Option<OverrideTypeAttribute>> {
267    for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)? {
268        if let StubGenAttr::OverrideType(attr) = attr {
269            return Ok(Some(attr));
270        }
271    }
272    Ok(None)
273}
274
275pub fn parse_gen_stub_override_return_type(
276    attrs: &[Attribute],
277) -> Result<Option<OverrideTypeAttribute>> {
278    for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
279        if let StubGenAttr::OverrideType(attr) = attr {
280            return Ok(Some(attr));
281        }
282    }
283    Ok(None)
284}
285
286pub fn parse_gen_stub_default(attrs: &[Attribute]) -> Result<Option<Expr>> {
287    for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
288        if let StubGenAttr::Default(default) = attr {
289            return Ok(Some(default));
290        }
291    }
292    Ok(None)
293}
294pub fn parse_gen_stub_skip(attrs: &[Attribute]) -> Result<bool> {
295    let skip = parse_gen_stub_attrs(
296        attrs,
297        AttributeLocation::Field,
298        Some(&["override_return_type", "default"]),
299    )?
300    .iter()
301    .any(|attr| matches!(attr, StubGenAttr::Skip));
302    Ok(skip)
303}
304fn parse_gen_stub_attrs(
305    attrs: &[Attribute],
306    location: AttributeLocation,
307    ignored_idents: Option<&[&str]>,
308) -> Result<Vec<StubGenAttr>> {
309    let mut out = Vec::new();
310    for attr in attrs {
311        let mut new = parse_gen_stub_attr(attr, location, ignored_idents.unwrap_or(&[]))?;
312        out.append(&mut new);
313    }
314    Ok(out)
315}
316
317fn parse_gen_stub_attr(
318    attr: &Attribute,
319    location: AttributeLocation,
320    ignored_idents: &[&str],
321) -> Result<Vec<StubGenAttr>> {
322    let mut gen_stub_attrs = Vec::new();
323    let path = attr.path();
324    if path.is_ident("gen_stub") {
325        attr.parse_args_with(|input: ParseStream| {
326            while !input.is_empty() {
327                let ident: Ident = input.parse()?;
328                let ignored_ident = ignored_idents.iter().any(|other| ident == other);
329                if (ident == "override_type"
330                    && (location == AttributeLocation::Argument || ignored_ident))
331                    || (ident == "override_return_type"
332                        && (location == AttributeLocation::Function || ignored_ident))
333                {
334                    let content;
335                    parenthesized!(content in input);
336                    let override_attr: OverrideTypeAttribute = content.parse()?;
337                    gen_stub_attrs.push(StubGenAttr::OverrideType(override_attr));
338                } else if ident == "skip" && (location == AttributeLocation::Field || ignored_ident)
339                {
340                    gen_stub_attrs.push(StubGenAttr::Skip);
341                } else if ident == "default"
342                    && input.peek(Token![=])
343                    && (location == AttributeLocation::Field || location == AttributeLocation::Function || ignored_ident)
344                {
345                    input.parse::<Token![=]>()?;
346                    gen_stub_attrs.push(StubGenAttr::Default(input.parse()?));
347                } else if ident == "override_type" {
348                    return Err(syn::Error::new(
349                        ident.span(),
350                        "`override_type(...)` is only valid in argument position".to_string(),
351                    ));
352                } else if ident == "override_return_type" {
353                    return Err(syn::Error::new(
354                        ident.span(),
355                        "`override_return_type(...)` is only valid in function position"
356                            .to_string(),
357                    ));
358                } else if ident == "skip" {
359                    return Err(syn::Error::new(
360                        ident.span(),
361                        "`skip` is only valid in field position".to_string(),
362                    ));
363                } else if ident == "default" {
364                    return Err(syn::Error::new(
365                        ident.span(),
366                        "`default=xxx` is only valid in field or function position".to_string(),
367                    ));
368                } else if location == AttributeLocation::Argument {
369                    return Err(syn::Error::new(
370                        ident.span(),
371                        format!("Unsupported keyword `{ident}`, valid is `override_type(...)`"),
372                    ));
373                } else if location == AttributeLocation::Field {
374                    return Err(syn::Error::new(
375                        ident.span(),
376                        format!("Unsupported keyword `{ident}`, valid is `default=xxx` or `skip`"),
377                    ));
378                } else if location == AttributeLocation::Function {
379                    return Err(syn::Error::new(
380                        ident.span(),
381                        format!(
382                            "Unsupported keyword `{ident}`, valid is `default=xxx` or `override_return_type(...)`"
383                        ),
384                    ));
385                } else {
386                    return Err(syn::Error::new(
387                        ident.span(),
388                        format!("Unsupported keyword `{ident}`"),
389                    ));
390                }
391                if input.peek(Token![,]) {
392                    input.parse::<Token![,]>()?;
393                } else {
394                    break;
395                }
396            }
397            Ok(())
398        })?;
399    }
400    Ok(gen_stub_attrs)
401}
402
/// Syntactic position of a `#[gen_stub(...)]` attribute; it decides which options the
/// attribute may carry.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) enum AttributeLocation {
    /// Attached to a function argument.
    Argument,
    /// Parsed with field-position rules.
    Field,
    /// Attached to a function.
    Function,
}
409
/// Parsed contents of `override_type(...)` / `override_return_type(...)`: a
/// `type_repr = "..."` plus an optional `imports = ("...", ...)` list.
#[derive(Clone, Debug, PartialEq)]
pub struct OverrideTypeAttribute {
    /// The Python type expression, e.g. `"typing.Never"`.
    pub(crate) type_repr: String,
    /// Module names listed in `imports = (...)`, e.g. `"typing"`.
    pub(crate) imports: HashSet<String>,
}
415
416mod kw {
417    syn::custom_keyword!(type_repr);
418    syn::custom_keyword!(imports);
419    syn::custom_keyword!(override_type);
420}
421
422impl Parse for OverrideTypeAttribute {
423    fn parse(input: ParseStream) -> Result<Self> {
424        let mut type_repr = None;
425        let mut imports = HashSet::new();
426
427        while !input.is_empty() {
428            let lookahead = input.lookahead1();
429
430            if lookahead.peek(kw::type_repr) {
431                input.parse::<kw::type_repr>()?;
432                input.parse::<Token![=]>()?;
433                type_repr = Some(input.parse::<LitStr>()?);
434            } else if lookahead.peek(kw::imports) {
435                input.parse::<kw::imports>()?;
436                input.parse::<Token![=]>()?;
437
438                let content;
439                parenthesized!(content in input);
440                let parsed_imports = Punctuated::<LitStr, Token![,]>::parse_terminated(&content)?;
441                imports = parsed_imports.into_iter().collect();
442            } else {
443                return Err(lookahead.error());
444            }
445
446            if !input.is_empty() {
447                input.parse::<Token![,]>()?;
448            }
449        }
450
451        Ok(OverrideTypeAttribute {
452            type_repr: type_repr
453                .ok_or_else(|| input.error("missing type_repr"))?
454                .value(),
455            imports: imports.iter().map(|i| i.value()).collect(),
456        })
457    }
458}
459
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Fields, ItemFn, ItemStruct, PatType};

    /// Options spread over `#[pyclass(...)]` and `#[pyo3(...)]` are all collected, and
    /// untracked ones (`mapping`) are skipped.
    #[test]
    fn test_parse_pyo3_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[pyo3(rename_all = "SCREAMING_SNAKE_CASE")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        // `#[pyclass]` part
        let attrs = parse_pyo3_attrs(&item.attrs)?;
        assert_eq!(
            attrs,
            vec![
                Attr::Module("my_module".to_string()),
                Attr::Name("Placeholder".to_string()),
                Attr::RenameAll(RenamingRule::ScreamingSnakeCase),
            ]
        );

        // `#[pyo3(get)]` part
        if let Fields::Named(fields) = item.fields {
            let attrs = parse_pyo3_attr(&fields.named[0].attrs[0])?;
            assert_eq!(attrs, vec![Attr::Get]);
        } else {
            unreachable!()
        }
        Ok(())
    }

    /// The fully-qualified `#[pyo3::pyclass(...)]` form is recognized.
    #[test]
    fn test_parse_pyo3_attr_full_path() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyo3::pyclass(mapping, module = "my_module", name = "Placeholder")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        // `#[pyclass]` part
        let attrs = parse_pyo3_attr(&item.attrs[0])?;
        assert_eq!(
            attrs,
            vec![
                Attr::Module("my_module".to_string()),
                Attr::Name("Placeholder".to_string())
            ]
        );

        // `#[pyo3(get)]` part
        if let Fields::Named(fields) = item.fields {
            let attrs = parse_pyo3_attr(&fields.named[0].attrs[0])?;
            assert_eq!(attrs, vec![Attr::Get]);
        } else {
            unreachable!()
        }
        Ok(())
    }

    /// `default = <expr>` and `skip` parse in field position, alone or combined.
    #[test]
    fn test_parse_gen_stub_field_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            pub struct PyPlaceholder {
                #[gen_stub(default = String::from("foo"), skip)]
                pub field0: String,
                #[gen_stub(skip)]
                pub field1: String,
                #[gen_stub(default = 1+2)]
                pub field2: usize,
            }
            "#,
        )?;
        let fields: Vec<_> = item.fields.into_iter().collect();
        let field0_attrs = parse_gen_stub_attrs(&fields[0].attrs, AttributeLocation::Field, None)?;
        if let StubGenAttr::Default(expr) = &field0_attrs[0] {
            assert_eq!(
                expr.to_token_stream().to_string(),
                "String :: from (\"foo\")"
            );
        } else {
            panic!("attr should be Default");
        };
        assert_eq!(&StubGenAttr::Skip, &field0_attrs[1]);
        let field1_attrs = parse_gen_stub_attrs(&fields[1].attrs, AttributeLocation::Field, None)?;
        assert_eq!(vec![StubGenAttr::Skip], field1_attrs);
        let field2_attrs = parse_gen_stub_attrs(&fields[2].attrs, AttributeLocation::Field, None)?;
        if let StubGenAttr::Default(expr) = &field2_attrs[0] {
            assert_eq!(expr.to_token_stream().to_string(), "1 + 2");
        } else {
            panic!("attr should be Default");
        };
        Ok(())
    }

    /// `override_return_type(...)` on a function and `override_type(...)` on an argument
    /// both parse into `OverrideTypeAttribute`.
    #[test]
    fn test_parse_gen_stub_override_type_attr() -> Result<()> {
        let item: ItemFn = parse_str(
            r#"
            #[gen_stub_pyfunction]
            #[pyfunction]
            #[gen_stub(override_return_type(type_repr="typing.Never", imports=("typing")))]
            fn say_hello_forever<'a>(
                #[gen_stub(override_type(type_repr="collections.abc.Callable[[str]]", imports=("collections.abc")))]
                cb: Bound<'a, PyAny>,
            ) -> PyResult<()> {
                loop {
                    cb.call1(("Hello!",))?;
                }
            }
            "#,
        )?;
        let fn_attrs = parse_gen_stub_attrs(&item.attrs, AttributeLocation::Function, None)?;
        assert_eq!(fn_attrs.len(), 1);
        if let StubGenAttr::OverrideType(expr) = &fn_attrs[0] {
            assert_eq!(
                *expr,
                OverrideTypeAttribute {
                    type_repr: "typing.Never".into(),
                    imports: HashSet::from(["typing".into()])
                }
            );
        } else {
            panic!("attr should be OverrideType");
        };
        // Without the `else`, a change in argument shape would make this test pass
        // vacuously.
        if let syn::FnArg::Typed(PatType { attrs, .. }) = &item.sig.inputs[0] {
            let arg_attrs = parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)?;
            assert_eq!(arg_attrs.len(), 1);
            if let StubGenAttr::OverrideType(expr) = &arg_attrs[0] {
                assert_eq!(
                    *expr,
                    OverrideTypeAttribute {
                        type_repr: "collections.abc.Callable[[str]]".into(),
                        imports: HashSet::from(["collections.abc".into()])
                    }
                );
            } else {
                panic!("attr should be OverrideType");
            };
        } else {
            panic!("first argument should be a typed argument");
        }
        Ok(())
    }
}