pyo3_stub_gen_derive/gen_stub/
attr.rs1use super::{RenamingRule, Signature};
2use proc_macro2::TokenTree;
3use quote::ToTokens;
4use syn::{Attribute, Expr, ExprLit, Ident, Lit, Meta, MetaList, Result, Type};
5
6pub fn extract_documents(attrs: &[Attribute]) -> Vec<String> {
7 let mut docs = Vec::new();
8 for attr in attrs {
9 if attr.path().is_ident("doc") {
11 if let Meta::NameValue(syn::MetaNameValue {
12 value:
13 Expr::Lit(ExprLit {
14 lit: Lit::Str(doc), ..
15 }),
16 ..
17 }) = &attr.meta
18 {
19 let doc = doc.value();
20 docs.push(if !doc.is_empty() && doc.starts_with(' ') {
27 doc[1..].to_string()
28 } else {
29 doc
30 });
31 }
32 }
33 }
34 docs
35}
36
/// A single PyO3 option recognised by [`parse_pyo3_attr`].
#[derive(Debug, Clone, PartialEq)]
pub enum Attr {
    // Options read from the argument list of `#[pyclass(...)]`, `#[pymethods(...)]`,
    // `#[pyfunction(...)]`, `#[pyo3(...)]` (or their `pyo3::`-qualified forms).
    /// `name = "..."`.
    Name(String),
    /// The bare `get` flag, e.g. `#[pyo3(get)]`.
    Get,
    /// The bare `get_all` flag.
    GetAll,
    /// `module = "..."`.
    Module(String),
    /// `signature = (...)`, parsed into a [`Signature`].
    Signature(Signature),
    /// `rename_all = "..."`, produced only when [`RenamingRule::try_new`] accepts the rule name.
    RenameAll(RenamingRule),
    /// `extends = BaseType`.
    Extends(Type),

    // Standalone attributes, recognised by their own path rather than inside an argument list.
    /// `#[new]`.
    New,
    /// `#[getter]` (`None`) or `#[getter(name)]` (`Some(name)`).
    Getter(Option<String>),
    /// `#[staticmethod]`.
    StaticMethod,
    /// `#[classmethod]`.
    ClassMethod,
}
67
68pub fn parse_pyo3_attrs(attrs: &[Attribute]) -> Result<Vec<Attr>> {
69 let mut out = Vec::new();
70 for attr in attrs {
71 let mut new = parse_pyo3_attr(attr)?;
72 out.append(&mut new);
73 }
74 Ok(out)
75}
76
77pub fn parse_pyo3_attr(attr: &Attribute) -> Result<Vec<Attr>> {
78 let mut pyo3_attrs = Vec::new();
79 let path = attr.path();
80 let is_full_path_pyo3_attr = path.segments.len() == 2
81 && path
82 .segments
83 .first()
84 .is_some_and(|seg| seg.ident.eq("pyo3"))
85 && path.segments.last().is_some_and(|seg| {
86 seg.ident.eq("pyclass") || seg.ident.eq("pymethods") || seg.ident.eq("pyfunction")
87 });
88 if path.is_ident("pyclass")
89 || path.is_ident("pymethods")
90 || path.is_ident("pyfunction")
91 || path.is_ident("pyo3")
92 || is_full_path_pyo3_attr
93 {
94 if let Meta::List(MetaList { tokens, .. }) = &attr.meta {
99 use TokenTree::*;
100 let tokens: Vec<TokenTree> = tokens.clone().into_iter().collect();
101 for tt in tokens.split(|tt| {
104 if let Punct(p) = tt {
105 p.as_char() == ','
106 } else {
107 false
108 }
109 }) {
110 match tt {
111 [Ident(ident)] => {
112 if ident == "get" {
113 pyo3_attrs.push(Attr::Get);
114 }
115 if ident == "get_all" {
116 pyo3_attrs.push(Attr::GetAll);
117 }
118 }
119 [Ident(ident), Punct(_), Literal(lit)] => {
120 if ident == "name" {
121 pyo3_attrs
122 .push(Attr::Name(lit.to_string().trim_matches('"').to_string()));
123 }
124 if ident == "module" {
125 pyo3_attrs
126 .push(Attr::Module(lit.to_string().trim_matches('"').to_string()));
127 }
128 if ident == "rename_all" {
129 let name = lit.to_string().trim_matches('"').to_string();
130 if let Some(renaming_rule) = RenamingRule::try_new(&name) {
131 pyo3_attrs.push(Attr::RenameAll(renaming_rule));
132 }
133 }
134 }
135 [Ident(ident), Punct(_), Group(group)] => {
136 if ident == "signature" {
137 pyo3_attrs.push(Attr::Signature(syn::parse2(group.to_token_stream())?));
138 }
139 }
140 [Ident(ident), Punct(_), Ident(ident2)] => {
141 if ident == "extends" {
142 pyo3_attrs.push(Attr::Extends(syn::parse2(ident2.to_token_stream())?));
143 }
144 }
145 _ => {}
146 }
147 }
148 }
149 } else if path.is_ident("new") {
150 pyo3_attrs.push(Attr::New);
151 } else if path.is_ident("staticmethod") {
152 pyo3_attrs.push(Attr::StaticMethod);
153 } else if path.is_ident("classmethod") {
154 pyo3_attrs.push(Attr::ClassMethod);
155 } else if path.is_ident("getter") {
156 if let Ok(inner) = attr.parse_args::<Ident>() {
157 pyo3_attrs.push(Attr::Getter(Some(inner.to_string())));
158 } else {
159 pyo3_attrs.push(Attr::Getter(None));
160 }
161 }
162
163 Ok(pyo3_attrs)
164}
165
#[cfg(test)]
mod test {
    use super::*;
    use syn::{parse_str, Fields, ItemStruct};

    /// Parses the first attribute of the first named field of `item` and checks that it is
    /// exactly `#[pyo3(get)]`.
    fn check_first_field_is_get(item: ItemStruct) -> Result<()> {
        let Fields::Named(named_fields) = item.fields else {
            unreachable!()
        };
        let field_attrs = parse_pyo3_attr(&named_fields.named[0].attrs[0])?;
        assert_eq!(field_attrs, vec![Attr::Get]);
        Ok(())
    }

    #[test]
    fn test_parse_pyo3_attr() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyclass(mapping, module = "my_module", name = "Placeholder")]
            #[pyo3(rename_all = "SCREAMING_SNAKE_CASE")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        // Both struct-level attributes are parsed together; `mapping` is not tracked.
        let expected = vec![
            Attr::Module("my_module".to_string()),
            Attr::Name("Placeholder".to_string()),
            Attr::RenameAll(RenamingRule::ScreamingSnakeCase),
        ];
        assert_eq!(parse_pyo3_attrs(&item.attrs)?, expected);
        check_first_field_is_get(item)
    }

    #[test]
    fn test_parse_pyo3_attr_full_path() -> Result<()> {
        let item: ItemStruct = parse_str(
            r#"
            #[pyo3::pyclass(mapping, module = "my_module", name = "Placeholder")]
            pub struct PyPlaceholder {
                #[pyo3(get)]
                pub name: String,
            }
            "#,
        )?;
        // The `pyo3::`-qualified macro path is recognised like the bare `pyclass`.
        let expected = vec![
            Attr::Module("my_module".to_string()),
            Attr::Name("Placeholder".to_string()),
        ];
        assert_eq!(parse_pyo3_attr(&item.attrs[0])?, expected);
        check_first_field_is_get(item)
    }
}