Enums and refactoring
Enumerations are types that take a set of pre-defined options, called members, which are also assigned values. Usually, enumeration members are, as the name implies, simply mapped to integer values, but any arbitrary value works. Here’s an example of an enumeration in Python for colors:
from enum import Enum
class Color(Enum):
    RED = 0
    GREEN = 1
    BLUE = 2
Creating this Enum allows us to use Color.RED in code as a symbolic value, instead of using a string or integer. This post goes over some of the intricacies of using enumerations in Python, and why you would want to use them over raw values. The Python docs have a really nice tutorial about Enum. This blog post will have some overlap, but will mostly focus on why, when, and where you would like to use them.
The TLDR version of this blog: enumerations are a great way to cure your codebase of what is called “stringly typing”. That’s it!
If you are still interested, let’s use a simple example from a Machine Learning workflow. In Machine Learning, we usually train models on some set of data, which we call the train set, and fine-tune our training procedure (e.g., the model hyperparameters) on another set, which is called the validation set. After training, we produce accuracy scores for held-out data, commonly called the test set.
Let’s think of a simple function that gives us the correct split for a given dataset. We’re pretending to use MongoDB.
def get_data(dataset_id: str, dataset_type: str) -> list[dict]:
    db_client = get_client()
    records = db_client.find({"dataset_id": dataset_id, "dataset_split": dataset_type})
    return list(records)
So far so good, looks nice! However, when maintaining this code, we will come to an ugly realization: it is unclear which values dataset_type can take. Are we using TRAIN or train? validation or val, or even dev? Is there another option, like unassigned or not_used? This becomes all the more problematic if you consider that dataset_type is often passed in via a command line argument, or via an API, so you can’t actually go through the code to find all possible values for dataset_type. Often, the most reliable way to find out is to do a database call, which is 👎👎👎.
Enter enumerations! Let’s define an Enum for our dataset_type.
from enum import Enum
class DatasetType(Enum):
    TRAIN = "train"
    VALIDATION = "validation"
    TEST = "test"
(Note that we could also have used auto here, in combination with a StrEnum, but only on Python >= 3.11.)
Now, our function above becomes:
def get_data(dataset_id: str, dataset_type: DatasetType) -> list[dict]:
    db_client = get_client()
    records = db_client.find({"dataset_id": dataset_id, "dataset_split": dataset_type.value})
    return list(records)
Note the use of .value in the query itself. This is because we don’t want to pass the Enum member itself, but the string value of that member. The database has no knowledge of the enumeration, so this query might fail if you don’t use .value.
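To make the difference between a member and its value concrete, here is a minimal sketch. The build_filter helper and the "mnist" dataset id are made up for illustration and are not part of the function above; the sketch only builds the query dictionary and never touches a database.

```python
from enum import Enum


class DatasetType(Enum):
    TRAIN = "train"
    VALIDATION = "validation"
    TEST = "test"


def build_filter(dataset_id: str, dataset_type: DatasetType) -> dict[str, str]:
    """Hypothetical helper: build the query filter with plain string values."""
    return {"dataset_id": dataset_id, "dataset_split": dataset_type.value}


# A plain Enum member is not equal to its underlying string...
print(DatasetType.TRAIN == "train")        # False
# ...but its .value is.
print(DatasetType.TRAIN.value == "train")  # True

print(build_filter("mnist", DatasetType.TRAIN))
# {'dataset_id': 'mnist', 'dataset_split': 'train'}
```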
As you can see, a programmer who has no knowledge of the code base will have a much easier time discovering the meaning of the argument called dataset_type, and can also easily determine the values this type can take. This makes refactors and changes in the code much less error-prone, and much easier.
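Since dataset_type often arrives as a command line argument, one possible way to convert it at the boundary is to hand the Enum to argparse as the type. This is only a sketch: the --dataset-type flag name and the build_parser function are assumptions for illustration, not something from the code above.

```python
import argparse
from enum import Enum


class DatasetType(Enum):
    TRAIN = "train"
    VALIDATION = "validation"
    TEST = "test"


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical CLI parser that turns the raw string into a DatasetType."""
    parser = argparse.ArgumentParser()
    # argparse calls DatasetType("<raw string>"), which is a lookup by value.
    # An unknown string such as "dev" raises ValueError inside that call,
    # which argparse reports as a usage error before exiting.
    parser.add_argument("--dataset-type", type=DatasetType, default=DatasetType.TRAIN)
    return parser


# Passing an explicit argument list instead of reading sys.argv, for the demo.
args = build_parser().parse_args(["--dataset-type", "validation"])
print(args.dataset_type is DatasetType.VALIDATION)  # True
```

From this point on, the rest of the program only ever sees DatasetType members, never raw strings.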
name and value
Enum members have two properties, a name and a value. The name of a member is always unique, and always has the type str, while the value can have any type. You might expect values to be unique as well, but that is only enforced if you use the unique decorator. Without it, a member that repeats an earlier value quietly becomes an alias of the first member with that value, and aliases are left out when you iterate over the Enum or take its length, which is weird to me. The following showcases what I mean:
from enum import Enum
class Color(Enum):
    RED = 0
    BLUE = 0
    GREEN = 1
len(Color)
# 2!
print(list(Color))
# [<Color.RED: 0>, <Color.GREEN: 1>]
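A short sketch of what happens to the duplicate: in the standard library, a member that repeats a value is kept as an alias of the first member rather than thrown away, and the unique decorator turns the duplicate into an error at class creation time. The StrictColor class and the duplicate_error variable are names I made up for this illustration.

```python
from enum import Enum, unique


class Color(Enum):
    RED = 0
    BLUE = 0   # same value as RED, so BLUE becomes an alias of RED
    GREEN = 1


# The alias still resolves; it is the very same member object as RED.
print(Color.BLUE is Color.RED)   # True
# Iteration and len() skip aliases, but __members__ lists every name.
print(len(Color))                # 2
print(list(Color.__members__))   # ['RED', 'BLUE', 'GREEN']

# With @unique, the same definition is rejected when the class is created.
duplicate_error = None
try:
    @unique
    class StrictColor(Enum):
        RED = 0
        BLUE = 0
        GREEN = 1
except ValueError as err:
    duplicate_error = err

print(duplicate_error is not None)  # True
```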
I don’t think I have ever used name for anything, and always use the members directly in functions. I tend to only coerce members to their value at the last possible moment, i.e., when interfacing with another library, a function that has no knowledge of the enumeration, or when returning API responses.
Also useful to know: you can get enumeration members by their value by calling the enumeration directly, and by their name by indexing the enumeration with the name as a string. Using the Color enumeration defined above:
Color(0)
# <Color.RED: 0>
Color["RED"]
# <Color.RED: 0>
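Both lookups fail loudly when nothing matches, which is worth knowing when the input comes from outside your program. A minimal sketch, using the first Color definition from this post (three distinct values); the two lookup helpers are hypothetical wrappers that return None instead of raising.

```python
from enum import Enum
from typing import Optional


class Color(Enum):
    RED = 0
    GREEN = 1
    BLUE = 2


def lookup_by_value(raw: int) -> Optional[Color]:
    """Hypothetical helper: calling the Enum raises ValueError for unknown values."""
    try:
        return Color(raw)
    except ValueError:
        return None


def lookup_by_name(raw: str) -> Optional[Color]:
    """Hypothetical helper: indexing the Enum raises KeyError for unknown names."""
    try:
        return Color[raw]
    except KeyError:
        return None


print(lookup_by_value(0) is Color.RED)      # True
print(lookup_by_value(5))                   # None
print(lookup_by_name("RED") is Color.RED)   # True
print(lookup_by_name("PURPLE"))             # None
```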
Advantages of using enums
Enumerations allow you to statically type check your code really nicely using mypy. They thus give much stronger correctness guarantees than using string or integer values directly.
An additional advantage is that it is very difficult to assign the wrong value to the wrong argument when calling functions. For example, given a function that takes two string values, it is possible to mix them up:
def get_data(dataset_type: str,
             dataset_split: str) -> list[dict]: ...
These, and other points, have also been made in the excellent “rusty python” blog post: by assigning tiny types to your function arguments, it becomes much more difficult to use the functions incorrectly, and much easier for programmers to reason about them, as follows:
def get_data(dataset_type: DatasetType,
             dataset_split: DatasetSplit) -> list[dict]: ...
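Here is a small runnable sketch of that idea. The post does not define the two enums for this signature, so the members below (IMAGES, TEXT, and the three splits) are invented for illustration, and get_data just returns a string instead of querying anything.

```python
from enum import Enum


class DatasetType(Enum):
    # Hypothetical members, only here to give the two arguments distinct types.
    IMAGES = "images"
    TEXT = "text"


class DatasetSplit(Enum):
    TRAIN = "train"
    VALIDATION = "validation"
    TEST = "test"


def get_data(dataset_type: DatasetType, dataset_split: DatasetSplit) -> str:
    """Stand-in for the real function: describe what would be fetched."""
    return f"{dataset_type.value}/{dataset_split.value}"


print(get_data(DatasetType.TEXT, DatasetSplit.TRAIN))  # text/train

# Swapping the arguments would still run, because Python does not enforce
# annotations at runtime, but a static checker such as mypy should flag both
# arguments as having incompatible types:
# get_data(DatasetSplit.TRAIN, DatasetType.TEXT)
```

With two str arguments, the swapped call would look perfectly fine to both the type checker and the reviewer.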
Another advantage is that using enumerations (or NewTypes, for that matter) forces you to not create circular dependencies in your code. By using strings directly, it is possible, and very easy, in fact, to pass these strings around, up and down your program, and then pass them back out again. If you ever want to replace such a string by an enumeration, you will be burned, and will probably have to refactor extensively.
A final advantage, and counterpoint to the above, is that enumerations don’t force you to use them everywhere. Once a circular dependency is uncovered, you can just start using whatever .value you have, and start passing that around. It’s not pretty, very ugly, in fact, but it greatly eases the burden of refactors.
So, that was it for the first part of enumerations. Next time we’ll discuss more complicated enumerations.