1use indexmap::IndexSet;
2
3use super::{RenamingRule, Signature};
4use proc_macro2::{TokenStream as TokenStream2, TokenTree};
5use quote::{quote, ToTokens, TokenStreamExt};
6use syn::{
7 parenthesized,
8 parse::{Parse, ParseStream},
9 punctuated::Punctuated,
10 Attribute, Expr, ExprLit, Ident, Lit, LitStr, Meta, MetaList, Result, Token, Type,
11};
12
13#[derive(Debug, Clone, PartialEq)]
15pub enum IgnoreTarget {
16 All,
18 SpecifiedLits(Vec<LitStr>),
20}
21
22pub fn extract_documents(attrs: &[Attribute]) -> Vec<String> {
23 let mut docs = Vec::new();
24 for attr in attrs {
25 if attr.path().is_ident("doc") {
27 if let Meta::NameValue(syn::MetaNameValue {
28 value:
29 Expr::Lit(ExprLit {
30 lit: Lit::Str(doc), ..
31 }),
32 ..
33 }) = &attr.meta
34 {
35 let doc = doc.value();
36 docs.push(if !doc.is_empty() && doc.starts_with(' ') {
43 doc[1..].to_string()
44 } else {
45 doc
46 });
47 }
48 }
49 }
50 docs
51}
52
53pub fn extract_deprecated(attrs: &[Attribute]) -> Option<DeprecatedInfo> {
55 for attr in attrs {
56 if attr.path().is_ident("deprecated") {
57 if let Ok(list) = attr.meta.require_list() {
58 let mut since = None;
59 let mut note = None;
60
61 list.parse_nested_meta(|meta| {
62 if meta.path.is_ident("since") {
63 let value = meta.value()?;
64 let lit: LitStr = value.parse()?;
65 since = Some(lit.value());
66 } else if meta.path.is_ident("note") {
67 let value = meta.value()?;
68 let lit: LitStr = value.parse()?;
69 note = Some(lit.value());
70 }
71 Ok(())
72 })
73 .ok()?;
74
75 return Some(DeprecatedInfo { since, note });
76 }
77 }
78 }
79 None
80}
81
/// Deprecation metadata read from a Rust `#[deprecated]` attribute by
/// `extract_deprecated`; emitted as tokens by its `ToTokens` impl.
#[derive(Debug, Clone, PartialEq)]
pub struct DeprecatedInfo {
    /// The `since = "..."` value, if given.
    pub since: Option<String>,
    /// The deprecation note, if given.
    pub note: Option<String>,
}
99
100impl ToTokens for DeprecatedInfo {
101 fn to_tokens(&self, tokens: &mut TokenStream2) {
102 let since = self
103 .since
104 .as_ref()
105 .map(|s| quote! { Some(#s) })
106 .unwrap_or_else(|| quote! { None });
107 let note = self
108 .note
109 .as_ref()
110 .map(|n| quote! { Some(#n) })
111 .unwrap_or_else(|| quote! { None });
112 tokens.append_all(quote! {
113 ::pyo3_stub_gen::type_info::DeprecatedInfo {
114 since: #since,
115 note: #note,
116 }
117 })
118 }
119}
120
121#[derive(Debug, Clone, PartialEq)]
122#[expect(clippy::enum_variant_names)]
123pub enum Attr {
124 Name(String),
126 Get,
127 GetAll,
128 Set,
129 SetAll,
130 Module(String),
131 Constructor(Signature),
132 Signature(Signature),
133 RenameAll(RenamingRule),
134 Extends(Type),
135
136 Eq,
138 Ord,
139 Hash,
140 Str,
141 Subclass,
142
143 New,
146 Getter(Option<String>),
147 Setter(Option<String>),
148 StaticMethod,
149 ClassMethod,
150 ClassAttr,
151}
152
153pub fn parse_pyo3_attrs(attrs: &[Attribute]) -> Result<Vec<Attr>> {
154 let mut out = Vec::new();
155 for attr in attrs {
156 let mut new = parse_pyo3_attr(attr)?;
157 out.append(&mut new);
158 }
159 Ok(out)
160}
161
162pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
163 let mut pyo3_attrs = Vec::new();
164 let path = attr.path();
165 let is_full_path_pyo3_attr = path.segments.len() == 2
166 && path
167 .segments
168 .first()
169 .is_some_and(|seg| seg.ident.eq("pyo3"))
170 && path.segments.last().is_some_and(|seg| {
171 seg.ident.eq("pyclass") || seg.ident.eq("pymethods") || seg.ident.eq("pyfunction")
172 });
173 if path.is_ident("pyclass")
174 || path.is_ident("pymethods")
175 || path.is_ident("pyfunction")
176 || path.is_ident("pyo3")
177 || is_full_path_pyo3_attr
178 {
179 if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
184 use TokenTree::*;
185 let tokens: Vec<TokenTree> = tokens.clone().into_iter().collect();
186 for tt in tokens.split(|tt| {
189 if let Punct(p) = tt {
190 p.as_char() == ','
191 } else {
192 false
193 }
194 }) {
195 match tt {
196 [Ident(ident)] => {
197 if ident == "get" {
198 pyo3_attrs.push(Attr::Get);
199 }
200 if ident == "get_all" {
201 pyo3_attrs.push(Attr::GetAll);
202 }
203 if ident == "set" {
204 pyo3_attrs.push(Attr::Set);
205 }
206 if ident == "set_all" {
207 pyo3_attrs.push(Attr::SetAll);
208 }
209 if ident == "eq" {
210 pyo3_attrs.push(Attr::Eq);
211 }
212 if ident == "ord" {
213 pyo3_attrs.push(Attr::Ord);
214 }
215 if ident == "hash" {
216 pyo3_attrs.push(Attr::Hash);
217 }
218 if ident == "str" {
219 pyo3_attrs.push(Attr::Str);
220 }
221 if ident == "subclass" {
222 pyo3_attrs.push(Attr::Subclass);
223 }
224 }
226 [Ident(ident), Punct(_), Literal(lit)] => {
227 if ident == "name" {
228 pyo3_attrs
229 .push(Attr::Name(lit.to_string().trim_matches('"').to_string()));
230 }
231 if ident == "module" {
232 pyo3_attrs
233 .push(Attr::Module(lit.to_string().trim_matches('"').to_string()));
234 }
235 if ident == "rename_all" {
236 let name = lit.to_string().trim_matches('"').to_string();
237 if let Some(renaming_rule) = RenamingRule::try_new(&name) {
238 pyo3_attrs.push(Attr::RenameAll(renaming_rule));
239 }
240 }
241 }
242 [Ident(ident), Punct(_), Group(group)] => {
243 if ident == "signature" {
244 pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?));
245 } else if ident == "constructor" {
246 pyo3_attrs
247 .push(Attr::Constructor(syn::parse2(group.to_token_stream())?));
248 }
249 }
250 [Ident(ident), Punct(_), Ident(ident2)] => {
251 if ident == "extends" {
252 pyo3_attrs.push(Attr::Extends(syn::parse2(ident2.to_token_stream())?));
253 }
254 }
255 _ => {}
256 }
257 }
258 }
259 } else if path.is_ident("new") {
260 pyo3_attrs.push(Attr::New);
261 } else if path.is_ident("staticmethod") {
262 pyo3_attrs.push(Attr::StaticMethod);
263 } else if path.is_ident("classmethod") {
264 pyo3_attrs.push(Attr::ClassMethod);
265 } else if path.is_ident("classattr") {
266 pyo3_attrs.push(Attr::ClassAttr);
267 } else if path.is_ident("getter") {
268 if let Ok(inner) = attr.parse_args::<Ident>() {
269 pyo3_attrs.push(Attr::Getter(Some(inner.to_string())));
270 } else {
271 pyo3_attrs.push(Attr::Getter(None));
272 }
273 } else if path.is_ident("setter") {
274 if let Ok(inner) = attr.parse_args::<Ident>() {
275 pyo3_attrs.push(Attr::Setter(Some(inner.to_string())));
276 } else {
277 pyo3_attrs.push(Attr::Setter(None));
278 }
279 }
280
281 Ok(pyo3_attrs)
282}
283
284#[derive(Debug, Clone, PartialEq)]
285pub enum StubGenAttr {
286 Default(Expr),
288 Skip,
290 OverrideType(OverrideTypeAttribute),
292 TypeIgnore(IgnoreTarget),
294}
295
296pub fn prune_attrs(attrs: &mut Vec<Attribute>) {
297 attrs.retain(|attr| !attr.path().is_ident("gen_stub"));
298}
299
300pub fn parse_gen_stub_override_type(attrs: &[Attribute]) -> Result<Option<OverrideTypeAttribute>> {
301 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)? {
302 if let StubGenAttr::OverrideType(attr) = attr {
303 return Ok(Some(attr));
304 }
305 }
306 Ok(None)
307}
308
309pub fn parse_gen_stub_override_return_type(
310 attrs: &[Attribute],
311) -> Result<Option<OverrideTypeAttribute>> {
312 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
313 if let StubGenAttr::OverrideType(attr) = attr {
314 return Ok(Some(attr));
315 }
316 }
317 Ok(None)
318}
319
320pub fn parse_gen_stub_default(attrs: &[Attribute]) -> Result<Option<Expr>> {
321 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
322 if let StubGenAttr::Default(default) = attr {
323 return Ok(Some(default));
324 }
325 }
326 Ok(None)
327}
328pub fn parse_gen_stub_skip(attrs: &[Attribute]) -> Result<bool> {
329 let skip = parse_gen_stub_attrs(
330 attrs,
331 AttributeLocation::Field,
332 Some(&["override_return_type", "default"]),
333 )?
334 .iter()
335 .any(|attr| matches!(attr, StubGenAttr::Skip));
336 Ok(skip)
337}
338
339pub fn parse_gen_stub_type_ignore(attrs: &[Attribute]) -> Result<Option<IgnoreTarget>> {
340 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
342 if let StubGenAttr::TypeIgnore(target) = attr {
343 return Ok(Some(target));
344 }
345 }
346 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Field, None)? {
348 if let StubGenAttr::TypeIgnore(target) = attr {
349 return Ok(Some(target));
350 }
351 }
352 Ok(None)
353}
354
355fn parse_gen_stub_attrs(
356 attrs: &[Attribute],
357 location: AttributeLocation,
358 ignored_idents: Option<&[&str]>,
359) -> Result<Vec<StubGenAttr>> {
360 let mut out = Vec::new();
361 for attr in attrs {
362 let mut new = parse_gen_stub_attr(attr, location, ignored_idents.unwrap_or(&[]))?;
363 out.append(&mut new);
364 }
365 Ok(out)
366}
367
368fn parse_gen_stub_attr(
369 attr: &Attribute,
370 location: AttributeLocation,
371 ignored_idents: &[&str],
372) -> Result<Vec<StubGenAttr>> {
373 let mut gen_stub_attrs = Vec::new();
374 let path = attr.path();
375 if path.is_ident("gen_stub") {
376 attr.parse_args_with(|input: ParseStream| {
377 while !input.is_empty() {
378 let ident: Ident = input.parse()?;
379 let ignored_ident = ignored_idents.iter().any(|other| ident == other);
380 if (ident == "override_type"
381 && (location == AttributeLocation::Argument || ignored_ident))
382 || (ident == "override_return_type"
383 && (location == AttributeLocation::Function || location == AttributeLocation::Field || ignored_ident))
384 {
385 let content;
386 parenthesized!(content in input);
387 let override_attr: OverrideTypeAttribute = content.parse()?;
388 gen_stub_attrs.push(StubGenAttr::OverrideType(override_attr));
389 } else if ident == "skip" && (location == AttributeLocation::Field || ignored_ident)
390 {
391 gen_stub_attrs.push(StubGenAttr::Skip);
392 } else if ident == "default"
393 && input.peek(Token![=])
394 && (location == AttributeLocation::Field || location == AttributeLocation::Function || ignored_ident)
395 {
396 input.parse::<Token![=]>()?;
397 gen_stub_attrs.push(StubGenAttr::Default(input.parse()?));
398 } else if ident == "type_ignore"
399 && (location == AttributeLocation::Function || location == AttributeLocation::Field || ignored_ident)
400 {
401 if input.peek(Token![=]) {
405 input.parse::<Token![=]>()?;
406 let content;
408 syn::bracketed!(content in input);
409 let rules = Punctuated::<LitStr, Token![,]>::parse_terminated(&content)?;
410
411 if rules.is_empty() {
413 return Err(syn::Error::new(
414 ident.span(),
415 "type_ignore with empty array is not allowed. Use type_ignore without equals for catch-all, or specify rules in the array."
416 ));
417 }
418
419 let rule_lits: Vec<LitStr> = rules.into_iter().collect();
421 gen_stub_attrs.push(StubGenAttr::TypeIgnore(IgnoreTarget::SpecifiedLits(rule_lits)));
422 } else {
423 gen_stub_attrs.push(StubGenAttr::TypeIgnore(IgnoreTarget::All));
425 }
426 } else if ident == "override_type" {
427 return Err(syn::Error::new(
428 ident.span(),
429 "`override_type(...)` is only valid in argument position".to_string(),
430 ));
431 } else if ident == "override_return_type" {
432 return Err(syn::Error::new(
433 ident.span(),
434 "`override_return_type(...)` is only valid in function or method position"
435 .to_string(),
436 ));
437 } else if ident == "skip" {
438 return Err(syn::Error::new(
439 ident.span(),
440 "`skip` is only valid in field position".to_string(),
441 ));
442 } else if ident == "default" {
443 return Err(syn::Error::new(
444 ident.span(),
445 "`default=xxx` is only valid in field or function position".to_string(),
446 ));
447 } else if ident == "type_ignore" {
448 return Err(syn::Error::new(
449 ident.span(),
450 "`type_ignore` or `type_ignore=[...]` is only valid in function or method position".to_string(),
451 ));
452 } else if location == AttributeLocation::Argument {
453 return Err(syn::Error::new(
454 ident.span(),
455 format!("Unsupported keyword `{ident}`, valid is `override_type(...)`"),
456 ));
457 } else if location == AttributeLocation::Field {
458 return Err(syn::Error::new(
459 ident.span(),
460 format!("Unsupported keyword `{ident}`, valid is `default=xxx`, `skip`, `override_return_type(...)`, `type_ignore`, or `type_ignore=[...]`"),
461 ));
462 } else if location == AttributeLocation::Function {
463 return Err(syn::Error::new(
464 ident.span(),
465 format!(
466 "Unsupported keyword `{ident}`, valid is `default=xxx`, `override_return_type(...)`, `type_ignore`, or `type_ignore=[...]`"
467 ),
468 ));
469 } else {
470 return Err(syn::Error::new(
471 ident.span(),
472 format!("Unsupported keyword `{ident}`"),
473 ));
474 }
475 if input.peek(Token![,]) {
476 input.parse::<Token![,]>()?;
477 } else {
478 break;
479 }
480 }
481 Ok(())
482 })?;
483 }
484 Ok(gen_stub_attrs)
485}
486
/// Where a `#[gen_stub(...)]` attribute was written; decides which keywords
/// `parse_gen_stub_attr` accepts.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum AttributeLocation {
    /// On a function argument.
    Argument,
    /// In field position (e.g. on a struct field).
    Field,
    /// On a function or method.
    Function,
}
493
494#[derive(Debug, Clone, PartialEq)]
495pub struct OverrideTypeAttribute {
496 pub(crate) type_repr: String,
497 pub(crate) imports: IndexSet<String>,
498}
499
500mod kw {
501 syn::custom_keyword!(type_repr);
502 syn::custom_keyword!(imports);
503 syn::custom_keyword!(override_type);
504}
505
506impl Parse for OverrideTypeAttribute {
507 fn parse(input: ParseStream) -> Result<Self> {
508 let mut type_repr = None;
509 let mut imports = IndexSet::new();
510
511 while !input.is_empty() {
512 let lookahead = input.lookahead1();
513
514 if lookahead.peek(kw::type_repr) {
515 input.parse::<kw::type_repr>()?;
516 input.parse::<Token![=]>()?;
517 type_repr = Some(input.parse::<LitStr>()?);
518 } else if lookahead.peek(kw::imports) {
519 input.parse::<kw::imports>()?;
520 input.parse::<Token![=]>()?;
521
522 let content;
523 parenthesized!(content in input);
524 let parsed_imports = Punctuated::<LitStr, Token![,]>::parse_terminated(&content)?;
525 imports = parsed_imports.into_iter().collect();
526 } else {
527 return Err(lookahead.error());
528 }
529
530 if !input.is_empty() {
531 input.parse::<Token![,]>()?;
532 }
533 }
534
535 Ok(OverrideTypeAttribute {
536 type_repr: type_repr
537 .ok_or_else(|| input.error("missing type_repr"))?
538 .value(),
539 imports: imports.iter().map(|i| i.value()).collect(),
540 })
541 }
542}
543
/// Comma-separated options parsed by `PyClassAttr`'s `Parse` impl.
#[derive(Default)]
pub struct PyClassAttr {
    /// `true` when the `skip_stub_type` keyword is present.
    pub skip_stub_type: bool,
}
550
551impl Parse for PyClassAttr {
552 fn parse(input: ParseStream) -> Result<Self> {
553 let mut skip_stub_type = false;
554
555 while !input.is_empty() {
557 let key: Ident = input.parse()?;
558
559 match key.to_string().as_str() {
560 "skip_stub_type" => {
561 skip_stub_type = true;
562 }
563 _ => {
564 return Err(syn::Error::new(
565 key.span(),
566 format!("Unknown parameter: {}", key),
567 ));
568 }
569 }
570
571 if input.peek(Token![,]) {
573 let _: Token![,] = input.parse()?;
574 } else {
575 break;
576 }
577 }
578
579 Ok(Self { skip_stub_type })
580 }
581}
582
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Fields, ItemFn, ItemStruct, PatType};

    // Struct-level options from `#[pyclass(...)]` and `#[pyo3(...)]` are collected in
    // attribute order; `mapping` is not relevant to stubs and is dropped.
    #[test]
    fn test_parse_pyo3_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[pyo3(rename_all = "SCREAMING_SNAKE_CASE")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        let attrs = parse_pyo3_attrs(&item.attrs)?;
        // Options from both struct attributes, in source order.
        assert_eq!(
            attrs,
            vec![
                Attr::Module("my_module".to_string()),
                Attr::Name("Placeholder".to_string()),
                Attr::RenameAll(RenamingRule::ScreamingSnakeCase),
            ]
        );

        if let Fields::Named(fields) = item.fields {
            // Field-level `#[pyo3(get)]`.
            let attrs = parse_pyo3_attr(&fields.named[0].attrs[0])?;
            assert_eq!(attrs, vec![Attr::Get]);
        } else {
            unreachable!()
        }
        Ok(())
    }

    // The fully-qualified `#[pyo3::pyclass(...)]` path is handled like `#[pyclass(...)]`.
    #[test]
    fn test_parse_pyo3_attr_full_path() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyo3::pyclass(mapping, module = "my_module", name = "Placeholder")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        let attrs = parse_pyo3_attr(&item.attrs[0])?;
        // Same options as the short-path form, minus `rename_all` (not present here).
        assert_eq!(
            attrs,
            vec![
                Attr::Module("my_module".to_string()),
                Attr::Name("Placeholder".to_string())
            ]
        );

        if let Fields::Named(fields) = item.fields {
            // Field-level `#[pyo3(get)]`.
            let attrs = parse_pyo3_attr(&fields.named[0].attrs[0])?;
            assert_eq!(attrs, vec![Attr::Get]);
        } else {
            unreachable!()
        }
        Ok(())
    }
    // Field-position `gen_stub` keywords: `default = <expr>` keeps the expression tokens,
    // and `skip` may appear alone or after `default`.
    #[test]
    fn test_parse_gen_stub_field_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            pub struct PyPlaceholder {
                #[gen_stub(default = String::from("foo"), skip)]
                pub field0: String,
                #[gen_stub(skip)]
                pub field1: String,
                #[gen_stub(default = 1+2)]
                pub field2: usize,
            }
            "#,
        )?;
        let fields: Vec<_> = item.fields.into_iter().collect();
        let field0_attrs = parse_gen_stub_attrs(&fields[0].attrs, AttributeLocation::Field, None)?;
        if let StubGenAttr::Default(expr) = &field0_attrs[0] {
            assert_eq!(
                expr.to_token_stream().to_string(),
                "String :: from (\"foo\")"
            );
        } else {
            panic!("attr should be Default");
        };
        assert_eq!(&StubGenAttr::Skip, &field0_attrs[1]);
        let field1_attrs = parse_gen_stub_attrs(&fields[1].attrs, AttributeLocation::Field, None)?;
        assert_eq!(vec![StubGenAttr::Skip], field1_attrs);
        let field2_attrs = parse_gen_stub_attrs(&fields[2].attrs, AttributeLocation::Field, None)?;
        if let StubGenAttr::Default(expr) = &field2_attrs[0] {
            assert_eq!(expr.to_token_stream().to_string(), "1 + 2");
        } else {
            panic!("attr should be Default");
        };
        Ok(())
    }
    // `override_return_type(...)` on a function and `override_type(...)` on an argument both
    // produce `StubGenAttr::OverrideType` with the given `type_repr` and `imports`.
    #[test]
    fn test_parse_gen_stub_override_type_attr() -> Result<()> {
        let item: ItemFn = parse_str(
            r#"
            #[gen_stub_pyfunction]
            #[pyfunction]
            #[gen_stub(override_return_type(type_repr="typing.Never", imports=("typing")))]
            fn say_hello_forever<'a>(
                #[gen_stub(override_type(type_repr="collections.abc.Callable[[str]]", imports=("collections.abc")))]
                cb: Bound<'a, PyAny>,
            ) -> PyResult<()> {
                loop {
                    cb.call1(("Hello!",))?;
                }
            }
            "#,
        )?;
        // Non-`gen_stub` attributes on the function are ignored.
        let fn_attrs = parse_gen_stub_attrs(&item.attrs, AttributeLocation::Function, None)?;
        assert_eq!(fn_attrs.len(), 1);
        if let StubGenAttr::OverrideType(expr) = &fn_attrs[0] {
            assert_eq!(
                *expr,
                OverrideTypeAttribute {
                    type_repr: "typing.Never".into(),
                    imports: IndexSet::from(["typing".into()])
                }
            );
        } else {
            panic!("attr should be OverrideType");
        };
        if let syn::FnArg::Typed(PatType { attrs, .. }) = &item.sig.inputs[0] {
            let arg_attrs = parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)?;
            assert_eq!(arg_attrs.len(), 1);
            if let StubGenAttr::OverrideType(expr) = &arg_attrs[0] {
                assert_eq!(
                    *expr,
                    OverrideTypeAttribute {
                        type_repr: "collections.abc.Callable[[str]]".into(),
                        imports: IndexSet::from(["collections.abc".into()])
                    }
                );
            } else {
                panic!("attr should be OverrideType");
            };
        }
        Ok(())
    }
}
733}