Stéphan Tulkens

NLP Person

Correctly typing cached functions

Caching, or memoization, is a useful way to speed up repeated calls to expensive, pure functions. When we call a function, we save the output, using the arguments of the call as the key into the cache. Then, instead of recalculating the result on each call, we simply return the value that was stored in the cache.

from functools import cache

@cache
def mul(x: int, y: int) -> int:
    print(f"Input arg: {x}")
    return x * y

mul(10, 5)  # Prints "Input arg: 10"
mul(10, 5)  # Doesn't print: the result comes from the cache
mul(5, 3)  # Prints "Input arg: 5"
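To make the "arguments as a key" idea concrete, here is a minimal sketch of a dictionary-based cache. The `memoize` helper and the `calls` list are hypothetical names I made up for illustration; this is not how functools implements it, and it only handles positional arguments.

```python
from typing import Any, Callable, Dict, Tuple


def memoize(function: Callable[..., Any]) -> Callable[..., Any]:
    # Hypothetical helper: maps the tuple of positional arguments to a result.
    store: Dict[Tuple[Any, ...], Any] = {}

    def wrapper(*args: Any) -> Any:
        if args not in store:
            # First call with these arguments: compute and remember the result.
            store[args] = function(*args)
        return store[args]

    return wrapper


calls = []  # Records every time the real function body runs.


@memoize
def mul(x: int, y: int) -> int:
    calls.append((x, y))
    return x * y


mul(10, 5)
mul(10, 5)  # Served from the cache, so the body does not run again.
mul(5, 3)
print(calls)  # [(10, 5), (5, 3)]
```

Note that the wrapper already takes `*args: Any`, which is exactly the typing problem discussed in the rest of this post.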

One easy trap to fall into, however, is that caching a function removes the type hints on its arguments, which leaves mypy powerless to detect any mistakes you might have made. This surprised me, and it only really registered when I made a trivial typing mistake that mypy did not catch and all my tests failed. In addition, using cache also removes the argument hints you would normally get in your editor, which is annoying.

In this post, I’ll show you how to remedy this issue, while also conceding that there is no nice solution that always works.

Typing a caching function

In terms of typing, the cache decorator can take anything as an input argument. In practice, your input arguments need to be hashable, but there is no easy way to represent this in typing (no, not Hashable). To illustrate what I mean, here is what a simple caching decorator could look like:

from typing import Callable, Any

def cache(function: Callable) -> Callable[..., Any]:
    # Initialize cache

    def cached_function(*args: Any, **kwargs: Any) -> Any:
        # Caching happens here
        return function(*args, **kwargs)

    return cached_function
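As a short aside on the hashability remark: one likely reason Hashable falls short is that it only checks whether a type defines `__hash__`, not whether hashing a particular value succeeds. The sketch below uses a made-up function `total` to show a tuple that passes the Hashable check but still breaks the cache at runtime.

```python
from collections.abc import Hashable
from functools import cache


@cache
def total(values):
    # Hypothetical function: sums the list stored in the first slot.
    return sum(values[0])


nested = ([1, 2, 3],)  # A tuple (hashable type) holding a list (unhashable).

# The isinstance check passes, because tuple defines __hash__.
looks_hashable = isinstance(nested, Hashable)

try:
    total(nested)
    failed = False
except TypeError:
    # Building the cache key hashes the tuple, which hashes the list inside.
    failed = True

print(looks_hashable, failed)
```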

Recall that a decorator is actually nothing more than a function applied to a function, so in the example below square and square_2 are equivalent:


@cache
def square(x: int) -> int:
    return x ** 2

def _square(x: int) -> int:
    return x ** 2

square_2 = cache(_square)

So, we simply apply cache to a function, which returns cached_function. The inner caching function, as you can see, has to take *args and **kwargs as its only arguments. Thus, any type hints that we had on our original function, in effect, disappear. We can pass anything to square, and mypy will never complain. It doesn’t even know how many arguments square has.

square(3, 5, 12)  # Passes mypy
square("blog", "glob", "dog", "apollo", 12, 3.14, object())  # Passes mypy
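Passing mypy is not the same as working, of course. As a quick sketch, using functools.cache and the same square function, the first call above still fails once it actually runs, because the cached wrapper forwards all arguments to the original function:

```python
from functools import cache


@cache
def square(x: int) -> int:
    return x ** 2


message = None
try:
    square(3, 5, 12)  # Accepted by mypy, rejected at runtime.
except TypeError as error:
    # The original function still only accepts a single argument.
    message = str(error)

print(message)
```

So the mistake is only caught when the code executes, which is exactly the feedback we wanted from the type checker.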

Another issue is that, at least in VS Code, any argument hints you would have gotten while coding disappear. Note that this only holds for the input arguments, not for the return type.

result = square(3)
reveal_type(square)
# note: Revealed type is "functools._lru_cache_wrapper[builtins.int]"
reveal_type(result)
# note: Revealed type is "builtins.int"

First solutions

First, a fair warning: the solutions I’ll present below are all quite verbose, and should not be considered if you don’t like writing code just to conform to mypy. The context in which I used the solutions below was a piece of code which was quite difficult to test because of a really long build process, so I was willing to write some extra code to get a better sense of correctness before starting the build step. If you have really short feedback loops, you might get very little value out of this.

A simple first solution is just to type the resulting function. This only works if you don’t use cache as a decorator, but just apply it to a function.

from functools import cache
from typing import Callable

def _repeat_string(x: str, y: int) -> str:
    return x * y

repeat_string: Callable[[str, int], str] = cache(_repeat_string)

repeat_string("dog", 3)  # Correct
repeat_string(15, 5)  # Incorrect
# error: Argument 1 has incompatible type "int"; expected "str"  [arg-type]

So, we can see that mypy now correctly checks our arguments, which is a big step up from what we had before: we can now safely use cached functions. One issue, however, is that the typing on our function does not have a notion of keyword arguments. This has two consequences:

  1. We don’t see any argument names in our IDE (i.e., in the box that pops up over your function while you are typing a call to it.)
  2. Mypy will complain if you use keyword arguments.

I can’t show the former, but I can show the latter, so:

repeat_string(x="dog", y=3)  
# error: Unexpected keyword argument "x"  [call-arg]
# error: Unexpected keyword argument "y"  [call-arg]

Again, note that this actually works, but just doesn’t pass mypy. We’re deep in the astral realm of typing, in which things work but aren’t astrally aligned.

Using Callable is insufficient because the arguments of a Callable are anonymous, and can thus only represent positional arguments. An improvement would be to use what is known as a Protocol. A Protocol is meant to do exactly what a plain Callable can’t: it can represent a function with keyword arguments, and even *args and **kwargs.

Here’s what the Protocol for the function above could look like:

from typing import Protocol

class RepeatStringProtocol(Protocol):

    def __call__(self, x: str, y: int) -> str:
        ...

repeat_string: RepeatStringProtocol = cache(_repeat_string)

repeat_string(x="dog", y=3)  # Passes mypy

This passes mypy and gives you nice hints in your editor.
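The same approach extends to signatures that Callable cannot express at all. Below is a sketch with a made-up function, `_join_words`, that takes variadic positional arguments and a keyword-only argument with a default. The names are mine, and I would expect mypy to accept the assignment for the same reason it accepts the one above.

```python
from functools import cache
from typing import Protocol


class JoinWordsProtocol(Protocol):
    # Describes *args plus a keyword-only argument with a default value.
    def __call__(self, *words: str, separator: str = " ") -> str:
        ...


def _join_words(*words: str, separator: str = " ") -> str:
    return separator.join(words)


join_words: JoinWordsProtocol = cache(_join_words)

print(join_words("a", "b"))  # a b
print(join_words("a", "b", separator="-"))  # a-b
```

The price is that the Protocol duplicates the signature, so the two have to be kept in sync by hand.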

Does a general solution exist?

As it turns out, there is a general solution for this issue: a ParamSpec, i.e., a Parameter Specification Variable. The best way to think of this is that a ParamSpec is like a TypeVar, but for function arguments. The nice thing is that ParamSpec even works for variadic (i.e., variable in length) functions. Unfortunately for us, a ParamSpec can’t be bolted onto the built-in cache, because it has to appear in the signature of the decorator itself. So, unlike the solutions above, you won’t be able to apply this to functools.cache directly without redefining something.

Nevertheless, here is how you could do it for your own decorators, using a dummy cache as an example:

from typing import Callable, TypeVar, ParamSpec

ReturnType = TypeVar("ReturnType")
Params = ParamSpec("Params")

def cache(function: Callable[Params, ReturnType]) -> Callable[Params, ReturnType]:
    # Initialize cache

    def cached_function(*args: Params.args, **kwargs: Params.kwargs) -> ReturnType:
        # Caching happens here
        return function(*args, **kwargs)

    return cached_function

# Using _repeat_string from above
repeat_string = cache(_repeat_string)

reveal_type(repeat_string)
# note: Revealed type is "def (x: builtins.str, y: builtins.int) -> builtins.str"
repeat_string("dog", 3)  # Fine
repeat_string("dog", 3, None)  # error: Too many arguments  [call-arg]

And this works, also in your editor. But it doesn’t help you to fix your built-in cache. In order to do that, you’ll need to use the solutions above.

Is it worth it?

As I said in the introduction of this post: I’m not sure whether all of this is worth it. It’s certainly not something I would do for every occurrence of a cached function. It all boils down to how much of a completionist, and stickler for mypy correctness, you are.

I would generally use it in the following cases:

  1. User-facing (“public”) code. Your users deserve to know what to put into your functions.
  2. Code that is difficult to test.
  3. Code in projects in which the cost of a re-compile (i.e., a CI/CD pipeline or dev -> master merge roundtrip) is high.
  4. Code I am really really proud of, and which is unlikely to change.

That’s all, I hope you learned something!