From 2e364847fed741141cdbc0523d3d9f02af2287ba Mon Sep 17 00:00:00 2001 From: Olivier Breuleux Date: Thu, 27 Apr 2023 15:31:59 -0400 Subject: [PATCH 1/4] Rewrite the docstrings are found/parsed --- simple_parsing/docstring.py | 343 +++++++---------------- simple_parsing/wrappers/field_wrapper.py | 74 ++++- 2 files changed, 155 insertions(+), 262 deletions(-) diff --git a/simple_parsing/docstring.py b/simple_parsing/docstring.py index c7bc16e5..65765796 100644 --- a/simple_parsing/docstring.py +++ b/simple_parsing/docstring.py @@ -4,9 +4,13 @@ from __future__ import annotations import functools +import ast import inspect +import tokenize from dataclasses import dataclass +from functools import partial from logging import getLogger +from textwrap import dedent import docstring_parser as dp from docstring_parser.common import Docstring @@ -101,17 +105,6 @@ def _get_attribute_docstring(dataclass: type, field_name: str) -> AttributeDocSt """Gets the AttributeDocString of the given field in the given dataclass. Doesn't inspect base classes. """ - try: - source = inspect.getsource(dataclass) - except (TypeError, OSError) as e: - logger.debug( - UserWarning( - f"Couldn't retrieve the source code of class {dataclass} " - f"(in order to retrieve the docstring of field {field_name}): {e}" - ) - ) - return None - # Parse docstring to use as help strings desc_from_cls_docstring = "" cls_docstring = inspect.getdoc(dataclass) @@ -121,257 +114,113 @@ def _get_attribute_docstring(dataclass: type, field_name: str) -> AttributeDocSt if param.arg_name == field_name: desc_from_cls_docstring = param.description or "" - # NOTE: We want to skip the docstring lines. - # NOTE: Currently, we just remove the __doc__ from the source. It's perhaps a bit crude, - # but it works. - if dataclass.__doc__ and dataclass.__doc__ in source: - source = source.replace(dataclass.__doc__, "\n", 1) - # note: does this remove the whitespace though? - - code_lines: list[str] = source.splitlines() - # the first line is the class definition (OR the decorator!), we skip it. - start_line_index = 1 - # starting at the second line, there might be the docstring for the class. - # We want to skip over that until we reach an attribute definition. - while start_line_index < len(code_lines): - if _contains_field_definition(code_lines[start_line_index]): - break - start_line_index += 1 - - lines_with_field_defs = [ - (index, line) for index, line in enumerate(code_lines) if _contains_field_definition(line) - ] - for i, line in lines_with_field_defs: - if _line_contains_definition_for(line, field_name): - # we found the line with the definition of this field. - comment_above = _get_comment_ending_at_line(code_lines, i - 1) - comment_inline = _get_inline_comment_at_line(code_lines, i) - docstring_below = _get_docstring_starting_at_line(code_lines, i + 1) - return AttributeDocString( - comment_above, - comment_inline, - docstring_below, - desc_from_cls_docstring=desc_from_cls_docstring, - ) - return None - - -def _contains_field_definition(line: str) -> bool: - """Returns whether or not a line contains a an dataclass field definition. - - Arguments: - line_str {str} -- the line content - - Returns: - bool -- True if there is an attribute definition in the line. - - >>> _contains_field_definition("a: int = 0") - True - >>> _contains_field_definition("a: int") - True - >>> _contains_field_definition("a: int # comment") - True - >>> _contains_field_definition("a: int = 0 # comment") - True - >>> _contains_field_definition("class FooBaz(Foo, Baz):") - False - >>> _contains_field_definition("a = 4") - False - >>> _contains_field_definition("fooooooooobar.append(123)") - False - >>> _contains_field_definition("{a: int}") - False - >>> _contains_field_definition(" foobaz: int = 123 #: The foobaz property") - True - >>> _contains_field_definition("a #:= 3") - False - """ - # Get rid of any comments first. - line, _, _ = line.partition("#") - - if ":" not in line: - return False - - if "=" in line: - attribute_and_type, _, _ = line.partition("=") + results = get_attribute_docstrings(dataclass).get(field_name, None) + if results: + results.desc_from_cls_docstring = desc_from_cls_docstring + return results else: - attribute_and_type = line - - field_name, _, type = attribute_and_type.partition(":") - field_name = field_name.strip() - if ":" in type: - # weird annotation or dictionary? - return False - if not field_name: - # Empty attribute name? - return False - return field_name.isidentifier() - - -def _line_contains_definition_for(line: str, field_name: str) -> bool: - line = line.strip() - if not _contains_field_definition(line): - return False - attribute, _, type_and_value_assignment = line.partition(":") - attribute = attribute.strip() # remove any whitespace after the attribute name. - return attribute.isidentifier() and attribute == field_name - - -def _is_empty(line_str: str) -> bool: - return line_str.strip() == "" - - -def _is_comment(line_str: str) -> bool: - return line_str.strip().startswith("#") - - -def _get_comment_at_line(code_lines: list[str], line: int) -> str: - """Gets the comment at line `line` in `code_lines`. - - Arguments: - line {int} -- the index of the line in code_lines - - Returns: - str -- the comment at the given line. empty string if not present. - """ - line_str = code_lines[line] - assert not _contains_field_definition(line_str) - if "#" not in line_str: - return "" - parts = line_str.split("#", maxsplit=1) - comment = parts[1].strip() - return comment + return None -def _get_inline_comment_at_line(code_lines: list[str], line: int) -> str: - """Gets the inline comment at line `line`. +def scrape_comments(src): + lines = bytes(src, encoding="utf8").splitlines(keepends=True) + return [ + (*tok.start, "COMMENT", tok.string[1:].strip()) + for tok in tokenize.tokenize(partial(next, iter(lines))) + if tok.type == tokenize.COMMENT + ] - Arguments: - line {int} -- the index of the line in code_lines - Returns: - str -- the inline comment at the given line, else an empty string. - """ - assert 0 <= line < len(code_lines) - assert _contains_field_definition(code_lines[line]) - line_str = code_lines[line] - parts = line_str.split("#", maxsplit=1) - if len(parts) != 2: - return "" - comment = parts[1].strip() - return comment +class AttributeVisitor(ast.NodeVisitor): + def __init__(self): + self.data = [] + self.prefix = None + def add_data(self, node, kind, content): + self.data.append((node.lineno, node.col_offset, kind, content)) -def _get_comment_ending_at_line(code_lines: list[str], line: int) -> str: - start_line = line - end_line = line - # move up the code, one line at a time, while we don't hit the start, - # an attribute definition, or the end of a docstring. - while start_line > 0: - line_str = code_lines[start_line] - if _contains_field_definition(line_str): - break # previous line is an assignment - if '"""' in line_str or "'''" in line_str: - break # previous line has a docstring - start_line -= 1 - start_line += 1 + def visit_body(self, name, stmts): + old_prefix = self.prefix + if self.prefix is None: + self.prefix = "" + else: + self.prefix += f"{name}." + for stmt in stmts: + if ( + isinstance(stmt, ast.Expr) + and isinstance(stmt.value, ast.Constant) + and isinstance(stmt.value.value, str) + ): + self.add_data(stmt, "DOC", stmt.value.value) + else: + self.visit(stmt) + self.prefix = old_prefix - lines = [] - for i in range(start_line, end_line + 1): - # print(f"line {i}: {code_lines[i]}") - if _is_empty(code_lines[i]): - continue - assert not _contains_field_definition(code_lines[i]) - comment = _get_comment_at_line(code_lines, i) - lines.append(comment) - return "\n".join(lines).strip() + def visit_ClassDef(self, node): + if self.prefix is not None: + self.add_data(node, "VARIABLE", f"{self.prefix}{node.name}") + self.visit_body(node.name, node.body) + def visit_FunctionDef(self, node): + if self.prefix is not None: + self.add_data(node, "VARIABLE", f"{self.prefix}{node.name}") + self.visit_body(node.name, node.body) -def _get_docstring_starting_at_line(code_lines: list[str], line: int) -> str: - i = line - token: str | None = None - triple_single = "'''" - triple_double = '"""' - # print("finding docstring starting from line", line) + def visit_Assign(self, node): + self.generic_visit(node, may_assign=True) - # if we are looking further down than the end of the code, there is no - # docstring. - if line >= len(code_lines): - return "" - # the list of lines making up the docstring. - docstring_contents: list[str] = [] + def visit_AnnAssign(self, node): + self.generic_visit(node, may_assign=True) - while i < len(code_lines): - line_str = code_lines[i] - # print(f"(docstring) line {line}: {line_str}") + def visit_Name(self, node): + if isinstance(node.ctx, ast.Store): + self.add_data(node, "VARIABLE", f"{self.prefix}{node.id}") - # we haven't identified the starting line yet. - if token is None: - if _is_empty(line_str): - i += 1 - continue + def generic_visit(self, node, may_assign=False): + if isinstance(node, ast.stmt) and not may_assign: + self.add_data(node, "OTHER", None) + super().generic_visit(node) - elif _contains_field_definition(line_str) or _is_comment(line_str): - # we haven't reached the start of a docstring yet (since token - # is None), and we reached a line with an attribute definition, - # or a comment, hence the docstring is empty. - return "" - elif triple_single in line_str and triple_double in line_str: - # This handles something stupid like: - # @dataclass - # class Bob: - # a: int - # """ hello ''' - # bob - # ''' bye - # """ - triple_single_index = line_str.index(triple_single) - triple_double_index = line_str.index(triple_double) - if triple_single_index < triple_double_index: - token = triple_single - else: - token = triple_double - elif triple_double in line_str: - token = triple_double - elif triple_single in line_str: - token = triple_single - else: - # for i, line in enumerate(code_lines): - # print(f"line {i}: <{line}>") - # print(f"token: <{token}>") - # print(line_str) - logger.debug(f"Warning: Unable to parse attribute docstring: {line_str}") - return "" +def scrape_docstrings(src): + visitor = AttributeVisitor() + visitor.visit(ast.parse(src)) + return visitor.data - # get the string portion of the line (after a token or possibly - # between two tokens). - parts = line_str.split(token, maxsplit=2) - if len(parts) == 3: - # This takes care of cases like: - # @dataclass - # class Bob: - # a: int - # """ hello """ - between_tokens = parts[1].strip() - # print("Between tokens:", between_tokens) - docstring_contents.append(between_tokens) - break - elif len(parts) == 2: - after_token = parts[1].strip() - # print("After token:", after_token) - docstring_contents.append(after_token) - else: - # print(f"token is <{token}>") - if token in line_str: - # print(f"Line {line} End of a docstring:", line_str) - before = line_str.split(token, maxsplit=1)[0] - docstring_contents.append(before.strip()) - break +def get_attribute_docstrings(cls): + docs = {} + current = None + current_line = None + comments_above = [] + try: + indented_src = inspect.getsource(cls) + except (TypeError, OSError) as e: + logger.debug( + UserWarning( + f"Couldn't retrieve the source code of class {cls} " + f"(in order to retrieve the docstrings of its fields): {e}" + ) + ) + return {} + src = dedent(indented_src) + data = scrape_comments(src) + scrape_docstrings(src) + for line, _, kind, content in sorted(data): + if kind == "COMMENT": + if current is not None and current_line == line: + docs[current].comment_inline = content else: - # intermediate line without the token. - docstring_contents.append(line_str.strip()) - i += 1 - # print("Docstring contents:", docstring_contents) - return "\n".join(docstring_contents) + comments_above.append(content) + elif kind == "DOC" and current: + docs[current].docstring_below = content + elif kind == "VARIABLE": + docs[content] = AttributeDocString( + comment_above="\n".join(comments_above) + ) + comments_above = [] + current = content + current_line = line + elif kind == "OTHER": + current = current_line = None + comments_above = [] + return docs diff --git a/simple_parsing/wrappers/field_wrapper.py b/simple_parsing/wrappers/field_wrapper.py index 0becac2b..216158e2 100644 --- a/simple_parsing/wrappers/field_wrapper.py +++ b/simple_parsing/wrappers/field_wrapper.py @@ -10,6 +10,7 @@ from typing import Any, Callable, ClassVar, Hashable, Union, cast from typing_extensions import Literal +from typing import Any, ClassVar, Dict, List, Optional, Set, Tuple, Type, Union, cast, Callable from simple_parsing.help_formatter import TEMPORARY_TOKEN @@ -409,6 +410,12 @@ def enum_to_str(e): else: _arg_options["nargs"] = "*" + if isinstance(self.type_metadata, dict): + _arg_options.update(self.type_metadata) + elif isinstance(self.type_metadata, Callable): + self.type_metadata(self, _arg_options) + else: + raise TypeError("Wrong type for metadata") return _arg_options def duplicate_if_needed(self, parsed_values: Any) -> list[Any]: @@ -830,26 +837,62 @@ def required(self) -> bool: def required(self, value: bool): self._required = value + def _compute_type_and_metadata(self): + # TODO: Refactor this. Really ugly. + typ = self.field.type + if isinstance(typ, str): + # The type of the field might be a string when using `from __future__ import annotations`. + # NOTE: Here we'd like to convert the fields type to an actual type, in case the + # `from __future__ import annotations` feature is used. + # This should also resolve most forward references. + from simple_parsing.annotation_utils.get_field_annotations import ( + get_field_type_from_annotations, + ) + + field_type = get_field_type_from_annotations(self.parent.dataclass, self.field.name) + typ = field_type + elif isinstance(typ, dataclasses.InitVar): + typ = typ.type + + if hasattr(typ, "__metadata__"): + (metadata,) = typ.__metadata__ + (effective_type,) = typ.__args__ + else: + metadata = {} + effective_type = typ + + self._type = effective_type + self._type_metadata = metadata + + @property + def type_metadata(self) -> Type[Any]: + """Returns the wrapped field's type metadata.""" + if self._type_metadata is None: + self._compute_type_and_metadata() + return self._type_metadata + @property def type(self) -> type[Any]: """Returns the wrapped field's type annotation.""" - # TODO: Refactor this. Really ugly. if self._type is None: - self._type = self.field.type - if isinstance(self._type, str): - # The type of the field might be a string when using `from __future__ import annotations`. - # NOTE: Here we'd like to convert the fields type to an actual type, in case the - # `from __future__ import annotations` feature is used. - # This should also resolve most forward references. - from simple_parsing.annotation_utils.get_field_annotations import ( - get_field_type_from_annotations, - ) - - field_type = get_field_type_from_annotations(self.parent.dataclass, self.field.name) - self._type = field_type - elif isinstance(self._type, dataclasses.InitVar): - self._type = self._type.type + self._compute_type_and_metadata() return self._type + # if self._type is None: + # self._type = self.field.type + # if isinstance(self._type, str): + # # The type of the field might be a string when using `from __future__ import annotations`. + # # NOTE: Here we'd like to convert the fields type to an actual type, in case the + # # `from __future__ import annotations` feature is used. + # # This should also resolve most forward references. + # from simple_parsing.annotation_utils.get_field_annotations import ( + # get_field_type_from_annotations, + # ) + + # field_type = get_field_type_from_annotations(self.parent.dataclass, self.field.name) + # self._type = field_type + # elif isinstance(self._type, dataclasses.InitVar): + # self._type = self._type.type + # return self._type def __str__(self): return f"""""" @@ -1086,6 +1129,7 @@ def only_keep_action_args(options: dict[str, Any], action: str | Any) -> dict[st kept_options, deleted_options = utils.keep_keys(options, args_to_keep) if deleted_options: + breakpoint() logger.debug( f"Some auto-generated options were deleted, as they were " f"not required by the Action constructor: {deleted_options}." From c853d5cdb50823a610cd9dba3aa1e24cb177325e Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 27 Apr 2023 16:11:51 -0400 Subject: [PATCH 2/4] Fix tests, strip inline and docstring comments Signed-off-by: Fabrice Normandin --- simple_parsing/docstring.py | 39 +++++++++++++++--------- simple_parsing/wrappers/field_wrapper.py | 4 +-- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/simple_parsing/docstring.py b/simple_parsing/docstring.py index 65765796..e0354a04 100644 --- a/simple_parsing/docstring.py +++ b/simple_parsing/docstring.py @@ -3,8 +3,8 @@ """ from __future__ import annotations -import functools import ast +import functools import inspect import tokenize from dataclasses import dataclass @@ -14,6 +14,9 @@ import docstring_parser as dp from docstring_parser.common import Docstring +from typing_extensions import Literal + +from simple_parsing.utils import Dataclass logger = getLogger(__name__) @@ -122,7 +125,7 @@ def _get_attribute_docstring(dataclass: type, field_name: str) -> AttributeDocSt return None -def scrape_comments(src): +def scrape_comments(src: str) -> list[tuple[int, int, Literal["COMMENT"], str]]: lines = bytes(src, encoding="utf8").splitlines(keepends=True) return [ (*tok.start, "COMMENT", tok.string[1:].strip()) @@ -133,13 +136,13 @@ def scrape_comments(src): class AttributeVisitor(ast.NodeVisitor): def __init__(self): - self.data = [] + self.data: list[tuple[int, int, str, str]] = [] self.prefix = None - def add_data(self, node, kind, content): + def add_data(self, node: ast.Expr, kind: str, content: str): self.data.append((node.lineno, node.col_offset, kind, content)) - def visit_body(self, name, stmts): + def visit_body(self, name: str, stmts: list[ast.stmt]): old_prefix = self.prefix if self.prefix is None: self.prefix = "" @@ -182,16 +185,16 @@ def generic_visit(self, node, may_assign=False): super().generic_visit(node) -def scrape_docstrings(src): +def scrape_docstrings(src: str): visitor = AttributeVisitor() visitor.visit(ast.parse(src)) return visitor.data -def get_attribute_docstrings(cls): - docs = {} - current = None - current_line = None +def get_attribute_docstrings(cls: type[Dataclass]) -> dict[str, AttributeDocString]: + docs: dict[str, AttributeDocString] = {} + current: str | None = None + current_line: int | None = None comments_above = [] try: indented_src = inspect.getsource(cls) @@ -208,15 +211,21 @@ def get_attribute_docstrings(cls): for line, _, kind, content in sorted(data): if kind == "COMMENT": if current is not None and current_line == line: - docs[current].comment_inline = content + docs[current].comment_inline = content.strip() else: comments_above.append(content) elif kind == "DOC" and current: - docs[current].docstring_below = content + + content_lines = content.splitlines() + if len(content_lines) > 1: + docs[current].docstring_below = ( + dedent(content_lines[0]) + "\n" + dedent("\n".join(content_lines[1:])) + ) + else: + docs[current].docstring_below = dedent(content.strip()) + elif kind == "VARIABLE": - docs[content] = AttributeDocString( - comment_above="\n".join(comments_above) - ) + docs[content] = AttributeDocString(comment_above=dedent("\n".join(comments_above))) comments_above = [] current = content current_line = line diff --git a/simple_parsing/wrappers/field_wrapper.py b/simple_parsing/wrappers/field_wrapper.py index 216158e2..39826b2e 100644 --- a/simple_parsing/wrappers/field_wrapper.py +++ b/simple_parsing/wrappers/field_wrapper.py @@ -10,7 +10,6 @@ from typing import Any, Callable, ClassVar, Hashable, Union, cast from typing_extensions import Literal -from typing import Any, ClassVar, Dict, List, Optional, Set, Tuple, Type, Union, cast, Callable from simple_parsing.help_formatter import TEMPORARY_TOKEN @@ -865,7 +864,7 @@ def _compute_type_and_metadata(self): self._type_metadata = metadata @property - def type_metadata(self) -> Type[Any]: + def type_metadata(self) -> type[Any]: """Returns the wrapped field's type metadata.""" if self._type_metadata is None: self._compute_type_and_metadata() @@ -1129,7 +1128,6 @@ def only_keep_action_args(options: dict[str, Any], action: str | Any) -> dict[st kept_options, deleted_options = utils.keep_keys(options, args_to_keep) if deleted_options: - breakpoint() logger.debug( f"Some auto-generated options were deleted, as they were " f"not required by the Action constructor: {deleted_options}." From 21a9dd9fdb6b46b7c526b9415d272b6bcd76a404 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 27 Apr 2023 16:21:31 -0400 Subject: [PATCH 3/4] Small typing improvements in docstrings.py Signed-off-by: Fabrice Normandin --- simple_parsing/docstring.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/simple_parsing/docstring.py b/simple_parsing/docstring.py index e0354a04..7f69d862 100644 --- a/simple_parsing/docstring.py +++ b/simple_parsing/docstring.py @@ -136,10 +136,12 @@ def scrape_comments(src: str) -> list[tuple[int, int, Literal["COMMENT"], str]]: class AttributeVisitor(ast.NodeVisitor): def __init__(self): - self.data: list[tuple[int, int, str, str]] = [] + self.data: list[tuple[int, int, Literal["DOC", "VARIABLE", "OTHER"], str | None]] = [] self.prefix = None - def add_data(self, node: ast.Expr, kind: str, content: str): + def add_data( + self, node: ast.AST, kind: Literal["DOC", "VARIABLE", "OTHER"], content: str | None + ): self.data.append((node.lineno, node.col_offset, kind, content)) def visit_body(self, name: str, stmts: list[ast.stmt]): @@ -159,27 +161,27 @@ def visit_body(self, name: str, stmts: list[ast.stmt]): self.visit(stmt) self.prefix = old_prefix - def visit_ClassDef(self, node): + def visit_ClassDef(self, node: ast.ClassDef): if self.prefix is not None: self.add_data(node, "VARIABLE", f"{self.prefix}{node.name}") self.visit_body(node.name, node.body) - def visit_FunctionDef(self, node): + def visit_FunctionDef(self, node: ast.FunctionDef): if self.prefix is not None: self.add_data(node, "VARIABLE", f"{self.prefix}{node.name}") self.visit_body(node.name, node.body) - def visit_Assign(self, node): + def visit_Assign(self, node: ast.Assign): self.generic_visit(node, may_assign=True) - def visit_AnnAssign(self, node): + def visit_AnnAssign(self, node: ast.AnnAssign): self.generic_visit(node, may_assign=True) - def visit_Name(self, node): + def visit_Name(self, node: ast.Name): if isinstance(node.ctx, ast.Store): self.add_data(node, "VARIABLE", f"{self.prefix}{node.id}") - def generic_visit(self, node, may_assign=False): + def generic_visit(self, node: ast.AST, may_assign: bool = False): if isinstance(node, ast.stmt) and not may_assign: self.add_data(node, "OTHER", None) super().generic_visit(node) @@ -210,12 +212,13 @@ def get_attribute_docstrings(cls: type[Dataclass]) -> dict[str, AttributeDocStri data = scrape_comments(src) + scrape_docstrings(src) for line, _, kind, content in sorted(data): if kind == "COMMENT": + assert content is not None if current is not None and current_line == line: docs[current].comment_inline = content.strip() else: comments_above.append(content) elif kind == "DOC" and current: - + assert content is not None content_lines = content.splitlines() if len(content_lines) > 1: docs[current].docstring_below = ( @@ -225,6 +228,7 @@ def get_attribute_docstrings(cls: type[Dataclass]) -> dict[str, AttributeDocStri docs[current].docstring_below = dedent(content.strip()) elif kind == "VARIABLE": + assert content is not None docs[content] = AttributeDocString(comment_above=dedent("\n".join(comments_above))) comments_above = [] current = content From 5cd51af4d0cd4e1af2e0310c3494bded8b2d2b90 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 27 Apr 2023 16:30:26 -0400 Subject: [PATCH 4/4] Simplify docstrings.py slightly Signed-off-by: Fabrice Normandin --- simple_parsing/docstring.py | 46 +++++++++++++++---------------------- 1 file changed, 18 insertions(+), 28 deletions(-) diff --git a/simple_parsing/docstring.py b/simple_parsing/docstring.py index 7f69d862..ae054d4b 100644 --- a/simple_parsing/docstring.py +++ b/simple_parsing/docstring.py @@ -4,6 +4,7 @@ from __future__ import annotations import ast +import dataclasses import functools import inspect import tokenize @@ -13,7 +14,6 @@ from textwrap import dedent import docstring_parser as dp -from docstring_parser.common import Docstring from typing_extensions import Literal from simple_parsing.utils import Dataclass @@ -34,7 +34,7 @@ class AttributeDocString: @property def help_string(self) -> str: - """Returns the value that will be used for the "--help" string, using the contents of self.""" + """Returns the value that will be used for the "--help" string.""" return ( self.docstring_below or self.comment_above @@ -56,9 +56,9 @@ def get_attribute_docstring( Arguments: some_dataclass: a dataclass field_name: the name of the field. - accumulate_from_bases: Whether to accumulate the docstring components by looking through the - base classes. When set to `False`, whenever one of the classes has a definition for the - field, it is directly returned. Otherwise, we accumulate the parts of the dodc + accumulate_from_bases: Whether to accumulate the docstring components by looking through + the base classes. When set to `False`, whenever one of the classes has a definition for + the field, it is directly returned. Otherwise, we accumulate the parts of the dodc Returns: AttributeDocString -- an object holding the string descriptions of the field. """ @@ -69,7 +69,8 @@ def get_attribute_docstring( assert mro[-1] is object mro = mro[:-1] for base_class in mro: - attribute_docstring = _get_attribute_docstring(base_class, field_name) + attribute_docstring = get_attribute_docstrings(base_class).get(field_name, None) + if not attribute_docstring: continue if not created_docstring: @@ -103,28 +104,6 @@ def get_attribute_docstring( return created_docstring -@functools.lru_cache(2048) -def _get_attribute_docstring(dataclass: type, field_name: str) -> AttributeDocString | None: - """Gets the AttributeDocString of the given field in the given dataclass. - Doesn't inspect base classes. - """ - # Parse docstring to use as help strings - desc_from_cls_docstring = "" - cls_docstring = inspect.getdoc(dataclass) - if cls_docstring: - docstring: Docstring = dp.parse(cls_docstring) - for param in docstring.params: - if param.arg_name == field_name: - desc_from_cls_docstring = param.description or "" - - results = get_attribute_docstrings(dataclass).get(field_name, None) - if results: - results.desc_from_cls_docstring = desc_from_cls_docstring - return results - else: - return None - - def scrape_comments(src: str) -> list[tuple[int, int, Literal["COMMENT"], str]]: lines = bytes(src, encoding="utf8").splitlines(keepends=True) return [ @@ -193,6 +172,7 @@ def scrape_docstrings(src: str): return visitor.data +@functools.lru_cache(2048) def get_attribute_docstrings(cls: type[Dataclass]) -> dict[str, AttributeDocString]: docs: dict[str, AttributeDocString] = {} current: str | None = None @@ -236,4 +216,14 @@ def get_attribute_docstrings(cls: type[Dataclass]) -> dict[str, AttributeDocStri elif kind == "OTHER": current = current_line = None comments_above = [] + + # Parse docstring to use as help strings + cls_docstring = inspect.getdoc(cls) + if cls_docstring: + docstring: dp.Docstring = dp.parse(cls_docstring) + for param in docstring.params: + for field in dataclasses.fields(cls): + if param.arg_name == field.name: + docs[field.name].desc_from_cls_docstring = param.description or "" + return docs