Skip to content

Numbers and core type fixes #966

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 13 additions & 22 deletions babel/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import os
import pickle
from collections.abc import Iterable, Mapping
from typing import TYPE_CHECKING, Any, overload
from typing import TYPE_CHECKING, Any

from babel import localedata
from babel.plural import PluralRule
Expand Down Expand Up @@ -260,21 +260,13 @@ def negotiate(
if identifier:
return Locale.parse(identifier, sep=sep)

@overload
@classmethod
def parse(cls, identifier: None, sep: str = ..., resolve_likely_subtags: bool = ...) -> None: ...

@overload
@classmethod
def parse(cls, identifier: str | Locale, sep: str = ..., resolve_likely_subtags: bool = ...) -> Locale: ...

@classmethod
def parse(
cls,
identifier: str | Locale | None,
sep: str = '_',
resolve_likely_subtags: bool = True,
) -> Locale | None:
) -> Locale:
"""Create a `Locale` instance for the given locale identifier.

>>> l = Locale.parse('de-DE', sep='-')
Expand Down Expand Up @@ -317,10 +309,9 @@ def parse(
identifier
:raise `UnknownLocaleError`: if no locale data is available for the
requested locale
:raise `TypeError`: if the identifier is not a string or a `Locale`
"""
if identifier is None:
return None
elif isinstance(identifier, Locale):
if isinstance(identifier, Locale):
return identifier
elif not isinstance(identifier, str):
raise TypeError(f"Unexpected value for identifier: {identifier!r}")
Expand Down Expand Up @@ -364,9 +355,9 @@ def _try_load_reducing(parts):
language, territory, script, variant = parts
modifier = None
language = get_global('language_aliases').get(language, language)
territory = get_global('territory_aliases').get(territory, (territory,))[0]
script = get_global('script_aliases').get(script, script)
variant = get_global('variant_aliases').get(variant, variant)
territory = get_global('territory_aliases').get(territory or '', (territory,))[0]
script = get_global('script_aliases').get(script or '', script)
variant = get_global('variant_aliases').get(variant or '', variant)

if territory == 'ZZ':
territory = None
Expand All @@ -389,9 +380,9 @@ def _try_load_reducing(parts):
if likely_subtag is not None:
parts2 = parse_locale(likely_subtag)
if len(parts2) == 5:
language2, _, script2, variant2, modifier2 = parse_locale(likely_subtag)
language2, _, script2, variant2, modifier2 = parts2
else:
language2, _, script2, variant2 = parse_locale(likely_subtag)
language2, _, script2, variant2 = parts2
modifier2 = None
locale = _try_load_reducing((language2, territory, script2, variant2, modifier2))
if locale is not None:
Expand Down Expand Up @@ -512,7 +503,7 @@ def get_territory_name(self, locale: Locale | str | None = None) -> str | None:
if locale is None:
locale = self
locale = Locale.parse(locale)
return locale.territories.get(self.territory)
return locale.territories.get(self.territory or '')

territory_name = property(get_territory_name, doc="""\
The localized territory name of the locale if available.
Expand All @@ -526,7 +517,7 @@ def get_script_name(self, locale: Locale | str | None = None) -> str | None:
if locale is None:
locale = self
locale = Locale.parse(locale)
return locale.scripts.get(self.script)
return locale.scripts.get(self.script or '')

script_name = property(get_script_name, doc="""\
The localized script name of the locale if available.
Expand Down Expand Up @@ -1147,7 +1138,7 @@ def negotiate_locale(preferred: Iterable[str], available: Iterable[str], sep: st
def parse_locale(
identifier: str,
sep: str = '_'
) -> tuple[str, str | None, str | None, str | None, str | None]:
) -> tuple[str, str | None, str | None, str | None] | tuple[str, str | None, str | None, str | None, str | None]:
"""Parse a locale identifier into a tuple of the form ``(language,
territory, script, variant, modifier)``.

Expand Down Expand Up @@ -1261,7 +1252,7 @@ def get_locale_identifier(
:param tup: the tuple as returned by :func:`parse_locale`.
:param sep: the separator for the identifier.
"""
tup = tuple(tup[:5])
tup = tuple(tup[:5]) # type: ignore # length should be no more than 5
lang, territory, script, variant, modifier = tup + (None,) * (5 - len(tup))
ret = sep.join(filter(None, (lang, script, territory, variant)))
return f'{ret}@{modifier}' if modifier else ret
45 changes: 23 additions & 22 deletions babel/numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import decimal
import re
import warnings
from typing import TYPE_CHECKING, Any, overload
from typing import TYPE_CHECKING, Any, cast, overload

from babel.core import Locale, default_locale, get_global
from babel.localedata import LocaleDataDict
Expand Down Expand Up @@ -428,7 +428,7 @@ def get_decimal_quantum(precision: int | decimal.Decimal) -> decimal.Decimal:

def format_decimal(
number: float | decimal.Decimal | str,
format: str | None = None,
format: str | NumberPattern | None = None,
locale: Locale | str | None = LC_NUMERIC,
decimal_quantization: bool = True,
group_separator: bool = True,
Expand Down Expand Up @@ -474,8 +474,8 @@ def format_decimal(
number format.
"""
locale = Locale.parse(locale)
if not format:
format = locale.decimal_formats.get(format)
if format is None:
format = locale.decimal_formats[format]
pattern = parse_pattern(format)
return pattern.apply(
number, locale, decimal_quantization=decimal_quantization, group_separator=group_separator)
Expand Down Expand Up @@ -513,15 +513,15 @@ def format_compact_decimal(
number, format = _get_compact_format(number, compact_format, locale, fraction_digits)
# Did not find a format, fall back.
if format is None:
format = locale.decimal_formats.get(None)
format = locale.decimal_formats[None]
pattern = parse_pattern(format)
return pattern.apply(number, locale, decimal_quantization=False)


def _get_compact_format(
number: float | decimal.Decimal | str,
compact_format: LocaleDataDict,
locale: Locale | str | None,
locale: Locale,
fraction_digits: int,
) -> tuple[decimal.Decimal, NumberPattern | None]:
"""Returns the number after dividing by the unit and the format pattern to use.
Expand All @@ -543,7 +543,7 @@ def _get_compact_format(
break
# otherwise, we need to divide the number by the magnitude but remove zeros
# equal to the number of 0's in the pattern minus 1
number = number / (magnitude // (10 ** (pattern.count("0") - 1)))
number = cast(decimal.Decimal, number / (magnitude // (10 ** (pattern.count("0") - 1))))
# round to the number of fraction digits requested
rounded = round(number, fraction_digits)
# if the remaining number is singular, use the singular format
Expand All @@ -565,7 +565,7 @@ class UnknownCurrencyFormatError(KeyError):
def format_currency(
number: float | decimal.Decimal | str,
currency: str,
format: str | None = None,
format: str | NumberPattern | None = None,
locale: Locale | str | None = LC_NUMERIC,
currency_digits: bool = True,
format_type: Literal["name", "standard", "accounting"] = "standard",
Expand Down Expand Up @@ -680,7 +680,7 @@ def format_currency(
def _format_currency_long_name(
number: float | decimal.Decimal | str,
currency: str,
format: str | None = None,
format: str | NumberPattern | None = None,
locale: Locale | str | None = LC_NUMERIC,
currency_digits: bool = True,
format_type: Literal["name", "standard", "accounting"] = "standard",
Expand All @@ -706,7 +706,7 @@ def _format_currency_long_name(

# Step 5.
if not format:
format = locale.decimal_formats.get(format)
format = locale.decimal_formats[format]

pattern = parse_pattern(format)

Expand Down Expand Up @@ -758,13 +758,15 @@ def format_compact_currency(
# compress adjacent spaces into one
format = re.sub(r'(\s)\s+', r'\1', format).strip()
break
if format is None:
raise ValueError('No compact currency format found for the given number and locale.')
pattern = parse_pattern(format)
return pattern.apply(number, locale, currency=currency, currency_digits=False, decimal_quantization=False)


def format_percent(
number: float | decimal.Decimal | str,
format: str | None = None,
format: str | NumberPattern | None = None,
locale: Locale | str | None = LC_NUMERIC,
decimal_quantization: bool = True,
group_separator: bool = True,
Expand Down Expand Up @@ -808,15 +810,15 @@ def format_percent(
"""
locale = Locale.parse(locale)
if not format:
format = locale.percent_formats.get(format)
format = locale.percent_formats[format]
pattern = parse_pattern(format)
return pattern.apply(
number, locale, decimal_quantization=decimal_quantization, group_separator=group_separator)


def format_scientific(
number: float | decimal.Decimal | str,
format: str | None = None,
format: str | NumberPattern | None = None,
locale: Locale | str | None = LC_NUMERIC,
decimal_quantization: bool = True,
) -> str:
Expand Down Expand Up @@ -847,7 +849,7 @@ def format_scientific(
"""
locale = Locale.parse(locale)
if not format:
format = locale.scientific_formats.get(format)
format = locale.scientific_formats[format]
pattern = parse_pattern(format)
return pattern.apply(
number, locale, decimal_quantization=decimal_quantization)
Expand All @@ -856,7 +858,7 @@ def format_scientific(
class NumberFormatError(ValueError):
"""Exception raised when a string cannot be parsed into a number."""

def __init__(self, message: str, suggestions: str | None = None) -> None:
def __init__(self, message: str, suggestions: list[str] | None = None) -> None:
super().__init__(message)
#: a list of properly formatted numbers derived from the invalid input
self.suggestions = suggestions
Expand Down Expand Up @@ -1140,7 +1142,7 @@ def scientific_notation_elements(self, value: decimal.Decimal, locale: Locale |

def apply(
self,
value: float | decimal.Decimal,
value: float | decimal.Decimal | str,
locale: Locale | str | None,
currency: str | None = None,
currency_digits: bool = True,
Expand Down Expand Up @@ -1211,9 +1213,9 @@ def apply(
number = ''.join([
self._quantize_value(value, locale, frac_prec, group_separator),
get_exponential_symbol(locale),
exp_sign,
self._format_int(
str(exp), self.exp_prec[0], self.exp_prec[1], locale)])
exp_sign, # type: ignore # exp_sign is always defined here
self._format_int(str(exp), self.exp_prec[0], self.exp_prec[1], locale) # type: ignore # exp is always defined here
])

# Is it a significant digits pattern?
elif '@' in self.pattern:
Expand All @@ -1234,9 +1236,8 @@ def apply(
number if self.number_pattern != '' else '',
self.suffix[is_negative]])

if '¤' in retval:
retval = retval.replace('¤¤¤',
get_currency_name(currency, value, locale))
if '¤' in retval and currency is not None:
retval = retval.replace('¤¤¤', get_currency_name(currency, value, locale))
retval = retval.replace('¤¤', currency.upper())
retval = retval.replace('¤', get_currency_symbol(currency, locale))

Expand Down