Skip to content

Commit e4b4959

Browse files
authored
Fixed stubgen parsing generics from C extensions (#8939)
pybind11 is capable of producing type signatures that use generics (for example https://github.com/pybind/pybind11/blob/4e3d9fea74ed50a042d98f68fa35a3133482289b/include/pybind11/stl.h#L140). A user may also opt to write a signature in the docstring that uses generics. Currently when stubgen parses one of these generics, it attempts to import a part of it. For example if a docstring had my_func(str, int) -> List[mypackage.module_being_parsed.MyClass], the resulting stub file tries to import List[mypackage.module_being_parsed. This change fixes this behaviour by breaking the found type down into the multiple types around [], characters, adding any imports from those types that are needed, and then stripping out the name of the module being parsed.
1 parent 77b8574 commit e4b4959

File tree

2 files changed

+85
-1
lines changed

2 files changed

+85
-1
lines changed

mypy/stubgenc.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,16 @@ def strip_or_import(typ: str, module: ModuleType, imports: List[str]) -> str:
214214
imports: list of import statements (may be modified during the call)
215215
"""
216216
stripped_type = typ
217-
if module and typ.startswith(module.__name__ + '.'):
217+
if any(c in typ for c in '[,'):
218+
for subtyp in re.split(r'[\[,\]]', typ):
219+
strip_or_import(subtyp.strip(), module, imports)
220+
if module:
221+
stripped_type = re.sub(
222+
r'(^|[\[, ]+)' + re.escape(module.__name__ + '.'),
223+
r'\1',
224+
typ,
225+
)
226+
elif module and typ.startswith(module.__name__ + '.'):
218227
stripped_type = typ[len(module.__name__) + 1:]
219228
elif '.' in typ:
220229
arg_module = typ[:typ.rindex('.')]

mypy/test/teststubgen.py

+75
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,81 @@ def get_attribute(self) -> None:
794794
generate_c_property_stub('attribute', TestClass.attribute, output, readonly=True)
795795
assert_equal(output, ['@property', 'def attribute(self) -> str: ...'])
796796

797+
def test_generate_c_type_with_single_arg_generic(self) -> None:
798+
class TestClass:
799+
def test(self, arg0: str) -> None:
800+
"""
801+
test(self: TestClass, arg0: List[int])
802+
"""
803+
pass
804+
output = [] # type: List[str]
805+
imports = [] # type: List[str]
806+
mod = ModuleType(TestClass.__module__, '')
807+
generate_c_function_stub(mod, 'test', TestClass.test, output, imports,
808+
self_var='self', class_name='TestClass')
809+
assert_equal(output, ['def test(self, arg0: List[int]) -> Any: ...'])
810+
assert_equal(imports, [])
811+
812+
def test_generate_c_type_with_double_arg_generic(self) -> None:
813+
class TestClass:
814+
def test(self, arg0: str) -> None:
815+
"""
816+
test(self: TestClass, arg0: Dict[str, int])
817+
"""
818+
pass
819+
output = [] # type: List[str]
820+
imports = [] # type: List[str]
821+
mod = ModuleType(TestClass.__module__, '')
822+
generate_c_function_stub(mod, 'test', TestClass.test, output, imports,
823+
self_var='self', class_name='TestClass')
824+
assert_equal(output, ['def test(self, arg0: Dict[str,int]) -> Any: ...'])
825+
assert_equal(imports, [])
826+
827+
def test_generate_c_type_with_nested_generic(self) -> None:
828+
class TestClass:
829+
def test(self, arg0: str) -> None:
830+
"""
831+
test(self: TestClass, arg0: Dict[str, List[int]])
832+
"""
833+
pass
834+
output = [] # type: List[str]
835+
imports = [] # type: List[str]
836+
mod = ModuleType(TestClass.__module__, '')
837+
generate_c_function_stub(mod, 'test', TestClass.test, output, imports,
838+
self_var='self', class_name='TestClass')
839+
assert_equal(output, ['def test(self, arg0: Dict[str,List[int]]) -> Any: ...'])
840+
assert_equal(imports, [])
841+
842+
def test_generate_c_type_with_generic_using_other_module_first(self) -> None:
843+
class TestClass:
844+
def test(self, arg0: str) -> None:
845+
"""
846+
test(self: TestClass, arg0: Dict[argparse.Action, int])
847+
"""
848+
pass
849+
output = [] # type: List[str]
850+
imports = [] # type: List[str]
851+
mod = ModuleType(TestClass.__module__, '')
852+
generate_c_function_stub(mod, 'test', TestClass.test, output, imports,
853+
self_var='self', class_name='TestClass')
854+
assert_equal(output, ['def test(self, arg0: Dict[argparse.Action,int]) -> Any: ...'])
855+
assert_equal(imports, ['import argparse'])
856+
857+
def test_generate_c_type_with_generic_using_other_module_last(self) -> None:
858+
class TestClass:
859+
def test(self, arg0: str) -> None:
860+
"""
861+
test(self: TestClass, arg0: Dict[str, argparse.Action])
862+
"""
863+
pass
864+
output = [] # type: List[str]
865+
imports = [] # type: List[str]
866+
mod = ModuleType(TestClass.__module__, '')
867+
generate_c_function_stub(mod, 'test', TestClass.test, output, imports,
868+
self_var='self', class_name='TestClass')
869+
assert_equal(output, ['def test(self, arg0: Dict[str,argparse.Action]) -> Any: ...'])
870+
assert_equal(imports, ['import argparse'])
871+
797872
def test_generate_c_type_with_overload_pybind11(self) -> None:
798873
class TestClass:
799874
def __init__(self, arg0: str) -> None:

0 commit comments

Comments
 (0)