1use indexmap::IndexSet;
2
3use super::{RenamingRule, Signature};
4use proc_macro2::{TokenStream as TokenStream2, TokenTree};
5use quote::{quote, ToTokens, TokenStreamExt};
6use syn::{
7 parenthesized,
8 parse::{Parse, ParseStream},
9 punctuated::Punctuated,
10 Attribute, Expr, ExprLit, Ident, Lit, LitStr, Meta, MetaList, Result, Token, Type,
11};
12
/// Target of a `#[gen_stub(type_ignore ...)]` entry.
///
/// NOTE(review): presumably rendered as a `# type: ignore[...]` marker in the generated
/// stub — confirm against the stub generator.
#[derive(Debug, Clone, PartialEq)]
pub enum IgnoreTarget {
    /// Bare `type_ignore` (no `=`): the catch-all form.
    All,
    /// `type_ignore = ["rule", ...]`: only the listed rules.
    /// The attribute parser rejects an empty list, so this is non-empty when produced by it.
    SpecifiedLits(Vec<LitStr>),
}
21
22pub fn extract_documents(attrs: &[Attribute]) -> Vec<String> {
23 let mut docs = Vec::new();
24 for attr in attrs {
25 if attr.path().is_ident("doc") {
27 if let Meta::NameValue(syn::MetaNameValue {
28 value:
29 Expr::Lit(ExprLit {
30 lit: Lit::Str(doc), ..
31 }),
32 ..
33 }) = &attr.meta
34 {
35 let doc = doc.value();
36 docs.push(if !doc.is_empty() && doc.starts_with(' ') {
43 doc[1..].to_string()
44 } else {
45 doc
46 });
47 }
48 }
49 }
50 docs
51}
52
53pub fn extract_deprecated(attrs: &[Attribute]) -> Option<DeprecatedInfo> {
55 for attr in attrs {
56 if attr.path().is_ident("deprecated") {
57 if let Ok(list) = attr.meta.require_list() {
58 let mut since = None;
59 let mut note = None;
60
61 list.parse_nested_meta(|meta| {
62 if meta.path.is_ident("since") {
63 let value = meta.value()?;
64 let lit: LitStr = value.parse()?;
65 since = Some(lit.value());
66 } else if meta.path.is_ident("note") {
67 let value = meta.value()?;
68 let lit: LitStr = value.parse()?;
69 note = Some(lit.value());
70 }
71 Ok(())
72 })
73 .ok()?;
74
75 return Some(DeprecatedInfo { since, note });
76 }
77 }
78 }
79 None
80}
81
/// Deprecation metadata read from a Rust `#[deprecated(...)]` attribute.
///
/// Turned into a `::pyo3_stub_gen::type_info::DeprecatedInfo` struct literal by its
/// `ToTokens` impl.
#[derive(Debug, Clone, PartialEq)]
pub struct DeprecatedInfo {
    /// Contents of `since = "..."`, if given.
    pub since: Option<String>,
    /// Contents of `note = "..."`, if given.
    pub note: Option<String>,
}
99
100impl ToTokens for DeprecatedInfo {
101 fn to_tokens(&self, tokens: &mut TokenStream2) {
102 let since = self
103 .since
104 .as_ref()
105 .map(|s| quote! { Some(#s) })
106 .unwrap_or_else(|| quote! { None });
107 let note = self
108 .note
109 .as_ref()
110 .map(|n| quote! { Some(#n) })
111 .unwrap_or_else(|| quote! { None });
112 tokens.append_all(quote! {
113 ::pyo3_stub_gen::type_info::DeprecatedInfo {
114 since: #since,
115 note: #note,
116 }
117 })
118 }
119}
120
/// A PyO3 option recognized by `parse_pyo3_attr`.
///
/// The first group comes from items inside `#[pyclass(...)]`, `#[pymethods(...)]`,
/// `#[pyfunction(...)]` or `#[pyo3(...)]`. The last group comes from standalone method
/// attributes such as `#[new]` or `#[getter]`.
#[derive(Debug, Clone, PartialEq)]
#[expect(clippy::enum_variant_names)]
pub enum Attr {
    /// `name = "..."` (the string contents).
    Name(String),
    /// `get`
    Get,
    /// `get_all`
    GetAll,
    /// `set`
    Set,
    /// `set_all`
    SetAll,
    /// `module = "..."` (the string contents).
    Module(String),
    /// `constructor = (...)`
    Constructor(Signature),
    /// `signature = (...)`
    Signature(Signature),
    /// `rename_all = "..."` when the string names a known `RenamingRule`.
    RenameAll(RenamingRule),
    /// The type given by `extends = ...`.
    Extends(Type),

    /// `eq`
    Eq,
    /// `ord`
    Ord,
    /// `hash`
    Hash,
    /// The `str` option.
    Str,
    /// `subclass`
    Subclass,

    /// Standalone `#[new]`.
    New,
    /// `#[getter]` (`None`) or `#[getter(ident)]` (`Some(ident)`).
    Getter(Option<String>),
    /// `#[setter]` (`None`) or `#[setter(ident)]` (`Some(ident)`).
    Setter(Option<String>),
    /// Standalone `#[staticmethod]`.
    StaticMethod,
    /// Standalone `#[classmethod]`.
    ClassMethod,
    /// Standalone `#[classattr]`.
    ClassAttr,
}
152
153pub fn parse_pyo3_attrs(attrs: &[Attribute]) -> Result<Vec<Attr>> {
154 let mut out = Vec::new();
155 for attr in attrs {
156 let mut new = parse_pyo3_attr(attr)?;
157 out.append(&mut new);
158 }
159 Ok(out)
160}
161
162pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
163 let mut pyo3_attrs = Vec::new();
164 let path = attr.path();
165 let is_full_path_pyo3_attr = path.segments.len() == 2
166 && path
167 .segments
168 .first()
169 .is_some_and(|seg| seg.ident.eq("pyo3"))
170 && path.segments.last().is_some_and(|seg| {
171 seg.ident.eq("pyclass") || seg.ident.eq("pymethods") || seg.ident.eq("pyfunction")
172 });
173 if path.is_ident("pyclass")
174 || path.is_ident("pymethods")
175 || path.is_ident("pyfunction")
176 || path.is_ident("pyo3")
177 || is_full_path_pyo3_attr
178 {
179 if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
184 use TokenTree::*;
185 let tokens: Vec<TokenTree> = tokens.clone().into_iter().collect();
186 for tt in tokens.split(|tt| {
189 if let Punct(p) = tt {
190 p.as_char() == ','
191 } else {
192 false
193 }
194 }) {
195 match tt {
196 [Ident(ident)] => {
197 if ident == "get" {
198 pyo3_attrs.push(Attr::Get);
199 }
200 if ident == "get_all" {
201 pyo3_attrs.push(Attr::GetAll);
202 }
203 if ident == "set" {
204 pyo3_attrs.push(Attr::Set);
205 }
206 if ident == "set_all" {
207 pyo3_attrs.push(Attr::SetAll);
208 }
209 if ident == "eq" {
210 pyo3_attrs.push(Attr::Eq);
211 }
212 if ident == "ord" {
213 pyo3_attrs.push(Attr::Ord);
214 }
215 if ident == "hash" {
216 pyo3_attrs.push(Attr::Hash);
217 }
218 if ident == "str" {
219 pyo3_attrs.push(Attr::Str);
220 }
221 if ident == "subclass" {
222 pyo3_attrs.push(Attr::Subclass);
223 }
224 }
226 [Ident(ident), Punct(_), Literal(lit)] => {
227 if ident == "name" {
228 pyo3_attrs
229 .push(Attr::Name(lit.to_string().trim_matches('"').to_string()));
230 }
231 if ident == "module" {
232 pyo3_attrs
233 .push(Attr::Module(lit.to_string().trim_matches('"').to_string()));
234 }
235 if ident == "rename_all" {
236 let name = lit.to_string().trim_matches('"').to_string();
237 if let Some(renaming_rule) = RenamingRule::try_new(&name) {
238 pyo3_attrs.push(Attr::RenameAll(renaming_rule));
239 }
240 }
241 }
242 [Ident(ident), Punct(_), Group(group)] => {
243 if ident == "signature" {
244 pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?));
245 } else if ident == "constructor" {
246 pyo3_attrs
247 .push(Attr::Constructor(syn::parse2(group.to_token_stream())?));
248 }
249 }
250 [Ident(ident), Punct(_), Ident(ident2)] => {
251 if ident == "extends" {
252 pyo3_attrs.push(Attr::Extends(syn::parse2(ident2.to_token_stream())?));
253 }
254 }
255 _ => {}
256 }
257 }
258 }
259 } else if path.is_ident("new") {
260 pyo3_attrs.push(Attr::New);
261 } else if path.is_ident("staticmethod") {
262 pyo3_attrs.push(Attr::StaticMethod);
263 } else if path.is_ident("classmethod") {
264 pyo3_attrs.push(Attr::ClassMethod);
265 } else if path.is_ident("classattr") {
266 pyo3_attrs.push(Attr::ClassAttr);
267 } else if path.is_ident("getter") {
268 if let Ok(inner) = attr.parse_args::<Ident>() {
269 pyo3_attrs.push(Attr::Getter(Some(inner.to_string())));
270 } else {
271 pyo3_attrs.push(Attr::Getter(None));
272 }
273 } else if path.is_ident("setter") {
274 if let Ok(inner) = attr.parse_args::<Ident>() {
275 pyo3_attrs.push(Attr::Setter(Some(inner.to_string())));
276 } else {
277 pyo3_attrs.push(Attr::Setter(None));
278 }
279 }
280
281 Ok(pyo3_attrs)
282}
283
/// One entry inside a `#[gen_stub(...)]` attribute.
#[derive(Debug, Clone, PartialEq)]
pub enum StubGenAttr {
    /// `default = <expr>`
    Default(Expr),
    /// `skip`
    Skip,
    /// `override_type(...)` or `override_return_type(...)`; both share the same body syntax.
    OverrideType(OverrideTypeAttribute),
    /// `type_ignore` or `type_ignore = ["...", ...]`
    TypeIgnore(IgnoreTarget),
}
295
296pub fn prune_attrs(attrs: &mut Vec<Attribute>) {
297 attrs.retain(|attr| !attr.path().is_ident("gen_stub"));
298}
299
300pub fn parse_gen_stub_override_type(attrs: &[Attribute]) -> Result<Option<OverrideTypeAttribute>> {
301 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)? {
302 if let StubGenAttr::OverrideType(attr) = attr {
303 return Ok(Some(attr));
304 }
305 }
306 Ok(None)
307}
308
309pub fn parse_gen_stub_override_return_type(
310 attrs: &[Attribute],
311) -> Result<Option<OverrideTypeAttribute>> {
312 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
313 if let StubGenAttr::OverrideType(attr) = attr {
314 return Ok(Some(attr));
315 }
316 }
317 Ok(None)
318}
319
320pub fn parse_gen_stub_default(attrs: &[Attribute]) -> Result<Option<Expr>> {
321 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
322 if let StubGenAttr::Default(default) = attr {
323 return Ok(Some(default));
324 }
325 }
326 Ok(None)
327}
328pub fn parse_gen_stub_skip(attrs: &[Attribute]) -> Result<bool> {
329 let skip = parse_gen_stub_attrs(
330 attrs,
331 AttributeLocation::Field,
332 Some(&["override_return_type", "default"]),
333 )?
334 .iter()
335 .any(|attr| matches!(attr, StubGenAttr::Skip));
336 Ok(skip)
337}
338
339pub fn parse_gen_stub_type_ignore(attrs: &[Attribute]) -> Result<Option<IgnoreTarget>> {
340 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
342 if let StubGenAttr::TypeIgnore(target) = attr {
343 return Ok(Some(target));
344 }
345 }
346 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Field, None)? {
348 if let StubGenAttr::TypeIgnore(target) = attr {
349 return Ok(Some(target));
350 }
351 }
352 Ok(None)
353}
354
355fn parse_gen_stub_attrs(
356 attrs: &[Attribute],
357 location: AttributeLocation,
358 ignored_idents: Option<&[&str]>,
359) -> Result<Vec<StubGenAttr>> {
360 let mut out = Vec::new();
361 for attr in attrs {
362 let mut new = parse_gen_stub_attr(attr, location, ignored_idents.unwrap_or(&[]))?;
363 out.append(&mut new);
364 }
365 Ok(out)
366}
367
368fn parse_gen_stub_attr(
369 attr: &Attribute,
370 location: AttributeLocation,
371 ignored_idents: &[&str],
372) -> Result<Vec<StubGenAttr>> {
373 let mut gen_stub_attrs = Vec::new();
374 let path = attr.path();
375 if path.is_ident("gen_stub") {
376 attr.parse_args_with(|input: ParseStream| {
377 while !input.is_empty() {
378 let ident: Ident = input.parse()?;
379 let ignored_ident = ignored_idents.iter().any(|other| ident == other);
380 if (ident == "override_type"
381 && (location == AttributeLocation::Argument || ignored_ident))
382 || (ident == "override_return_type"
383 && (location == AttributeLocation::Function || location == AttributeLocation::Field || ignored_ident))
384 {
385 let content;
386 parenthesized!(content in input);
387 let override_attr: OverrideTypeAttribute = content.parse()?;
388 gen_stub_attrs.push(StubGenAttr::OverrideType(override_attr));
389 } else if ident == "skip" && (location == AttributeLocation::Field || ignored_ident)
390 {
391 gen_stub_attrs.push(StubGenAttr::Skip);
392 } else if ident == "default"
393 && input.peek(Token![=])
394 && (location == AttributeLocation::Field || location == AttributeLocation::Function || ignored_ident)
395 {
396 input.parse::<Token![=]>()?;
397 gen_stub_attrs.push(StubGenAttr::Default(input.parse()?));
398 } else if ident == "type_ignore"
399 && (location == AttributeLocation::Function || location == AttributeLocation::Field || ignored_ident)
400 {
401 if input.peek(Token![=]) {
405 input.parse::<Token![=]>()?;
406 let content;
408 syn::bracketed!(content in input);
409 let rules = Punctuated::<LitStr, Token![,]>::parse_terminated(&content)?;
410
411 if rules.is_empty() {
413 return Err(syn::Error::new(
414 ident.span(),
415 "type_ignore with empty array is not allowed. Use type_ignore without equals for catch-all, or specify rules in the array."
416 ));
417 }
418
419 let rule_lits: Vec<LitStr> = rules.into_iter().collect();
421 gen_stub_attrs.push(StubGenAttr::TypeIgnore(IgnoreTarget::SpecifiedLits(rule_lits)));
422 } else {
423 gen_stub_attrs.push(StubGenAttr::TypeIgnore(IgnoreTarget::All));
425 }
426 } else if ident == "override_type" {
427 return Err(syn::Error::new(
428 ident.span(),
429 "`override_type(...)` is only valid in argument position".to_string(),
430 ));
431 } else if ident == "override_return_type" {
432 return Err(syn::Error::new(
433 ident.span(),
434 "`override_return_type(...)` is only valid in function or method position"
435 .to_string(),
436 ));
437 } else if ident == "skip" {
438 return Err(syn::Error::new(
439 ident.span(),
440 "`skip` is only valid in field position".to_string(),
441 ));
442 } else if ident == "default" {
443 return Err(syn::Error::new(
444 ident.span(),
445 "`default=xxx` is only valid in field or function position".to_string(),
446 ));
447 } else if ident == "type_ignore" {
448 return Err(syn::Error::new(
449 ident.span(),
450 "`type_ignore` or `type_ignore=[...]` is only valid in function or method position".to_string(),
451 ));
452 } else if location == AttributeLocation::Argument {
453 return Err(syn::Error::new(
454 ident.span(),
455 format!("Unsupported keyword `{ident}`, valid is `override_type(...)`"),
456 ));
457 } else if location == AttributeLocation::Field {
458 return Err(syn::Error::new(
459 ident.span(),
460 format!("Unsupported keyword `{ident}`, valid is `default=xxx`, `skip`, `override_return_type(...)`, `type_ignore`, or `type_ignore=[...]`"),
461 ));
462 } else if location == AttributeLocation::Function {
463 return Err(syn::Error::new(
464 ident.span(),
465 format!(
466 "Unsupported keyword `{ident}`, valid is `default=xxx`, `override_return_type(...)`, `type_ignore`, or `type_ignore=[...]`"
467 ),
468 ));
469 } else {
470 return Err(syn::Error::new(
471 ident.span(),
472 format!("Unsupported keyword `{ident}`"),
473 ));
474 }
475 if input.peek(Token![,]) {
476 input.parse::<Token![,]>()?;
477 } else {
478 break;
479 }
480 }
481 Ok(())
482 })?;
483 }
484 Ok(gen_stub_attrs)
485}
486
/// Where a `#[gen_stub(...)]` attribute is being parsed. This determines which keywords
/// `parse_gen_stub_attr` accepts.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum AttributeLocation {
    /// On a function argument: accepts `override_type(...)`.
    Argument,
    /// Field position: accepts `default = ...`, `skip`, `override_return_type(...)`,
    /// and `type_ignore`.
    Field,
    /// Function/method position: accepts `default = ...`, `override_return_type(...)`,
    /// and `type_ignore`.
    Function,
}
493
/// Parsed body of `override_type(...)` / `override_return_type(...)`, written as
/// `type_repr = "...", imports = ("...", ...)`.
#[derive(Debug, Clone, PartialEq)]
pub struct OverrideTypeAttribute {
    /// Contents of `type_repr = "..."`, e.g. `"typing.Never"`.
    /// NOTE(review): presumably emitted into the stub verbatim — confirm with the generator.
    pub(crate) type_repr: String,
    /// Strings listed in `imports = (...)` (e.g. `"typing"`), deduplicated in first-seen order.
    pub(crate) imports: IndexSet<String>,
}
499
/// Custom keywords used when parsing `OverrideTypeAttribute`.
mod kw {
    syn::custom_keyword!(type_repr);
    syn::custom_keyword!(imports);
    // NOTE(review): not referenced by the parsing code in this file — confirm it is still needed.
    syn::custom_keyword!(override_type);
}
505
506impl Parse for OverrideTypeAttribute {
507 fn parse(input: ParseStream) -> Result<Self> {
508 let mut type_repr = None;
509 let mut imports = IndexSet::new();
510
511 while !input.is_empty() {
512 let lookahead = input.lookahead1();
513
514 if lookahead.peek(kw::type_repr) {
515 input.parse::<kw::type_repr>()?;
516 input.parse::<Token![=]>()?;
517 type_repr = Some(input.parse::<LitStr>()?);
518 } else if lookahead.peek(kw::imports) {
519 input.parse::<kw::imports>()?;
520 input.parse::<Token![=]>()?;
521
522 let content;
523 parenthesized!(content in input);
524 let parsed_imports = Punctuated::<LitStr, Token![,]>::parse_terminated(&content)?;
525 imports = parsed_imports.into_iter().collect();
526 } else {
527 return Err(lookahead.error());
528 }
529
530 if !input.is_empty() {
531 input.parse::<Token![,]>()?;
532 }
533 }
534
535 Ok(OverrideTypeAttribute {
536 type_repr: type_repr
537 .ok_or_else(|| input.error("missing type_repr"))?
538 .value(),
539 imports: imports.iter().map(|i| i.value()).collect(),
540 })
541 }
542}
543
/// Unit tests for PyO3 and `gen_stub` attribute parsing.
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Fields, ItemFn, ItemStruct, PatType};

    /// Short-form `#[pyclass]` / `#[pyo3]` options are collected in order.
    /// The unrecognized `mapping` option contributes nothing.
    #[test]
    fn test_parse_pyo3_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[pyo3(rename_all = "SCREAMING_SNAKE_CASE")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        let attrs = parse_pyo3_attrs(&item.attrs)?;
        assert_eq!(
            attrs,
            vec![
                Attr::Module("my_module".to_string()),
                Attr::Name("Placeholder".to_string()),
                Attr::RenameAll(RenamingRule::ScreamingSnakeCase),
            ]
        );

        // Field-level `#[pyo3(get)]`.
        if let Fields::Named(fields) = item.fields {
            let attrs = parse_pyo3_attr(&fields.named[0].attrs[0])?;
            assert_eq!(attrs, vec![Attr::Get]);
        } else {
            unreachable!()
        }
        Ok(())
    }

    /// The fully-qualified `#[pyo3::pyclass(...)]` path is treated like `#[pyclass(...)]`.
    #[test]
    fn test_parse_pyo3_attr_full_path() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyo3::pyclass(mapping, module = "my_module", name = "Placeholder")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        let attrs = parse_pyo3_attr(&item.attrs[0])?;
        assert_eq!(
            attrs,
            vec![
                Attr::Module("my_module".to_string()),
                Attr::Name("Placeholder".to_string())
            ]
        );

        if let Fields::Named(fields) = item.fields {
            let attrs = parse_pyo3_attr(&fields.named[0].attrs[0])?;
            assert_eq!(attrs, vec![Attr::Get]);
        } else {
            unreachable!()
        }
        Ok(())
    }
    /// Field-position `gen_stub` parsing.
    /// A `default = <expr>` stops at the next top-level comma, so `skip` can follow it.
    #[test]
    fn test_parse_gen_stub_field_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            pub struct PyPlaceholder {
                #[gen_stub(default = String::from("foo"), skip)]
                pub field0: String,
                #[gen_stub(skip)]
                pub field1: String,
                #[gen_stub(default = 1+2)]
                pub field2: usize,
            }
            "#,
        )?;
        let fields: Vec<_> = item.fields.into_iter().collect();
        let field0_attrs = parse_gen_stub_attrs(&fields[0].attrs, AttributeLocation::Field, None)?;
        if let StubGenAttr::Default(expr) = &field0_attrs[0] {
            assert_eq!(
                expr.to_token_stream().to_string(),
                "String :: from (\"foo\")"
            );
        } else {
            panic!("attr should be Default");
        };
        assert_eq!(&StubGenAttr::Skip, &field0_attrs[1]);
        let field1_attrs = parse_gen_stub_attrs(&fields[1].attrs, AttributeLocation::Field, None)?;
        assert_eq!(vec![StubGenAttr::Skip], field1_attrs);
        let field2_attrs = parse_gen_stub_attrs(&fields[2].attrs, AttributeLocation::Field, None)?;
        if let StubGenAttr::Default(expr) = &field2_attrs[0] {
            assert_eq!(expr.to_token_stream().to_string(), "1 + 2");
        } else {
            panic!("attr should be Default");
        };
        Ok(())
    }
    /// Tests `override_return_type(...)` on a function and `override_type(...)` on an argument.
    /// Each is parsed in its own position and yields an `OverrideTypeAttribute`.
    #[test]
    fn test_parse_gen_stub_override_type_attr() -> Result<()> {
        let item: ItemFn = parse_str(
            r#"
            #[gen_stub_pyfunction]
            #[pyfunction]
            #[gen_stub(override_return_type(type_repr="typing.Never", imports=("typing")))]
            fn say_hello_forever<'a>(
                #[gen_stub(override_type(type_repr="collections.abc.Callable[[str]]", imports=("collections.abc")))]
                cb: Bound<'a, PyAny>,
            ) -> PyResult<()> {
                loop {
                    cb.call1(("Hello!",))?;
                }
            }
            "#,
        )?;
        let fn_attrs = parse_gen_stub_attrs(&item.attrs, AttributeLocation::Function, None)?;
        assert_eq!(fn_attrs.len(), 1);
        if let StubGenAttr::OverrideType(expr) = &fn_attrs[0] {
            assert_eq!(
                *expr,
                OverrideTypeAttribute {
                    type_repr: "typing.Never".into(),
                    imports: IndexSet::from(["typing".into()])
                }
            );
        } else {
            panic!("attr should be OverrideType");
        };
        if let syn::FnArg::Typed(PatType { attrs, .. }) = &item.sig.inputs[0] {
            let arg_attrs = parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)?;
            assert_eq!(arg_attrs.len(), 1);
            if let StubGenAttr::OverrideType(expr) = &arg_attrs[0] {
                assert_eq!(
                    *expr,
                    OverrideTypeAttribute {
                        type_repr: "collections.abc.Callable[[str]]".into(),
                        imports: IndexSet::from(["collections.abc".into()])
                    }
                );
            } else {
                panic!("attr should be OverrideType");
            };
        }
        Ok(())
    }
}