How to check that a string is a string literal for mypy?

Time:01-08

With this code

import os
from typing import Literal, get_args

Markets = Literal[
    "BE", "DE", "DK", "EE", "ES", "FI", "FR", "GB", "IT", "LT", "LV", "NL", "NO", "PL", "PT", "SE"
]
MARKETS: list[Markets] = list(get_args(Markets))


def foo(x: Markets) -> None:
    print(x)


market = os.environ.get("market")


if market not in MARKETS:
    raise ValueError


foo(market)

I get the following error.

Argument 1 to "foo" has incompatible type "str"; expected "Literal['BE', 'DE', 'DK', 'EE', 'ES', 'FI', 'FR', 'GB', 'IT', 'LT', 'LV', 'NL', 'NO', 'PL', 'PT', 'SE']"  [arg-type]mypy(error)

How should I check the market variable so that mypy knows it is of the correct type?

Answer:

Alternative: Custom TypeGuard

Since Python 3.10 (or earlier versions, by importing TypeGuard from the typing_extensions package), you can define your own type guards, which can make this slightly more elegant:

import os
from typing import Literal, TypeGuard, get_args


MarketT = Literal[
    "BE", "DE", "DK", "EE", "ES", "FI", "FR", "GB", "IT", "LT", "LV", "NL", "NO", "PL", "PT", "SE"
]
MARKETS: list[MarketT] = list(get_args(MarketT))


def is_valid_market(val: str) -> TypeGuard[MarketT]:
    return val in MARKETS


def foo(x: MarketT) -> None:
    print(x)


market = os.environ.get("market", "")
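# reveal_type is recognized by mypy; without an import it raises NameError at runtime,
# so remove these calls (or import it from typing on Python 3.11+) before running the script
reveal_type(market)
assert is_valid_market(market)
reveal_type(market)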

foo(market)

Running mypy over this will show you that before the assert the type is inferred as str, whereas after the assert it is narrowed to that union of string literals you defined earlier. This basically combines both the runtime check (that you already had) and the static narrowing into one.

Note: I still need to pass a str as the default to os.environ.get. Otherwise market could be None (its inferred type would be Optional[str]), which is_valid_market(val: str) does not accept. Alternatively, we could annotate the val parameter of is_valid_market with Optional[str] to avoid that type checker error. This is just a matter of preference.


Original post

Yes, cast is the easiest way IMO:

import os
from typing import Literal, cast, get_args


Market = Literal[
    "BE", "DE", "DK", "EE", "ES", "FI", "FR", "GB", "IT", "LT", "LV", "NL", "NO", "PL", "PT", "SE"
]
MARKETS: list[Market] = list(get_args(Market))


def foo(x: Market) -> None:
    print(x)


market = cast(Market, os.environ.get("market"))
# reveal_type(market)

if market not in MARKETS:
    raise ValueError

foo(market)

Uncommenting the reveal_type statement and running mypy will give you the following:

note: Revealed type is "Union[Literal['BE'], Literal['DE'], Literal['DK'], Literal['EE'], Literal['ES'], Literal['FI'], Literal['FR'], Literal['GB'], Literal['IT'], Literal['LT'], Literal['LV'], Literal['NL'], Literal['NO'], Literal['PL'], Literal['PT'], Literal['SE']]"

So the type is correctly inferred as a union of those string literals.
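Keep in mind that cast only informs the type checker. At runtime typing.cast returns its argument unchanged. In the code above, it is the membership check that actually rejects bad values, including None when the variable is unset.

Here is a variation that validates first and casts afterwards, so the cast never claims something that has not been checked. The name parse_market is hypothetical.

```python
from typing import Literal, Optional, cast, get_args

Market = Literal[
    "BE", "DE", "DK", "EE", "ES", "FI", "FR", "GB", "IT", "LT", "LV", "NL", "NO", "PL", "PT", "SE"
]
MARKETS: list[Market] = list(get_args(Market))


def parse_market(raw: Optional[str]) -> Market:
    # Runtime validation: rejects None and any code outside the Literal.
    if raw not in MARKETS:
        raise ValueError(f"unknown market: {raw!r}")
    # The value has been checked, so the cast only records that fact for mypy.
    return cast(Market, raw)
```

Calling foo(parse_market(os.environ.get("market"))) then type-checks without a cast at the call site.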

As a side note, semantically I would name your literal union Market, not Markets (or even MarketType or MarketT). After all, it is the type of a variable that represents a single market, not several. The list name, on the other hand, is fitting, since it refers to the collection of all possible markets.

Answer:

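No need to use cast(..). Just type your variable:

import os
from typing import Literal

Market = Literal[
    "BE", "DE", "DK", "EE", "ES", "FI", "FR", "GB", "IT", "LT", "LV", "NL", "NO", "PL", "PT", "SE"
]

def foo(x: Market) -> None:
    print(x)

market: Market = os.environ.get("market")
foo(market)

Be aware, though, that mypy accepts such an annotation only when the assigned value already has type Market. Here os.environ.get("market") returns Optional[str], so mypy reports an incompatible-types error on the annotated assignment. The value still has to be validated and narrowed (for example with the type guard or cast shown above) before it can be passed to foo.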