pyo3_stub_gen/generate/qualifier.rs
1//! Context-aware type name qualification for Python stub files.
2//!
3//! This module provides utilities to qualify type identifiers within compound type expressions
4//! based on the target module context. For example, `typing.Optional[ClassA]` should become
5//! `typing.Optional[sub_mod.ClassA]` when ClassA is from a different module.
6
7use crate::stub_type::{ImportKind, TypeIdentifierRef};
8use std::collections::HashMap;
9
10/// Token types in Python type expressions
11#[derive(Debug, Clone, PartialEq)]
12enum Token {
13 /// Bare identifier (e.g., "ClassA", "int")
14 Identifier(String),
15 /// Dotted path (e.g., "typing.Optional", "collections.abc.Callable")
16 DottedPath(Vec<String>),
17 /// Opening bracket: [ or (
18 OpenBracket(char),
19 /// Closing bracket: ] or )
20 CloseBracket(char),
21 /// Comma separator
22 Comma,
23 /// Pipe operator for unions (PEP 604)
24 Pipe,
25 /// Ellipsis (...)
26 Ellipsis,
27 /// String literal for forward references
28 StringLiteral(String),
29 /// Whitespace (preserved for formatting)
30 Whitespace(String),
31}
32
33/// Tokenizes a Python type expression into tokens.
34///
35/// Handles:
36/// - Identifiers: `ClassA`, `int`, `str`
37/// - Dotted paths: `typing.Optional`, `collections.abc.Callable`
38/// - Brackets: `[`, `]`, `(`, `)`
39/// - Special characters: `,`, `|`, `...`
40/// - String literals: `"ForwardRef"`
41/// - Whitespace preservation
42fn tokenize(expr: &str) -> Vec<Token> {
43 let mut tokens = Vec::new();
44 let mut chars = expr.chars().peekable();
45
46 while let Some(&ch) = chars.peek() {
47 match ch {
48 // Whitespace
49 ' ' | '\t' | '\n' | '\r' => {
50 let mut ws = String::new();
51 while let Some(&c) = chars.peek() {
52 if c.is_whitespace() {
53 ws.push(c);
54 chars.next();
55 } else {
56 break;
57 }
58 }
59 tokens.push(Token::Whitespace(ws));
60 }
61
62 // Brackets
63 '[' | '(' => {
64 tokens.push(Token::OpenBracket(ch));
65 chars.next();
66 }
67 ']' | ')' => {
68 tokens.push(Token::CloseBracket(ch));
69 chars.next();
70 }
71
72 // Comma
73 ',' => {
74 tokens.push(Token::Comma);
75 chars.next();
76 }
77
78 // Pipe (union operator)
79 '|' => {
80 tokens.push(Token::Pipe);
81 chars.next();
82 }
83
84 // String literals (forward references)
85 '"' | '\'' => {
86 let quote_char = ch;
87 chars.next(); // consume opening quote
88 let mut content = String::new();
89
90 while let Some(&c) = chars.peek() {
91 chars.next();
92 if c == quote_char {
93 break;
94 }
95 // Handle escape sequences
96 if c == '\\' {
97 if let Some(&next) = chars.peek() {
98 content.push(c);
99 content.push(next);
100 chars.next();
101 }
102 } else {
103 content.push(c);
104 }
105 }
106
107 tokens.push(Token::StringLiteral(content));
108 }
109
110 // Dot - could be start of ellipsis or part of dotted path
111 '.' => {
112 // Look ahead for ellipsis
113 let mut peek_chars = chars.clone();
114 peek_chars.next(); // skip first dot
115 if matches!(peek_chars.peek(), Some(&'.')) {
116 peek_chars.next();
117 if matches!(peek_chars.peek(), Some(&'.')) {
118 // It's an ellipsis
119 chars.next();
120 chars.next();
121 chars.next();
122 tokens.push(Token::Ellipsis);
123 continue;
124 }
125 }
126
127 // Otherwise, it's part of a dotted path - this shouldn't happen
128 // as dots should be consumed as part of identifiers
129 chars.next();
130 }
131
132 // Identifier or dotted path
133 _ if ch.is_alphabetic() || ch == '_' => {
134 let mut ident = String::new();
135 let mut parts = Vec::new();
136
137 // Read first identifier
138 while let Some(&c) = chars.peek() {
139 if c.is_alphanumeric() || c == '_' {
140 ident.push(c);
141 chars.next();
142 } else {
143 break;
144 }
145 }
146
147 parts.push(ident.clone());
148
149 // Check for dotted path
150 while let Some(&'.') = chars.peek() {
151 // Look ahead to see if there's an identifier after the dot
152 let mut peek = chars.clone();
153 peek.next(); // skip dot
154
155 if let Some(&c) = peek.peek() {
156 if c.is_alphabetic() || c == '_' {
157 // It's a dotted path
158 chars.next(); // consume dot
159 ident.clear();
160
161 while let Some(&c) = chars.peek() {
162 if c.is_alphanumeric() || c == '_' {
163 ident.push(c);
164 chars.next();
165 } else {
166 break;
167 }
168 }
169
170 parts.push(ident.clone());
171 } else {
172 break;
173 }
174 } else {
175 break;
176 }
177 }
178
179 // Create token based on whether it's a dotted path
180 if parts.len() > 1 {
181 tokens.push(Token::DottedPath(parts));
182 } else {
183 tokens.push(Token::Identifier(parts[0].clone()));
184 }
185 }
186
187 // Skip other characters (shouldn't happen in valid type expressions)
188 _ => {
189 chars.next();
190 }
191 }
192 }
193
194 tokens
195}
196
197/// Type expression qualifier that rewrites identifiers based on module context.
198pub(crate) struct TypeExpressionQualifier;
199
200impl TypeExpressionQualifier {
201 /// Qualify a type expression based on the type references
202 ///
203 /// This rewrites bare identifiers in the expression to add module qualifiers
204 /// when necessary, based on the import context.
205 ///
206 /// # Parameters
207 /// - `expr`: The type expression to qualify
208 /// - `type_refs`: Map of type names to their module references
209 /// - `target_module`: The module where this type expression will be used
210 pub(crate) fn qualify_expression(
211 expr: &str,
212 type_refs: &HashMap<String, TypeIdentifierRef>,
213 target_module: &str,
214 ) -> String {
215 let tokens = tokenize(expr);
216 let mut result = String::new();
217
218 for token in tokens {
219 match token {
220 Token::Identifier(ref name) => {
221 // Check if this identifier needs qualification
222 if let Some(type_ref) = type_refs.get(name) {
223 match type_ref.import_kind {
224 ImportKind::ByName | ImportKind::SameModule => {
225 // Can use unqualified
226 result.push_str(name);
227 }
228 ImportKind::Module => {
229 // Need to qualify with module component
230 if let Some(module_name) = type_ref.module.get() {
231 // Check if type is from same module as target
232 if module_name == target_module {
233 // Same module - use unqualified name
234 result.push_str(name);
235 } else {
236 // Different module - qualify with last component
237 let module_component =
238 module_name.rsplit('.').next().unwrap_or(module_name);
239 result.push_str(module_component);
240 result.push('.');
241 result.push_str(name);
242 }
243 } else {
244 // No module info, use as-is
245 result.push_str(name);
246 }
247 }
248 }
249 } else if Self::is_python_builtin(name) {
250 // Known Python builtin or typing construct - use as-is
251 result.push_str(name);
252 } else {
253 // Unknown identifier - preserve as-is
254 result.push_str(name);
255 }
256 }
257 Token::DottedPath(parts) => {
258 // Check if this is an over-qualified path (e.g., "my_module.Type" when we're already in "my_module")
259 // If the dotted path is module.Type and module matches target_module, simplify to just Type
260 if parts.len() == 2 {
261 let module_path = &parts[0];
262 let type_name = &parts[1];
263
264 // Check if target_module matches or ends with the module_path
265 // E.g., target="pkg.sub_mod" matches module_path="sub_mod"
266 let is_same_module = module_path == target_module
267 || target_module.ends_with(&format!(".{}", module_path));
268
269 if is_same_module {
270 // Over-qualified - just use the type name
271 result.push_str(type_name);
272 } else {
273 // Different module - keep the qualification
274 result.push_str(&parts.join("."));
275 }
276 } else {
277 // More complex path - preserve as-is
278 result.push_str(&parts.join("."));
279 }
280 }
281 Token::OpenBracket(ch) => result.push(ch),
282 Token::CloseBracket(ch) => result.push(ch),
283 Token::Comma => result.push(','),
284 Token::Pipe => result.push_str(" | "),
285 Token::Ellipsis => result.push_str("..."),
286 Token::StringLiteral(s) => {
287 // String literals (forward references) - wrap in quotes
288 result.push('"');
289 result.push_str(&s);
290 result.push('"');
291 }
292 Token::Whitespace(ws) => result.push_str(&ws),
293 }
294 }
295
296 result
297 }
298
299 /// Check if an identifier is a known Python builtin or typing construct
300 fn is_python_builtin(identifier: &str) -> bool {
301 matches!(
302 identifier,
303 // typing module types
304 "Any" | "Optional" | "Union" | "List" | "Dict" | "Tuple" | "Set" |
305 "Callable" | "Sequence" | "Mapping" | "Iterable" | "Iterator" |
306 "Literal" | "TypeVar" | "Generic" | "Protocol" | "TypeAlias" |
307 "Final" | "ClassVar" | "Annotated" | "TypeGuard" | "Never" |
308 // builtins
309 "int" | "str" | "float" | "bool" | "bytes" | "bytearray" |
310 "list" | "dict" | "tuple" | "set" | "frozenset" |
311 "object" | "type" | "None" | "Ellipsis" |
312 "complex" | "slice" | "range" | "memoryview" |
313 // Special
314 "typing" | "collections" | "abc" | "builtins"
315 )
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stub_type::ModuleRef;

    /// Shorthand for a bare identifier token.
    fn ident(name: &str) -> Token {
        Token::Identifier(name.to_string())
    }

    /// Shorthand for a single-space whitespace token.
    fn space() -> Token {
        Token::Whitespace(" ".to_string())
    }

    /// Builds a type reference that lives in `module` and is imported as `import_kind`.
    fn located_in(module: &'static str, import_kind: ImportKind) -> TypeIdentifierRef {
        TypeIdentifierRef {
            module: ModuleRef::Named(module.into()),
            import_kind,
        }
    }

    /// Collects `(type name, reference)` pairs into the map `qualify_expression` expects.
    fn refs(entries: Vec<(&str, TypeIdentifierRef)>) -> HashMap<String, TypeIdentifierRef> {
        entries
            .into_iter()
            .map(|(name, type_ref)| (name.to_string(), type_ref))
            .collect()
    }

    #[test]
    fn test_tokenize_simple() {
        assert_eq!(tokenize("ClassA"), vec![ident("ClassA")]);
    }

    #[test]
    fn test_tokenize_optional() {
        let expected = vec![
            Token::DottedPath(vec!["typing".to_string(), "Optional".to_string()]),
            Token::OpenBracket('['),
            ident("ClassA"),
            Token::CloseBracket(']'),
        ];
        assert_eq!(tokenize("typing.Optional[ClassA]"), expected);
    }

    #[test]
    fn test_tokenize_callable() {
        let expected = vec![
            ident("Callable"),
            Token::OpenBracket('['),
            Token::OpenBracket('['),
            ident("ClassA"),
            Token::Comma,
            space(),
            ident("str"),
            Token::CloseBracket(']'),
            Token::Comma,
            space(),
            ident("int"),
            Token::CloseBracket(']'),
        ];
        assert_eq!(tokenize("Callable[[ClassA, str], int]"), expected);
    }

    #[test]
    fn test_tokenize_union() {
        let expected = vec![
            ident("ClassA"),
            space(),
            Token::Pipe,
            space(),
            ident("ClassB"),
        ];
        assert_eq!(tokenize("ClassA | ClassB"), expected);
    }

    #[test]
    fn test_qualify_simple() {
        let type_refs = refs(vec![(
            "ClassA",
            located_in("test_package.sub_mod", ImportKind::Module),
        )]);

        let qualified =
            TypeExpressionQualifier::qualify_expression("ClassA", &type_refs, "test_package");
        assert_eq!(qualified, "sub_mod.ClassA");
    }

    #[test]
    fn test_qualify_optional() {
        let type_refs = refs(vec![(
            "ClassA",
            located_in("test_package.sub_mod", ImportKind::Module),
        )]);

        let qualified = TypeExpressionQualifier::qualify_expression(
            "typing.Optional[ClassA]",
            &type_refs,
            "test_package",
        );
        assert_eq!(qualified, "typing.Optional[sub_mod.ClassA]");
    }

    #[test]
    fn test_qualify_same_module() {
        let type_refs = refs(vec![(
            "ClassA",
            located_in("test_package.sub_mod", ImportKind::SameModule),
        )]);

        let qualified = TypeExpressionQualifier::qualify_expression(
            "typing.Optional[ClassA]",
            &type_refs,
            "test_package.sub_mod",
        );
        assert_eq!(qualified, "typing.Optional[ClassA]");
    }

    #[test]
    fn test_qualify_callable() {
        let type_refs = refs(vec![
            (
                "ClassA",
                located_in("test_package.sub_mod", ImportKind::Module),
            ),
            (
                "ClassB",
                located_in("test_package.other_mod", ImportKind::Module),
            ),
        ]);

        let qualified = TypeExpressionQualifier::qualify_expression(
            "collections.abc.Callable[[ClassA, str], ClassB]",
            &type_refs,
            "test_package",
        );
        assert_eq!(
            qualified,
            "collections.abc.Callable[[sub_mod.ClassA, str], other_mod.ClassB]"
        );
    }
}