Stéphan Tulkens

NLP Person

Using overload to handle union return types

Here’s a function with an idiom I’ve seen a lot (probably copied from sentence-transformers):

def score_sentences(sentences: str | list[str]) -> float | list[float]:
    was_single = False
    if isinstance(sentences, str):
        sentences = [sentences]
        was_single = True
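    # Do something; string length stands in for a real scoring step here
    scores = [float(len(sentence)) for sentence in sentences]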
    if was_single:
        return scores[0]
    return scores

reveal_type(score_sentences("dog"))
# Revealed type is "Union[builtins.float, builtins.list[builtins.float]]"
reveal_type(score_sentences(["dog", "cat"]))
# Revealed type is "Union[builtins.float, builtins.list[builtins.float]]"

That is, the output type depends on whether a single sentence or a list of sentences was passed in. It’s an interesting idiom, but one that static type analysis handles poorly: even if you pass in a str, the static analyzer can’t determine that this is what causes a single float to be returned instead of a list of floats. So the return type of this function will always be float | list[float], a union type. You can see this in the revealed types of the example above: both function calls return float | list[float], and both are less precise than what actually comes back at run time. So in both cases, you’d have to cast or assert before any code that uses the result will type-check.

So here’s my first recommendation: don’t use union types as return types if you can help it. In the example above, consider the following alternative. It’s a trade-off, because callers now have to pick the right function, but it’s certainly easier to reason about.

def score_sentence(sentence: str) -> float:
    return score_sentences([sentence])[0]

def score_sentences(sentences: list[str]) -> list[float]:
    ...

But ok, let’s say you are stuck with this implementation, and you need a union return type. What to do? Here’s something you can consider: the typing.overload decorator. Other programming languages have overloaded functions, where a single function name has multiple implementations, and the correct implementation is selected based on the arguments passed to that function.

With typing.overload, you still have a single implementation of your function; here, that is the function above. The types used in this implementation should be the most general ones, i.e., unions for the parameters and the return value. You then define additional signatures for the same function, decorated with @overload, to specify that for these particular input types the output type is narrower.

See here for an example:

from typing import overload

@overload
def score_sentences(sentences: str) -> float:
    ...

@overload
def score_sentences(sentences: list[str]) -> list[float]:
    ...

# Original implementation
def score_sentences(sentences: str | list[str]) -> float | list[float]:
    ...

reveal_type(score_sentences("dog"))
# note: Revealed type is "builtins.float"
reveal_type(score_sentences(["dog", "cat"]))
# note: Revealed type is "builtins.list[builtins.float]"

As you can see, using an overload solved the issue! Pretty nice. Some remarks:

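  • If no overload matches a call, the type checker reports an error; it does not fall back to the signature of the original implementation.
  • The overloads have no effect at run time: only the final, undecorated definition is ever executed. (Since Python 3.11 they can still be looked up at run time with typing.get_overloads, so in that sense they do exist.)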
  • Conflicting overloads will be flagged by the type checker.
  • All parameters need to be redefined by every overload.

Now, let’s deal with a more complicated case:

def score_sentences(sentences: str | list[str], return_integers: bool = False) -> int | float | list[int] | list[float]:
    ...

This function has the following behavior: if a single string is passed, it returns a single float, or a single integer if return_integers is True; if a list of strings is passed, it returns a list of floats or a list of integers. Note that this is a BAD function: it uses a flag argument. As above, it should really be split into separate functions.
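For illustration, here is a rough sketch of what that split could look like for the list case. The name score_sentences_as_integers is hypothetical, and both the length-based scoring and the rounding are placeholders of mine, not the real behavior.

```python
def score_sentences(sentences: list[str]) -> list[float]:
    # Placeholder scoring: string length stands in for a real model.
    return [float(len(sentence)) for sentence in sentences]


def score_sentences_as_integers(sentences: list[str]) -> list[int]:
    # Hypothetical name; reuses the float scorer and rounds each score.
    return [round(score) for score in score_sentences(sentences)]


float_scores = score_sentences(["dog", "cats"])
integer_scores = score_sentences_as_integers(["dog", "cats"])
```

Each function now has exactly one return type, so no flag and no overloads are needed.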

So this is where it gets super-gnarly, because as it turns out, you can’t overload functions on the values of arguments, only on their types. So there’s no direct way to say: “if this is True, return X, otherwise return Y”. There is a way around this, though: turning specific values into types using Literal. This lets us turn the boolean values True and False into their own types. The type checker is smart enough to work out that Literal[True] | Literal[False] covers the bool type.

from typing import overload, Literal

TrueType = Literal[True]
FalseType = Literal[False]


@overload
def score_sentences(sentences: str, return_integers: TrueType) -> int: ...


@overload
def score_sentences(sentences: list[str], return_integers: TrueType) -> list[int]: ...


@overload
def score_sentences(sentences: str, return_integers: FalseType = ...) -> float: ...


@overload
def score_sentences(
    sentences: list[str], return_integers: FalseType = ...
) -> list[float]: ...


# The default matches the original signature, so calls that omit the flag
# resolve to the FalseType overloads.
def score_sentences(
    sentences: str | list[str], return_integers: bool = False
) -> int | float | list[int] | list[float]: ...


reveal_type(score_sentences("dog", True))
# Revealed type is "builtins.int"
reveal_type(score_sentences("dog", False))
# Revealed type is "builtins.float"
reveal_type(score_sentences(["dog", "cat"], True))
# Revealed type is "builtins.list[builtins.int]"
reveal_type(score_sentences(["dog", "cat"], False))
# Revealed type is "builtins.list[builtins.float]"

This works. This is not something I would advocate doing, but it’s a nice idiom to keep in mind. Legacy code that is weakly typed or not typed at all is bound to have things like this, and this is a nice way to fix it. For new code, I would err on the side of caution and avoid union types if possible.
