pyo3_stub_gen/generate/
qualifier.rs

//! Context-aware type name qualification for Python stub files.
//!
//! This module provides utilities to qualify type identifiers within compound type expressions
//! based on the target module context. For example, `typing.Optional[ClassA]` should become
//! `typing.Optional[sub_mod.ClassA]` when `ClassA` is defined in a different module (here
//! `sub_mod`) than the target module.
6
7use crate::stub_type::{ImportKind, TypeIdentifierRef};
8use std::collections::HashMap;
9
/// Lexical tokens produced by [`tokenize`] for a Python type expression.
///
/// All payloads are plain strings or characters, so the enum supports full
/// equality (`Eq`) in addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Token {
    /// Bare identifier without dots (e.g., `ClassA`, `int`).
    Identifier(String),
    /// Dotted path split into its components
    /// (e.g., `typing.Optional` is stored as `["typing", "Optional"]`).
    DottedPath(Vec<String>),
    /// Opening bracket: `[` or `(`.
    OpenBracket(char),
    /// Closing bracket: `]` or `)`.
    CloseBracket(char),
    /// Comma separator.
    Comma,
    /// Pipe operator for unions (PEP 604).
    Pipe,
    /// Ellipsis (`...`).
    Ellipsis,
    /// String literal (e.g., a forward reference); holds the text between the
    /// quotes, with escape sequences kept verbatim.
    StringLiteral(String),
    /// A run of whitespace, kept verbatim so formatting can be preserved.
    Whitespace(String),
    /// Numeric literal (e.g., `42`, `3.14`, `-1`).
    NumericLiteral(String),
}
34
35/// Tokenizes a Python type expression into tokens.
36///
37/// Handles:
38/// - Identifiers: `ClassA`, `int`, `str`
39/// - Dotted paths: `typing.Optional`, `collections.abc.Callable`
40/// - Brackets: `[`, `]`, `(`, `)`
41/// - Special characters: `,`, `|`, `...`
42/// - String literals: `"ForwardRef"`
43/// - Whitespace preservation
44pub(crate) fn tokenize(expr: &str) -> Vec<Token> {
45    let mut tokens = Vec::new();
46    let mut chars = expr.chars().peekable();
47
48    while let Some(&ch) = chars.peek() {
49        match ch {
50            // Whitespace
51            ' ' | '\t' | '\n' | '\r' => {
52                let mut ws = String::new();
53                while let Some(&c) = chars.peek() {
54                    if c.is_whitespace() {
55                        ws.push(c);
56                        chars.next();
57                    } else {
58                        break;
59                    }
60                }
61                tokens.push(Token::Whitespace(ws));
62            }
63
64            // Brackets
65            '[' | '(' => {
66                tokens.push(Token::OpenBracket(ch));
67                chars.next();
68            }
69            ']' | ')' => {
70                tokens.push(Token::CloseBracket(ch));
71                chars.next();
72            }
73
74            // Comma
75            ',' => {
76                tokens.push(Token::Comma);
77                chars.next();
78            }
79
80            // Pipe (union operator)
81            '|' => {
82                tokens.push(Token::Pipe);
83                chars.next();
84            }
85
86            // String literals (forward references)
87            '"' | '\'' => {
88                let quote_char = ch;
89                chars.next(); // consume opening quote
90                let mut content = String::new();
91
92                while let Some(&c) = chars.peek() {
93                    chars.next();
94                    if c == quote_char {
95                        break;
96                    }
97                    // Handle escape sequences
98                    if c == '\\' {
99                        if let Some(&next) = chars.peek() {
100                            content.push(c);
101                            content.push(next);
102                            chars.next();
103                        }
104                    } else {
105                        content.push(c);
106                    }
107                }
108
109                tokens.push(Token::StringLiteral(content));
110            }
111
112            // Dot - could be start of ellipsis or part of dotted path
113            '.' => {
114                // Look ahead for ellipsis
115                let mut peek_chars = chars.clone();
116                peek_chars.next(); // skip first dot
117                if matches!(peek_chars.peek(), Some(&'.')) {
118                    peek_chars.next();
119                    if matches!(peek_chars.peek(), Some(&'.')) {
120                        // It's an ellipsis
121                        chars.next();
122                        chars.next();
123                        chars.next();
124                        tokens.push(Token::Ellipsis);
125                        continue;
126                    }
127                }
128
129                // Otherwise, it's part of a dotted path - this shouldn't happen
130                // as dots should be consumed as part of identifiers
131                chars.next();
132            }
133
134            // Identifier or dotted path
135            _ if ch.is_alphabetic() || ch == '_' => {
136                let mut ident = String::new();
137                let mut parts = Vec::new();
138
139                // Read first identifier
140                while let Some(&c) = chars.peek() {
141                    if c.is_alphanumeric() || c == '_' {
142                        ident.push(c);
143                        chars.next();
144                    } else {
145                        break;
146                    }
147                }
148
149                parts.push(ident.clone());
150
151                // Check for dotted path
152                while let Some(&'.') = chars.peek() {
153                    // Look ahead to see if there's an identifier after the dot
154                    let mut peek = chars.clone();
155                    peek.next(); // skip dot
156
157                    if let Some(&c) = peek.peek() {
158                        if c.is_alphabetic() || c == '_' {
159                            // It's a dotted path
160                            chars.next(); // consume dot
161                            ident.clear();
162
163                            while let Some(&c) = chars.peek() {
164                                if c.is_alphanumeric() || c == '_' {
165                                    ident.push(c);
166                                    chars.next();
167                                } else {
168                                    break;
169                                }
170                            }
171
172                            parts.push(ident.clone());
173                        } else {
174                            break;
175                        }
176                    } else {
177                        break;
178                    }
179                }
180
181                // Create token based on whether it's a dotted path
182                if parts.len() > 1 {
183                    tokens.push(Token::DottedPath(parts));
184                } else {
185                    tokens.push(Token::Identifier(parts[0].clone()));
186                }
187            }
188
189            // Numeric literals (e.g., 42, 3.14, -1)
190            _ if ch.is_ascii_digit()
191                || (ch == '-' && chars.clone().nth(1).is_some_and(|c| c.is_ascii_digit())) =>
192            {
193                let mut num = String::new();
194                // Handle negative sign
195                if ch == '-' {
196                    num.push(ch);
197                    chars.next();
198                }
199                // Read digits, dots, and scientific notation
200                while let Some(&c) = chars.peek() {
201                    if c.is_ascii_digit()
202                        || c == '.'
203                        || c == 'e'
204                        || c == 'E'
205                        || c == '+'
206                        || c == '-'
207                    {
208                        // Special handling: dot must be followed by digit for float
209                        if c == '.' {
210                            let mut peek = chars.clone();
211                            peek.next();
212                            if !peek.peek().is_some_and(|&d| d.is_ascii_digit()) {
213                                break;
214                            }
215                        }
216                        // For +/- in scientific notation, must be after e/E
217                        if (c == '+' || c == '-') && !num.ends_with('e') && !num.ends_with('E') {
218                            break;
219                        }
220                        num.push(c);
221                        chars.next();
222                    } else {
223                        break;
224                    }
225                }
226                tokens.push(Token::NumericLiteral(num));
227            }
228
229            // Skip other characters (shouldn't happen in valid type expressions)
230            _ => {
231                chars.next();
232            }
233        }
234    }
235
236    tokens
237}
238
239/// Type expression qualifier that rewrites identifiers based on module context.
240pub(crate) struct TypeExpressionQualifier;
241
242impl TypeExpressionQualifier {
243    /// Qualify a type expression based on the type references
244    ///
245    /// This rewrites bare identifiers in the expression to add module qualifiers
246    /// when necessary, based on the import context.
247    ///
248    /// # Parameters
249    /// - `expr`: The type expression to qualify
250    /// - `type_refs`: Map of type names to their module references
251    /// - `target_module`: The module where this type expression will be used
252    pub(crate) fn qualify_expression(
253        expr: &str,
254        type_refs: &HashMap<String, TypeIdentifierRef>,
255        target_module: &str,
256    ) -> String {
257        let tokens = tokenize(expr);
258        let mut result = String::new();
259
260        for token in tokens {
261            match token {
262                Token::Identifier(ref name) => {
263                    // Check if this identifier needs qualification
264                    if let Some(type_ref) = type_refs.get(name) {
265                        match type_ref.import_kind {
266                            ImportKind::ByName | ImportKind::SameModule => {
267                                // Can use unqualified
268                                result.push_str(name);
269                            }
270                            ImportKind::Module => {
271                                // Need to qualify with module component
272                                if let Some(module_name) = type_ref.module.get() {
273                                    // Check if type is from same module as target
274                                    if module_name == target_module {
275                                        // Same module - use unqualified name
276                                        result.push_str(name);
277                                    } else {
278                                        // Different module - qualify with last component
279                                        let module_component =
280                                            module_name.rsplit('.').next().unwrap_or(module_name);
281                                        result.push_str(module_component);
282                                        result.push('.');
283                                        result.push_str(name);
284                                    }
285                                } else {
286                                    // No module info, use as-is
287                                    result.push_str(name);
288                                }
289                            }
290                        }
291                    } else if Self::is_python_builtin(name) {
292                        // Known Python builtin or typing construct - use as-is
293                        result.push_str(name);
294                    } else {
295                        // Unknown identifier - preserve as-is
296                        result.push_str(name);
297                    }
298                }
299                Token::DottedPath(parts) => {
300                    // Handle dotted paths like "module.Type" or "module.Class.Member"
301                    // We need to distinguish between:
302                    // 1. Local over-qualified paths: "_core.C.C1" in "pkg._core" → "C.C1"
303                    // 2. External paths: "collections.abc.Callable" → preserve as-is
304                    //
305                    // Strategy:
306                    // - If any part is a known Python builtin → external, preserve
307                    // - If the type is in type_refs → use that module info
308                    // - Otherwise, use suffix matching for local modules
309                    if parts.len() >= 2 {
310                        // Check if any part is a Python builtin (indicates external path)
311                        let contains_builtin = parts.iter().any(|p| Self::is_python_builtin(p));
312
313                        if contains_builtin {
314                            // External type path - preserve as-is
315                            result.push_str(&parts.join("."));
316                        } else {
317                            // Check if the type (parts[1] or last meaningful part) is in type_refs
318                            let type_name = &parts[1];
319                            let module_prefix = &parts[0];
320
321                            if let Some(type_ref) = type_refs.get(type_name) {
322                                // We have type info - use it to decide
323                                if let Some(module_name) = type_ref.module.get() {
324                                    // Check if the module matches target_module
325                                    if module_name == target_module {
326                                        // Same module - strip prefix
327                                        result.push_str(&parts[1..].join("."));
328                                    } else {
329                                        // Different module - preserve full path
330                                        result.push_str(&parts.join("."));
331                                    }
332                                } else {
333                                    // No module info in type_ref - preserve as-is
334                                    result.push_str(&parts.join("."));
335                                }
336                            } else {
337                                // Type not in type_refs - use suffix matching as fallback
338                                // This handles cases like "_core.C.C1" when C is not explicitly tracked
339                                let is_local_module = target_module == module_prefix
340                                    || target_module.ends_with(&format!(".{}", module_prefix));
341
342                                if is_local_module {
343                                    // Local module - strip prefix
344                                    result.push_str(&parts[1..].join("."));
345                                } else {
346                                    // Different module - preserve full path
347                                    result.push_str(&parts.join("."));
348                                }
349                            }
350                        }
351                    } else {
352                        // Single-part path - preserve as-is (shouldn't happen for DottedPath)
353                        result.push_str(&parts.join("."));
354                    }
355                }
356                Token::OpenBracket(ch) => result.push(ch),
357                Token::CloseBracket(ch) => result.push(ch),
358                Token::Comma => result.push(','),
359                Token::Pipe => result.push_str(" | "),
360                Token::Ellipsis => result.push_str("..."),
361                Token::StringLiteral(s) => {
362                    // String literals (forward references) - wrap in quotes
363                    result.push('"');
364                    result.push_str(&s);
365                    result.push('"');
366                }
367                Token::Whitespace(ws) => result.push_str(&ws),
368                Token::NumericLiteral(num) => result.push_str(&num),
369            }
370        }
371
372        result
373    }
374
375    /// Check if an identifier is a known Python builtin or typing construct
376    fn is_python_builtin(identifier: &str) -> bool {
377        matches!(
378            identifier,
379            // typing module types
380            "Any" | "Optional" | "Union" | "List" | "Dict" | "Tuple" | "Set" |
381            "Callable" | "Sequence" | "Mapping" | "Iterable" | "Iterator" |
382            "Literal" | "TypeVar" | "Generic" | "Protocol" | "TypeAlias" |
383            "Final" | "ClassVar" | "Annotated" | "TypeGuard" | "Never" |
384            // builtins
385            "int" | "str" | "float" | "bool" | "bytes" | "bytearray" |
386            "list" | "dict" | "tuple" | "set" | "frozenset" |
387            "object" | "type" | "None" | "Ellipsis" |
388            "complex" | "slice" | "range" | "memoryview" |
389            // Special
390            "typing" | "collections" | "abc" | "builtins"
391        )
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stub_type::ModuleRef;

    /// Builds an owned identifier token.
    fn ident(name: &str) -> Token {
        Token::Identifier(name.to_owned())
    }

    /// Builds a whitespace token holding a single space.
    fn space() -> Token {
        Token::Whitespace(" ".to_owned())
    }

    /// Builds a dotted-path token from its components.
    fn path(parts: &[&str]) -> Token {
        Token::DottedPath(parts.iter().map(|part| (*part).to_owned()).collect())
    }

    /// Builds a `type_refs` map from `(type name, module, import kind)` entries.
    fn module_refs(
        entries: Vec<(&'static str, &'static str, ImportKind)>,
    ) -> HashMap<String, TypeIdentifierRef> {
        entries
            .into_iter()
            .map(|(name, module, import_kind)| {
                let type_ref = TypeIdentifierRef {
                    module: ModuleRef::Named(module.into()),
                    import_kind,
                };
                (name.to_owned(), type_ref)
            })
            .collect()
    }

    /// Qualifies `expr` for `target_module` with no tracked type references.
    fn qualify_untracked(expr: &str, target_module: &str) -> String {
        TypeExpressionQualifier::qualify_expression(expr, &HashMap::new(), target_module)
    }

    #[test]
    fn test_tokenize_simple() {
        assert_eq!(tokenize("ClassA"), vec![ident("ClassA")]);
    }

    #[test]
    fn test_tokenize_optional() {
        let expected = vec![
            path(&["typing", "Optional"]),
            Token::OpenBracket('['),
            ident("ClassA"),
            Token::CloseBracket(']'),
        ];
        assert_eq!(tokenize("typing.Optional[ClassA]"), expected);
    }

    #[test]
    fn test_tokenize_callable() {
        let expected = vec![
            ident("Callable"),
            Token::OpenBracket('['),
            Token::OpenBracket('['),
            ident("ClassA"),
            Token::Comma,
            space(),
            ident("str"),
            Token::CloseBracket(']'),
            Token::Comma,
            space(),
            ident("int"),
            Token::CloseBracket(']'),
        ];
        assert_eq!(tokenize("Callable[[ClassA, str], int]"), expected);
    }

    #[test]
    fn test_tokenize_union() {
        let expected = vec![
            ident("ClassA"),
            space(),
            Token::Pipe,
            space(),
            ident("ClassB"),
        ];
        assert_eq!(tokenize("ClassA | ClassB"), expected);
    }

    #[test]
    fn test_qualify_simple() {
        let type_refs = module_refs(vec![(
            "ClassA",
            "test_package.sub_mod",
            ImportKind::Module,
        )]);

        let qualified =
            TypeExpressionQualifier::qualify_expression("ClassA", &type_refs, "test_package");
        assert_eq!(qualified, "sub_mod.ClassA");
    }

    #[test]
    fn test_qualify_optional() {
        let type_refs = module_refs(vec![(
            "ClassA",
            "test_package.sub_mod",
            ImportKind::Module,
        )]);

        let qualified = TypeExpressionQualifier::qualify_expression(
            "typing.Optional[ClassA]",
            &type_refs,
            "test_package",
        );
        assert_eq!(qualified, "typing.Optional[sub_mod.ClassA]");
    }

    #[test]
    fn test_qualify_same_module() {
        let type_refs = module_refs(vec![(
            "ClassA",
            "test_package.sub_mod",
            ImportKind::SameModule,
        )]);

        let qualified = TypeExpressionQualifier::qualify_expression(
            "typing.Optional[ClassA]",
            &type_refs,
            "test_package.sub_mod",
        );
        assert_eq!(qualified, "typing.Optional[ClassA]");
    }

    #[test]
    fn test_qualify_callable() {
        let type_refs = module_refs(vec![
            ("ClassA", "test_package.sub_mod", ImportKind::Module),
            ("ClassB", "test_package.other_mod", ImportKind::Module),
        ]);

        let qualified = TypeExpressionQualifier::qualify_expression(
            "collections.abc.Callable[[ClassA, str], ClassB]",
            &type_refs,
            "test_package",
        );
        assert_eq!(
            qualified,
            "collections.abc.Callable[[sub_mod.ClassA, str], other_mod.ClassB]"
        );
    }

    #[test]
    fn test_qualify_dotted_path_three_parts_same_module() {
        // `_core.C.C1` used inside `pkg._core` loses its module prefix.
        assert_eq!(qualify_untracked("_core.C.C1", "pkg._core"), "C.C1");
    }

    #[test]
    fn test_qualify_dotted_path_three_parts_different_module() {
        // `_core.C.C1` used inside `pkg.other` keeps its module prefix.
        assert_eq!(qualify_untracked("_core.C.C1", "pkg.other"), "_core.C.C1");
    }

    #[test]
    fn test_qualify_dotted_path_two_parts_same_module() {
        // `_core.C` used inside `pkg._core` becomes `C`.
        assert_eq!(qualify_untracked("_core.C", "pkg._core"), "C");
    }

    #[test]
    fn test_tokenize_numeric_literals() {
        // Integer, float, and negative integer each form a single token.
        for literal in ["42", "3.14", "-1"] {
            assert_eq!(
                tokenize(literal),
                vec![Token::NumericLiteral(literal.to_owned())]
            );
        }
    }

    #[test]
    fn test_qualify_numeric_literal_preserved() {
        // Numeric literals pass through qualification untouched.
        for literal in ["2", "1.0"] {
            assert_eq!(qualify_untracked(literal, "pkg._core"), literal);
        }
    }

    #[test]
    fn test_external_dotted_path_preserved() {
        // External paths must never be stripped, even when the target module
        // happens to end with a matching component.
        let cases = [
            ("collections.abc.Callable", "pkg.collections"),
            ("typing.Optional", "pkg.typing"),
            ("builtins.int", "pkg.builtins"),
        ];
        for (expr, target_module) in cases {
            assert_eq!(qualify_untracked(expr, target_module), expr);
        }
    }
}