// pyo3_stub_gen_derive/gen_stub/attr.rs
1use super::Signature;
2use proc_macro2::TokenTree;
3use quote::ToTokens;
4use syn::{Attribute, Expr, ExprLit, Ident, Lit, Meta, MetaList, Result};
5
6pub fn extract_documents(attrs: &[Attribute]) -> Vec<String> {
7 let mut docs = Vec::new();
8 for attr in attrs {
9 if attr.path().is_ident("doc") {
11 if let Meta::NameValue(syn::MetaNameValue {
12 value:
13 Expr::Lit(ExprLit {
14 lit: Lit::Str(doc), ..
15 }),
16 ..
17 }) = &attr.meta
18 {
19 let doc = doc.value();
20 docs.push(if !doc.is_empty() && doc.starts_with(' ') {
27 doc[1..].to_string()
28 } else {
29 doc
30 });
31 }
32 }
33 }
34 docs
35}
36
37#[derive(Debug, Clone, PartialEq)]
50pub enum Attr {
51 Name(String),
53 Get,
54 GetAll,
55 Module(String),
56 Signature(Signature),
57
58 New,
61 Getter(Option<String>),
62 StaticMethod,
63 ClassMethod,
64}
65
66pub fn parse_pyo3_attrs(attrs: &[Attribute]) -> Result<Vec<Attr>> {
67 let mut out = Vec::new();
68 for attr in attrs {
69 let mut new = parse_pyo3_attr(attr)?;
70 out.append(&mut new);
71 }
72 Ok(out)
73}
74
75pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
76 let mut pyo3_attrs = Vec::new();
77 let path = attr.path();
78 if path.is_ident("pyclass")
79 || path.is_ident("pymethods")
80 || path.is_ident("pyfunction")
81 || path.is_ident("pyo3")
82 {
83 if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
88 use TokenTree::*;
89 let tokens: Vec<TokenTree> = tokens.clone().into_iter().collect();
90 for tt in tokens.split(|tt| {
93 if let Punct(p) = tt {
94 p.as_char() == ','
95 } else {
96 false
97 }
98 }) {
99 match tt {
100 [Ident(ident)] => {
101 if ident == "get" {
102 pyo3_attrs.push(Attr::Get);
103 }
104 if ident == "get_all" {
105 pyo3_attrs.push(Attr::GetAll);
106 }
107 }
108 [Ident(ident), Punct(_), Literal(lit)] => {
109 if ident == "name" {
110 pyo3_attrs
111 .push(Attr::Name(lit.to_string().trim_matches('"').to_string()));
112 }
113 if ident == "module" {
114 pyo3_attrs
115 .push(Attr::Module(lit.to_string().trim_matches('"').to_string()));
116 }
117 }
118 [Ident(ident), Punct(_), Group(group)] => {
119 if ident == "signature" {
120 pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?));
121 }
122 }
123 _ => {}
124 }
125 }
126 }
127 } else if path.is_ident("new") {
128 pyo3_attrs.push(Attr::New);
129 } else if path.is_ident("staticmethod") {
130 pyo3_attrs.push(Attr::StaticMethod);
131 } else if path.is_ident("classmethod") {
132 pyo3_attrs.push(Attr::ClassMethod);
133 } else if path.is_ident("getter") {
134 if let Ok(inner) = attr.parse_args::<Ident>() {
135 pyo3_attrs.push(Attr::Getter(Some(inner.to_string())));
136 } else {
137 pyo3_attrs.push(Attr::Getter(None));
138 }
139 }
140
141 Ok(pyo3_attrs)
142}
143
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Fields, ItemStruct};

    /// Struct-level `#[pyclass(...)]` arguments and field-level `#[pyo3(get)]` are
    /// recognised; unsupported arguments such as `mapping` are skipped.
    #[test]
    fn test_parse_pyo3_attr() -> Result<()> {
        let source = r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#;
        let item: ItemStruct = parse_str(source)?;

        let class_attrs = parse_pyo3_attr(&item.attrs[0])?;
        let expected_class_attrs = vec![
            Attr::Module("my_module".to_string()),
            Attr::Name("Placeholder".to_string()),
        ];
        assert_eq!(class_attrs, expected_class_attrs);

        let named_fields = match item.fields {
            Fields::Named(named) => named,
            _ => unreachable!(),
        };
        let field_attrs = parse_pyo3_attr(&named_fields.named[0].attrs[0])?;
        assert_eq!(field_attrs, vec![Attr::Get]);

        Ok(())
    }
}