pyo3_stub_gen/generate/
qualifier.rs

1//! Context-aware type name qualification for Python stub files.
2//!
3//! This module provides utilities to qualify type identifiers within compound type expressions
4//! based on the target module context. For example, `typing.Optional[ClassA]` should become
5//! `typing.Optional[sub_mod.ClassA]` when ClassA is from a different module.
6
7use crate::stub_type::{ImportKind, TypeIdentifierRef};
8use std::collections::HashMap;
9
/// Token types in Python type expressions
#[derive(Debug, Clone, PartialEq)]
enum Token {
    /// Bare identifier (e.g., "ClassA", "int").
    ///
    /// The tokenizer also uses this variant for atoms that are not names but must
    /// survive unchanged, such as the `1` in `Literal[1]` or a stray `-` or `*`.
    Identifier(String),
    /// Dotted path (e.g., "typing.Optional", "collections.abc.Callable")
    DottedPath(Vec<String>),
    /// Opening bracket: [ or (
    OpenBracket(char),
    /// Closing bracket: ] or )
    CloseBracket(char),
    /// Comma separator
    Comma,
    /// Pipe operator for unions (PEP 604)
    Pipe,
    /// Ellipsis (...)
    Ellipsis,
    /// String literal for forward references.
    ///
    /// Holds the text between the quotes (escape sequences kept as written). The
    /// original quote character is not recorded, so the content is stored in a form
    /// that is valid between double quotes.
    StringLiteral(String),
    /// Whitespace (preserved for formatting)
    Whitespace(String),
}

/// Character stream shared by the tokenizer helpers.
type CharStream<'a> = std::iter::Peekable<std::str::Chars<'a>>;

/// Returns `true` if `c` may start an identifier segment.
fn is_ident_start(c: char) -> bool {
    c.is_alphabetic() || c == '_'
}

/// Returns `true` if `c` may appear inside an identifier segment.
fn is_ident_continue(c: char) -> bool {
    c.is_alphanumeric() || c == '_'
}

/// Consumes and returns the longest prefix of `chars` whose characters satisfy `keep`.
fn take_while_chars(chars: &mut CharStream<'_>, keep: impl Fn(char) -> bool) -> String {
    let mut run = String::new();
    while let Some(c) = chars.next_if(|&c| keep(c)) {
        run.push(c);
    }
    run
}

/// Reads the body of a string literal whose opening `quote` has already been consumed.
///
/// Stops after the matching closing quote, or at end of input if the literal is
/// unterminated. Escape sequences are copied through as written.
fn read_string_literal(chars: &mut CharStream<'_>, quote: char) -> String {
    let mut content = String::new();
    while let Some(c) = chars.next() {
        if c == quote {
            break;
        }
        match c {
            '\\' => {
                // A backslash at the very end of input is dropped: keeping it would
                // escape the closing quote added when the literal is written out.
                if let Some(escaped) = chars.next() {
                    content.push('\\');
                    content.push(escaped);
                }
            }
            // Only reachable inside a single-quoted literal. The quote style is not
            // recorded in the token, so make the content safe inside double quotes.
            '"' => content.push_str("\\\""),
            _ => content.push(c),
        }
    }
    content
}

/// Reads an identifier and any `.segment` continuations that follow it.
///
/// A dot extends the path only when an identifier start character comes right
/// after it, so `Foo...` yields `Foo` and leaves the ellipsis for the caller.
fn read_identifier_path(chars: &mut CharStream<'_>) -> Token {
    let mut parts = vec![take_while_chars(chars, is_ident_continue)];

    loop {
        let mut lookahead = chars.clone();
        let dot_then_ident = lookahead.next() == Some('.')
            && matches!(lookahead.peek(), Some(&c) if is_ident_start(c));
        if !dot_then_ident {
            break;
        }
        chars.next(); // consume the dot
        parts.push(take_while_chars(chars, is_ident_continue));
    }

    if parts.len() == 1 {
        Token::Identifier(parts.swap_remove(0))
    } else {
        Token::DottedPath(parts)
    }
}

/// Reads a numeric literal (e.g. `1`, `0x1F`, `1_000`, `2.5`) as one verbatim atom.
///
/// Grouping the characters keeps letters inside a number (`0x1F`, `1e5`) from being
/// split off and treated as a separate identifier.
fn read_number(chars: &mut CharStream<'_>) -> String {
    let mut literal = take_while_chars(chars, is_ident_continue);

    // Optional fractional part: a dot immediately followed by a digit.
    let mut lookahead = chars.clone();
    if lookahead.next() == Some('.') && matches!(lookahead.peek(), Some(c) if c.is_ascii_digit()) {
        chars.next();
        literal.push('.');
        literal.push_str(&take_while_chars(chars, is_ident_continue));
    }
    literal
}

/// Tokenizes a Python type expression into tokens.
///
/// Handles:
/// - Identifiers: `ClassA`, `int`, `str`
/// - Dotted paths: `typing.Optional`, `collections.abc.Callable`
/// - Brackets: `[`, `]`, `(`, `)`
/// - Special characters: `,`, `|`, `...`
/// - String literals: `"ForwardRef"`, `'ForwardRef'`
/// - Whitespace preservation
///
/// Nothing from the input is discarded: numeric literals (`Literal[1]`) and any
/// character that fits none of the categories above (`-`, `*`, a lone `.`, ...) are
/// returned as [`Token::Identifier`] atoms holding the original text, so that
/// writing the tokens back out reproduces the expression.
fn tokenize(expr: &str) -> Vec<Token> {
    let mut tokens = Vec::new();
    let mut chars: CharStream<'_> = expr.chars().peekable();

    while let Some(&ch) = chars.peek() {
        // Every arm consumes at least one character, so the loop always advances.
        let token = match ch {
            _ if ch.is_whitespace() => {
                Token::Whitespace(take_while_chars(&mut chars, char::is_whitespace))
            }
            '[' | '(' => {
                chars.next();
                Token::OpenBracket(ch)
            }
            ']' | ')' => {
                chars.next();
                Token::CloseBracket(ch)
            }
            ',' => {
                chars.next();
                Token::Comma
            }
            '|' => {
                chars.next();
                Token::Pipe
            }
            '"' | '\'' => {
                chars.next(); // consume opening quote
                Token::StringLiteral(read_string_literal(&mut chars, ch))
            }
            '.' => {
                if chars.clone().take(3).eq("...".chars()) {
                    for _ in 0..3 {
                        chars.next();
                    }
                    Token::Ellipsis
                } else {
                    // Dots inside paths are consumed by `read_identifier_path`; a dot
                    // seen here is stray and is kept verbatim.
                    chars.next();
                    Token::Identifier(ch.to_string())
                }
            }
            _ if is_ident_start(ch) => read_identifier_path(&mut chars),
            _ if ch.is_numeric() => Token::Identifier(read_number(&mut chars)),
            _ => {
                // Unrecognised character: keep it rather than silently changing the
                // meaning of the expression.
                chars.next();
                Token::Identifier(ch.to_string())
            }
        };
        tokens.push(token);
    }

    tokens
}
196
197/// Type expression qualifier that rewrites identifiers based on module context.
198pub(crate) struct TypeExpressionQualifier;
199
200impl TypeExpressionQualifier {
201    /// Qualify a type expression based on the type references
202    ///
203    /// This rewrites bare identifiers in the expression to add module qualifiers
204    /// when necessary, based on the import context.
205    ///
206    /// # Parameters
207    /// - `expr`: The type expression to qualify
208    /// - `type_refs`: Map of type names to their module references
209    /// - `target_module`: The module where this type expression will be used
210    pub(crate) fn qualify_expression(
211        expr: &str,
212        type_refs: &HashMap<String, TypeIdentifierRef>,
213        target_module: &str,
214    ) -> String {
215        let tokens = tokenize(expr);
216        let mut result = String::new();
217
218        for token in tokens {
219            match token {
220                Token::Identifier(ref name) => {
221                    // Check if this identifier needs qualification
222                    if let Some(type_ref) = type_refs.get(name) {
223                        match type_ref.import_kind {
224                            ImportKind::ByName | ImportKind::SameModule => {
225                                // Can use unqualified
226                                result.push_str(name);
227                            }
228                            ImportKind::Module => {
229                                // Need to qualify with module component
230                                if let Some(module_name) = type_ref.module.get() {
231                                    // Check if type is from same module as target
232                                    if module_name == target_module {
233                                        // Same module - use unqualified name
234                                        result.push_str(name);
235                                    } else {
236                                        // Different module - qualify with last component
237                                        let module_component =
238                                            module_name.rsplit('.').next().unwrap_or(module_name);
239                                        result.push_str(module_component);
240                                        result.push('.');
241                                        result.push_str(name);
242                                    }
243                                } else {
244                                    // No module info, use as-is
245                                    result.push_str(name);
246                                }
247                            }
248                        }
249                    } else if Self::is_python_builtin(name) {
250                        // Known Python builtin or typing construct - use as-is
251                        result.push_str(name);
252                    } else {
253                        // Unknown identifier - preserve as-is
254                        result.push_str(name);
255                    }
256                }
257                Token::DottedPath(parts) => {
258                    // Check if this is an over-qualified path (e.g., "my_module.Type" when we're already in "my_module")
259                    // If the dotted path is module.Type and module matches target_module, simplify to just Type
260                    if parts.len() == 2 {
261                        let module_path = &parts[0];
262                        let type_name = &parts[1];
263
264                        // Check if target_module matches or ends with the module_path
265                        // E.g., target="pkg.sub_mod" matches module_path="sub_mod"
266                        let is_same_module = module_path == target_module
267                            || target_module.ends_with(&format!(".{}", module_path));
268
269                        if is_same_module {
270                            // Over-qualified - just use the type name
271                            result.push_str(type_name);
272                        } else {
273                            // Different module - keep the qualification
274                            result.push_str(&parts.join("."));
275                        }
276                    } else {
277                        // More complex path - preserve as-is
278                        result.push_str(&parts.join("."));
279                    }
280                }
281                Token::OpenBracket(ch) => result.push(ch),
282                Token::CloseBracket(ch) => result.push(ch),
283                Token::Comma => result.push(','),
284                Token::Pipe => result.push_str(" | "),
285                Token::Ellipsis => result.push_str("..."),
286                Token::StringLiteral(s) => {
287                    // String literals (forward references) - wrap in quotes
288                    result.push('"');
289                    result.push_str(&s);
290                    result.push('"');
291                }
292                Token::Whitespace(ws) => result.push_str(&ws),
293            }
294        }
295
296        result
297    }
298
299    /// Check if an identifier is a known Python builtin or typing construct
300    fn is_python_builtin(identifier: &str) -> bool {
301        matches!(
302            identifier,
303            // typing module types
304            "Any" | "Optional" | "Union" | "List" | "Dict" | "Tuple" | "Set" |
305            "Callable" | "Sequence" | "Mapping" | "Iterable" | "Iterator" |
306            "Literal" | "TypeVar" | "Generic" | "Protocol" | "TypeAlias" |
307            "Final" | "ClassVar" | "Annotated" | "TypeGuard" | "Never" |
308            // builtins
309            "int" | "str" | "float" | "bool" | "bytes" | "bytearray" |
310            "list" | "dict" | "tuple" | "set" | "frozenset" |
311            "object" | "type" | "None" | "Ellipsis" |
312            "complex" | "slice" | "range" | "memoryview" |
313            // Special
314            "typing" | "collections" | "abc" | "builtins"
315        )
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stub_type::ModuleRef;

    /// Builds a reference to a type that lives in `module` and is imported as `import_kind`.
    fn type_ref(module: &'static str, import_kind: ImportKind) -> TypeIdentifierRef {
        TypeIdentifierRef {
            module: ModuleRef::Named(module.into()),
            import_kind,
        }
    }

    /// Collects `(type name, reference)` pairs into the lookup map the qualifier expects.
    fn refs(entries: Vec<(&str, TypeIdentifierRef)>) -> HashMap<String, TypeIdentifierRef> {
        entries
            .into_iter()
            .map(|(name, reference)| (name.to_string(), reference))
            .collect()
    }

    /// Shorthand for an identifier token.
    fn ident(name: &str) -> Token {
        Token::Identifier(name.to_string())
    }

    /// Shorthand for a single-space whitespace token.
    fn space() -> Token {
        Token::Whitespace(" ".to_string())
    }

    #[test]
    fn test_tokenize_simple() {
        assert_eq!(tokenize("ClassA"), vec![ident("ClassA")]);
    }

    #[test]
    fn test_tokenize_optional() {
        let expected = vec![
            Token::DottedPath(vec!["typing".to_string(), "Optional".to_string()]),
            Token::OpenBracket('['),
            ident("ClassA"),
            Token::CloseBracket(']'),
        ];
        assert_eq!(tokenize("typing.Optional[ClassA]"), expected);
    }

    #[test]
    fn test_tokenize_callable() {
        let expected = vec![
            ident("Callable"),
            Token::OpenBracket('['),
            Token::OpenBracket('['),
            ident("ClassA"),
            Token::Comma,
            space(),
            ident("str"),
            Token::CloseBracket(']'),
            Token::Comma,
            space(),
            ident("int"),
            Token::CloseBracket(']'),
        ];
        assert_eq!(tokenize("Callable[[ClassA, str], int]"), expected);
    }

    #[test]
    fn test_tokenize_union() {
        let expected = vec![
            ident("ClassA"),
            space(),
            Token::Pipe,
            space(),
            ident("ClassB"),
        ];
        assert_eq!(tokenize("ClassA | ClassB"), expected);
    }

    #[test]
    fn test_qualify_simple() {
        let type_refs = refs(vec![(
            "ClassA",
            type_ref("test_package.sub_mod", ImportKind::Module),
        )]);

        let qualified =
            TypeExpressionQualifier::qualify_expression("ClassA", &type_refs, "test_package");
        assert_eq!(qualified, "sub_mod.ClassA");
    }

    #[test]
    fn test_qualify_optional() {
        let type_refs = refs(vec![(
            "ClassA",
            type_ref("test_package.sub_mod", ImportKind::Module),
        )]);

        let qualified = TypeExpressionQualifier::qualify_expression(
            "typing.Optional[ClassA]",
            &type_refs,
            "test_package",
        );
        assert_eq!(qualified, "typing.Optional[sub_mod.ClassA]");
    }

    #[test]
    fn test_qualify_same_module() {
        let type_refs = refs(vec![(
            "ClassA",
            type_ref("test_package.sub_mod", ImportKind::SameModule),
        )]);

        let qualified = TypeExpressionQualifier::qualify_expression(
            "typing.Optional[ClassA]",
            &type_refs,
            "test_package.sub_mod",
        );
        assert_eq!(qualified, "typing.Optional[ClassA]");
    }

    #[test]
    fn test_qualify_callable() {
        let type_refs = refs(vec![
            (
                "ClassA",
                type_ref("test_package.sub_mod", ImportKind::Module),
            ),
            (
                "ClassB",
                type_ref("test_package.other_mod", ImportKind::Module),
            ),
        ]);

        let qualified = TypeExpressionQualifier::qualify_expression(
            "collections.abc.Callable[[ClassA, str], ClassB]",
            &type_refs,
            "test_package",
        );
        assert_eq!(
            qualified,
            "collections.abc.Callable[[sub_mod.ClassA, str], other_mod.ClassB]"
        );
    }
}