1use indexmap::IndexSet;
2
3use super::{RenamingRule, Signature};
4use proc_macro2::{Delimiter, TokenStream as TokenStream2, TokenTree};
5use quote::{quote, ToTokens, TokenStreamExt};
6use syn::{
7 parenthesized,
8 parse::{Parse, ParseStream},
9 punctuated::Punctuated,
10 Attribute, Expr, ExprLit, Ident, Lit, LitStr, Meta, MetaList, Result, Token, Type,
11};
12
13#[derive(Debug, Clone, PartialEq)]
/// Target of a `#[gen_stub(type_ignore ...)]` attribute, as produced by `parse_gen_stub_attr`.
#[derive(Debug, Clone, PartialEq)]
pub enum IgnoreTarget {
    /// Bare `type_ignore` (no `=`): the catch-all form that names no specific rules.
    All,
    /// `type_ignore = ["rule", ...]`: only the listed rules. The parser in this file rejects an
    /// empty list. The rules are kept as `LitStr` rather than `String` (presumably so their
    /// spans stay available for later diagnostics — TODO confirm).
    SpecifiedLits(Vec<LitStr>),
}
21
22pub fn extract_documents(attrs: &[Attribute]) -> Vec<String> {
23 let mut docs = Vec::new();
24 for attr in attrs {
25 if attr.path().is_ident("doc") {
27 if let Meta::NameValue(syn::MetaNameValue {
28 value:
29 Expr::Lit(ExprLit {
30 lit: Lit::Str(doc), ..
31 }),
32 ..
33 }) = &attr.meta
34 {
35 let doc = doc.value();
36 docs.push(if !doc.is_empty() && doc.starts_with(' ') {
43 doc[1..].to_string()
44 } else {
45 doc
46 });
47 }
48 }
49 }
50 docs
51}
52
53pub fn extract_deprecated(attrs: &[Attribute]) -> Option<DeprecatedInfo> {
55 for attr in attrs {
56 if attr.path().is_ident("deprecated") {
57 if let Ok(list) = attr.meta.require_list() {
58 let mut since = None;
59 let mut note = None;
60
61 list.parse_nested_meta(|meta| {
62 if meta.path.is_ident("since") {
63 let value = meta.value()?;
64 let lit: LitStr = value.parse()?;
65 since = Some(lit.value());
66 } else if meta.path.is_ident("note") {
67 let value = meta.value()?;
68 let lit: LitStr = value.parse()?;
69 note = Some(lit.value());
70 }
71 Ok(())
72 })
73 .ok()?;
74
75 return Some(DeprecatedInfo { since, note });
76 }
77 }
78 }
79 None
80}
81
/// Deprecation metadata taken from a Rust `#[deprecated]` attribute (see `extract_deprecated`).
#[derive(Debug, Clone, PartialEq)]
pub struct DeprecatedInfo {
    /// Value of `since = "..."`, if given.
    pub since: Option<String>,
    /// Value of `note = "..."` (or of `#[deprecated = "..."]`), if given.
    pub note: Option<String>,
}
99
100impl ToTokens for DeprecatedInfo {
101 fn to_tokens(&self, tokens: &mut TokenStream2) {
102 let since = self
103 .since
104 .as_ref()
105 .map(|s| quote! { Some(#s) })
106 .unwrap_or_else(|| quote! { None });
107 let note = self
108 .note
109 .as_ref()
110 .map(|n| quote! { Some(#n) })
111 .unwrap_or_else(|| quote! { None });
112 tokens.append_all(quote! {
113 ::pyo3_stub_gen::type_info::DeprecatedInfo {
114 since: #since,
115 note: #note,
116 }
117 })
118 }
119}
120
/// A PyO3 (or `gen_stub`) option recognised by `parse_pyo3_attr` / `parse_gen_stub_module_attr`.
#[derive(Debug, Clone, PartialEq)]
#[expect(clippy::enum_variant_names)]
pub enum Attr {
    /// `name = "..."`.
    Name(String),
    /// `get` flag (e.g. `#[pyo3(get)]`).
    Get,
    /// `get_all` flag.
    GetAll,
    /// `set` flag.
    Set,
    /// `set_all` flag.
    SetAll,
    /// `module = "..."` inside a PyO3 attribute.
    Module(String),
    /// `constructor = (...)`.
    Constructor(Signature),
    /// `signature = (...)`.
    Signature(Signature),
    /// `rename_all = "..."`; only produced when `RenamingRule::try_new` accepts the string.
    RenameAll(RenamingRule),
    /// `extends = Type`.
    Extends(Type),

    /// `eq` flag.
    Eq,
    /// `ord` flag.
    Ord,
    /// `hash` flag.
    Hash,
    /// `str` flag.
    Str,
    /// `subclass` flag.
    Subclass,

    /// `#[gen_stub(module = "...")]`.
    GenStubModule(String),

    /// `#[new]`.
    New,
    /// `#[getter]` (`None`) or `#[getter(name)]` (`Some(name)`).
    Getter(Option<String>),
    /// `#[setter]` (`None`) or `#[setter(name)]` (`Some(name)`).
    Setter(Option<String>),
    /// `#[staticmethod]`.
    StaticMethod,
    /// `#[classmethod]`.
    ClassMethod,
    /// `#[classattr]`.
    ClassAttr,
}
155
156pub fn parse_pyo3_attrs(attrs: &[Attribute]) -> Result<Vec<Attr>> {
157 let mut out = Vec::new();
158 for attr in attrs {
159 let mut new = parse_pyo3_attr(attr)?;
160 out.append(&mut new);
161 if let Some(gen_stub_attr) = parse_gen_stub_module_attr(attr)? {
163 out.push(gen_stub_attr);
164 }
165 }
166 Ok(out)
167}
168
169pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
170 let mut pyo3_attrs = Vec::new();
171 let path = attr.path();
172 let is_full_path_pyo3_attr = path.segments.len() == 2
173 && path
174 .segments
175 .first()
176 .is_some_and(|seg| seg.ident.eq("pyo3"))
177 && path.segments.last().is_some_and(|seg| {
178 seg.ident.eq("pyclass") || seg.ident.eq("pymethods") || seg.ident.eq("pyfunction")
179 });
180 if path.is_ident("pyclass")
181 || path.is_ident("pymethods")
182 || path.is_ident("pyfunction")
183 || path.is_ident("pyo3")
184 || is_full_path_pyo3_attr
185 {
186 if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
191 use TokenTree::*;
192 let tokens: Vec<TokenTree> = flatten_invisible_groups(tokens.clone()).collect();
198 for tt in tokens.split(|tt| {
201 if let Punct(p) = tt {
202 p.as_char() == ','
203 } else {
204 false
205 }
206 }) {
207 match tt {
208 [Ident(ident)] => {
209 if ident == "get" {
210 pyo3_attrs.push(Attr::Get);
211 }
212 if ident == "get_all" {
213 pyo3_attrs.push(Attr::GetAll);
214 }
215 if ident == "set" {
216 pyo3_attrs.push(Attr::Set);
217 }
218 if ident == "set_all" {
219 pyo3_attrs.push(Attr::SetAll);
220 }
221 if ident == "eq" {
222 pyo3_attrs.push(Attr::Eq);
223 }
224 if ident == "ord" {
225 pyo3_attrs.push(Attr::Ord);
226 }
227 if ident == "hash" {
228 pyo3_attrs.push(Attr::Hash);
229 }
230 if ident == "str" {
231 pyo3_attrs.push(Attr::Str);
232 }
233 if ident == "subclass" {
234 pyo3_attrs.push(Attr::Subclass);
235 }
236 }
238 [Ident(ident), Punct(_), Literal(lit)] => {
239 if ident == "name" {
240 pyo3_attrs
241 .push(Attr::Name(lit.to_string().trim_matches('"').to_string()));
242 }
243 if ident == "module" {
244 pyo3_attrs
245 .push(Attr::Module(lit.to_string().trim_matches('"').to_string()));
246 }
247 if ident == "rename_all" {
248 let name = lit.to_string().trim_matches('"').to_string();
249 if let Some(renaming_rule) = RenamingRule::try_new(&name) {
250 pyo3_attrs.push(Attr::RenameAll(renaming_rule));
251 }
252 }
253 }
254 [Ident(ident), Punct(_), Group(group)] => {
255 if ident == "signature" {
256 pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?));
257 } else if ident == "constructor" {
258 pyo3_attrs
259 .push(Attr::Constructor(syn::parse2(group.to_token_stream())?));
260 }
261 }
262 [Ident(ident), Punct(_), Ident(ident2)] => {
263 if ident == "extends" {
264 pyo3_attrs.push(Attr::Extends(syn::parse2(ident2.to_token_stream())?));
265 }
266 }
267 _ => {}
268 }
269 }
270 }
271 } else if path.is_ident("new") {
272 pyo3_attrs.push(Attr::New);
273 } else if path.is_ident("staticmethod") {
274 pyo3_attrs.push(Attr::StaticMethod);
275 } else if path.is_ident("classmethod") {
276 pyo3_attrs.push(Attr::ClassMethod);
277 } else if path.is_ident("classattr") {
278 pyo3_attrs.push(Attr::ClassAttr);
279 } else if path.is_ident("getter") {
280 if let Ok(inner) = attr.parse_args::<Ident>() {
281 pyo3_attrs.push(Attr::Getter(Some(inner.to_string())));
282 } else {
283 pyo3_attrs.push(Attr::Getter(None));
284 }
285 } else if path.is_ident("setter") {
286 if let Ok(inner) = attr.parse_args::<Ident>() {
287 pyo3_attrs.push(Attr::Setter(Some(inner.to_string())));
288 } else {
289 pyo3_attrs.push(Attr::Setter(None));
290 }
291 }
292
293 Ok(pyo3_attrs)
294}
295
296fn flatten_invisible_groups(tokens: TokenStream2) -> impl Iterator<Item = TokenTree> {
300 tokens
301 .into_iter()
302 .flat_map(|tt| -> Box<dyn Iterator<Item = TokenTree>> {
303 match tt {
304 TokenTree::Group(g) if g.delimiter() == Delimiter::None => {
305 Box::new(flatten_invisible_groups(g.stream()))
306 }
307 other => Box::new(std::iter::once(other)),
308 }
309 })
310}
311
312pub fn parse_gen_stub_module_attr(attr: &Attribute) -> Result<Option<Attr>> {
314 let path = attr.path();
315 if path.is_ident("gen_stub") {
316 if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
318 use TokenTree::*;
319 let tokens: Vec<TokenTree> = flatten_invisible_groups(tokens.clone()).collect();
321
322 for tt in tokens.split(|tt| {
324 if let Punct(p) = tt {
325 p.as_char() == ','
326 } else {
327 false
328 }
329 }) {
330 match tt {
331 [Ident(ident), Punct(_), Literal(lit)] if ident == "module" => {
332 return Ok(Some(Attr::GenStubModule(
333 lit.to_string().trim_matches('"').to_string(),
334 )));
335 }
336 _ => {}
337 }
338 }
339 }
340 }
341 Ok(None)
342}
343
/// One option parsed from a `#[gen_stub(...)]` attribute by `parse_gen_stub_attr`.
#[derive(Debug, Clone, PartialEq)]
pub enum StubGenAttr {
    /// `default = <expr>`: the expression as written.
    Default(Expr),
    /// `skip`.
    Skip,
    /// `override_type(...)` or `override_return_type(...)`.
    OverrideType(OverrideTypeAttribute),
    /// `type_ignore` or `type_ignore = ["rule", ...]`.
    TypeIgnore(IgnoreTarget),
}
355
356pub fn prune_attrs(attrs: &mut Vec<Attribute>) {
357 attrs.retain(|attr| !attr.path().is_ident("gen_stub"));
358}
359
360pub fn parse_gen_stub_override_type(attrs: &[Attribute]) -> Result<Option<OverrideTypeAttribute>> {
361 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)? {
362 if let StubGenAttr::OverrideType(attr) = attr {
363 return Ok(Some(attr));
364 }
365 }
366 Ok(None)
367}
368
369pub fn parse_gen_stub_override_return_type(
370 attrs: &[Attribute],
371) -> Result<Option<OverrideTypeAttribute>> {
372 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
373 if let StubGenAttr::OverrideType(attr) = attr {
374 return Ok(Some(attr));
375 }
376 }
377 Ok(None)
378}
379
380pub fn parse_gen_stub_default(attrs: &[Attribute]) -> Result<Option<Expr>> {
381 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
382 if let StubGenAttr::Default(default) = attr {
383 return Ok(Some(default));
384 }
385 }
386 Ok(None)
387}
388pub fn parse_gen_stub_skip(attrs: &[Attribute]) -> Result<bool> {
389 let skip = parse_gen_stub_attrs(
390 attrs,
391 AttributeLocation::Field,
392 Some(&["override_return_type", "default"]),
393 )?
394 .iter()
395 .any(|attr| matches!(attr, StubGenAttr::Skip));
396 Ok(skip)
397}
398
399pub fn parse_gen_stub_type_ignore(attrs: &[Attribute]) -> Result<Option<IgnoreTarget>> {
400 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
402 if let StubGenAttr::TypeIgnore(target) = attr {
403 return Ok(Some(target));
404 }
405 }
406 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Field, None)? {
408 if let StubGenAttr::TypeIgnore(target) = attr {
409 return Ok(Some(target));
410 }
411 }
412 Ok(None)
413}
414
415fn parse_gen_stub_attrs(
416 attrs: &[Attribute],
417 location: AttributeLocation,
418 ignored_idents: Option<&[&str]>,
419) -> Result<Vec<StubGenAttr>> {
420 let mut out = Vec::new();
421 for attr in attrs {
422 let mut new = parse_gen_stub_attr(attr, location, ignored_idents.unwrap_or(&[]))?;
423 out.append(&mut new);
424 }
425 Ok(out)
426}
427
428fn parse_gen_stub_attr(
429 attr: &Attribute,
430 location: AttributeLocation,
431 ignored_idents: &[&str],
432) -> Result<Vec<StubGenAttr>> {
433 let mut gen_stub_attrs = Vec::new();
434 let path = attr.path();
435 if path.is_ident("gen_stub") {
436 attr.parse_args_with(|input: ParseStream| {
437 while !input.is_empty() {
438 let ident: Ident = input.parse()?;
439 let ignored_ident = ignored_idents.iter().any(|other| ident == other);
440 if (ident == "override_type"
441 && (location == AttributeLocation::Argument || ignored_ident))
442 || (ident == "override_return_type"
443 && (location == AttributeLocation::Function || location == AttributeLocation::Field || ignored_ident))
444 {
445 let content;
446 parenthesized!(content in input);
447 let override_attr: OverrideTypeAttribute = content.parse()?;
448 gen_stub_attrs.push(StubGenAttr::OverrideType(override_attr));
449 } else if ident == "skip" && (location == AttributeLocation::Field || ignored_ident)
450 {
451 gen_stub_attrs.push(StubGenAttr::Skip);
452 } else if ident == "default"
453 && input.peek(Token![=])
454 && (location == AttributeLocation::Field || location == AttributeLocation::Function || ignored_ident)
455 {
456 input.parse::<Token![=]>()?;
457 gen_stub_attrs.push(StubGenAttr::Default(input.parse()?));
458 } else if ident == "type_ignore"
459 && (location == AttributeLocation::Function || location == AttributeLocation::Field || ignored_ident)
460 {
461 if input.peek(Token![=]) {
465 input.parse::<Token![=]>()?;
466 let content;
468 syn::bracketed!(content in input);
469 let rules = Punctuated::<LitStr, Token![,]>::parse_terminated(&content)?;
470
471 if rules.is_empty() {
473 return Err(syn::Error::new(
474 ident.span(),
475 "type_ignore with empty array is not allowed. Use type_ignore without equals for catch-all, or specify rules in the array."
476 ));
477 }
478
479 let rule_lits: Vec<LitStr> = rules.into_iter().collect();
481 gen_stub_attrs.push(StubGenAttr::TypeIgnore(IgnoreTarget::SpecifiedLits(rule_lits)));
482 } else {
483 gen_stub_attrs.push(StubGenAttr::TypeIgnore(IgnoreTarget::All));
485 }
486 } else if ident == "override_type" {
487 return Err(syn::Error::new(
488 ident.span(),
489 "`override_type(...)` is only valid in argument position".to_string(),
490 ));
491 } else if ident == "override_return_type" {
492 return Err(syn::Error::new(
493 ident.span(),
494 "`override_return_type(...)` is only valid in function or method position"
495 .to_string(),
496 ));
497 } else if ident == "skip" {
498 return Err(syn::Error::new(
499 ident.span(),
500 "`skip` is only valid in field position".to_string(),
501 ));
502 } else if ident == "default" {
503 return Err(syn::Error::new(
504 ident.span(),
505 "`default=xxx` is only valid in field or function position".to_string(),
506 ));
507 } else if ident == "type_ignore" {
508 return Err(syn::Error::new(
509 ident.span(),
510 "`type_ignore` or `type_ignore=[...]` is only valid in function or method position".to_string(),
511 ));
512 } else if location == AttributeLocation::Argument {
513 return Err(syn::Error::new(
514 ident.span(),
515 format!("Unsupported keyword `{ident}`, valid is `override_type(...)`"),
516 ));
517 } else if location == AttributeLocation::Field {
518 return Err(syn::Error::new(
519 ident.span(),
520 format!("Unsupported keyword `{ident}`, valid is `default=xxx`, `skip`, `override_return_type(...)`, `type_ignore`, or `type_ignore=[...]`"),
521 ));
522 } else if location == AttributeLocation::Function {
523 return Err(syn::Error::new(
524 ident.span(),
525 format!(
526 "Unsupported keyword `{ident}`, valid is `default=xxx`, `override_return_type(...)`, `type_ignore`, or `type_ignore=[...]`"
527 ),
528 ));
529 } else {
530 return Err(syn::Error::new(
531 ident.span(),
532 format!("Unsupported keyword `{ident}`"),
533 ));
534 }
535 if input.peek(Token![,]) {
536 input.parse::<Token![,]>()?;
537 } else {
538 break;
539 }
540 }
541 Ok(())
542 })?;
543 }
544 Ok(gen_stub_attrs)
545}
546
/// Where a `#[gen_stub(...)]` attribute was written; `parse_gen_stub_attr` uses it to decide
/// which keywords are accepted.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum AttributeLocation {
    /// On a function argument: accepts `override_type(...)`.
    Argument,
    /// On a field: accepts `default = ...`, `skip`, `override_return_type(...)` and
    /// `type_ignore`.
    Field,
    /// On a function or method: accepts `default = ...`, `override_return_type(...)` and
    /// `type_ignore`.
    Function,
}
553
/// Parsed body of `override_type(...)` / `override_return_type(...)`:
/// `type_repr = "...", imports = ("module", ...)`.
#[derive(Debug, Clone, PartialEq)]
pub struct OverrideTypeAttribute {
    /// The user-supplied type expression, e.g. `"typing.Never"`.
    pub(crate) type_repr: String,
    /// Modules listed in `imports = (...)`, deduplicated in first-seen order.
    pub(crate) imports: IndexSet<String>,
}
559
/// Custom keywords used by the `OverrideTypeAttribute` parser.
mod kw {
    syn::custom_keyword!(type_repr);
    syn::custom_keyword!(imports);
    // NOTE(review): not referenced by the parsers visible in this file — confirm whether it
    // is still needed elsewhere.
    syn::custom_keyword!(override_type);
}
565
566impl Parse for OverrideTypeAttribute {
567 fn parse(input: ParseStream) -> Result<Self> {
568 let mut type_repr = None;
569 let mut imports = IndexSet::new();
570
571 while !input.is_empty() {
572 let lookahead = input.lookahead1();
573
574 if lookahead.peek(kw::type_repr) {
575 input.parse::<kw::type_repr>()?;
576 input.parse::<Token![=]>()?;
577 type_repr = Some(input.parse::<LitStr>()?);
578 } else if lookahead.peek(kw::imports) {
579 input.parse::<kw::imports>()?;
580 input.parse::<Token![=]>()?;
581
582 let content;
583 parenthesized!(content in input);
584 let parsed_imports = Punctuated::<LitStr, Token![,]>::parse_terminated(&content)?;
585 imports = parsed_imports.into_iter().collect();
586 } else {
587 return Err(lookahead.error());
588 }
589
590 if !input.is_empty() {
591 input.parse::<Token![,]>()?;
592 }
593 }
594
595 Ok(OverrideTypeAttribute {
596 type_repr: type_repr
597 .ok_or_else(|| input.error("missing type_repr"))?
598 .value(),
599 imports: imports.iter().map(|i| i.value()).collect(),
600 })
601 }
602}
603
/// Options accepted by this type's `Parse` impl: the `skip_stub_type` flag and
/// `module = "..."`.
/// NOTE(review): presumably the arguments of a stub-generation macro such as
/// `#[gen_stub_pyclass(...)]` — confirm at the call site.
#[derive(Default)]
pub struct PyClassAttr {
    /// Set by the bare `skip_stub_type` flag.
    pub skip_stub_type: bool,
    /// Value of `module = "..."`, if given.
    pub module: Option<String>,
}
611
612impl Parse for PyClassAttr {
613 fn parse(input: ParseStream) -> Result<Self> {
614 let mut skip_stub_type = false;
615 let mut module = None;
616
617 while !input.is_empty() {
619 let key: Ident = input.parse()?;
620
621 match key.to_string().as_str() {
622 "skip_stub_type" => {
623 skip_stub_type = true;
624 }
625 "module" => {
626 let _: Token![=] = input.parse()?;
627 let value: LitStr = input.parse()?;
628 module = Some(value.value());
629 }
630 _ => {
631 return Err(syn::Error::new(
632 key.span(),
633 format!("Unknown parameter: {}", key),
634 ));
635 }
636 }
637
638 if input.peek(Token![,]) {
640 let _: Token![,] = input.parse()?;
641 } else {
642 break;
643 }
644 }
645
646 Ok(Self {
647 skip_stub_type,
648 module,
649 })
650 }
651}
652
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Fields, ItemFn, ItemStruct, PatType};

    /// Token-string of a `StubGenAttr::Default` expression; panics on any other variant.
    fn default_expr_string(attr: &StubGenAttr) -> String {
        match attr {
            StubGenAttr::Default(expr) => expr.to_token_stream().to_string(),
            _ => panic!("attr should be Default"),
        }
    }

    /// Payload of a `StubGenAttr::OverrideType`; panics on any other variant.
    fn override_payload(attr: &StubGenAttr) -> &OverrideTypeAttribute {
        match attr {
            StubGenAttr::OverrideType(payload) => payload,
            _ => panic!("attr should be OverrideType"),
        }
    }

    /// Attributes of `#[<attr_path>(<prefix> = <value>)] struct Dummy;`, with `value` wrapped
    /// in an invisible `Delimiter::None` group as a `macro_rules!` substitution would be.
    fn attrs_with_invisible_group(
        attr_path: TokenStream2,
        prefix: TokenStream2,
        value: TokenStream2,
    ) -> Vec<Attribute> {
        let invisible = TokenTree::Group(proc_macro2::Group::new(Delimiter::None, value));
        let dummy: ItemStruct = syn::parse2(quote! {
            #[#attr_path(#prefix = #invisible)]
            struct Dummy;
        })
        .unwrap();
        dummy.attrs
    }

    /// Struct-level options are collected in order (`mapping` is ignored) and a field-level
    /// `#[pyo3(get)]` is recognised.
    #[test]
    fn test_parse_pyo3_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[pyo3(rename_all = "SCREAMING_SNAKE_CASE")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        let struct_attrs = parse_pyo3_attrs(&item.attrs)?;
        assert_eq!(
            struct_attrs,
            vec![
                Attr::Module("my_module".to_string()),
                Attr::Name("Placeholder".to_string()),
                Attr::RenameAll(RenamingRule::ScreamingSnakeCase),
            ]
        );

        let Fields::Named(named) = item.fields else {
            unreachable!()
        };
        let field_attrs = parse_pyo3_attr(&named.named[0].attrs[0])?;
        assert_eq!(field_attrs, vec![Attr::Get]);
        Ok(())
    }

    /// `#[pyo3::pyclass(...)]` is handled like `#[pyclass(...)]`.
    #[test]
    fn test_parse_pyo3_attr_full_path() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyo3::pyclass(mapping, module = "my_module", name = "Placeholder")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        let struct_attrs = parse_pyo3_attr(&item.attrs[0])?;
        assert_eq!(
            struct_attrs,
            vec![
                Attr::Module("my_module".to_string()),
                Attr::Name("Placeholder".to_string())
            ]
        );

        let Fields::Named(named) = item.fields else {
            unreachable!()
        };
        let field_attrs = parse_pyo3_attr(&named.named[0].attrs[0])?;
        assert_eq!(field_attrs, vec![Attr::Get]);
        Ok(())
    }

    #[test]
    fn test_parse_pyo3_attr_name_from_macro_substitution() -> Result<()> {
        let attrs = attrs_with_invisible_group(quote!(pyclass), quote!(name), quote!("FromMacro"));
        assert_eq!(
            parse_pyo3_attrs(&attrs)?,
            vec![Attr::Name("FromMacro".to_string())]
        );
        Ok(())
    }

    #[test]
    fn test_parse_pyo3_attr_module_from_macro_substitution() -> Result<()> {
        let attrs =
            attrs_with_invisible_group(quote!(pyclass), quote!(module), quote!("my.mod.path"));
        assert_eq!(
            parse_pyo3_attrs(&attrs)?,
            vec![Attr::Module("my.mod.path".to_string())]
        );
        Ok(())
    }

    #[test]
    fn test_parse_pyo3_attr_rename_all_from_macro_substitution() -> Result<()> {
        let attrs = attrs_with_invisible_group(
            quote!(pyo3),
            quote!(rename_all),
            quote!("SCREAMING_SNAKE_CASE"),
        );
        assert_eq!(
            parse_pyo3_attrs(&attrs)?,
            vec![Attr::RenameAll(RenamingRule::ScreamingSnakeCase)]
        );
        Ok(())
    }

    #[test]
    fn test_parse_gen_stub_module_attr_from_macro_substitution() -> Result<()> {
        let attrs =
            attrs_with_invisible_group(quote!(gen_stub), quote!(module), quote!("explicit.mod"));
        assert_eq!(
            parse_pyo3_attrs(&attrs)?,
            vec![Attr::GenStubModule("explicit.mod".to_string())]
        );
        Ok(())
    }

    /// Field-position `default = ...` and `skip` are parsed, in order.
    #[test]
    fn test_parse_gen_stub_field_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            pub struct PyPlaceholder {
                #[gen_stub(default = String::from("foo"), skip)]
                pub field0: String,
                #[gen_stub(skip)]
                pub field1: String,
                #[gen_stub(default = 1+2)]
                pub field2: usize,
            }
            "#,
        )?;
        let fields: Vec<_> = item.fields.into_iter().collect();
        let parse_field = |index: usize| {
            parse_gen_stub_attrs(&fields[index].attrs, AttributeLocation::Field, None)
        };

        let field0_attrs = parse_field(0)?;
        assert_eq!(
            default_expr_string(&field0_attrs[0]),
            "String :: from (\"foo\")"
        );
        assert_eq!(&StubGenAttr::Skip, &field0_attrs[1]);

        assert_eq!(vec![StubGenAttr::Skip], parse_field(1)?);

        let field2_attrs = parse_field(2)?;
        assert_eq!(default_expr_string(&field2_attrs[0]), "1 + 2");
        Ok(())
    }

    /// `override_return_type` on a function and `override_type` on an argument.
    #[test]
    fn test_parse_gen_stub_override_type_attr() -> Result<()> {
        let item: ItemFn = parse_str(
            r#"
            #[gen_stub_pyfunction]
            #[pyfunction]
            #[gen_stub(override_return_type(type_repr="typing.Never", imports=("typing")))]
            fn say_hello_forever<'a>(
                #[gen_stub(override_type(type_repr="collections.abc.Callable[[str]]", imports=("collections.abc")))]
                cb: Bound<'a, PyAny>,
            ) -> PyResult<()> {
                loop {
                    cb.call1(("Hello!",))?;
                }
            }
            "#,
        )?;
        let fn_attrs = parse_gen_stub_attrs(&item.attrs, AttributeLocation::Function, None)?;
        assert_eq!(fn_attrs.len(), 1);
        assert_eq!(
            *override_payload(&fn_attrs[0]),
            OverrideTypeAttribute {
                type_repr: "typing.Never".into(),
                imports: IndexSet::from(["typing".into()])
            }
        );

        if let syn::FnArg::Typed(PatType { attrs, .. }) = &item.sig.inputs[0] {
            let arg_attrs = parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)?;
            assert_eq!(arg_attrs.len(), 1);
            assert_eq!(
                *override_payload(&arg_attrs[0]),
                OverrideTypeAttribute {
                    type_repr: "collections.abc.Callable[[str]]".into(),
                    imports: IndexSet::from(["collections.abc".into()])
                }
            );
        }
        Ok(())
    }
}