1use indexmap::IndexSet;
2
3use super::{RenamingRule, Signature};
4use proc_macro2::{TokenStream as TokenStream2, TokenTree};
5use quote::{quote, ToTokens, TokenStreamExt};
6use syn::{
7 parenthesized,
8 parse::{Parse, ParseStream},
9 punctuated::Punctuated,
10 Attribute, Expr, ExprLit, Ident, Lit, LitStr, Meta, MetaList, Result, Token, Type,
11};
12
/// What a `#[gen_stub(type_ignore ...)]` entry applies to.
#[derive(Debug, Clone, PartialEq)]
pub enum IgnoreTarget {
    /// Bare `type_ignore` (no `=`): the catch-all form.
    All,
    /// `type_ignore = ["rule", ...]`: only the listed rules. The string literals are
    /// kept as-is (not converted to `String`), presumably so spans stay available
    /// for later diagnostics — TODO confirm.
    SpecifiedLits(Vec<LitStr>),
}
21
22pub fn extract_documents(attrs: &[Attribute]) -> Vec<String> {
23 let mut docs = Vec::new();
24 for attr in attrs {
25 if attr.path().is_ident("doc") {
27 if let Meta::NameValue(syn::MetaNameValue {
28 value:
29 Expr::Lit(ExprLit {
30 lit: Lit::Str(doc), ..
31 }),
32 ..
33 }) = &attr.meta
34 {
35 let doc = doc.value();
36 docs.push(if !doc.is_empty() && doc.starts_with(' ') {
43 doc[1..].to_string()
44 } else {
45 doc
46 });
47 }
48 }
49 }
50 docs
51}
52
53pub fn extract_deprecated(attrs: &[Attribute]) -> Option<DeprecatedInfo> {
55 for attr in attrs {
56 if attr.path().is_ident("deprecated") {
57 if let Ok(list) = attr.meta.require_list() {
58 let mut since = None;
59 let mut note = None;
60
61 list.parse_nested_meta(|meta| {
62 if meta.path.is_ident("since") {
63 let value = meta.value()?;
64 let lit: LitStr = value.parse()?;
65 since = Some(lit.value());
66 } else if meta.path.is_ident("note") {
67 let value = meta.value()?;
68 let lit: LitStr = value.parse()?;
69 note = Some(lit.value());
70 }
71 Ok(())
72 })
73 .ok()?;
74
75 return Some(DeprecatedInfo { since, note });
76 }
77 }
78 }
79 None
80}
81
/// Data parsed from a Rust `#[deprecated]` attribute by `extract_deprecated`.
///
/// Its `ToTokens` impl emits a `::pyo3_stub_gen::type_info::DeprecatedInfo` literal
/// with the same two fields.
#[derive(Debug, Clone, PartialEq)]
pub struct DeprecatedInfo {
    /// Value of `since = "..."`, if given.
    pub since: Option<String>,
    /// Value of `note = "..."` (or the string in `#[deprecated = "..."]`), if given.
    pub note: Option<String>,
}
99
100impl ToTokens for DeprecatedInfo {
101 fn to_tokens(&self, tokens: &mut TokenStream2) {
102 let since = self
103 .since
104 .as_ref()
105 .map(|s| quote! { Some(#s) })
106 .unwrap_or_else(|| quote! { None });
107 let note = self
108 .note
109 .as_ref()
110 .map(|n| quote! { Some(#n) })
111 .unwrap_or_else(|| quote! { None });
112 tokens.append_all(quote! {
113 ::pyo3_stub_gen::type_info::DeprecatedInfo {
114 since: #since,
115 note: #note,
116 }
117 })
118 }
119}
120
/// A single PyO3 (or `gen_stub(module = ...)`) setting relevant to stub generation,
/// as produced by `parse_pyo3_attr` / `parse_pyo3_attrs`.
#[derive(Debug, Clone, PartialEq)]
#[expect(clippy::enum_variant_names)]
pub enum Attr {
    /// `name = "..."`.
    Name(String),
    /// `get` flag.
    Get,
    /// `get_all` flag.
    GetAll,
    /// `set` flag.
    Set,
    /// `set_all` flag.
    SetAll,
    /// `module = "..."`.
    Module(String),
    /// `constructor = (...)`.
    Constructor(Signature),
    /// `signature = (...)`.
    Signature(Signature),
    /// `rename_all = "..."` when the rule name is recognised by `RenamingRule::try_new`.
    RenameAll(RenamingRule),
    /// `extends = BaseType`.
    Extends(Type),

    /// `eq` flag.
    Eq,
    /// `ord` flag.
    Ord,
    /// `hash` flag.
    Hash,
    /// `str` flag.
    Str,
    /// `subclass` flag.
    Subclass,

    /// `#[gen_stub(module = "...")]`.
    GenStubModule(String),

    /// `#[new]` marker.
    New,
    /// `#[getter]` or `#[getter(name)]`.
    Getter(Option<String>),
    /// `#[setter]` or `#[setter(name)]`.
    Setter(Option<String>),
    /// `#[staticmethod]` marker.
    StaticMethod,
    /// `#[classmethod]` marker.
    ClassMethod,
    /// `#[classattr]` marker.
    ClassAttr,
}
155
156pub fn parse_pyo3_attrs(attrs: &[Attribute]) -> Result<Vec<Attr>> {
157 let mut out = Vec::new();
158 for attr in attrs {
159 let mut new = parse_pyo3_attr(attr)?;
160 out.append(&mut new);
161 if let Some(gen_stub_attr) = parse_gen_stub_module_attr(attr)? {
163 out.push(gen_stub_attr);
164 }
165 }
166 Ok(out)
167}
168
169pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
170 let mut pyo3_attrs = Vec::new();
171 let path = attr.path();
172 let is_full_path_pyo3_attr = path.segments.len() == 2
173 && path
174 .segments
175 .first()
176 .is_some_and(|seg| seg.ident.eq("pyo3"))
177 && path.segments.last().is_some_and(|seg| {
178 seg.ident.eq("pyclass") || seg.ident.eq("pymethods") || seg.ident.eq("pyfunction")
179 });
180 if path.is_ident("pyclass")
181 || path.is_ident("pymethods")
182 || path.is_ident("pyfunction")
183 || path.is_ident("pyo3")
184 || is_full_path_pyo3_attr
185 {
186 if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
191 use TokenTree::*;
192 let tokens: Vec<TokenTree> = tokens.clone().into_iter().collect();
193 for tt in tokens.split(|tt| {
196 if let Punct(p) = tt {
197 p.as_char() == ','
198 } else {
199 false
200 }
201 }) {
202 match tt {
203 [Ident(ident)] => {
204 if ident == "get" {
205 pyo3_attrs.push(Attr::Get);
206 }
207 if ident == "get_all" {
208 pyo3_attrs.push(Attr::GetAll);
209 }
210 if ident == "set" {
211 pyo3_attrs.push(Attr::Set);
212 }
213 if ident == "set_all" {
214 pyo3_attrs.push(Attr::SetAll);
215 }
216 if ident == "eq" {
217 pyo3_attrs.push(Attr::Eq);
218 }
219 if ident == "ord" {
220 pyo3_attrs.push(Attr::Ord);
221 }
222 if ident == "hash" {
223 pyo3_attrs.push(Attr::Hash);
224 }
225 if ident == "str" {
226 pyo3_attrs.push(Attr::Str);
227 }
228 if ident == "subclass" {
229 pyo3_attrs.push(Attr::Subclass);
230 }
231 }
233 [Ident(ident), Punct(_), Literal(lit)] => {
234 if ident == "name" {
235 pyo3_attrs
236 .push(Attr::Name(lit.to_string().trim_matches('"').to_string()));
237 }
238 if ident == "module" {
239 pyo3_attrs
240 .push(Attr::Module(lit.to_string().trim_matches('"').to_string()));
241 }
242 if ident == "rename_all" {
243 let name = lit.to_string().trim_matches('"').to_string();
244 if let Some(renaming_rule) = RenamingRule::try_new(&name) {
245 pyo3_attrs.push(Attr::RenameAll(renaming_rule));
246 }
247 }
248 }
249 [Ident(ident), Punct(_), Group(group)] => {
250 if ident == "signature" {
251 pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?));
252 } else if ident == "constructor" {
253 pyo3_attrs
254 .push(Attr::Constructor(syn::parse2(group.to_token_stream())?));
255 }
256 }
257 [Ident(ident), Punct(_), Ident(ident2)] => {
258 if ident == "extends" {
259 pyo3_attrs.push(Attr::Extends(syn::parse2(ident2.to_token_stream())?));
260 }
261 }
262 _ => {}
263 }
264 }
265 }
266 } else if path.is_ident("new") {
267 pyo3_attrs.push(Attr::New);
268 } else if path.is_ident("staticmethod") {
269 pyo3_attrs.push(Attr::StaticMethod);
270 } else if path.is_ident("classmethod") {
271 pyo3_attrs.push(Attr::ClassMethod);
272 } else if path.is_ident("classattr") {
273 pyo3_attrs.push(Attr::ClassAttr);
274 } else if path.is_ident("getter") {
275 if let Ok(inner) = attr.parse_args::<Ident>() {
276 pyo3_attrs.push(Attr::Getter(Some(inner.to_string())));
277 } else {
278 pyo3_attrs.push(Attr::Getter(None));
279 }
280 } else if path.is_ident("setter") {
281 if let Ok(inner) = attr.parse_args::<Ident>() {
282 pyo3_attrs.push(Attr::Setter(Some(inner.to_string())));
283 } else {
284 pyo3_attrs.push(Attr::Setter(None));
285 }
286 }
287
288 Ok(pyo3_attrs)
289}
290
291pub fn parse_gen_stub_module_attr(attr: &Attribute) -> Result<Option<Attr>> {
293 let path = attr.path();
294 if path.is_ident("gen_stub") {
295 if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
297 use TokenTree::*;
298 let tokens: Vec<TokenTree> = tokens.clone().into_iter().collect();
299
300 for tt in tokens.split(|tt| {
302 if let Punct(p) = tt {
303 p.as_char() == ','
304 } else {
305 false
306 }
307 }) {
308 match tt {
309 [Ident(ident), Punct(_), Literal(lit)] if ident == "module" => {
310 return Ok(Some(Attr::GenStubModule(
311 lit.to_string().trim_matches('"').to_string(),
312 )));
313 }
314 _ => {}
315 }
316 }
317 }
318 }
319 Ok(None)
320}
321
/// One keyword parsed from a `#[gen_stub(...)]` attribute by `parse_gen_stub_attr`.
#[derive(Debug, Clone, PartialEq)]
pub enum StubGenAttr {
    /// `default = <expr>`.
    Default(Expr),
    /// `skip`.
    Skip,
    /// `override_type(...)` or `override_return_type(...)`; both share this payload.
    OverrideType(OverrideTypeAttribute),
    /// `type_ignore` or `type_ignore = [...]`.
    TypeIgnore(IgnoreTarget),
}
333
334pub fn prune_attrs(attrs: &mut Vec<Attribute>) {
335 attrs.retain(|attr| !attr.path().is_ident("gen_stub"));
336}
337
338pub fn parse_gen_stub_override_type(attrs: &[Attribute]) -> Result<Option<OverrideTypeAttribute>> {
339 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)? {
340 if let StubGenAttr::OverrideType(attr) = attr {
341 return Ok(Some(attr));
342 }
343 }
344 Ok(None)
345}
346
347pub fn parse_gen_stub_override_return_type(
348 attrs: &[Attribute],
349) -> Result<Option<OverrideTypeAttribute>> {
350 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
351 if let StubGenAttr::OverrideType(attr) = attr {
352 return Ok(Some(attr));
353 }
354 }
355 Ok(None)
356}
357
358pub fn parse_gen_stub_default(attrs: &[Attribute]) -> Result<Option<Expr>> {
359 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
360 if let StubGenAttr::Default(default) = attr {
361 return Ok(Some(default));
362 }
363 }
364 Ok(None)
365}
366pub fn parse_gen_stub_skip(attrs: &[Attribute]) -> Result<bool> {
367 let skip = parse_gen_stub_attrs(
368 attrs,
369 AttributeLocation::Field,
370 Some(&["override_return_type", "default"]),
371 )?
372 .iter()
373 .any(|attr| matches!(attr, StubGenAttr::Skip));
374 Ok(skip)
375}
376
377pub fn parse_gen_stub_type_ignore(attrs: &[Attribute]) -> Result<Option<IgnoreTarget>> {
378 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
380 if let StubGenAttr::TypeIgnore(target) = attr {
381 return Ok(Some(target));
382 }
383 }
384 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Field, None)? {
386 if let StubGenAttr::TypeIgnore(target) = attr {
387 return Ok(Some(target));
388 }
389 }
390 Ok(None)
391}
392
393fn parse_gen_stub_attrs(
394 attrs: &[Attribute],
395 location: AttributeLocation,
396 ignored_idents: Option<&[&str]>,
397) -> Result<Vec<StubGenAttr>> {
398 let mut out = Vec::new();
399 for attr in attrs {
400 let mut new = parse_gen_stub_attr(attr, location, ignored_idents.unwrap_or(&[]))?;
401 out.append(&mut new);
402 }
403 Ok(out)
404}
405
/// Parses one `#[gen_stub(...)]` attribute into its `StubGenAttr` entries.
///
/// `location` says where the attribute was written. Each keyword is accepted only
/// in certain positions:
/// - `override_type(...)`: argument position;
/// - `override_return_type(...)`: function or field position;
/// - `skip`: field position;
/// - `default = <expr>`: function or field position;
/// - `type_ignore` / `type_ignore = ["rule", ...]`: function or field position.
///
/// Any keyword listed in `ignored_idents` is accepted and parsed regardless of
/// `location`. Attributes whose path is not `gen_stub` yield an empty `Vec`.
///
/// # Errors
/// Returns an error for a keyword used in an unsupported position, an unknown
/// keyword, an empty `type_ignore = []`, or malformed arguments.
fn parse_gen_stub_attr(
    attr: &Attribute,
    location: AttributeLocation,
    ignored_idents: &[&str],
) -> Result<Vec<StubGenAttr>> {
    let mut gen_stub_attrs = Vec::new();
    let path = attr.path();
    if path.is_ident("gen_stub") {
        attr.parse_args_with(|input: ParseStream| {
            // Comma-separated keywords; the loop ends when input or commas run out.
            while !input.is_empty() {
                let ident: Ident = input.parse()?;
                // Caller-supplied keywords that are accepted in any position.
                let ignored_ident = ignored_idents.iter().any(|other| ident == other);
                if (ident == "override_type"
                    && (location == AttributeLocation::Argument || ignored_ident))
                    || (ident == "override_return_type"
                        && (location == AttributeLocation::Function || location == AttributeLocation::Field || ignored_ident))
                {
                    // Both keywords share the `(type_repr = "...", imports = (...))` payload.
                    let content;
                    parenthesized!(content in input);
                    let override_attr: OverrideTypeAttribute = content.parse()?;
                    gen_stub_attrs.push(StubGenAttr::OverrideType(override_attr));
                } else if ident == "skip" && (location == AttributeLocation::Field || ignored_ident)
                {
                    gen_stub_attrs.push(StubGenAttr::Skip);
                } else if ident == "default"
                    && input.peek(Token![=])
                    && (location == AttributeLocation::Field || location == AttributeLocation::Function || ignored_ident)
                {
                    input.parse::<Token![=]>()?;
                    gen_stub_attrs.push(StubGenAttr::Default(input.parse()?));
                } else if ident == "type_ignore"
                    && (location == AttributeLocation::Function || location == AttributeLocation::Field || ignored_ident)
                {
                    if input.peek(Token![=]) {
                        // `type_ignore = ["rule", ...]`: only the listed rules.
                        input.parse::<Token![=]>()?;
                        let content;
                        syn::bracketed!(content in input);
                        let rules = Punctuated::<LitStr, Token![,]>::parse_terminated(&content)?;

                        if rules.is_empty() {
                            // An empty list is rejected; the bare form is the catch-all.
                            return Err(syn::Error::new(
                                ident.span(),
                                "type_ignore with empty array is not allowed. Use type_ignore without equals for catch-all, or specify rules in the array."
                            ));
                        }

                        let rule_lits: Vec<LitStr> = rules.into_iter().collect();
                        gen_stub_attrs.push(StubGenAttr::TypeIgnore(IgnoreTarget::SpecifiedLits(rule_lits)));
                    } else {
                        // Bare `type_ignore`: catch-all.
                        gen_stub_attrs.push(StubGenAttr::TypeIgnore(IgnoreTarget::All));
                    }
                } else if ident == "override_type" {
                    // From here on: a known keyword in an unsupported position, or an
                    // unknown keyword.
                    return Err(syn::Error::new(
                        ident.span(),
                        "`override_type(...)` is only valid in argument position".to_string(),
                    ));
                } else if ident == "override_return_type" {
                    // NOTE(review): only reachable in argument position; field position
                    // is also accepted above but is not mentioned in this message.
                    return Err(syn::Error::new(
                        ident.span(),
                        "`override_return_type(...)` is only valid in function or method position"
                            .to_string(),
                    ));
                } else if ident == "skip" {
                    return Err(syn::Error::new(
                        ident.span(),
                        "`skip` is only valid in field position".to_string(),
                    ));
                } else if ident == "default" {
                    // NOTE(review): also reached in field/function position when the `=`
                    // is missing, where this message is misleading.
                    return Err(syn::Error::new(
                        ident.span(),
                        "`default=xxx` is only valid in field or function position".to_string(),
                    ));
                } else if ident == "type_ignore" {
                    // NOTE(review): only reachable in argument position; field position
                    // is also accepted above but is not mentioned in this message.
                    return Err(syn::Error::new(
                        ident.span(),
                        "`type_ignore` or `type_ignore=[...]` is only valid in function or method position".to_string(),
                    ));
                } else if location == AttributeLocation::Argument {
                    return Err(syn::Error::new(
                        ident.span(),
                        format!("Unsupported keyword `{ident}`, valid is `override_type(...)`"),
                    ));
                } else if location == AttributeLocation::Field {
                    return Err(syn::Error::new(
                        ident.span(),
                        format!("Unsupported keyword `{ident}`, valid is `default=xxx`, `skip`, `override_return_type(...)`, `type_ignore`, or `type_ignore=[...]`"),
                    ));
                } else if location == AttributeLocation::Function {
                    return Err(syn::Error::new(
                        ident.span(),
                        format!(
                            "Unsupported keyword `{ident}`, valid is `default=xxx`, `override_return_type(...)`, `type_ignore`, or `type_ignore=[...]`"
                        ),
                    ));
                } else {
                    // NOTE(review): unreachable while `AttributeLocation` has only the
                    // three variants handled above.
                    return Err(syn::Error::new(
                        ident.span(),
                        format!("Unsupported keyword `{ident}`"),
                    ));
                }
                if input.peek(Token![,]) {
                    input.parse::<Token![,]>()?;
                } else {
                    // No comma: any remaining tokens make `parse_args_with` fail.
                    break;
                }
            }
            Ok(())
        })?;
    }
    Ok(gen_stub_attrs)
}
524
/// Where a `#[gen_stub(...)]` attribute was written; decides which keywords
/// `parse_gen_stub_attr` accepts.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum AttributeLocation {
    /// On a function argument.
    Argument,
    /// On a struct field.
    Field,
    /// On a function or method.
    Function,
}
531
/// Payload of `override_type(...)` / `override_return_type(...)`:
/// `(type_repr = "...", imports = ("...", ...))`.
#[derive(Debug, Clone, PartialEq)]
pub struct OverrideTypeAttribute {
    /// The type written verbatim, e.g. `"typing.Never"`.
    pub(crate) type_repr: String,
    /// Module names listed in `imports`, de-duplicated, first-occurrence order.
    pub(crate) imports: IndexSet<String>,
}
537
/// Custom keywords used when parsing `OverrideTypeAttribute`.
mod kw {
    syn::custom_keyword!(type_repr);
    syn::custom_keyword!(imports);
    // NOTE(review): not referenced by the parsing code visible in this file.
    syn::custom_keyword!(override_type);
}
543
544impl Parse for OverrideTypeAttribute {
545 fn parse(input: ParseStream) -> Result<Self> {
546 let mut type_repr = None;
547 let mut imports = IndexSet::new();
548
549 while !input.is_empty() {
550 let lookahead = input.lookahead1();
551
552 if lookahead.peek(kw::type_repr) {
553 input.parse::<kw::type_repr>()?;
554 input.parse::<Token![=]>()?;
555 type_repr = Some(input.parse::<LitStr>()?);
556 } else if lookahead.peek(kw::imports) {
557 input.parse::<kw::imports>()?;
558 input.parse::<Token![=]>()?;
559
560 let content;
561 parenthesized!(content in input);
562 let parsed_imports = Punctuated::<LitStr, Token![,]>::parse_terminated(&content)?;
563 imports = parsed_imports.into_iter().collect();
564 } else {
565 return Err(lookahead.error());
566 }
567
568 if !input.is_empty() {
569 input.parse::<Token![,]>()?;
570 }
571 }
572
573 Ok(OverrideTypeAttribute {
574 type_repr: type_repr
575 .ok_or_else(|| input.error("missing type_repr"))?
576 .value(),
577 imports: imports.iter().map(|i| i.value()).collect(),
578 })
579 }
580}
581
/// Options parsed by `PyClassAttr`'s `Parse` impl: `skip_stub_type` and
/// `module = "..."`.
#[derive(Default)]
pub struct PyClassAttr {
    /// Set when the `skip_stub_type` flag is present.
    pub skip_stub_type: bool,
    /// Value of `module = "..."`, if present.
    pub module: Option<String>,
}
589
590impl Parse for PyClassAttr {
591 fn parse(input: ParseStream) -> Result<Self> {
592 let mut skip_stub_type = false;
593 let mut module = None;
594
595 while !input.is_empty() {
597 let key: Ident = input.parse()?;
598
599 match key.to_string().as_str() {
600 "skip_stub_type" => {
601 skip_stub_type = true;
602 }
603 "module" => {
604 let _: Token![=] = input.parse()?;
605 let value: LitStr = input.parse()?;
606 module = Some(value.value());
607 }
608 _ => {
609 return Err(syn::Error::new(
610 key.span(),
611 format!("Unknown parameter: {}", key),
612 ));
613 }
614 }
615
616 if input.peek(Token![,]) {
618 let _: Token![,] = input.parse()?;
619 } else {
620 break;
621 }
622 }
623
624 Ok(Self {
625 skip_stub_type,
626 module,
627 })
628 }
629}
630
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Fields, ItemFn, ItemStruct, PatType};

    // Item-level options from `#[pyclass]` and `#[pyo3]` are collected in order;
    // unrecognised options such as `mapping` are skipped.
    #[test]
    fn test_parse_pyo3_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[pyo3(rename_all = "SCREAMING_SNAKE_CASE")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        let attrs = parse_pyo3_attrs(&item.attrs)?;
        assert_eq!(
            attrs,
            vec![
                Attr::Module("my_module".to_string()),
                Attr::Name("Placeholder".to_string()),
                Attr::RenameAll(RenamingRule::ScreamingSnakeCase),
            ]
        );

        // Field-level `#[pyo3(get)]`.
        if let Fields::Named(fields) = item.fields {
            let attrs = parse_pyo3_attr(&fields.named[0].attrs[0])?;
            assert_eq!(attrs, vec![Attr::Get]);
        } else {
            unreachable!()
        }
        Ok(())
    }

    // The fully qualified `#[pyo3::pyclass(...)]` form is parsed like `#[pyclass(...)]`.
    #[test]
    fn test_parse_pyo3_attr_full_path() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyo3::pyclass(mapping, module = "my_module", name = "Placeholder")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        let attrs = parse_pyo3_attr(&item.attrs[0])?;
        assert_eq!(
            attrs,
            vec![
                Attr::Module("my_module".to_string()),
                Attr::Name("Placeholder".to_string())
            ]
        );

        // Field-level `#[pyo3(get)]`.
        if let Fields::Named(fields) = item.fields {
            let attrs = parse_pyo3_attr(&fields.named[0].attrs[0])?;
            assert_eq!(attrs, vec![Attr::Get]);
        } else {
            unreachable!()
        }
        Ok(())
    }
    // Field position accepts `default = <expr>` and `skip`, alone or combined.
    #[test]
    fn test_parse_gen_stub_field_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            pub struct PyPlaceholder {
                #[gen_stub(default = String::from("foo"), skip)]
                pub field0: String,
                #[gen_stub(skip)]
                pub field1: String,
                #[gen_stub(default = 1+2)]
                pub field2: usize,
            }
            "#,
        )?;
        let fields: Vec<_> = item.fields.into_iter().collect();
        let field0_attrs = parse_gen_stub_attrs(&fields[0].attrs, AttributeLocation::Field, None)?;
        if let StubGenAttr::Default(expr) = &field0_attrs[0] {
            assert_eq!(
                expr.to_token_stream().to_string(),
                "String :: from (\"foo\")"
            );
        } else {
            panic!("attr should be Default");
        };
        assert_eq!(&StubGenAttr::Skip, &field0_attrs[1]);
        let field1_attrs = parse_gen_stub_attrs(&fields[1].attrs, AttributeLocation::Field, None)?;
        assert_eq!(vec![StubGenAttr::Skip], field1_attrs);
        let field2_attrs = parse_gen_stub_attrs(&fields[2].attrs, AttributeLocation::Field, None)?;
        if let StubGenAttr::Default(expr) = &field2_attrs[0] {
            assert_eq!(expr.to_token_stream().to_string(), "1 + 2");
        } else {
            panic!("attr should be Default");
        };
        Ok(())
    }
    // `override_return_type` on the function and `override_type` on an argument both
    // produce `StubGenAttr::OverrideType` with the given repr and imports.
    #[test]
    fn test_parse_gen_stub_override_type_attr() -> Result<()> {
        let item: ItemFn = parse_str(
            r#"
            #[gen_stub_pyfunction]
            #[pyfunction]
            #[gen_stub(override_return_type(type_repr="typing.Never", imports=("typing")))]
            fn say_hello_forever<'a>(
                #[gen_stub(override_type(type_repr="collections.abc.Callable[[str]]", imports=("collections.abc")))]
                cb: Bound<'a, PyAny>,
            ) -> PyResult<()> {
                loop {
                    cb.call1(("Hello!",))?;
                }
            }
            "#,
        )?;
        let fn_attrs = parse_gen_stub_attrs(&item.attrs, AttributeLocation::Function, None)?;
        assert_eq!(fn_attrs.len(), 1);
        if let StubGenAttr::OverrideType(expr) = &fn_attrs[0] {
            assert_eq!(
                *expr,
                OverrideTypeAttribute {
                    type_repr: "typing.Never".into(),
                    imports: IndexSet::from(["typing".into()])
                }
            );
        } else {
            panic!("attr should be OverrideType");
        };
        if let syn::FnArg::Typed(PatType { attrs, .. }) = &item.sig.inputs[0] {
            let arg_attrs = parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)?;
            assert_eq!(arg_attrs.len(), 1);
            if let StubGenAttr::OverrideType(expr) = &arg_attrs[0] {
                assert_eq!(
                    *expr,
                    OverrideTypeAttribute {
                        type_repr: "collections.abc.Callable[[str]]".into(),
                        imports: IndexSet::from(["collections.abc".into()])
                    }
                );
            } else {
                panic!("attr should be OverrideType");
            };
        }
        Ok(())
    }
}