With this code:
import os
from typing import Literal, get_args

Markets = Literal[
    "BE", "DE", "DK", "EE", "ES", "FI", "FR", "GB", "IT", "LT", "LV", "NL", "NO", "PL", "PT", "SE"
]
MARKETS: list[Markets] = list(get_args(Markets))

def foo(x: Markets) -> None:
    print(x)

market = os.environ.get("market")
if market not in MARKETS:
    raise ValueError
foo(market)
I get the following error.
Argument 1 to "foo" has incompatible type "str"; expected "Literal['BE', 'DE', 'DK', 'EE', 'ES', 'FI', 'FR', 'GB', 'IT', 'LT', 'LV', 'NL', 'NO', 'PL', 'PT', 'SE']" [arg-type]mypy(error)
How should I check the market variable so that mypy knows that it is of the correct type?
Answer:
Alternative: Custom TypeGuard
Since Python 3.10, you can define your own type guards, which can make this slightly more elegant:
import os
from typing import Literal, TypeGuard, get_args

MarketT = Literal[
    "BE", "DE", "DK", "EE", "ES", "FI", "FR", "GB", "IT", "LT", "LV", "NL", "NO", "PL", "PT", "SE"
]
MARKETS: list[MarketT] = list(get_args(MarketT))

def is_valid_market(val: str) -> TypeGuard[MarketT]:
    return val in MARKETS

def foo(x: MarketT) -> None:
    print(x)

market = os.environ.get("market", "")
# reveal_type is meant for mypy; remove these calls before running the script
# (on Python 3.11+ it can also be imported from typing)
reveal_type(market)
assert is_valid_market(market)
reveal_type(market)
foo(market)
Running mypy over this will show you that before the assert the type is inferred as str, whereas after the assert it is narrowed to the union of string literals you defined earlier. This combines the runtime check (which you already had) and the static narrowing into one.
Note: I still need to provide a str instance as the default for os.environ.get, because otherwise market might turn out to be None. We could instead annotate the val parameter of is_valid_market with Optional[str] to avoid that type checker error. This is just a matter of preference.
Original post
Yes, cast is the easiest way IMO:
import os
from typing import Literal, cast, get_args

Market = Literal[
    "BE", "DE", "DK", "EE", "ES", "FI", "FR", "GB", "IT", "LT", "LV", "NL", "NO", "PL", "PT", "SE"
]
MARKETS: list[Market] = list(get_args(Market))

def foo(x: Market) -> None:
    print(x)

market = cast(Market, os.environ.get("market"))
# reveal_type(market)
if market not in MARKETS:
    raise ValueError
foo(market)
Uncommenting the reveal_type statement and running mypy will give you the following:
note: Revealed type is "Union[Literal['BE'], Literal['DE'], Literal['DK'], Literal['EE'], Literal['ES'], Literal['FI'], Literal['FR'], Literal['GB'], Literal['IT'], Literal['LT'], Literal['LV'], Literal['NL'], Literal['NO'], Literal['PL'], Literal['PT'], Literal['SE']]"
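So mypy now treats market as a union of those string literals. Keep in mind that cast does nothing at runtime (it returns its argument unchanged); it only tells the type checker to trust you, which is why the membership check after it is still needed.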
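Because cast only changes what mypy believes, one way to keep the runtime check and the cast from drifting apart is to put both in a single function. A minimal sketch; the name to_market is hypothetical:

```python
from typing import Literal, Optional, cast, get_args

Market = Literal[
    "BE", "DE", "DK", "EE", "ES", "FI", "FR", "GB", "IT", "LT", "LV", "NL", "NO", "PL", "PT", "SE"
]
MARKETS: list[Market] = list(get_args(Market))


def to_market(value: Optional[str]) -> Market:
    """Validate at runtime, then tell mypy the value is a Market."""
    if value not in MARKETS:
        raise ValueError(f"invalid market: {value!r}")
    # Safe: the membership check above guarantees value is one of the literals.
    return cast(Market, value)
```

The call site then shrinks to market = to_market(os.environ.get("market")).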
As a side note, semantically, I would say the name of your literal union should be Market, not Markets (maybe even MarketType or MarketT). After all, it refers to the type of a variable that represents a single market, not multiple. The list name, on the other hand, is fitting, since it refers to a collection of all the possible markets.
Answer:
No need to use cast(..). Just type your variable:
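import os
# Market is the Literal alias defined in the answer above
def foo(x: Market) -> None:
    print(x)
# mypy still reports an incompatible-assignment error here, because
# os.environ.get returns "str | None", which is wider than Market
market: Market = os.environ.get("market")
foo(market)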