// pyo3_stub_gen_derive/gen_stub/attr.rs

1use std::collections::HashSet;
2
3use super::{RenamingRule, Signature};
4use proc_macro2::{TokenStream as TokenStream2, TokenTree};
5use quote::{quote, ToTokens, TokenStreamExt};
6use syn::{
7    parenthesized,
8    parse::{Parse, ParseStream},
9    punctuated::Punctuated,
10    Attribute, Expr, ExprLit, Ident, Lit, LitStr, Meta, MetaList, Result, Token, Type,
11};
12
13/// Represents the target of type ignore comments during parsing
14#[derive(Debug, Clone, PartialEq)]
15pub enum IgnoreTarget {
16    /// Ignore all type checking errors `(# type: ignore)`
17    All,
18    /// Ignore specific type checking rules (stored as LitStr during parsing)
19    SpecifiedLits(Vec<LitStr>),
20}
21
22pub fn extract_documents(attrs: &[Attribute]) -> Vec<String> {
23    let mut docs = Vec::new();
24    for attr in attrs {
25        // `#[doc = "..."]` case
26        if attr.path().is_ident("doc") {
27            if let Meta::NameValue(syn::MetaNameValue {
28                value:
29                    Expr::Lit(ExprLit {
30                        lit: Lit::Str(doc), ..
31                    }),
32                ..
33            }) = &attr.meta
34            {
35                let doc = doc.value();
36                // Remove head space
37                //
38                // ```
39                // /// This is special document!
40                //    ^ This space is trimmed here
41                // ```
42                docs.push(if !doc.is_empty() && doc.starts_with(' ') {
43                    doc[1..].to_string()
44                } else {
45                    doc
46                });
47            }
48        }
49    }
50    docs
51}
52
53/// Extract `#[deprecated(...)]` attribute
54pub fn extract_deprecated(attrs: &[Attribute]) -> Option<DeprecatedInfo> {
55    for attr in attrs {
56        if attr.path().is_ident("deprecated") {
57            if let Ok(list) = attr.meta.require_list() {
58                let mut since = None;
59                let mut note = None;
60
61                list.parse_nested_meta(|meta| {
62                    if meta.path.is_ident("since") {
63                        let value = meta.value()?;
64                        let lit: LitStr = value.parse()?;
65                        since = Some(lit.value());
66                    } else if meta.path.is_ident("note") {
67                        let value = meta.value()?;
68                        let lit: LitStr = value.parse()?;
69                        note = Some(lit.value());
70                    }
71                    Ok(())
72                })
73                .ok()?;
74
75                return Some(DeprecatedInfo { since, note });
76            }
77        }
78    }
79    None
80}
81
82/// `#[pyo3(...)]` style attributes appear in `#[pyclass]` and `#[pymethods]` proc-macros
83///
84/// As the reference of PyO3 says:
85///
86/// https://docs.rs/pyo3/latest/pyo3/attr.pyclass.html
87/// > All of these parameters can either be passed directly on the `#[pyclass(...)]` annotation,
88/// > or as one or more accompanying `#[pyo3(...)]` annotations,
89///
90/// `#[pyclass(name = "MyClass", module = "MyModule")]` will be decomposed into
91/// `#[pyclass]` + `#[pyo3(name = "MyClass")]` + `#[pyo3(module = "MyModule")]`,
92/// i.e. two `Attr`s will be created for this case.
93///
94#[derive(Debug, Clone, PartialEq)]
95pub struct DeprecatedInfo {
96    pub since: Option<String>,
97    pub note: Option<String>,
98}
99
100impl ToTokens for DeprecatedInfo {
101    fn to_tokens(&self, tokens: &mut TokenStream2) {
102        let since = self
103            .since
104            .as_ref()
105            .map(|s| quote! { Some(#s) })
106            .unwrap_or_else(|| quote! { None });
107        let note = self
108            .note
109            .as_ref()
110            .map(|n| quote! { Some(#n) })
111            .unwrap_or_else(|| quote! { None });
112        tokens.append_all(quote! {
113            ::pyo3_stub_gen::type_info::DeprecatedInfo {
114                since: #since,
115                note: #note,
116            }
117        })
118    }
119}
120
121#[derive(Debug, Clone, PartialEq)]
122#[expect(clippy::enum_variant_names)]
123pub enum Attr {
124    // Attributes appears in `#[pyo3(...)]` form or its equivalence
125    Name(String),
126    Get,
127    GetAll,
128    Set,
129    SetAll,
130    Module(String),
131    Constructor(Signature),
132    Signature(Signature),
133    RenameAll(RenamingRule),
134    Extends(Type),
135
136    // Comparison and special method attributes for pyclass
137    Eq,
138    Ord,
139    Hash,
140    Str,
141
142    // Attributes appears in components within `#[pymethods]`
143    // <https://docs.rs/pyo3/latest/pyo3/attr.pymethods.html>
144    New,
145    Getter(Option<String>),
146    Setter(Option<String>),
147    StaticMethod,
148    ClassMethod,
149    ClassAttr,
150}
151
152pub fn parse_pyo3_attrs(attrs: &[Attribute]) -> Result<Vec<Attr>> {
153    let mut out = Vec::new();
154    for attr in attrs {
155        let mut new = parse_pyo3_attr(attr)?;
156        out.append(&mut new);
157    }
158    Ok(out)
159}
160
161pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
162    let mut pyo3_attrs = Vec::new();
163    let path = attr.path();
164    let is_full_path_pyo3_attr = path.segments.len() == 2
165        && path
166            .segments
167            .first()
168            .is_some_and(|seg| seg.ident.eq("pyo3"))
169        && path.segments.last().is_some_and(|seg| {
170            seg.ident.eq("pyclass") || seg.ident.eq("pymethods") || seg.ident.eq("pyfunction")
171        });
172    if path.is_ident("pyclass")
173        || path.is_ident("pymethods")
174        || path.is_ident("pyfunction")
175        || path.is_ident("pyo3")
176        || is_full_path_pyo3_attr
177    {
178        // Inner tokens of `#[pyo3(...)]` may not be nested meta
179        // which can be parsed by `Attribute::parse_nested_meta`
180        // due to the case of `#[pyo3(signature = (...))]`.
181        // https://pyo3.rs/v0.19.1/function/signature
182        if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
183            use TokenTree::*;
184            let tokens: Vec<TokenTree> = tokens.clone().into_iter().collect();
185            // Since `(...)` part with `signature` becomes `TokenTree::Group`,
186            // we can split entire stream by `,` first, and then pattern match to each cases.
187            for tt in tokens.split(|tt| {
188                if let Punct(p) = tt {
189                    p.as_char() == ','
190                } else {
191                    false
192                }
193            }) {
194                match tt {
195                    [Ident(ident)] => {
196                        if ident == "get" {
197                            pyo3_attrs.push(Attr::Get);
198                        }
199                        if ident == "get_all" {
200                            pyo3_attrs.push(Attr::GetAll);
201                        }
202                        if ident == "set" {
203                            pyo3_attrs.push(Attr::Set);
204                        }
205                        if ident == "set_all" {
206                            pyo3_attrs.push(Attr::SetAll);
207                        }
208                        if ident == "eq" {
209                            pyo3_attrs.push(Attr::Eq);
210                        }
211                        if ident == "ord" {
212                            pyo3_attrs.push(Attr::Ord);
213                        }
214                        if ident == "hash" {
215                            pyo3_attrs.push(Attr::Hash);
216                        }
217                        if ident == "str" {
218                            pyo3_attrs.push(Attr::Str);
219                        }
220                        // frozen is required by PyO3 when using hash, but doesn't affect stub generation
221                    }
222                    [Ident(ident), Punct(_), Literal(lit)] => {
223                        if ident == "name" {
224                            pyo3_attrs
225                                .push(Attr::Name(lit.to_string().trim_matches('"').to_string()));
226                        }
227                        if ident == "module" {
228                            pyo3_attrs
229                                .push(Attr::Module(lit.to_string().trim_matches('"').to_string()));
230                        }
231                        if ident == "rename_all" {
232                            let name = lit.to_string().trim_matches('"').to_string();
233                            if let Some(renaming_rule) = RenamingRule::try_new(&name) {
234                                pyo3_attrs.push(Attr::RenameAll(renaming_rule));
235                            }
236                        }
237                    }
238                    [Ident(ident), Punct(_), Group(group)] => {
239                        if ident == "signature" {
240                            pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?));
241                        } else if ident == "constructor" {
242                            pyo3_attrs
243                                .push(Attr::Constructor(syn::parse2(group.to_token_stream())?));
244                        }
245                    }
246                    [Ident(ident), Punct(_), Ident(ident2)] => {
247                        if ident == "extends" {
248                            pyo3_attrs.push(Attr::Extends(syn::parse2(ident2.to_token_stream())?));
249                        }
250                    }
251                    _ => {}
252                }
253            }
254        }
255    } else if path.is_ident("new") {
256        pyo3_attrs.push(Attr::New);
257    } else if path.is_ident("staticmethod") {
258        pyo3_attrs.push(Attr::StaticMethod);
259    } else if path.is_ident("classmethod") {
260        pyo3_attrs.push(Attr::ClassMethod);
261    } else if path.is_ident("classattr") {
262        pyo3_attrs.push(Attr::ClassAttr);
263    } else if path.is_ident("getter") {
264        if let Ok(inner) = attr.parse_args::<Ident>() {
265            pyo3_attrs.push(Attr::Getter(Some(inner.to_string())));
266        } else {
267            pyo3_attrs.push(Attr::Getter(None));
268        }
269    } else if path.is_ident("setter") {
270        if let Ok(inner) = attr.parse_args::<Ident>() {
271            pyo3_attrs.push(Attr::Setter(Some(inner.to_string())));
272        } else {
273            pyo3_attrs.push(Attr::Setter(None));
274        }
275    }
276
277    Ok(pyo3_attrs)
278}
279
280#[derive(Debug, Clone, PartialEq)]
281pub enum StubGenAttr {
282    /// Default value for getter
283    Default(Expr),
284    /// Skip a function in #[pymethods]
285    Skip,
286    /// Override the python type for a function argument or return type
287    OverrideType(OverrideTypeAttribute),
288    /// Type checker rules to ignore for this function/method
289    TypeIgnore(IgnoreTarget),
290}
291
292pub fn prune_attrs(attrs: &mut Vec<Attribute>) {
293    attrs.retain(|attr| !attr.path().is_ident("gen_stub"));
294}
295
296pub fn parse_gen_stub_override_type(attrs: &[Attribute]) -> Result<Option<OverrideTypeAttribute>> {
297    for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)? {
298        if let StubGenAttr::OverrideType(attr) = attr {
299            return Ok(Some(attr));
300        }
301    }
302    Ok(None)
303}
304
305pub fn parse_gen_stub_override_return_type(
306    attrs: &[Attribute],
307) -> Result<Option<OverrideTypeAttribute>> {
308    for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
309        if let StubGenAttr::OverrideType(attr) = attr {
310            return Ok(Some(attr));
311        }
312    }
313    Ok(None)
314}
315
316pub fn parse_gen_stub_default(attrs: &[Attribute]) -> Result<Option<Expr>> {
317    for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
318        if let StubGenAttr::Default(default) = attr {
319            return Ok(Some(default));
320        }
321    }
322    Ok(None)
323}
324pub fn parse_gen_stub_skip(attrs: &[Attribute]) -> Result<bool> {
325    let skip = parse_gen_stub_attrs(
326        attrs,
327        AttributeLocation::Field,
328        Some(&["override_return_type", "default"]),
329    )?
330    .iter()
331    .any(|attr| matches!(attr, StubGenAttr::Skip));
332    Ok(skip)
333}
334
335pub fn parse_gen_stub_type_ignore(attrs: &[Attribute]) -> Result<Option<IgnoreTarget>> {
336    // Try Function location first (for regular functions)
337    for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
338        if let StubGenAttr::TypeIgnore(target) = attr {
339            return Ok(Some(target));
340        }
341    }
342    // Try Field location (for methods in #[pymethods] blocks)
343    for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Field, None)? {
344        if let StubGenAttr::TypeIgnore(target) = attr {
345            return Ok(Some(target));
346        }
347    }
348    Ok(None)
349}
350
351fn parse_gen_stub_attrs(
352    attrs: &[Attribute],
353    location: AttributeLocation,
354    ignored_idents: Option<&[&str]>,
355) -> Result<Vec<StubGenAttr>> {
356    let mut out = Vec::new();
357    for attr in attrs {
358        let mut new = parse_gen_stub_attr(attr, location, ignored_idents.unwrap_or(&[]))?;
359        out.append(&mut new);
360    }
361    Ok(out)
362}
363
364fn parse_gen_stub_attr(
365    attr: &Attribute,
366    location: AttributeLocation,
367    ignored_idents: &[&str],
368) -> Result<Vec<StubGenAttr>> {
369    let mut gen_stub_attrs = Vec::new();
370    let path = attr.path();
371    if path.is_ident("gen_stub") {
372        attr.parse_args_with(|input: ParseStream| {
373            while !input.is_empty() {
374                let ident: Ident = input.parse()?;
375                let ignored_ident = ignored_idents.iter().any(|other| ident == other);
376                if (ident == "override_type"
377                    && (location == AttributeLocation::Argument || ignored_ident))
378                    || (ident == "override_return_type"
379                        && (location == AttributeLocation::Function || location == AttributeLocation::Field || ignored_ident))
380                {
381                    let content;
382                    parenthesized!(content in input);
383                    let override_attr: OverrideTypeAttribute = content.parse()?;
384                    gen_stub_attrs.push(StubGenAttr::OverrideType(override_attr));
385                } else if ident == "skip" && (location == AttributeLocation::Field || ignored_ident)
386                {
387                    gen_stub_attrs.push(StubGenAttr::Skip);
388                } else if ident == "default"
389                    && input.peek(Token![=])
390                    && (location == AttributeLocation::Field || location == AttributeLocation::Function || ignored_ident)
391                {
392                    input.parse::<Token![=]>()?;
393                    gen_stub_attrs.push(StubGenAttr::Default(input.parse()?));
394                } else if ident == "type_ignore"
395                    && (location == AttributeLocation::Function || location == AttributeLocation::Field || ignored_ident)
396                {
397                    // Handle two cases:
398                    // 1. type_ignore (without equals) -> IgnoreTarget::All
399                    // 2. type_ignore = [...] -> IgnoreTarget::Specified(rules)
400                    if input.peek(Token![=]) {
401                        input.parse::<Token![=]>()?;
402                        // Parse array of rule names
403                        let content;
404                        syn::bracketed!(content in input);
405                        let rules = Punctuated::<LitStr, Token![,]>::parse_terminated(&content)?;
406
407                        // Validate: empty Specified should be an error
408                        if rules.is_empty() {
409                            return Err(syn::Error::new(
410                                ident.span(),
411                                "type_ignore with empty array is not allowed. Use type_ignore without equals for catch-all, or specify rules in the array."
412                            ));
413                        }
414
415                        // Store the rules as LitStr for now, will be converted to strings during code generation
416                        let rule_lits: Vec<LitStr> = rules.into_iter().collect();
417                        gen_stub_attrs.push(StubGenAttr::TypeIgnore(IgnoreTarget::SpecifiedLits(rule_lits)));
418                    } else {
419                        // No equals sign means catch-all
420                        gen_stub_attrs.push(StubGenAttr::TypeIgnore(IgnoreTarget::All));
421                    }
422                } else if ident == "override_type" {
423                    return Err(syn::Error::new(
424                        ident.span(),
425                        "`override_type(...)` is only valid in argument position".to_string(),
426                    ));
427                } else if ident == "override_return_type" {
428                    return Err(syn::Error::new(
429                        ident.span(),
430                        "`override_return_type(...)` is only valid in function or method position"
431                            .to_string(),
432                    ));
433                } else if ident == "skip" {
434                    return Err(syn::Error::new(
435                        ident.span(),
436                        "`skip` is only valid in field position".to_string(),
437                    ));
438                } else if ident == "default" {
439                    return Err(syn::Error::new(
440                        ident.span(),
441                        "`default=xxx` is only valid in field or function position".to_string(),
442                    ));
443                } else if ident == "type_ignore" {
444                    return Err(syn::Error::new(
445                        ident.span(),
446                        "`type_ignore` or `type_ignore=[...]` is only valid in function or method position".to_string(),
447                    ));
448                } else if location == AttributeLocation::Argument {
449                    return Err(syn::Error::new(
450                        ident.span(),
451                        format!("Unsupported keyword `{ident}`, valid is `override_type(...)`"),
452                    ));
453                } else if location == AttributeLocation::Field {
454                    return Err(syn::Error::new(
455                        ident.span(),
456                        format!("Unsupported keyword `{ident}`, valid is `default=xxx`, `skip`, `override_return_type(...)`, `type_ignore`, or `type_ignore=[...]`"),
457                    ));
458                } else if location == AttributeLocation::Function {
459                    return Err(syn::Error::new(
460                        ident.span(),
461                        format!(
462                            "Unsupported keyword `{ident}`, valid is `default=xxx`, `override_return_type(...)`, `type_ignore`, or `type_ignore=[...]`"
463                        ),
464                    ));
465                } else {
466                    return Err(syn::Error::new(
467                        ident.span(),
468                        format!("Unsupported keyword `{ident}`"),
469                    ));
470                }
471                if input.peek(Token![,]) {
472                    input.parse::<Token![,]>()?;
473                } else {
474                    break;
475                }
476            }
477            Ok(())
478        })?;
479    }
480    Ok(gen_stub_attrs)
481}
482
/// Where a `#[gen_stub(...)]` attribute is attached.
///
/// The location determines which keywords are accepted.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum AttributeLocation {
    /// On a function argument: `override_type(...)`.
    Argument,
    /// On a struct field or a method inside `#[pymethods]`: `default = ...`, `skip`,
    /// `override_return_type(...)` and `type_ignore`.
    Field,
    /// On a function: `default = ...`, `override_return_type(...)` and `type_ignore`.
    Function,
}
489
/// Parsed contents of `override_type(...)` / `override_return_type(...)`.
///
/// These are `type_repr = "..."` plus an optional `imports = ("module", ...)`.
#[derive(Debug, Clone, PartialEq)]
pub struct OverrideTypeAttribute {
    /// The `type_repr` string: the Python type expression that replaces the inferred type.
    pub(crate) type_repr: String,
    /// The distinct module names listed in `imports`.
    pub(crate) imports: HashSet<String>,
}
495
496mod kw {
497    syn::custom_keyword!(type_repr);
498    syn::custom_keyword!(imports);
499    syn::custom_keyword!(override_type);
500}
501
502impl Parse for OverrideTypeAttribute {
503    fn parse(input: ParseStream) -> Result<Self> {
504        let mut type_repr = None;
505        let mut imports = HashSet::new();
506
507        while !input.is_empty() {
508            let lookahead = input.lookahead1();
509
510            if lookahead.peek(kw::type_repr) {
511                input.parse::<kw::type_repr>()?;
512                input.parse::<Token![=]>()?;
513                type_repr = Some(input.parse::<LitStr>()?);
514            } else if lookahead.peek(kw::imports) {
515                input.parse::<kw::imports>()?;
516                input.parse::<Token![=]>()?;
517
518                let content;
519                parenthesized!(content in input);
520                let parsed_imports = Punctuated::<LitStr, Token![,]>::parse_terminated(&content)?;
521                imports = parsed_imports.into_iter().collect();
522            } else {
523                return Err(lookahead.error());
524            }
525
526            if !input.is_empty() {
527                input.parse::<Token![,]>()?;
528            }
529        }
530
531        Ok(OverrideTypeAttribute {
532            type_repr: type_repr
533                .ok_or_else(|| input.error("missing type_repr"))?
534                .value(),
535            imports: imports.iter().map(|i| i.value()).collect(),
536        })
537    }
538}
539
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Fields, ItemFn, ItemStruct, PatType};

    /// Assert that `attrs` is exactly one `OverrideType` item and return its payload.
    fn single_override(attrs: &[StubGenAttr]) -> &OverrideTypeAttribute {
        assert_eq!(attrs.len(), 1);
        match &attrs[0] {
            StubGenAttr::OverrideType(override_attr) => override_attr,
            _ => panic!("attr should be OverrideType"),
        }
    }

    /// `#[pyclass(...)]` options plus a separate `#[pyo3(...)]` on the struct, and
    /// `#[pyo3(get)]` on a field.
    #[test]
    fn test_parse_pyo3_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[pyo3(rename_all = "SCREAMING_SNAKE_CASE")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;

        let expected = vec![
            Attr::Module("my_module".to_string()),
            Attr::Name("Placeholder".to_string()),
            Attr::RenameAll(RenamingRule::ScreamingSnakeCase),
        ];
        assert_eq!(parse_pyo3_attrs(&item.attrs)?, expected);

        let Fields::Named(fields) = item.fields else {
            unreachable!()
        };
        assert_eq!(parse_pyo3_attr(&fields.named[0].attrs[0])?, vec![Attr::Get]);
        Ok(())
    }

    /// Same as above, but with the fully qualified `#[pyo3::pyclass(...)]` path.
    #[test]
    fn test_parse_pyo3_attr_full_path() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyo3::pyclass(mapping, module = "my_module", name = "Placeholder")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;

        let expected = vec![
            Attr::Module("my_module".to_string()),
            Attr::Name("Placeholder".to_string()),
        ];
        assert_eq!(parse_pyo3_attr(&item.attrs[0])?, expected);

        let Fields::Named(fields) = item.fields else {
            unreachable!()
        };
        assert_eq!(parse_pyo3_attr(&fields.named[0].attrs[0])?, vec![Attr::Get]);
        Ok(())
    }

    /// `default = <expr>` and `skip` in field position.
    #[test]
    fn test_parse_gen_stub_field_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            pub struct PyPlaceholder {
                #[gen_stub(default = String::from("foo"), skip)]
                pub field0: String,
                #[gen_stub(skip)]
                pub field1: String,
                #[gen_stub(default = 1+2)]
                pub field2: usize,
            }
            "#,
        )?;
        let fields: Vec<_> = item.fields.into_iter().collect();
        let parse_field =
            |index: usize| parse_gen_stub_attrs(&fields[index].attrs, AttributeLocation::Field, None);

        let field0_attrs = parse_field(0)?;
        let StubGenAttr::Default(expr) = &field0_attrs[0] else {
            panic!("attr should be Default");
        };
        assert_eq!(
            expr.to_token_stream().to_string(),
            "String :: from (\"foo\")"
        );
        assert_eq!(&StubGenAttr::Skip, &field0_attrs[1]);

        assert_eq!(vec![StubGenAttr::Skip], parse_field(1)?);

        let field2_attrs = parse_field(2)?;
        let StubGenAttr::Default(expr) = &field2_attrs[0] else {
            panic!("attr should be Default");
        };
        assert_eq!(expr.to_token_stream().to_string(), "1 + 2");
        Ok(())
    }

    /// `override_return_type(...)` on a function and `override_type(...)` on an argument.
    #[test]
    fn test_parse_gen_stub_override_type_attr() -> Result<()> {
        let item: ItemFn = parse_str(
            r#"
            #[gen_stub_pyfunction]
            #[pyfunction]
            #[gen_stub(override_return_type(type_repr="typing.Never", imports=("typing")))]
            fn say_hello_forever<'a>(
                #[gen_stub(override_type(type_repr="collections.abc.Callable[[str]]", imports=("collections.abc")))]
                cb: Bound<'a, PyAny>,
            ) -> PyResult<()> {
                loop {
                    cb.call1(("Hello!",))?;
                }
            }
            "#,
        )?;

        let fn_attrs = parse_gen_stub_attrs(&item.attrs, AttributeLocation::Function, None)?;
        assert_eq!(
            *single_override(&fn_attrs),
            OverrideTypeAttribute {
                type_repr: "typing.Never".into(),
                imports: HashSet::from(["typing".into()])
            }
        );

        if let syn::FnArg::Typed(PatType { attrs, .. }) = &item.sig.inputs[0] {
            let arg_attrs = parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)?;
            assert_eq!(
                *single_override(&arg_attrs),
                OverrideTypeAttribute {
                    type_repr: "collections.abc.Callable[[str]]".into(),
                    imports: HashSet::from(["collections.abc".into()])
                }
            );
        }
        Ok(())
    }
}