pyo3_stub_gen/generate/qualifier.rs
1//! Context-aware type name qualification for Python stub files.
2//!
3//! This module provides utilities to qualify type identifiers within compound type expressions
4//! based on the target module context. For example, `typing.Optional[ClassA]` should become
5//! `typing.Optional[sub_mod.ClassA]` when ClassA is from a different module.
6
7use crate::stub_type::{ImportKind, TypeIdentifierRef};
8use std::collections::HashMap;
9
/// A lexical token of a Python type expression, as produced by [`tokenize`].
///
/// Variants carrying a `String` hold the token text only; in particular a
/// [`Token::StringLiteral`] stores the text between the quotes, not the quote
/// characters themselves.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Token {
    /// Bare identifier (e.g., `ClassA`, `int`).
    Identifier(String),
    /// Dotted path split into its segments
    /// (e.g., `typing.Optional` is `["typing", "Optional"]`).
    DottedPath(Vec<String>),
    /// Opening bracket: `[` or `(`.
    OpenBracket(char),
    /// Closing bracket: `]` or `)`.
    CloseBracket(char),
    /// Comma separator.
    Comma,
    /// Pipe operator for unions (PEP 604).
    Pipe,
    /// Ellipsis (`...`).
    Ellipsis,
    /// String literal (e.g., a forward reference); holds the unquoted content
    /// with escape sequences kept verbatim.
    StringLiteral(String),
    /// A run of whitespace, preserved for formatting.
    Whitespace(String),
    /// Numeric literal text (e.g., `42`, `3.14`, `-1`).
    NumericLiteral(String),
}
34
35/// Tokenizes a Python type expression into tokens.
36///
37/// Handles:
38/// - Identifiers: `ClassA`, `int`, `str`
39/// - Dotted paths: `typing.Optional`, `collections.abc.Callable`
40/// - Brackets: `[`, `]`, `(`, `)`
41/// - Special characters: `,`, `|`, `...`
42/// - String literals: `"ForwardRef"`
43/// - Whitespace preservation
44pub(crate) fn tokenize(expr: &str) -> Vec<Token> {
45 let mut tokens = Vec::new();
46 let mut chars = expr.chars().peekable();
47
48 while let Some(&ch) = chars.peek() {
49 match ch {
50 // Whitespace
51 ' ' | '\t' | '\n' | '\r' => {
52 let mut ws = String::new();
53 while let Some(&c) = chars.peek() {
54 if c.is_whitespace() {
55 ws.push(c);
56 chars.next();
57 } else {
58 break;
59 }
60 }
61 tokens.push(Token::Whitespace(ws));
62 }
63
64 // Brackets
65 '[' | '(' => {
66 tokens.push(Token::OpenBracket(ch));
67 chars.next();
68 }
69 ']' | ')' => {
70 tokens.push(Token::CloseBracket(ch));
71 chars.next();
72 }
73
74 // Comma
75 ',' => {
76 tokens.push(Token::Comma);
77 chars.next();
78 }
79
80 // Pipe (union operator)
81 '|' => {
82 tokens.push(Token::Pipe);
83 chars.next();
84 }
85
86 // String literals (forward references)
87 '"' | '\'' => {
88 let quote_char = ch;
89 chars.next(); // consume opening quote
90 let mut content = String::new();
91
92 while let Some(&c) = chars.peek() {
93 chars.next();
94 if c == quote_char {
95 break;
96 }
97 // Handle escape sequences
98 if c == '\\' {
99 if let Some(&next) = chars.peek() {
100 content.push(c);
101 content.push(next);
102 chars.next();
103 }
104 } else {
105 content.push(c);
106 }
107 }
108
109 tokens.push(Token::StringLiteral(content));
110 }
111
112 // Dot - could be start of ellipsis or part of dotted path
113 '.' => {
114 // Look ahead for ellipsis
115 let mut peek_chars = chars.clone();
116 peek_chars.next(); // skip first dot
117 if matches!(peek_chars.peek(), Some(&'.')) {
118 peek_chars.next();
119 if matches!(peek_chars.peek(), Some(&'.')) {
120 // It's an ellipsis
121 chars.next();
122 chars.next();
123 chars.next();
124 tokens.push(Token::Ellipsis);
125 continue;
126 }
127 }
128
129 // Otherwise, it's part of a dotted path - this shouldn't happen
130 // as dots should be consumed as part of identifiers
131 chars.next();
132 }
133
134 // Identifier or dotted path
135 _ if ch.is_alphabetic() || ch == '_' => {
136 let mut ident = String::new();
137 let mut parts = Vec::new();
138
139 // Read first identifier
140 while let Some(&c) = chars.peek() {
141 if c.is_alphanumeric() || c == '_' {
142 ident.push(c);
143 chars.next();
144 } else {
145 break;
146 }
147 }
148
149 parts.push(ident.clone());
150
151 // Check for dotted path
152 while let Some(&'.') = chars.peek() {
153 // Look ahead to see if there's an identifier after the dot
154 let mut peek = chars.clone();
155 peek.next(); // skip dot
156
157 if let Some(&c) = peek.peek() {
158 if c.is_alphabetic() || c == '_' {
159 // It's a dotted path
160 chars.next(); // consume dot
161 ident.clear();
162
163 while let Some(&c) = chars.peek() {
164 if c.is_alphanumeric() || c == '_' {
165 ident.push(c);
166 chars.next();
167 } else {
168 break;
169 }
170 }
171
172 parts.push(ident.clone());
173 } else {
174 break;
175 }
176 } else {
177 break;
178 }
179 }
180
181 // Create token based on whether it's a dotted path
182 if parts.len() > 1 {
183 tokens.push(Token::DottedPath(parts));
184 } else {
185 tokens.push(Token::Identifier(parts[0].clone()));
186 }
187 }
188
189 // Numeric literals (e.g., 42, 3.14, -1)
190 _ if ch.is_ascii_digit()
191 || (ch == '-' && chars.clone().nth(1).is_some_and(|c| c.is_ascii_digit())) =>
192 {
193 let mut num = String::new();
194 // Handle negative sign
195 if ch == '-' {
196 num.push(ch);
197 chars.next();
198 }
199 // Read digits, dots, and scientific notation
200 while let Some(&c) = chars.peek() {
201 if c.is_ascii_digit()
202 || c == '.'
203 || c == 'e'
204 || c == 'E'
205 || c == '+'
206 || c == '-'
207 {
208 // Special handling: dot must be followed by digit for float
209 if c == '.' {
210 let mut peek = chars.clone();
211 peek.next();
212 if !peek.peek().is_some_and(|&d| d.is_ascii_digit()) {
213 break;
214 }
215 }
216 // For +/- in scientific notation, must be after e/E
217 if (c == '+' || c == '-') && !num.ends_with('e') && !num.ends_with('E') {
218 break;
219 }
220 num.push(c);
221 chars.next();
222 } else {
223 break;
224 }
225 }
226 tokens.push(Token::NumericLiteral(num));
227 }
228
229 // Skip other characters (shouldn't happen in valid type expressions)
230 _ => {
231 chars.next();
232 }
233 }
234 }
235
236 tokens
237}
238
239/// Type expression qualifier that rewrites identifiers based on module context.
240pub(crate) struct TypeExpressionQualifier;
241
242impl TypeExpressionQualifier {
243 /// Qualify a type expression based on the type references
244 ///
245 /// This rewrites bare identifiers in the expression to add module qualifiers
246 /// when necessary, based on the import context.
247 ///
248 /// # Parameters
249 /// - `expr`: The type expression to qualify
250 /// - `type_refs`: Map of type names to their module references
251 /// - `target_module`: The module where this type expression will be used
252 pub(crate) fn qualify_expression(
253 expr: &str,
254 type_refs: &HashMap<String, TypeIdentifierRef>,
255 target_module: &str,
256 ) -> String {
257 let tokens = tokenize(expr);
258 let mut result = String::new();
259
260 for token in tokens {
261 match token {
262 Token::Identifier(ref name) => {
263 // Check if this identifier needs qualification
264 if let Some(type_ref) = type_refs.get(name) {
265 match type_ref.import_kind {
266 ImportKind::ByName | ImportKind::SameModule => {
267 // Can use unqualified
268 result.push_str(name);
269 }
270 ImportKind::Module => {
271 // Need to qualify with module component
272 if let Some(module_name) = type_ref.module.get() {
273 // Check if type is from same module as target
274 if module_name == target_module {
275 // Same module - use unqualified name
276 result.push_str(name);
277 } else {
278 // Different module - qualify with last component
279 let module_component =
280 module_name.rsplit('.').next().unwrap_or(module_name);
281 result.push_str(module_component);
282 result.push('.');
283 result.push_str(name);
284 }
285 } else {
286 // No module info, use as-is
287 result.push_str(name);
288 }
289 }
290 }
291 } else if Self::is_python_builtin(name) {
292 // Known Python builtin or typing construct - use as-is
293 result.push_str(name);
294 } else {
295 // Unknown identifier - preserve as-is
296 result.push_str(name);
297 }
298 }
299 Token::DottedPath(parts) => {
300 // Handle dotted paths like "module.Type" or "module.Class.Member"
301 // We need to distinguish between:
302 // 1. Local over-qualified paths: "_core.C.C1" in "pkg._core" → "C.C1"
303 // 2. External paths: "collections.abc.Callable" → preserve as-is
304 //
305 // Strategy:
306 // - If any part is a known Python builtin → external, preserve
307 // - If the type is in type_refs → use that module info
308 // - Otherwise, use suffix matching for local modules
309 if parts.len() >= 2 {
310 // Check if any part is a Python builtin (indicates external path)
311 let contains_builtin = parts.iter().any(|p| Self::is_python_builtin(p));
312
313 if contains_builtin {
314 // External type path - preserve as-is
315 result.push_str(&parts.join("."));
316 } else {
317 // Check if the type (parts[1] or last meaningful part) is in type_refs
318 let type_name = &parts[1];
319 let module_prefix = &parts[0];
320
321 if let Some(type_ref) = type_refs.get(type_name) {
322 // We have type info - use it to decide
323 if let Some(module_name) = type_ref.module.get() {
324 // Check if the module matches target_module
325 if module_name == target_module {
326 // Same module - strip prefix
327 result.push_str(&parts[1..].join("."));
328 } else {
329 // Different module - preserve full path
330 result.push_str(&parts.join("."));
331 }
332 } else {
333 // No module info in type_ref - preserve as-is
334 result.push_str(&parts.join("."));
335 }
336 } else {
337 // Type not in type_refs - use suffix matching as fallback
338 // This handles cases like "_core.C.C1" when C is not explicitly tracked
339 let is_local_module = target_module == module_prefix
340 || target_module.ends_with(&format!(".{}", module_prefix));
341
342 if is_local_module {
343 // Local module - strip prefix
344 result.push_str(&parts[1..].join("."));
345 } else {
346 // Different module - preserve full path
347 result.push_str(&parts.join("."));
348 }
349 }
350 }
351 } else {
352 // Single-part path - preserve as-is (shouldn't happen for DottedPath)
353 result.push_str(&parts.join("."));
354 }
355 }
356 Token::OpenBracket(ch) => result.push(ch),
357 Token::CloseBracket(ch) => result.push(ch),
358 Token::Comma => result.push(','),
359 Token::Pipe => result.push_str(" | "),
360 Token::Ellipsis => result.push_str("..."),
361 Token::StringLiteral(s) => {
362 // String literals (forward references) - wrap in quotes
363 result.push('"');
364 result.push_str(&s);
365 result.push('"');
366 }
367 Token::Whitespace(ws) => result.push_str(&ws),
368 Token::NumericLiteral(num) => result.push_str(&num),
369 }
370 }
371
372 result
373 }
374
375 /// Check if an identifier is a known Python builtin or typing construct
376 fn is_python_builtin(identifier: &str) -> bool {
377 matches!(
378 identifier,
379 // typing module types
380 "Any" | "Optional" | "Union" | "List" | "Dict" | "Tuple" | "Set" |
381 "Callable" | "Sequence" | "Mapping" | "Iterable" | "Iterator" |
382 "Literal" | "TypeVar" | "Generic" | "Protocol" | "TypeAlias" |
383 "Final" | "ClassVar" | "Annotated" | "TypeGuard" | "Never" |
384 // builtins
385 "int" | "str" | "float" | "bool" | "bytes" | "bytearray" |
386 "list" | "dict" | "tuple" | "set" | "frozenset" |
387 "object" | "type" | "None" | "Ellipsis" |
388 "complex" | "slice" | "range" | "memoryview" |
389 // Special
390 "typing" | "collections" | "abc" | "builtins"
391 )
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stub_type::ModuleRef;

    // ---- fixture helpers -------------------------------------------------

    /// Shorthand for an identifier token.
    fn ident(name: &str) -> Token {
        Token::Identifier(name.to_string())
    }

    /// Shorthand for a single-space whitespace token.
    fn space() -> Token {
        Token::Whitespace(" ".to_string())
    }

    /// Shorthand for a numeric literal token.
    fn number(text: &str) -> Token {
        Token::NumericLiteral(text.to_string())
    }

    /// Builds a type reference living in `module` with the given import kind.
    fn type_ref(module: &'static str, import_kind: ImportKind) -> TypeIdentifierRef {
        TypeIdentifierRef {
            module: ModuleRef::Named(module.into()),
            import_kind,
        }
    }

    /// Map containing only `ClassA`, defined in `test_package.sub_mod`.
    fn class_a_refs(import_kind: ImportKind) -> HashMap<String, TypeIdentifierRef> {
        HashMap::from([(
            "ClassA".to_string(),
            type_ref("test_package.sub_mod", import_kind),
        )])
    }

    /// Qualifies `expr` for `target_module` without any known type references.
    fn qualify_without_refs(expr: &str, target_module: &str) -> String {
        TypeExpressionQualifier::qualify_expression(expr, &HashMap::new(), target_module)
    }

    // ---- tokenizer -------------------------------------------------------

    #[test]
    fn test_tokenize_simple() {
        assert_eq!(tokenize("ClassA"), vec![ident("ClassA")]);
    }

    #[test]
    fn test_tokenize_optional() {
        let expected = vec![
            Token::DottedPath(vec!["typing".to_string(), "Optional".to_string()]),
            Token::OpenBracket('['),
            ident("ClassA"),
            Token::CloseBracket(']'),
        ];
        assert_eq!(tokenize("typing.Optional[ClassA]"), expected);
    }

    #[test]
    fn test_tokenize_callable() {
        let expected = vec![
            ident("Callable"),
            Token::OpenBracket('['),
            Token::OpenBracket('['),
            ident("ClassA"),
            Token::Comma,
            space(),
            ident("str"),
            Token::CloseBracket(']'),
            Token::Comma,
            space(),
            ident("int"),
            Token::CloseBracket(']'),
        ];
        assert_eq!(tokenize("Callable[[ClassA, str], int]"), expected);
    }

    #[test]
    fn test_tokenize_union() {
        let expected = vec![
            ident("ClassA"),
            space(),
            Token::Pipe,
            space(),
            ident("ClassB"),
        ];
        assert_eq!(tokenize("ClassA | ClassB"), expected);
    }

    #[test]
    fn test_tokenize_numeric_literals() {
        // Integer, float, and negative integer each form a single token.
        for text in ["42", "3.14", "-1"] {
            assert_eq!(tokenize(text), vec![number(text)]);
        }
    }

    // ---- qualification of bare identifiers --------------------------------

    #[test]
    fn test_qualify_simple() {
        let type_refs = class_a_refs(ImportKind::Module);
        let qualified =
            TypeExpressionQualifier::qualify_expression("ClassA", &type_refs, "test_package");
        assert_eq!(qualified, "sub_mod.ClassA");
    }

    #[test]
    fn test_qualify_optional() {
        let type_refs = class_a_refs(ImportKind::Module);
        let qualified = TypeExpressionQualifier::qualify_expression(
            "typing.Optional[ClassA]",
            &type_refs,
            "test_package",
        );
        assert_eq!(qualified, "typing.Optional[sub_mod.ClassA]");
    }

    #[test]
    fn test_qualify_same_module() {
        let type_refs = class_a_refs(ImportKind::SameModule);
        let qualified = TypeExpressionQualifier::qualify_expression(
            "typing.Optional[ClassA]",
            &type_refs,
            "test_package.sub_mod",
        );
        assert_eq!(qualified, "typing.Optional[ClassA]");
    }

    #[test]
    fn test_qualify_callable() {
        let mut type_refs = class_a_refs(ImportKind::Module);
        type_refs.insert(
            "ClassB".to_string(),
            type_ref("test_package.other_mod", ImportKind::Module),
        );

        let qualified = TypeExpressionQualifier::qualify_expression(
            "collections.abc.Callable[[ClassA, str], ClassB]",
            &type_refs,
            "test_package",
        );
        assert_eq!(
            qualified,
            "collections.abc.Callable[[sub_mod.ClassA, str], other_mod.ClassB]"
        );
    }

    // ---- qualification of dotted paths -------------------------------------

    #[test]
    fn test_qualify_dotted_path_three_parts_same_module() {
        // `_core.C.C1` used inside `pkg._core` loses its module prefix.
        assert_eq!(qualify_without_refs("_core.C.C1", "pkg._core"), "C.C1");
    }

    #[test]
    fn test_qualify_dotted_path_three_parts_different_module() {
        // `_core.C.C1` used inside `pkg.other` keeps its module prefix.
        assert_eq!(qualify_without_refs("_core.C.C1", "pkg.other"), "_core.C.C1");
    }

    #[test]
    fn test_qualify_dotted_path_two_parts_same_module() {
        // `_core.C` used inside `pkg._core` becomes plain `C`.
        assert_eq!(qualify_without_refs("_core.C", "pkg._core"), "C");
    }

    #[test]
    fn test_qualify_numeric_literal_preserved() {
        // Numeric literals pass through unchanged.
        assert_eq!(qualify_without_refs("2", "pkg._core"), "2");
        assert_eq!(qualify_without_refs("1.0", "pkg._core"), "1.0");
    }

    #[test]
    fn test_external_dotted_path_preserved() {
        // External paths must never be stripped, even when the target module's
        // last component happens to match the path's first segment.
        assert_eq!(
            qualify_without_refs("collections.abc.Callable", "pkg.collections"),
            "collections.abc.Callable"
        );
        assert_eq!(
            qualify_without_refs("typing.Optional", "pkg.typing"),
            "typing.Optional"
        );
        assert_eq!(
            qualify_without_refs("builtins.int", "pkg.builtins"),
            "builtins.int"
        );
    }
}