pyo3_stub_gen_derive/gen_stub/attr.rs

1use indexmap::IndexSet;
2
3use super::{RenamingRule, Signature};
4use proc_macro2::{Delimiter, TokenStream as TokenStream2, TokenTree};
5use quote::{quote, ToTokens, TokenStreamExt};
6use syn::{
7    parenthesized,
8    parse::{Parse, ParseStream},
9    punctuated::Punctuated,
10    Attribute, Expr, ExprLit, Ident, Lit, LitStr, Meta, MetaList, Result, Token, Type,
11};
12
13/// Represents the target of type ignore comments during parsing
14#[derive(Debug, Clone, PartialEq)]
15pub enum IgnoreTarget {
16    /// Ignore all type checking errors `(# type: ignore)`
17    All,
18    /// Ignore specific type checking rules (stored as LitStr during parsing)
19    SpecifiedLits(Vec<LitStr>),
20}
21
22pub fn extract_documents(attrs: &[Attribute]) -> Vec<String> {
23    let mut docs = Vec::new();
24    for attr in attrs {
25        // `#[doc = "..."]` case
26        if attr.path().is_ident("doc") {
27            if let Meta::NameValue(syn::MetaNameValue {
28                value:
29                    Expr::Lit(ExprLit {
30                        lit: Lit::Str(doc), ..
31                    }),
32                ..
33            }) = &attr.meta
34            {
35                let doc = doc.value();
36                // Remove head space
37                //
38                // ```
39                // /// This is special document!
40                //    ^ This space is trimmed here
41                // ```
42                docs.push(if !doc.is_empty() && doc.starts_with(' ') {
43                    doc[1..].to_string()
44                } else {
45                    doc
46                });
47            }
48        }
49    }
50    docs
51}
52
53/// Extract `#[deprecated(...)]` attribute
54pub fn extract_deprecated(attrs: &[Attribute]) -> Option<DeprecatedInfo> {
55    for attr in attrs {
56        if attr.path().is_ident("deprecated") {
57            if let Ok(list) = attr.meta.require_list() {
58                let mut since = None;
59                let mut note = None;
60
61                list.parse_nested_meta(|meta| {
62                    if meta.path.is_ident("since") {
63                        let value = meta.value()?;
64                        let lit: LitStr = value.parse()?;
65                        since = Some(lit.value());
66                    } else if meta.path.is_ident("note") {
67                        let value = meta.value()?;
68                        let lit: LitStr = value.parse()?;
69                        note = Some(lit.value());
70                    }
71                    Ok(())
72                })
73                .ok()?;
74
75                return Some(DeprecatedInfo { since, note });
76            }
77        }
78    }
79    None
80}
81
82/// `#[pyo3(...)]` style attributes appear in `#[pyclass]` and `#[pymethods]` proc-macros
83///
84/// As the reference of PyO3 says:
85///
86/// https://docs.rs/pyo3/latest/pyo3/attr.pyclass.html
87/// > All of these parameters can either be passed directly on the `#[pyclass(...)]` annotation,
88/// > or as one or more accompanying `#[pyo3(...)]` annotations,
89///
90/// `#[pyclass(name = "MyClass", module = "MyModule")]` will be decomposed into
91/// `#[pyclass]` + `#[pyo3(name = "MyClass")]` + `#[pyo3(module = "MyModule")]`,
92/// i.e. two `Attr`s will be created for this case.
93///
94#[derive(Debug, Clone, PartialEq)]
95pub struct DeprecatedInfo {
96    pub since: Option<String>,
97    pub note: Option<String>,
98}
99
100impl ToTokens for DeprecatedInfo {
101    fn to_tokens(&self, tokens: &mut TokenStream2) {
102        let since = self
103            .since
104            .as_ref()
105            .map(|s| quote! { Some(#s) })
106            .unwrap_or_else(|| quote! { None });
107        let note = self
108            .note
109            .as_ref()
110            .map(|n| quote! { Some(#n) })
111            .unwrap_or_else(|| quote! { None });
112        tokens.append_all(quote! {
113            ::pyo3_stub_gen::type_info::DeprecatedInfo {
114                since: #since,
115                note: #note,
116            }
117        })
118    }
119}
120
121#[derive(Debug, Clone, PartialEq)]
122#[expect(clippy::enum_variant_names)]
123pub enum Attr {
124    // Attributes appears in `#[pyo3(...)]` form or its equivalence
125    Name(String),
126    Get,
127    GetAll,
128    Set,
129    SetAll,
130    Module(String),
131    Constructor(Signature),
132    Signature(Signature),
133    RenameAll(RenamingRule),
134    Extends(Type),
135
136    // Comparison and special method attributes for pyclass
137    Eq,
138    Ord,
139    Hash,
140    Str,
141    Subclass,
142
143    // Standalone #[gen_stub(...)] attribute
144    GenStubModule(String),
145
146    // Attributes appears in components within `#[pymethods]`
147    // <https://docs.rs/pyo3/latest/pyo3/attr.pymethods.html>
148    New,
149    Getter(Option<String>),
150    Setter(Option<String>),
151    StaticMethod,
152    ClassMethod,
153    ClassAttr,
154}
155
156pub fn parse_pyo3_attrs(attrs: &[Attribute]) -> Result<Vec<Attr>> {
157    let mut out = Vec::new();
158    for attr in attrs {
159        let mut new = parse_pyo3_attr(attr)?;
160        out.append(&mut new);
161        // Also parse standalone #[gen_stub(module = "...")] attributes
162        if let Some(gen_stub_attr) = parse_gen_stub_module_attr(attr)? {
163            out.push(gen_stub_attr);
164        }
165    }
166    Ok(out)
167}
168
169pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
170    let mut pyo3_attrs = Vec::new();
171    let path = attr.path();
172    let is_full_path_pyo3_attr = path.segments.len() == 2
173        && path
174            .segments
175            .first()
176            .is_some_and(|seg| seg.ident.eq("pyo3"))
177        && path.segments.last().is_some_and(|seg| {
178            seg.ident.eq("pyclass") || seg.ident.eq("pymethods") || seg.ident.eq("pyfunction")
179        });
180    if path.is_ident("pyclass")
181        || path.is_ident("pymethods")
182        || path.is_ident("pyfunction")
183        || path.is_ident("pyo3")
184        || is_full_path_pyo3_attr
185    {
186        // Inner tokens of `#[pyo3(...)]` may not be nested meta
187        // which can be parsed by `Attribute::parse_nested_meta`
188        // due to the case of `#[pyo3(signature = (...))]`.
189        // https://pyo3.rs/v0.19.1/function/signature
190        if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
191            use TokenTree::*;
192            // `macro_rules!` substitutions (e.g. `#[pyclass(name = $name)]` where
193            // `$name:literal`) wrap the inserted token in an "invisible group"
194            // (`Group { delimiter: None, .. }`). Flattening these first lets a
195            // single pattern arm handle both direct literals and macro-substituted
196            // values uniformly.
197            let tokens: Vec<TokenTree> = flatten_invisible_groups(tokens.clone()).collect();
198            // Since `(...)` part with `signature` becomes `TokenTree::Group`,
199            // we can split entire stream by `,` first, and then pattern match to each cases.
200            for tt in tokens.split(|tt| {
201                if let Punct(p) = tt {
202                    p.as_char() == ','
203                } else {
204                    false
205                }
206            }) {
207                match tt {
208                    [Ident(ident)] => {
209                        if ident == "get" {
210                            pyo3_attrs.push(Attr::Get);
211                        }
212                        if ident == "get_all" {
213                            pyo3_attrs.push(Attr::GetAll);
214                        }
215                        if ident == "set" {
216                            pyo3_attrs.push(Attr::Set);
217                        }
218                        if ident == "set_all" {
219                            pyo3_attrs.push(Attr::SetAll);
220                        }
221                        if ident == "eq" {
222                            pyo3_attrs.push(Attr::Eq);
223                        }
224                        if ident == "ord" {
225                            pyo3_attrs.push(Attr::Ord);
226                        }
227                        if ident == "hash" {
228                            pyo3_attrs.push(Attr::Hash);
229                        }
230                        if ident == "str" {
231                            pyo3_attrs.push(Attr::Str);
232                        }
233                        if ident == "subclass" {
234                            pyo3_attrs.push(Attr::Subclass);
235                        }
236                        // frozen is required by PyO3 when using hash, but doesn't affect stub generation
237                    }
238                    [Ident(ident), Punct(_), Literal(lit)] => {
239                        if ident == "name" {
240                            pyo3_attrs
241                                .push(Attr::Name(lit.to_string().trim_matches('"').to_string()));
242                        }
243                        if ident == "module" {
244                            pyo3_attrs
245                                .push(Attr::Module(lit.to_string().trim_matches('"').to_string()));
246                        }
247                        if ident == "rename_all" {
248                            let name = lit.to_string().trim_matches('"').to_string();
249                            if let Some(renaming_rule) = RenamingRule::try_new(&name) {
250                                pyo3_attrs.push(Attr::RenameAll(renaming_rule));
251                            }
252                        }
253                    }
254                    [Ident(ident), Punct(_), Group(group)] => {
255                        if ident == "signature" {
256                            pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?));
257                        } else if ident == "constructor" {
258                            pyo3_attrs
259                                .push(Attr::Constructor(syn::parse2(group.to_token_stream())?));
260                        }
261                    }
262                    [Ident(ident), Punct(_), Ident(ident2)] => {
263                        if ident == "extends" {
264                            pyo3_attrs.push(Attr::Extends(syn::parse2(ident2.to_token_stream())?));
265                        }
266                    }
267                    _ => {}
268                }
269            }
270        }
271    } else if path.is_ident("new") {
272        pyo3_attrs.push(Attr::New);
273    } else if path.is_ident("staticmethod") {
274        pyo3_attrs.push(Attr::StaticMethod);
275    } else if path.is_ident("classmethod") {
276        pyo3_attrs.push(Attr::ClassMethod);
277    } else if path.is_ident("classattr") {
278        pyo3_attrs.push(Attr::ClassAttr);
279    } else if path.is_ident("getter") {
280        if let Ok(inner) = attr.parse_args::<Ident>() {
281            pyo3_attrs.push(Attr::Getter(Some(inner.to_string())));
282        } else {
283            pyo3_attrs.push(Attr::Getter(None));
284        }
285    } else if path.is_ident("setter") {
286        if let Ok(inner) = attr.parse_args::<Ident>() {
287            pyo3_attrs.push(Attr::Setter(Some(inner.to_string())));
288        } else {
289            pyo3_attrs.push(Attr::Setter(None));
290        }
291    }
292
293    Ok(pyo3_attrs)
294}
295
296/// Flatten `Group { delimiter: None, .. }` tokens (produced by `macro_rules!`
297/// substitutions) into their contents so downstream pattern matching can treat
298/// macro-substituted values the same as written-out literals.
299fn flatten_invisible_groups(tokens: TokenStream2) -> impl Iterator<Item = TokenTree> {
300    tokens
301        .into_iter()
302        .flat_map(|tt| -> Box<dyn Iterator<Item = TokenTree>> {
303            match tt {
304                TokenTree::Group(g) if g.delimiter() == Delimiter::None => {
305                    Box::new(flatten_invisible_groups(g.stream()))
306                }
307                other => Box::new(std::iter::once(other)),
308            }
309        })
310}
311
312/// Parse standalone `#[gen_stub(module = "...")]` attribute for module override
313pub fn parse_gen_stub_module_attr(attr: &Attribute) -> Result<Option<Attr>> {
314    let path = attr.path();
315    if path.is_ident("gen_stub") {
316        // Parse the inner tokens to find module = "..."
317        if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
318            use TokenTree::*;
319            // See note in `parse_pyo3_attr` about invisible groups.
320            let tokens: Vec<TokenTree> = flatten_invisible_groups(tokens.clone()).collect();
321
322            // Split by comma and look for module = "..."
323            for tt in tokens.split(|tt| {
324                if let Punct(p) = tt {
325                    p.as_char() == ','
326                } else {
327                    false
328                }
329            }) {
330                match tt {
331                    [Ident(ident), Punct(_), Literal(lit)] if ident == "module" => {
332                        return Ok(Some(Attr::GenStubModule(
333                            lit.to_string().trim_matches('"').to_string(),
334                        )));
335                    }
336                    _ => {}
337                }
338            }
339        }
340    }
341    Ok(None)
342}
343
344#[derive(Debug, Clone, PartialEq)]
345pub enum StubGenAttr {
346    /// Default value for getter
347    Default(Expr),
348    /// Skip a function in #[pymethods]
349    Skip,
350    /// Override the python type for a function argument or return type
351    OverrideType(OverrideTypeAttribute),
352    /// Type checker rules to ignore for this function/method
353    TypeIgnore(IgnoreTarget),
354}
355
356pub fn prune_attrs(attrs: &mut Vec<Attribute>) {
357    attrs.retain(|attr| !attr.path().is_ident("gen_stub"));
358}
359
360pub fn parse_gen_stub_override_type(attrs: &[Attribute]) -> Result<Option<OverrideTypeAttribute>> {
361    for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)? {
362        if let StubGenAttr::OverrideType(attr) = attr {
363            return Ok(Some(attr));
364        }
365    }
366    Ok(None)
367}
368
369pub fn parse_gen_stub_override_return_type(
370    attrs: &[Attribute],
371) -> Result<Option<OverrideTypeAttribute>> {
372    for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
373        if let StubGenAttr::OverrideType(attr) = attr {
374            return Ok(Some(attr));
375        }
376    }
377    Ok(None)
378}
379
380pub fn parse_gen_stub_default(attrs: &[Attribute]) -> Result<Option<Expr>> {
381    for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
382        if let StubGenAttr::Default(default) = attr {
383            return Ok(Some(default));
384        }
385    }
386    Ok(None)
387}
388pub fn parse_gen_stub_skip(attrs: &[Attribute]) -> Result<bool> {
389    let skip = parse_gen_stub_attrs(
390        attrs,
391        AttributeLocation::Field,
392        Some(&["override_return_type", "default"]),
393    )?
394    .iter()
395    .any(|attr| matches!(attr, StubGenAttr::Skip));
396    Ok(skip)
397}
398
399pub fn parse_gen_stub_type_ignore(attrs: &[Attribute]) -> Result<Option<IgnoreTarget>> {
400    // Try Function location first (for regular functions)
401    for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
402        if let StubGenAttr::TypeIgnore(target) = attr {
403            return Ok(Some(target));
404        }
405    }
406    // Try Field location (for methods in #[pymethods] blocks)
407    for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Field, None)? {
408        if let StubGenAttr::TypeIgnore(target) = attr {
409            return Ok(Some(target));
410        }
411    }
412    Ok(None)
413}
414
415fn parse_gen_stub_attrs(
416    attrs: &[Attribute],
417    location: AttributeLocation,
418    ignored_idents: Option<&[&str]>,
419) -> Result<Vec<StubGenAttr>> {
420    let mut out = Vec::new();
421    for attr in attrs {
422        let mut new = parse_gen_stub_attr(attr, location, ignored_idents.unwrap_or(&[]))?;
423        out.append(&mut new);
424    }
425    Ok(out)
426}
427
428fn parse_gen_stub_attr(
429    attr: &Attribute,
430    location: AttributeLocation,
431    ignored_idents: &[&str],
432) -> Result<Vec<StubGenAttr>> {
433    let mut gen_stub_attrs = Vec::new();
434    let path = attr.path();
435    if path.is_ident("gen_stub") {
436        attr.parse_args_with(|input: ParseStream| {
437            while !input.is_empty() {
438                let ident: Ident = input.parse()?;
439                let ignored_ident = ignored_idents.iter().any(|other| ident == other);
440                if (ident == "override_type"
441                    && (location == AttributeLocation::Argument || ignored_ident))
442                    || (ident == "override_return_type"
443                        && (location == AttributeLocation::Function || location == AttributeLocation::Field || ignored_ident))
444                {
445                    let content;
446                    parenthesized!(content in input);
447                    let override_attr: OverrideTypeAttribute = content.parse()?;
448                    gen_stub_attrs.push(StubGenAttr::OverrideType(override_attr));
449                } else if ident == "skip" && (location == AttributeLocation::Field || ignored_ident)
450                {
451                    gen_stub_attrs.push(StubGenAttr::Skip);
452                } else if ident == "default"
453                    && input.peek(Token![=])
454                    && (location == AttributeLocation::Field || location == AttributeLocation::Function || ignored_ident)
455                {
456                    input.parse::<Token![=]>()?;
457                    gen_stub_attrs.push(StubGenAttr::Default(input.parse()?));
458                } else if ident == "type_ignore"
459                    && (location == AttributeLocation::Function || location == AttributeLocation::Field || ignored_ident)
460                {
461                    // Handle two cases:
462                    // 1. type_ignore (without equals) -> IgnoreTarget::All
463                    // 2. type_ignore = [...] -> IgnoreTarget::Specified(rules)
464                    if input.peek(Token![=]) {
465                        input.parse::<Token![=]>()?;
466                        // Parse array of rule names
467                        let content;
468                        syn::bracketed!(content in input);
469                        let rules = Punctuated::<LitStr, Token![,]>::parse_terminated(&content)?;
470
471                        // Validate: empty Specified should be an error
472                        if rules.is_empty() {
473                            return Err(syn::Error::new(
474                                ident.span(),
475                                "type_ignore with empty array is not allowed. Use type_ignore without equals for catch-all, or specify rules in the array."
476                            ));
477                        }
478
479                        // Store the rules as LitStr for now, will be converted to strings during code generation
480                        let rule_lits: Vec<LitStr> = rules.into_iter().collect();
481                        gen_stub_attrs.push(StubGenAttr::TypeIgnore(IgnoreTarget::SpecifiedLits(rule_lits)));
482                    } else {
483                        // No equals sign means catch-all
484                        gen_stub_attrs.push(StubGenAttr::TypeIgnore(IgnoreTarget::All));
485                    }
486                } else if ident == "override_type" {
487                    return Err(syn::Error::new(
488                        ident.span(),
489                        "`override_type(...)` is only valid in argument position".to_string(),
490                    ));
491                } else if ident == "override_return_type" {
492                    return Err(syn::Error::new(
493                        ident.span(),
494                        "`override_return_type(...)` is only valid in function or method position"
495                            .to_string(),
496                    ));
497                } else if ident == "skip" {
498                    return Err(syn::Error::new(
499                        ident.span(),
500                        "`skip` is only valid in field position".to_string(),
501                    ));
502                } else if ident == "default" {
503                    return Err(syn::Error::new(
504                        ident.span(),
505                        "`default=xxx` is only valid in field or function position".to_string(),
506                    ));
507                } else if ident == "type_ignore" {
508                    return Err(syn::Error::new(
509                        ident.span(),
510                        "`type_ignore` or `type_ignore=[...]` is only valid in function or method position".to_string(),
511                    ));
512                } else if location == AttributeLocation::Argument {
513                    return Err(syn::Error::new(
514                        ident.span(),
515                        format!("Unsupported keyword `{ident}`, valid is `override_type(...)`"),
516                    ));
517                } else if location == AttributeLocation::Field {
518                    return Err(syn::Error::new(
519                        ident.span(),
520                        format!("Unsupported keyword `{ident}`, valid is `default=xxx`, `skip`, `override_return_type(...)`, `type_ignore`, or `type_ignore=[...]`"),
521                    ));
522                } else if location == AttributeLocation::Function {
523                    return Err(syn::Error::new(
524                        ident.span(),
525                        format!(
526                            "Unsupported keyword `{ident}`, valid is `default=xxx`, `override_return_type(...)`, `type_ignore`, or `type_ignore=[...]`"
527                        ),
528                    ));
529                } else {
530                    return Err(syn::Error::new(
531                        ident.span(),
532                        format!("Unsupported keyword `{ident}`"),
533                    ));
534                }
535                if input.peek(Token![,]) {
536                    input.parse::<Token![,]>()?;
537                } else {
538                    break;
539                }
540            }
541            Ok(())
542        })?;
543    }
544    Ok(gen_stub_attrs)
545}
546
/// Where a `#[gen_stub(...)]` attribute is attached. This decides which keywords are
/// accepted when the attribute is parsed.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) enum AttributeLocation {
    /// On a function argument.
    Argument,
    /// On a struct field, or on a method inside a `#[pymethods]` block.
    Field,
    /// On a regular function.
    Function,
}
553
554#[derive(Debug, Clone, PartialEq)]
555pub struct OverrideTypeAttribute {
556    pub(crate) type_repr: String,
557    pub(crate) imports: IndexSet<String>,
558}
559
560mod kw {
561    syn::custom_keyword!(type_repr);
562    syn::custom_keyword!(imports);
563    syn::custom_keyword!(override_type);
564}
565
566impl Parse for OverrideTypeAttribute {
567    fn parse(input: ParseStream) -> Result<Self> {
568        let mut type_repr = None;
569        let mut imports = IndexSet::new();
570
571        while !input.is_empty() {
572            let lookahead = input.lookahead1();
573
574            if lookahead.peek(kw::type_repr) {
575                input.parse::<kw::type_repr>()?;
576                input.parse::<Token![=]>()?;
577                type_repr = Some(input.parse::<LitStr>()?);
578            } else if lookahead.peek(kw::imports) {
579                input.parse::<kw::imports>()?;
580                input.parse::<Token![=]>()?;
581
582                let content;
583                parenthesized!(content in input);
584                let parsed_imports = Punctuated::<LitStr, Token![,]>::parse_terminated(&content)?;
585                imports = parsed_imports.into_iter().collect();
586            } else {
587                return Err(lookahead.error());
588            }
589
590            if !input.is_empty() {
591                input.parse::<Token![,]>()?;
592            }
593        }
594
595        Ok(OverrideTypeAttribute {
596            type_repr: type_repr
597                .ok_or_else(|| input.error("missing type_repr"))?
598                .value(),
599            imports: imports.iter().map(|i| i.value()).collect(),
600        })
601    }
602}
603
604/// Common attributes for `#[gen_stub_pyclass(...)]`, `#[gen_stub_pyclass_enum(...)]`,
605/// and `#[gen_stub_pyclass_complex_enum(...)]` macros
606#[derive(Default)]
607pub struct PyClassAttr {
608    pub skip_stub_type: bool,
609    pub module: Option<String>,
610}
611
612impl Parse for PyClassAttr {
613    fn parse(input: ParseStream) -> Result<Self> {
614        let mut skip_stub_type = false;
615        let mut module = None;
616
617        // Parse comma-separated flags
618        while !input.is_empty() {
619            let key: Ident = input.parse()?;
620
621            match key.to_string().as_str() {
622                "skip_stub_type" => {
623                    skip_stub_type = true;
624                }
625                "module" => {
626                    let _: Token![=] = input.parse()?;
627                    let value: LitStr = input.parse()?;
628                    module = Some(value.value());
629                }
630                _ => {
631                    return Err(syn::Error::new(
632                        key.span(),
633                        format!("Unknown parameter: {}", key),
634                    ));
635                }
636            }
637
638            // Check for comma separator
639            if input.peek(Token![,]) {
640                let _: Token![,] = input.parse()?;
641            } else {
642                break;
643            }
644        }
645
646        Ok(Self {
647            skip_stub_type,
648            module,
649        })
650    }
651}
652
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Fields, ItemFn, ItemStruct, PatType};

    /// Struct-level `#[pyclass(...)]` / `#[pyo3(...)]` arguments and a field-level
    /// `#[pyo3(get)]` are turned into `Attr`s. The unrecognized `mapping` flag produces
    /// nothing.
    #[test]
    fn test_parse_pyo3_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[pyo3(rename_all = "SCREAMING_SNAKE_CASE")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        // Struct-level `#[pyclass]` + `#[pyo3]` part
        let attrs = parse_pyo3_attrs(&item.attrs)?;
        assert_eq!(
            attrs,
            vec![
                Attr::Module("my_module".to_string()),
                Attr::Name("Placeholder".to_string()),
                Attr::RenameAll(RenamingRule::ScreamingSnakeCase),
            ]
        );

        // `#[pyo3(get)]` part
        if let Fields::Named(fields) = item.fields {
            let attrs = parse_pyo3_attr(&fields.named[0].attrs[0])?;
            assert_eq!(attrs, vec![Attr::Get]);
        } else {
            unreachable!()
        }
        Ok(())
    }

    /// The fully qualified `#[pyo3::pyclass(...)]` path is recognized like `#[pyclass(...)]`.
    #[test]
    fn test_parse_pyo3_attr_full_path() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyo3::pyclass(mapping, module = "my_module", name = "Placeholder")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        // `#[pyclass]` part
        let attrs = parse_pyo3_attr(&item.attrs[0])?;
        assert_eq!(
            attrs,
            vec![
                Attr::Module("my_module".to_string()),
                Attr::Name("Placeholder".to_string())
            ]
        );

        // `#[pyo3(get)]` part
        if let Fields::Named(fields) = item.fields {
            let attrs = parse_pyo3_attr(&fields.named[0].attrs[0])?;
            assert_eq!(attrs, vec![Attr::Get]);
        } else {
            unreachable!()
        }
        Ok(())
    }
    /// Build a `Vec<Attribute>` whose tokens contain an "invisible group"
    /// (`Group { delimiter: None, .. }`) for the value of a key, simulating
    /// what `macro_rules!` produces when substituting a meta-variable into an
    /// attribute argument list (e.g. `#[pyclass(name = $name)]`).
    fn attrs_with_invisible_group(
        attr_path: TokenStream2,
        prefix: TokenStream2,
        value: TokenStream2,
    ) -> Vec<Attribute> {
        let invisible = TokenTree::Group(proc_macro2::Group::new(Delimiter::None, value));
        let item: ItemStruct = syn::parse2(quote! {
            #[#attr_path(#prefix = #invisible)]
            struct Dummy;
        })
        .unwrap();
        item.attrs
    }

    /// `name = <invisible group>` is parsed the same as a written-out string literal.
    #[test]
    fn test_parse_pyo3_attr_name_from_macro_substitution() -> Result<()> {
        // Simulates `macro_rules! { ($n:literal) => { #[pyclass(name = $n)] ... } }`.
        let attrs = attrs_with_invisible_group(quote!(pyclass), quote!(name), quote!("FromMacro"));
        let parsed = parse_pyo3_attrs(&attrs)?;
        assert_eq!(parsed, vec![Attr::Name("FromMacro".to_string())]);
        Ok(())
    }

    /// `module = <invisible group>` is parsed the same as a written-out string literal.
    #[test]
    fn test_parse_pyo3_attr_module_from_macro_substitution() -> Result<()> {
        let attrs =
            attrs_with_invisible_group(quote!(pyclass), quote!(module), quote!("my.mod.path"));
        let parsed = parse_pyo3_attrs(&attrs)?;
        assert_eq!(parsed, vec![Attr::Module("my.mod.path".to_string())]);
        Ok(())
    }

    /// `#[pyo3(rename_all = <invisible group>)]` resolves to the matching `RenamingRule`.
    #[test]
    fn test_parse_pyo3_attr_rename_all_from_macro_substitution() -> Result<()> {
        let attrs = attrs_with_invisible_group(
            quote!(pyo3),
            quote!(rename_all),
            quote!("SCREAMING_SNAKE_CASE"),
        );
        let parsed = parse_pyo3_attrs(&attrs)?;
        assert_eq!(
            parsed,
            vec![Attr::RenameAll(RenamingRule::ScreamingSnakeCase)]
        );
        Ok(())
    }

    /// `#[gen_stub(module = <invisible group>)]` yields `Attr::GenStubModule`.
    #[test]
    fn test_parse_gen_stub_module_attr_from_macro_substitution() -> Result<()> {
        let attrs =
            attrs_with_invisible_group(quote!(gen_stub), quote!(module), quote!("explicit.mod"));
        // `parse_pyo3_attrs` also dispatches to `parse_gen_stub_module_attr`.
        let parsed = parse_pyo3_attrs(&attrs)?;
        assert_eq!(
            parsed,
            vec![Attr::GenStubModule("explicit.mod".to_string())]
        );
        Ok(())
    }

    /// Field-position `#[gen_stub(...)]`: `default = expr` and `skip`, alone and combined,
    /// are returned in source order.
    #[test]
    fn test_parse_gen_stub_field_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            pub struct PyPlaceholder {
                #[gen_stub(default = String::from("foo"), skip)]
                pub field0: String,
                #[gen_stub(skip)]
                pub field1: String,
                #[gen_stub(default = 1+2)]
                pub field2: usize,
            }
            "#,
        )?;
        let fields: Vec<_> = item.fields.into_iter().collect();
        let field0_attrs = parse_gen_stub_attrs(&fields[0].attrs, AttributeLocation::Field, None)?;
        // Expressions are compared via their token-stream string form.
        if let StubGenAttr::Default(expr) = &field0_attrs[0] {
            assert_eq!(
                expr.to_token_stream().to_string(),
                "String :: from (\"foo\")"
            );
        } else {
            panic!("attr should be Default");
        };
        assert_eq!(&StubGenAttr::Skip, &field0_attrs[1]);
        let field1_attrs = parse_gen_stub_attrs(&fields[1].attrs, AttributeLocation::Field, None)?;
        assert_eq!(vec![StubGenAttr::Skip], field1_attrs);
        let field2_attrs = parse_gen_stub_attrs(&fields[2].attrs, AttributeLocation::Field, None)?;
        if let StubGenAttr::Default(expr) = &field2_attrs[0] {
            assert_eq!(expr.to_token_stream().to_string(), "1 + 2");
        } else {
            panic!("attr should be Default");
        };
        Ok(())
    }
    /// `override_return_type(...)` on a function and `override_type(...)` on one of its
    /// arguments are parsed into `OverrideTypeAttribute`s with their `type_repr` and
    /// `imports`.
    #[test]
    fn test_parse_gen_stub_override_type_attr() -> Result<()> {
        let item: ItemFn = parse_str(
            r#"
            #[gen_stub_pyfunction]
            #[pyfunction]
            #[gen_stub(override_return_type(type_repr="typing.Never", imports=("typing")))]
            fn say_hello_forever<'a>(
                #[gen_stub(override_type(type_repr="collections.abc.Callable[[str]]", imports=("collections.abc")))]
                cb: Bound<'a, PyAny>,
            ) -> PyResult<()> {
                loop {
                    cb.call1(("Hello!",))?;
                }
            }
            "#,
        )?;
        // Function-level attributes; the non-`gen_stub` attributes contribute nothing.
        let fn_attrs = parse_gen_stub_attrs(&item.attrs, AttributeLocation::Function, None)?;
        assert_eq!(fn_attrs.len(), 1);
        if let StubGenAttr::OverrideType(expr) = &fn_attrs[0] {
            assert_eq!(
                *expr,
                OverrideTypeAttribute {
                    type_repr: "typing.Never".into(),
                    imports: IndexSet::from(["typing".into()])
                }
            );
        } else {
            panic!("attr should be OverrideType");
        };
        // Argument-level attribute on the first parameter.
        if let syn::FnArg::Typed(PatType { attrs, .. }) = &item.sig.inputs[0] {
            let arg_attrs = parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)?;
            assert_eq!(arg_attrs.len(), 1);
            if let StubGenAttr::OverrideType(expr) = &arg_attrs[0] {
                assert_eq!(
                    *expr,
                    OverrideTypeAttribute {
                        type_repr: "collections.abc.Callable[[str]]".into(),
                        imports: IndexSet::from(["collections.abc".into()])
                    }
                );
            } else {
                panic!("attr should be OverrideType");
            };
        }
        Ok(())
    }
}