1use std::collections::HashSet;
2
3use super::{RenamingRule, Signature};
4use proc_macro2::{TokenStream as TokenStream2, TokenTree};
5use quote::{quote, ToTokens, TokenStreamExt};
6use syn::{
7 parenthesized,
8 parse::{Parse, ParseStream},
9 punctuated::Punctuated,
10 Attribute, Expr, ExprLit, Ident, Lit, LitStr, Meta, MetaList, Result, Token, Type,
11};
12
13pub fn extract_documents(attrs: &[Attribute]) -> Vec<String> {
14 let mut docs = Vec::new();
15 for attr in attrs {
16 if attr.path().is_ident("doc") {
18 if let Meta::NameValue(syn::MetaNameValue {
19 value:
20 Expr::Lit(ExprLit {
21 lit: Lit::Str(doc), ..
22 }),
23 ..
24 }) = &attr.meta
25 {
26 let doc = doc.value();
27 docs.push(if !doc.is_empty() && doc.starts_with(' ') {
34 doc[1..].to_string()
35 } else {
36 doc
37 });
38 }
39 }
40 }
41 docs
42}
43
44pub fn extract_deprecated(attrs: &[Attribute]) -> Option<DeprecatedInfo> {
46 for attr in attrs {
47 if attr.path().is_ident("deprecated") {
48 if let Ok(list) = attr.meta.require_list() {
49 let mut since = None;
50 let mut note = None;
51
52 list.parse_nested_meta(|meta| {
53 if meta.path.is_ident("since") {
54 let value = meta.value()?;
55 let lit: LitStr = value.parse()?;
56 since = Some(lit.value());
57 } else if meta.path.is_ident("note") {
58 let value = meta.value()?;
59 let lit: LitStr = value.parse()?;
60 note = Some(lit.value());
61 }
62 Ok(())
63 })
64 .ok()?;
65
66 return Some(DeprecatedInfo { since, note });
67 }
68 }
69 }
70 None
71}
72
/// Deprecation details gathered from a `#[deprecated]` attribute.
#[derive(Debug, Clone, PartialEq)]
pub struct DeprecatedInfo {
    /// Value of the `since = "..."` key, if it was given.
    pub since: Option<String>,
    /// Value of the `note = "..."` key, if it was given.
    pub note: Option<String>,
}
90
91impl ToTokens for DeprecatedInfo {
92 fn to_tokens(&self, tokens: &mut TokenStream2) {
93 let since = self
94 .since
95 .as_ref()
96 .map(|s| quote! { Some(#s) })
97 .unwrap_or_else(|| quote! { None });
98 let note = self
99 .note
100 .as_ref()
101 .map(|n| quote! { Some(#n) })
102 .unwrap_or_else(|| quote! { None });
103 tokens.append_all(quote! {
104 ::pyo3_stub_gen::type_info::DeprecatedInfo {
105 since: #since,
106 note: #note,
107 }
108 })
109 }
110}
111
112#[derive(Debug, Clone, PartialEq)]
113#[expect(clippy::enum_variant_names)]
114pub enum Attr {
115 Name(String),
117 Get,
118 GetAll,
119 Set,
120 SetAll,
121 Module(String),
122 Constructor(Signature),
123 Signature(Signature),
124 RenameAll(RenamingRule),
125 Extends(Type),
126
127 New,
130 Getter(Option<String>),
131 Setter(Option<String>),
132 StaticMethod,
133 ClassMethod,
134 ClassAttr,
135}
136
137pub fn parse_pyo3_attrs(attrs: &[Attribute]) -> Result<Vec<Attr>> {
138 let mut out = Vec::new();
139 for attr in attrs {
140 let mut new = parse_pyo3_attr(attr)?;
141 out.append(&mut new);
142 }
143 Ok(out)
144}
145
146pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
147 let mut pyo3_attrs = Vec::new();
148 let path = attr.path();
149 let is_full_path_pyo3_attr = path.segments.len() == 2
150 && path
151 .segments
152 .first()
153 .is_some_and(|seg| seg.ident.eq("pyo3"))
154 && path.segments.last().is_some_and(|seg| {
155 seg.ident.eq("pyclass") || seg.ident.eq("pymethods") || seg.ident.eq("pyfunction")
156 });
157 if path.is_ident("pyclass")
158 || path.is_ident("pymethods")
159 || path.is_ident("pyfunction")
160 || path.is_ident("pyo3")
161 || is_full_path_pyo3_attr
162 {
163 if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
168 use TokenTree::*;
169 let tokens: Vec<TokenTree> = tokens.clone().into_iter().collect();
170 for tt in tokens.split(|tt| {
173 if let Punct(p) = tt {
174 p.as_char() == ','
175 } else {
176 false
177 }
178 }) {
179 match tt {
180 [Ident(ident)] => {
181 if ident == "get" {
182 pyo3_attrs.push(Attr::Get);
183 }
184 if ident == "get_all" {
185 pyo3_attrs.push(Attr::GetAll);
186 }
187 if ident == "set" {
188 pyo3_attrs.push(Attr::Set);
189 }
190 if ident == "set_all" {
191 pyo3_attrs.push(Attr::SetAll);
192 }
193 }
194 [Ident(ident), Punct(_), Literal(lit)] => {
195 if ident == "name" {
196 pyo3_attrs
197 .push(Attr::Name(lit.to_string().trim_matches('"').to_string()));
198 }
199 if ident == "module" {
200 pyo3_attrs
201 .push(Attr::Module(lit.to_string().trim_matches('"').to_string()));
202 }
203 if ident == "rename_all" {
204 let name = lit.to_string().trim_matches('"').to_string();
205 if let Some(renaming_rule) = RenamingRule::try_new(&name) {
206 pyo3_attrs.push(Attr::RenameAll(renaming_rule));
207 }
208 }
209 }
210 [Ident(ident), Punct(_), Group(group)] => {
211 if ident == "signature" {
212 pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?));
213 } else if ident == "constructor" {
214 pyo3_attrs
215 .push(Attr::Constructor(syn::parse2(group.to_token_stream())?));
216 }
217 }
218 [Ident(ident), Punct(_), Ident(ident2)] => {
219 if ident == "extends" {
220 pyo3_attrs.push(Attr::Extends(syn::parse2(ident2.to_token_stream())?));
221 }
222 }
223 _ => {}
224 }
225 }
226 }
227 } else if path.is_ident("new") {
228 pyo3_attrs.push(Attr::New);
229 } else if path.is_ident("staticmethod") {
230 pyo3_attrs.push(Attr::StaticMethod);
231 } else if path.is_ident("classmethod") {
232 pyo3_attrs.push(Attr::ClassMethod);
233 } else if path.is_ident("classattr") {
234 pyo3_attrs.push(Attr::ClassAttr);
235 } else if path.is_ident("getter") {
236 if let Ok(inner) = attr.parse_args::<Ident>() {
237 pyo3_attrs.push(Attr::Getter(Some(inner.to_string())));
238 } else {
239 pyo3_attrs.push(Attr::Getter(None));
240 }
241 } else if path.is_ident("setter") {
242 if let Ok(inner) = attr.parse_args::<Ident>() {
243 pyo3_attrs.push(Attr::Setter(Some(inner.to_string())));
244 } else {
245 pyo3_attrs.push(Attr::Setter(None));
246 }
247 }
248
249 Ok(pyo3_attrs)
250}
251
252#[derive(Debug, Clone, PartialEq)]
253pub enum StubGenAttr {
254 Default(Expr),
256 Skip,
258 OverrideType(OverrideTypeAttribute),
260}
261
262pub fn prune_attrs(attrs: &mut Vec<Attribute>) {
263 attrs.retain(|attr| !attr.path().is_ident("gen_stub"));
264}
265
266pub fn parse_gen_stub_override_type(attrs: &[Attribute]) -> Result<Option<OverrideTypeAttribute>> {
267 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)? {
268 if let StubGenAttr::OverrideType(attr) = attr {
269 return Ok(Some(attr));
270 }
271 }
272 Ok(None)
273}
274
275pub fn parse_gen_stub_override_return_type(
276 attrs: &[Attribute],
277) -> Result<Option<OverrideTypeAttribute>> {
278 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
279 if let StubGenAttr::OverrideType(attr) = attr {
280 return Ok(Some(attr));
281 }
282 }
283 Ok(None)
284}
285
286pub fn parse_gen_stub_default(attrs: &[Attribute]) -> Result<Option<Expr>> {
287 for attr in parse_gen_stub_attrs(attrs, AttributeLocation::Function, None)? {
288 if let StubGenAttr::Default(default) = attr {
289 return Ok(Some(default));
290 }
291 }
292 Ok(None)
293}
294pub fn parse_gen_stub_skip(attrs: &[Attribute]) -> Result<bool> {
295 let skip = parse_gen_stub_attrs(
296 attrs,
297 AttributeLocation::Field,
298 Some(&["override_return_type", "default"]),
299 )?
300 .iter()
301 .any(|attr| matches!(attr, StubGenAttr::Skip));
302 Ok(skip)
303}
304fn parse_gen_stub_attrs(
305 attrs: &[Attribute],
306 location: AttributeLocation,
307 ignored_idents: Option<&[&str]>,
308) -> Result<Vec<StubGenAttr>> {
309 let mut out = Vec::new();
310 for attr in attrs {
311 let mut new = parse_gen_stub_attr(attr, location, ignored_idents.unwrap_or(&[]))?;
312 out.append(&mut new);
313 }
314 Ok(out)
315}
316
317fn parse_gen_stub_attr(
318 attr: &Attribute,
319 location: AttributeLocation,
320 ignored_idents: &[&str],
321) -> Result<Vec<StubGenAttr>> {
322 let mut gen_stub_attrs = Vec::new();
323 let path = attr.path();
324 if path.is_ident("gen_stub") {
325 attr.parse_args_with(|input: ParseStream| {
326 while !input.is_empty() {
327 let ident: Ident = input.parse()?;
328 let ignored_ident = ignored_idents.iter().any(|other| ident == other);
329 if (ident == "override_type"
330 && (location == AttributeLocation::Argument || ignored_ident))
331 || (ident == "override_return_type"
332 && (location == AttributeLocation::Function || ignored_ident))
333 {
334 let content;
335 parenthesized!(content in input);
336 let override_attr: OverrideTypeAttribute = content.parse()?;
337 gen_stub_attrs.push(StubGenAttr::OverrideType(override_attr));
338 } else if ident == "skip" && (location == AttributeLocation::Field || ignored_ident)
339 {
340 gen_stub_attrs.push(StubGenAttr::Skip);
341 } else if ident == "default"
342 && input.peek(Token![=])
343 && (location == AttributeLocation::Field || location == AttributeLocation::Function || ignored_ident)
344 {
345 input.parse::<Token![=]>()?;
346 gen_stub_attrs.push(StubGenAttr::Default(input.parse()?));
347 } else if ident == "override_type" {
348 return Err(syn::Error::new(
349 ident.span(),
350 "`override_type(...)` is only valid in argument position".to_string(),
351 ));
352 } else if ident == "override_return_type" {
353 return Err(syn::Error::new(
354 ident.span(),
355 "`override_return_type(...)` is only valid in function position"
356 .to_string(),
357 ));
358 } else if ident == "skip" {
359 return Err(syn::Error::new(
360 ident.span(),
361 "`skip` is only valid in field position".to_string(),
362 ));
363 } else if ident == "default" {
364 return Err(syn::Error::new(
365 ident.span(),
366 "`default=xxx` is only valid in field or function position".to_string(),
367 ));
368 } else if location == AttributeLocation::Argument {
369 return Err(syn::Error::new(
370 ident.span(),
371 format!("Unsupported keyword `{ident}`, valid is `override_type(...)`"),
372 ));
373 } else if location == AttributeLocation::Field {
374 return Err(syn::Error::new(
375 ident.span(),
376 format!("Unsupported keyword `{ident}`, valid is `default=xxx` or `skip`"),
377 ));
378 } else if location == AttributeLocation::Function {
379 return Err(syn::Error::new(
380 ident.span(),
381 format!(
382 "Unsupported keyword `{ident}`, valid is `default=xxx` or `override_return_type(...)`"
383 ),
384 ));
385 } else {
386 return Err(syn::Error::new(
387 ident.span(),
388 format!("Unsupported keyword `{ident}`"),
389 ));
390 }
391 if input.peek(Token![,]) {
392 input.parse::<Token![,]>()?;
393 } else {
394 break;
395 }
396 }
397 Ok(())
398 })?;
399 }
400 Ok(gen_stub_attrs)
401}
402
/// The kind of item a `#[gen_stub(...)]` attribute is attached to; it decides
/// which keywords the attribute may contain.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum AttributeLocation {
    /// Accepts `override_type(...)`.
    Argument,
    /// Accepts `default = <expr>` and `skip`.
    Field,
    /// Accepts `default = <expr>` and `override_return_type(...)`.
    Function,
}
409
/// Arguments of `override_type(...)` / `override_return_type(...)`.
#[derive(Debug, Clone, PartialEq)]
pub struct OverrideTypeAttribute {
    /// Value of the mandatory `type_repr = "..."` argument.
    pub(crate) type_repr: String,
    /// Values listed in the optional `imports = ("...", ...)` argument.
    pub(crate) imports: HashSet<String>,
}
415
416mod kw {
417 syn::custom_keyword!(type_repr);
418 syn::custom_keyword!(imports);
419 syn::custom_keyword!(override_type);
420}
421
422impl Parse for OverrideTypeAttribute {
423 fn parse(input: ParseStream) -> Result<Self> {
424 let mut type_repr = None;
425 let mut imports = HashSet::new();
426
427 while !input.is_empty() {
428 let lookahead = input.lookahead1();
429
430 if lookahead.peek(kw::type_repr) {
431 input.parse::<kw::type_repr>()?;
432 input.parse::<Token![=]>()?;
433 type_repr = Some(input.parse::<LitStr>()?);
434 } else if lookahead.peek(kw::imports) {
435 input.parse::<kw::imports>()?;
436 input.parse::<Token![=]>()?;
437
438 let content;
439 parenthesized!(content in input);
440 let parsed_imports = Punctuated::<LitStr, Token![,]>::parse_terminated(&content)?;
441 imports = parsed_imports.into_iter().collect();
442 } else {
443 return Err(lookahead.error());
444 }
445
446 if !input.is_empty() {
447 input.parse::<Token![,]>()?;
448 }
449 }
450
451 Ok(OverrideTypeAttribute {
452 type_repr: type_repr
453 .ok_or_else(|| input.error("missing type_repr"))?
454 .value(),
455 imports: imports.iter().map(|i| i.value()).collect(),
456 })
457 }
458}
459
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Fields, ItemFn, ItemStruct, PatType};

    /// Class-level and field-level PyO3 attributes are recognised; unknown
    /// items such as `mapping` are ignored.
    #[test]
    fn test_parse_pyo3_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[pyo3(rename_all = "SCREAMING_SNAKE_CASE")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        let class_attrs = parse_pyo3_attrs(&item.attrs)?;
        assert_eq!(
            class_attrs,
            vec![
                Attr::Module("my_module".to_string()),
                Attr::Name("Placeholder".to_string()),
                Attr::RenameAll(RenamingRule::ScreamingSnakeCase),
            ]
        );

        let Fields::Named(fields) = item.fields else {
            unreachable!()
        };
        let field_attrs = parse_pyo3_attr(&fields.named[0].attrs[0])?;
        assert_eq!(field_attrs, vec![Attr::Get]);
        Ok(())
    }

    /// The fully qualified `pyo3::pyclass` path is handled like `pyclass`.
    #[test]
    fn test_parse_pyo3_attr_full_path() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyo3::pyclass(mapping, module = "my_module", name = "Placeholder")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        let class_attrs = parse_pyo3_attr(&item.attrs[0])?;
        assert_eq!(
            class_attrs,
            vec![
                Attr::Module("my_module".to_string()),
                Attr::Name("Placeholder".to_string())
            ]
        );

        let Fields::Named(fields) = item.fields else {
            unreachable!()
        };
        let field_attrs = parse_pyo3_attr(&fields.named[0].attrs[0])?;
        assert_eq!(field_attrs, vec![Attr::Get]);
        Ok(())
    }

    /// `default = <expr>` and `skip` are accepted in field position.
    #[test]
    fn test_parse_gen_stub_field_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            pub struct PyPlaceholder {
                #[gen_stub(default = String::from("foo"), skip)]
                pub field0: String,
                #[gen_stub(skip)]
                pub field1: String,
                #[gen_stub(default = 1+2)]
                pub field2: usize,
            }
            "#,
        )?;
        let fields: Vec<_> = item.fields.into_iter().collect();

        let first = parse_gen_stub_attrs(&fields[0].attrs, AttributeLocation::Field, None)?;
        match &first[0] {
            StubGenAttr::Default(expr) => assert_eq!(
                expr.to_token_stream().to_string(),
                "String :: from (\"foo\")"
            ),
            _ => panic!("attr should be Default"),
        }
        assert_eq!(&StubGenAttr::Skip, &first[1]);

        let second = parse_gen_stub_attrs(&fields[1].attrs, AttributeLocation::Field, None)?;
        assert_eq!(vec![StubGenAttr::Skip], second);

        let third = parse_gen_stub_attrs(&fields[2].attrs, AttributeLocation::Field, None)?;
        match &third[0] {
            StubGenAttr::Default(expr) => {
                assert_eq!(expr.to_token_stream().to_string(), "1 + 2")
            }
            _ => panic!("attr should be Default"),
        }
        Ok(())
    }

    /// Type overrides are accepted on the function and on its argument.
    #[test]
    fn test_parse_gen_stub_override_type_attr() -> Result<()> {
        let item: ItemFn = parse_str(
            r#"
            #[gen_stub_pyfunction]
            #[pyfunction]
            #[gen_stub(override_return_type(type_repr="typing.Never", imports=("typing")))]
            fn say_hello_forever<'a>(
                #[gen_stub(override_type(type_repr="collections.abc.Callable[[str]]", imports=("collections.abc")))]
                cb: Bound<'a, PyAny>,
            ) -> PyResult<()> {
                loop {
                    cb.call1(("Hello!",))?;
                }
            }
            "#,
        )?;

        let fn_attrs = parse_gen_stub_attrs(&item.attrs, AttributeLocation::Function, None)?;
        assert_eq!(fn_attrs.len(), 1);
        match &fn_attrs[0] {
            StubGenAttr::OverrideType(found) => assert_eq!(
                *found,
                OverrideTypeAttribute {
                    type_repr: "typing.Never".into(),
                    imports: HashSet::from(["typing".into()])
                }
            ),
            _ => panic!("attr should be OverrideType"),
        }

        if let syn::FnArg::Typed(PatType { attrs, .. }) = &item.sig.inputs[0] {
            let arg_attrs = parse_gen_stub_attrs(attrs, AttributeLocation::Argument, None)?;
            assert_eq!(arg_attrs.len(), 1);
            match &arg_attrs[0] {
                StubGenAttr::OverrideType(found) => assert_eq!(
                    *found,
                    OverrideTypeAttribute {
                        type_repr: "collections.abc.Callable[[str]]".into(),
                        imports: HashSet::from(["collections.abc".into()])
                    }
                ),
                _ => panic!("attr should be OverrideType"),
            }
        }
        Ok(())
    }
}