1use std::collections::HashSet;
2
3use super::{RenamingRule, Signature};
4use proc_macro2::{TokenStream as TokenStream2, TokenTree};
5use quote::{quote, ToTokens, TokenStreamExt};
6use syn::{
7 parenthesized,
8 parse::{Parse, ParseStream},
9 punctuated::Punctuated,
10 Attribute, Expr, ExprLit, Ident, Lit, LitStr, Meta, MetaList, Result, Token, Type,
11};
12
/// Target of a `#[gen_stub(type_ignore ...)]` attribute, as produced by `parse_gen_stub_attr`.
#[derive(Debug, Clone, PartialEq)]
pub enum IgnoreTarget {
    /// Bare `type_ignore`: a catch-all covering every rule.
    All,
    /// `type_ignore = ["rule", ...]`: only the listed rules. Kept as string literals rather
    /// than `String`s (presumably so their spans remain available — TODO confirm at use site).
    SpecifiedLits(Vec<LitStr>),
}
21
22pub fn extract_documents(attrs: &[Attribute]) -> Vec<String> {
23 let mut docs = Vec::new();
24 for attr in attrs {
25 if attr.path().is_ident("doc") {
27 if let Meta::NameValue(syn::MetaNameValue {
28 value:
29 Expr::Lit(ExprLit {
30 lit: Lit::Str(doc), ..
31 }),
32 ..
33 }) = &attr.meta
34 {
35 let doc = doc.value();
36 docs.push(if !doc.is_empty() && doc.starts_with(' ') {
43 doc[1..].to_string()
44 } else {
45 doc
46 });
47 }
48 }
49 }
50 docs
51}
52
53pub fn extract_deprecated(attrs: &[Attribute]) -> Option<DeprecatedInfo> {
55 for attr in attrs {
56 if attr.path().is_ident("deprecated") {
57 if let Ok(list) = attr.meta.require_list() {
58 let mut since = None;
59 let mut note = None;
60
61 list.parse_nested_meta(|meta| {
62 if meta.path.is_ident("since") {
63 let value = meta.value()?;
64 let lit: LitStr = value.parse()?;
65 since = Some(lit.value());
66 } else if meta.path.is_ident("note") {
67 let value = meta.value()?;
68 let lit: LitStr = value.parse()?;
69 note = Some(lit.value());
70 }
71 Ok(())
72 })
73 .ok()?;
74
75 return Some(DeprecatedInfo { since, note });
76 }
77 }
78 }
79 None
80}
81
/// Values taken from a Rust `#[deprecated]` attribute (see `extract_deprecated`).
///
/// Converted to tokens as a `::pyo3_stub_gen::type_info::DeprecatedInfo` struct expression.
#[derive(Debug, Clone, PartialEq)]
pub struct DeprecatedInfo {
    /// The `since = "..."` value, if given.
    pub since: Option<String>,
    /// The `note = "..."` value, if given.
    pub note: Option<String>,
}
99
100impl ToTokens for DeprecatedInfo {
101 fn to_tokens(&self, tokens: &mut TokenStream2) {
102 let since = self
103 .since
104 .as_ref()
105 .map(|s| quote! { Some(#s) })
106 .unwrap_or_else(|| quote! { None });
107 let note = self
108 .note
109 .as_ref()
110 .map(|n| quote! { Some(#n) })
111 .unwrap_or_else(|| quote! { None });
112 tokens.append_all(quote! {
113 ::pyo3_stub_gen::type_info::DeprecatedInfo {
114 since: #since,
115 note: #note,
116 }
117 })
118 }
119}
120
/// A pyo3 setting extracted from an attribute by `parse_pyo3_attr`.
#[derive(Debug, Clone, PartialEq)]
#[expect(clippy::enum_variant_names)]
pub enum Attr {
    /// `name = "..."` inside a pyo3 list attribute.
    Name(String),
    /// `get` flag.
    Get,
    /// `get_all` flag.
    GetAll,
    /// `set` flag.
    Set,
    /// `set_all` flag.
    SetAll,
    /// `module = "..."`.
    Module(String),
    /// `constructor = (...)`.
    Constructor(Signature),
    /// `signature = (...)`.
    Signature(Signature),
    /// `rename_all = "..."`, only when the value is a rule `RenamingRule::try_new` accepts.
    RenameAll(RenamingRule),
    /// `extends = <type>`.
    Extends(Type),

    /// `eq` flag.
    Eq,
    /// `ord` flag.
    Ord,
    /// `hash` flag.
    Hash,
    /// `str` flag.
    Str,

    /// Standalone `#[new]` attribute.
    New,
    /// `#[getter]` (`None`) or `#[getter(name)]` (`Some(name)`).
    Getter(Option<String>),
    /// `#[setter]` (`None`) or `#[setter(name)]` (`Some(name)`).
    Setter(Option<String>),
    /// Standalone `#[staticmethod]` attribute.
    StaticMethod,
    /// Standalone `#[classmethod]` attribute.
    ClassMethod,
    /// Standalone `#[classattr]` attribute.
    ClassAttr,
}
151
152pub fn parse_pyo3_attrs(attrs: &[Attribute]) -> Result<Vec<Attr>> {
153 let mut out = Vec::new();
154 for attr in attrs {
155 let mut new = parse_pyo3_attr(attr)?;
156 out.append(&mut new);
157 }
158 Ok(out)
159}
160
161pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
162 let mut pyo3_attrs = Vec::new();
163 let path = attr.path();
164 let is_full_path_pyo3_attr = path.segments.len() == 2
165 && path
166 .segments
167 .first()
168 .is_some_and(|seg| seg.ident.eq("pyo3"))
169 && path.segments.last().is_some_and(|seg| {
170 seg.ident.eq("pyclass") || seg.ident.eq("pymethods") || seg.ident.eq("pyfunction")
171 });
172 if path.is_ident("pyclass")
173 || path.is_ident("pymethods")
174 || path.is_ident("pyfunction")
175 || path.is_ident("pyo3")
176 || is_full_path_pyo3_attr
177 {
178 if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
183 use TokenTree::*;
184 let tokens: Vec<TokenTree> = tokens.clone().into_iter().collect();
185 for tt in tokens.split(|tt| {
188 if let Punct(p) = tt {
189 p.as_char() == ','
190 } else {
191 false
192 }
193 }) {
194 match tt {
195 [Ident(ident)] => {
196 if ident == "get" {
197 pyo3_attrs.push(Attr::Get);
198 }
199 if ident == "get_all" {
200 pyo3_attrs.push(Attr::GetAll);
201 }
202 if ident == "set" {
203 pyo3_attrs.push(Attr::Set);
204 }
205 if ident == "set_all" {
206 pyo3_attrs.push(Attr::SetAll);
207 }
208 if ident == "eq" {
209 pyo3_attrs.push(Attr::Eq);
210 }
211 if ident == "ord" {
212 pyo3_attrs.push(Attr::Ord);
213 }
214 if ident == "hash" {
215 pyo3_attrs.push(Attr::Hash);
216 }
217 if ident == "str" {
218 pyo3_attrs.push(Attr::Str);
219 }
220 }
222 [Ident(ident), Punct(_), Literal(lit)] => {
223 if ident == "name" {
224 pyo3_attrs
225 .push(Attr::Name(lit.to_string().trim_matches('"').to_string()));
226 }
227 if ident == "module" {
228 pyo3_attrs
229 .push(Attr::Module(lit.to_string().trim_matches('"').to_string()));
230 }
231 if ident == "rename_all" {
232 let name = lit.to_string().trim_matches('"').to_string();
233 if let Some(renaming_rule) = RenamingRule::try_new(&name) {
234 pyo3_attrs.push(Attr::RenameAll(renaming_rule));
235 }
236 }
237 }
238 [Ident(ident), Punct(_), Group(group)] => {
239 if ident == "signature" {
240 pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?));
241 } else if ident == "constructor" {
242 pyo3_attrs
243 .push(Attr::Constructor(syn::parse2(group.to_token_stream())?));
244 }
245 }
246 [Ident(ident), Punct(_), Ident(ident2)] => {
247 if ident == "extends" {
248 pyo3_attrs.push(Attr::Extends(syn::parse2(ident2.to_token_stream())?));
249 }
250 }
251 _ => {}
252 }
253 }
254 }
255 } else if path.is_ident("new") {
256 pyo3_attrs.push(Attr::New);
257 } else if path.is_ident("staticmethod") {
258 pyo3_attrs.push(Attr::StaticMethod);
259 } else if path.is_ident("classmethod") {
260 pyo3_attrs.push(Attr::ClassMethod);
261 } else if path.is_ident("classattr") {
262 pyo3_attrs.push(Attr::ClassAttr);
263 } else if path.is_ident("getter") {
264 if let Ok(inner) = attr.parse_args::<Ident>() {
265 pyo3_attrs.push(Attr::Getter(Some(inner.to_string())));
266 } else {
267 pyo3_attrs.push(Attr::Getter(None));
268 }
269 } else if path.is_ident("setter") {
270 if let Ok(inner) = attr.parse_args::<Ident>() {
271 pyo3_attrs.push(Attr::Setter(Some(inner.to_string())));
272 } else {
273 pyo3_attrs.push(Attr::Setter(None));
274 }
275 }
276
277 Ok(pyo3_attrs)
278}
279
/// One keyword parsed from a `#[gen_stub(...)]` attribute by `parse_gen_stub_attr`.
#[derive(Debug, Clone, PartialEq)]
pub enum StubGenAttr {
    /// `default = <expr>`: an arbitrary Rust expression.
    Default(Expr),
    /// `skip`.
    Skip,
    /// Either `override_type(...)` or `override_return_type(...)`; both carry the same payload.
    OverrideType(OverrideTypeAttribute),
    /// `type_ignore` or `type_ignore = [...]`.
    TypeIgnore(IgnoreTarget),
}
291
292pub fn prune_attrs(attrs: &mut Vec<Attribute>) {
293 attrs.retain(|attr| !attr.path().is_ident("gen_stub"));
294}
295
296pub fn parse_gen_stub_override_type(attrs: &[Attribute]) -> Result<Option<OverrideTypeAttribute>> {
297 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)? {
298 if let StubGenAttr::OverrideType(attr) = attr {
299 return Ok(Some(attr));
300 }
301 }
302 Ok(None)
303}
304
305pub fn parse_gen_stub_override_return_type(
306 attrs: &[Attribute],
307) -> Result<Option<OverrideTypeAttribute>> {
308 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
309 if let StubGenAttr::OverrideType(attr) = attr {
310 return Ok(Some(attr));
311 }
312 }
313 Ok(None)
314}
315
316pub fn parse_gen_stub_default(attrs: &[Attribute]) -> Result<Option<Expr>> {
317 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
318 if let StubGenAttr::Default(default) = attr {
319 return Ok(Some(default));
320 }
321 }
322 Ok(None)
323}
324pub fn parse_gen_stub_skip(attrs: &[Attribute]) -> Result<bool> {
325 let skip = parse_gen_stub_attrs(
326 attrs,
327 AttributeLocation::Field,
328 Some(&["override_return_type", "default"]),
329 )?
330 .iter()
331 .any(|attr| matches!(attr, StubGenAttr::Skip));
332 Ok(skip)
333}
334
335pub fn parse_gen_stub_type_ignore(attrs: &[Attribute]) -> Result<Option<IgnoreTarget>> {
336 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
338 if let StubGenAttr::TypeIgnore(target) = attr {
339 return Ok(Some(target));
340 }
341 }
342 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Field, None)? {
344 if let StubGenAttr::TypeIgnore(target) = attr {
345 return Ok(Some(target));
346 }
347 }
348 Ok(None)
349}
350
351fn parse_gen_stub_attrs(
352 attrs: &[Attribute],
353 location: AttributeLocation,
354 ignored_idents: Option<&[&str]>,
355) -> Result<Vec<StubGenAttr>> {
356 let mut out = Vec::new();
357 for attr in attrs {
358 let mut new = parse_gen_stub_attr(attr, location, ignored_idents.unwrap_or(&[]))?;
359 out.append(&mut new);
360 }
361 Ok(out)
362}
363
/// Parses one `#[gen_stub(...)]` attribute into its [`StubGenAttr`] items.
///
/// An attribute with any other path yields an empty vector. Inside `gen_stub(...)`, each
/// comma-separated keyword is checked against `location`:
///
/// - `override_type(...)`: `Argument`
/// - `override_return_type(...)`: `Function`, `Field`
/// - `skip`: `Field`
/// - `default = <expr>`: `Field`, `Function`
/// - `type_ignore` / `type_ignore = [...]`: `Function`, `Field`
///
/// A keyword listed in `ignored_idents` is accepted regardless of `location`. Despite the
/// name, it is still parsed and returned, not skipped.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - a keyword is used outside its accepted positions;
/// - a keyword is unknown;
/// - a keyword has malformed arguments;
/// - `type_ignore = []` is given.
fn parse_gen_stub_attr(
    attr: &Attribute,
    location: AttributeLocation,
    ignored_idents: &[&str],
) -> Result<Vec<StubGenAttr>> {
    let mut gen_stub_attrs = Vec::new();
    let path = attr.path();
    if path.is_ident("gen_stub") {
        attr.parse_args_with(|input: ParseStream| {
            while !input.is_empty() {
                let ident: Ident = input.parse()?;
                // Keywords in `ignored_idents` bypass the location checks below.
                let ignored_ident = ignored_idents.iter().any(|other| ident == other);
                // Accepted keywords come first. The trailing `else if` arms report a known
                // keyword in the wrong position, then an unknown keyword.
                if (ident == "override_type"
                    && (location == AttributeLocation::Argument || ignored_ident))
                    || (ident == "override_return_type"
                        && (location == AttributeLocation::Function
                            || location == AttributeLocation::Field
                            || ignored_ident))
                {
                    // Both spellings take the same `(type_repr = ..., imports = (...))` payload.
                    let content;
                    parenthesized!(content in input);
                    let override_attr: OverrideTypeAttribute = content.parse()?;
                    gen_stub_attrs.push(StubGenAttr::OverrideType(override_attr));
                } else if ident == "skip" && (location == AttributeLocation::Field || ignored_ident)
                {
                    gen_stub_attrs.push(StubGenAttr::Skip);
                } else if ident == "default"
                    && input.peek(Token![=])
                    && (location == AttributeLocation::Field
                        || location == AttributeLocation::Function
                        || ignored_ident)
                {
                    input.parse::<Token![=]>()?;
                    // The value is an arbitrary Rust expression, e.g. `String::from("foo")`.
                    gen_stub_attrs.push(StubGenAttr::Default(input.parse()?));
                } else if ident == "type_ignore"
                    && (location == AttributeLocation::Function
                        || location == AttributeLocation::Field
                        || ignored_ident)
                {
                    if input.peek(Token![=]) {
                        // `type_ignore = ["rule", ...]`: only the listed rules.
                        input.parse::<Token![=]>()?;
                        let content;
                        syn::bracketed!(content in input);
                        let rules = Punctuated::<LitStr, Token![,]>::parse_terminated(&content)?;

                        if rules.is_empty() {
                            // An empty list is rejected; the message points to the bare
                            // catch-all form instead.
                            return Err(syn::Error::new(
                                ident.span(),
                                "type_ignore with empty array is not allowed. Use type_ignore without equals for catch-all, or specify rules in the array."
                            ));
                        }

                        let rule_lits: Vec<LitStr> = rules.into_iter().collect();
                        gen_stub_attrs.push(StubGenAttr::TypeIgnore(IgnoreTarget::SpecifiedLits(rule_lits)));
                    } else {
                        // Bare `type_ignore`: catch-all.
                        gen_stub_attrs.push(StubGenAttr::TypeIgnore(IgnoreTarget::All));
                    }
                } else if ident == "override_type" {
                    // Reached in `Function` or `Field` position.
                    return Err(syn::Error::new(
                        ident.span(),
                        "`override_type(...)` is only valid in argument position".to_string(),
                    ));
                } else if ident == "override_return_type" {
                    // Reached only in `Argument` position.
                    return Err(syn::Error::new(
                        ident.span(),
                        "`override_return_type(...)` is only valid in function or method position"
                            .to_string(),
                    ));
                } else if ident == "skip" {
                    // Reached in `Argument` or `Function` position.
                    return Err(syn::Error::new(
                        ident.span(),
                        "`skip` is only valid in field position".to_string(),
                    ));
                } else if ident == "default" {
                    // NOTE(review): also reached for `default` without `= ...` in field or
                    // function position, where this message is misleading.
                    return Err(syn::Error::new(
                        ident.span(),
                        "`default=xxx` is only valid in field or function position".to_string(),
                    ));
                } else if ident == "type_ignore" {
                    // Reached only in `Argument` position.
                    return Err(syn::Error::new(
                        ident.span(),
                        "`type_ignore` or `type_ignore=[...]` is only valid in function or method position".to_string(),
                    ));
                } else if location == AttributeLocation::Argument {
                    return Err(syn::Error::new(
                        ident.span(),
                        format!("Unsupported keyword `{ident}`, valid is `override_type(...)`"),
                    ));
                } else if location == AttributeLocation::Field {
                    return Err(syn::Error::new(
                        ident.span(),
                        format!("Unsupported keyword `{ident}`, valid is `default=xxx`, `skip`, `override_return_type(...)`, `type_ignore`, or `type_ignore=[...]`"),
                    ));
                } else if location == AttributeLocation::Function {
                    return Err(syn::Error::new(
                        ident.span(),
                        format!(
                            "Unsupported keyword `{ident}`, valid is `default=xxx`, `override_return_type(...)`, `type_ignore`, or `type_ignore=[...]`"
                        ),
                    ));
                } else {
                    // Unreachable with the current three `AttributeLocation` variants.
                    return Err(syn::Error::new(
                        ident.span(),
                        format!("Unsupported keyword `{ident}`"),
                    ));
                }
                // Keywords are comma-separated. Anything else ends the loop, and leftover
                // tokens are left for `parse_args_with` to reject.
                if input.peek(Token![,]) {
                    input.parse::<Token![,]>()?;
                } else {
                    break;
                }
            }
            Ok(())
        })?;
    }
    Ok(gen_stub_attrs)
}
482
/// Where a `#[gen_stub(...)]` attribute appears; decides which keywords it may contain.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum AttributeLocation {
    /// On a function argument.
    Argument,
    /// On a struct field.
    Field,
    /// On a function or method.
    Function,
}
489
/// Payload of `override_type(...)` / `override_return_type(...)`.
#[derive(Debug, Clone, PartialEq)]
pub struct OverrideTypeAttribute {
    /// The type text given as `type_repr = "..."` (e.g. `"typing.Never"`).
    pub(crate) type_repr: String,
    /// Module names given as `imports = ("a", "b", ...)`; empty when omitted.
    pub(crate) imports: HashSet<String>,
}
495
/// Custom keywords used when parsing [`OverrideTypeAttribute`].
mod kw {
    syn::custom_keyword!(type_repr);
    syn::custom_keyword!(imports);
    // NOTE(review): `override_type` is not referenced anywhere in this file; possibly unused.
    syn::custom_keyword!(override_type);
}
501
502impl Parse for OverrideTypeAttribute {
503 fn parse(input: ParseStream) -> Result<Self> {
504 let mut type_repr = None;
505 let mut imports = HashSet::new();
506
507 while !input.is_empty() {
508 let lookahead = input.lookahead1();
509
510 if lookahead.peek(kw::type_repr) {
511 input.parse::<kw::type_repr>()?;
512 input.parse::<Token![=]>()?;
513 type_repr = Some(input.parse::<LitStr>()?);
514 } else if lookahead.peek(kw::imports) {
515 input.parse::<kw::imports>()?;
516 input.parse::<Token![=]>()?;
517
518 let content;
519 parenthesized!(content in input);
520 let parsed_imports = Punctuated::<LitStr, Token![,]>::parse_terminated(&content)?;
521 imports = parsed_imports.into_iter().collect();
522 } else {
523 return Err(lookahead.error());
524 }
525
526 if !input.is_empty() {
527 input.parse::<Token![,]>()?;
528 }
529 }
530
531 Ok(OverrideTypeAttribute {
532 type_repr: type_repr
533 .ok_or_else(|| input.error("missing type_repr"))?
534 .value(),
535 imports: imports.iter().map(|i| i.value()).collect(),
536 })
537 }
538}
539
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Fields, ItemFn, ItemStruct, PatType};

    /// Item-level options from `#[pyclass]` and `#[pyo3]` are collected in source order.
    /// The untracked `mapping` flag is dropped. A field-level `#[pyo3(get)]` yields `Attr::Get`.
    #[test]
    fn test_parse_pyo3_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[pyo3(rename_all = "SCREAMING_SNAKE_CASE")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        let attrs = parse_pyo3_attrs(&item.attrs)?;
        assert_eq!(
            attrs,
            vec![
                Attr::Module("my_module".to_string()),
                Attr::Name("Placeholder".to_string()),
                Attr::RenameAll(RenamingRule::ScreamingSnakeCase),
            ]
        );

        if let Fields::Named(fields) = item.fields {
            // First attribute of the only field: `#[pyo3(get)]`.
            let attrs = parse_pyo3_attr(&fields.named[0].attrs[0])?;
            assert_eq!(attrs, vec![Attr::Get]);
        } else {
            unreachable!()
        }
        Ok(())
    }

    /// The `pyo3::pyclass` path form is recognised like the bare `pyclass` form.
    #[test]
    fn test_parse_pyo3_attr_full_path() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyo3::pyclass(mapping, module = "my_module", name = "Placeholder")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        let attrs = parse_pyo3_attr(&item.attrs[0])?;
        assert_eq!(
            attrs,
            vec![
                Attr::Module("my_module".to_string()),
                Attr::Name("Placeholder".to_string())
            ]
        );

        if let Fields::Named(fields) = item.fields {
            // First attribute of the only field: `#[pyo3(get)]`.
            let attrs = parse_pyo3_attr(&fields.named[0].attrs[0])?;
            assert_eq!(attrs, vec![Attr::Get]);
        } else {
            unreachable!()
        }
        Ok(())
    }
    /// In field position, `default = <expr>` accepts arbitrary expressions and can be
    /// combined with `skip` in one attribute.
    #[test]
    fn test_parse_gen_stub_field_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            pub struct PyPlaceholder {
                #[gen_stub(default = String::from("foo"), skip)]
                pub field0: String,
                #[gen_stub(skip)]
                pub field1: String,
                #[gen_stub(default = 1+2)]
                pub field2: usize,
            }
            "#,
        )?;
        let fields: Vec<_> = item.fields.into_iter().collect();
        let field0_attrs = parse_gen_stub_attrs(&fields[0].attrs, AttributeLocation::Field, None)?;
        if let StubGenAttr::Default(expr) = &field0_attrs[0] {
            assert_eq!(
                expr.to_token_stream().to_string(),
                "String :: from (\"foo\")"
            );
        } else {
            panic!("attr should be Default");
        };
        assert_eq!(&StubGenAttr::Skip, &field0_attrs[1]);
        let field1_attrs = parse_gen_stub_attrs(&fields[1].attrs, AttributeLocation::Field, None)?;
        assert_eq!(vec![StubGenAttr::Skip], field1_attrs);
        let field2_attrs = parse_gen_stub_attrs(&fields[2].attrs, AttributeLocation::Field, None)?;
        if let StubGenAttr::Default(expr) = &field2_attrs[0] {
            assert_eq!(expr.to_token_stream().to_string(), "1 + 2");
        } else {
            panic!("attr should be Default");
        };
        Ok(())
    }
    /// `override_return_type` on a function and `override_type` on an argument both parse
    /// into `OverrideTypeAttribute`. Non-`gen_stub` attributes on the function are ignored.
    #[test]
    fn test_parse_gen_stub_override_type_attr() -> Result<()> {
        let item: ItemFn = parse_str(
            r#"
            #[gen_stub_pyfunction]
            #[pyfunction]
            #[gen_stub(override_return_type(type_repr="typing.Never", imports=("typing")))]
            fn say_hello_forever<'a>(
                #[gen_stub(override_type(type_repr="collections.abc.Callable[[str]]", imports=("collections.abc")))]
                cb: Bound<'a, PyAny>,
            ) -> PyResult<()> {
                loop {
                    cb.call1(("Hello!",))?;
                }
            }
            "#,
        )?;
        let fn_attrs = parse_gen_stub_attrs(&item.attrs, AttributeLocation::Function, None)?;
        assert_eq!(fn_attrs.len(), 1);
        if let StubGenAttr::OverrideType(expr) = &fn_attrs[0] {
            assert_eq!(
                *expr,
                OverrideTypeAttribute {
                    type_repr: "typing.Never".into(),
                    imports: HashSet::from(["typing".into()])
                }
            );
        } else {
            panic!("attr should be OverrideType");
        };
        if let syn::FnArg::Typed(PatType { attrs, .. }) = &item.sig.inputs[0] {
            let arg_attrs = parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)?;
            assert_eq!(arg_attrs.len(), 1);
            if let StubGenAttr::OverrideType(expr) = &arg_attrs[0] {
                assert_eq!(
                    *expr,
                    OverrideTypeAttribute {
                        type_repr: "collections.abc.Callable[[str]]".into(),
                        imports: HashSet::from(["collections.abc".into()])
                    }
                );
            } else {
                panic!("attr should be OverrideType");
            };
        }
        Ok(())
    }
}