Test Driven Development with TensorFlow
(30 min read) This post demonstrates how Test Driven Development (TDD) applies to developing machine learning models with frameworks such as TensorFlow
1. Introduction
During my past ML research internships I frequently saw colleagues, and even tech leads, writing code that was not tested at all. Near the end of a product cycle, the ‘prototyped’ code and TensorFlow models were dumped on software engineers, who almost always either rewrote them or accepted them as black-box programs. Needless to say, the pain comes when the model is deployed: code and models naturally break, and you typically add another week of toil debugging who messed up what. Data science is becoming an ever more integral part of the team, so developing bug-free, reproducible code should naturally be part of a data scientist’s job.
Test Driven Development, or TDD, is a development technique where you must first write a failing test before you write new functional code. It is a key part of the agile software development process. There are ample tutorials online on how to do TDD, but I could not find many resources on how to do TDD with machine learning frameworks such as TensorFlow. Personally, I think a lot of data scientists come from academia, where rapid prototyping is always preferred over code quality. This post therefore serves to demonstrate how the model development process can actually be improved by making TDD part of the workflow.
2. A Quick Introduction to the Conventional TDD Steps
There is actually a whole book on TDD, written by Kent Beck. Here I reproduce his summary of the TDD workflow:
- Quickly add a test before writing your actual code
- Run all tests and watch the new one fail; make sure it fails where you expect it to
- Make a little change that makes the test pass, without worrying about code quality
- Once all tests pass, refactor your code
- Repeat from step 1
There are several nice properties of the TDD process. Obviously, you no longer need to worry about catching up on unit tests later. Secondly, you can actually worry less when writing the actual code, as the tests are there to help you debug. Thirdly, you can refactor your code without worrying about breaking anything. Lastly, and perhaps most crucially, you become faster at coding, because the ‘contract’ specified by your tests tells you exactly what you need to build.
When applying the TDD workflow to ML research, it is worth noting that prototypes by no means need to be tested: one should be free to open an IPython console or Jupyter notebook and experiment with ideas. But the moment a prototype goes anywhere near production code, the TDD workflow should be applied.
3. TDD By Example: Scaled Dot Product Attention
Let’s say you have decided to write your own scaled dot product attention layer in TensorFlow. The equation to be implemented is:
$$\text{attention}(Q, K, V) = \sigma\left(\dfrac{QK^T}{\sqrt{\text{dim}(K)}}\right)V$$
where $Q, K, V$ refer to the queries, keys and values matrices respectively, and $\sigma$ denotes the softmax function.
Before we even think about implementing the line above, we create a test script called `test_attention.py`. First, we imagine that the script has already been written for us, by importing it (the test will obviously fail):
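A minimal sketch of what that first test file might look like (the module path `models.attention` follows the directory layout described just below, and is my assumption):

```python
# test_attention.py
# At this point the module does not exist yet, so pytest fails at import time.
from models import attention
```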
Now we run pytest and observe that dreaded import error. We move on to step 3: make a little change that makes the test pass, by creating the script `attention.py` under the `models` directory and adding the appropriate paths to `PYTHONPATH`.
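For example (a sketch; a `models/__init__.py` and the repository root on `PYTHONPATH` are assumed so the import resolves):

```python
# models/attention.py
# Intentionally empty for now: its mere existence is enough to make the
# import in test_attention.py succeed.
```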
Now we run the test and everything passes (there is no test yet, so…), and we move on to the next task: how do we test this attention layer? Sure, we could do some complicated maths by hand with some random $Q, K, V$. Or, we can begin by testing a very simple case: when the query is fully aligned with one of the keys and orthogonal to the rest, the output should obviously be the value assigned to that key!
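Here is a sketch of such a test; the exact tensor values are my own choice, arranged so that the scaled logits saturate the softmax:

```python
# test_attention.py
import numpy as np
import tensorflow as tf

from models.attention import scaled_dot_product_attention


def test_output_is_value_of_aligned_key():
    # The query points in the same direction as the second key and is
    # orthogonal to the first; large magnitudes make the softmax near one-hot.
    query = tf.constant([[0.0, 10.0]])
    keys = tf.constant([[10.0, 0.0],
                        [0.0, 10.0]])
    values = tf.constant([[1.0, 5.0],
                          [2.0, 6.0]])
    output = scaled_dot_product_attention(query, keys, values)
    np.testing.assert_allclose(output.numpy(), [[2.0, 6.0]], atol=1e-3)
```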
In the test above, our query is aligned with the second key, whose associated value is $[2, 6]^T$, which is therefore the expected output. Now we run the test and realise that we don’t even have the function `scaled_dot_product_attention` in our script! We then add the following to `attention.py`:
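Something like this (a sketch):

```python
# models/attention.py
def scaled_dot_product_attention(query, keys, values):
    # Stub only: implicitly returns None, so the test now fails on the
    # assertion rather than on the import.
    pass
```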
The test now fails at the right location: our model output is `None`. Now we actually implement the equation:
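A straightforward eager-mode implementation of the formula above (a sketch):

```python
# models/attention.py
import tensorflow as tf


def scaled_dot_product_attention(query, keys, values):
    # attention(Q, K, V) = softmax(Q K^T / sqrt(dim(K))) V
    dim_k = tf.cast(tf.shape(keys)[-1], tf.float32)
    logits = tf.matmul(query, keys, transpose_b=True) / tf.sqrt(dim_k)
    weights = tf.nn.softmax(logits, axis=-1)
    return tf.matmul(weights, values)
```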
The test passes! What about masked inputs? Attention layers are typically designed to ignore masked (e.g. padded) positions. Again, we start by writing a test:
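A sketch of the mask test, reusing the setup above with one extra key; the mask convention (1 = keep, 0 = mask) is my assumption:

```python
def test_masked_key_is_ignored():
    query = tf.constant([[0.0, 10.0]])
    # The third key also aligns perfectly with the query, but carries
    # extreme values that must not leak into the output.
    keys = tf.constant([[10.0, 0.0],
                        [0.0, 10.0],
                        [0.0, 10.0]])
    values = tf.constant([[1.0, 5.0],
                          [2.0, 6.0],
                          [1000.0, 1000.0]])
    mask = tf.constant([[1.0, 1.0, 0.0]])  # mask out the last key
    output = scaled_dot_product_attention(query, keys, values, mask)
    np.testing.assert_allclose(output.numpy(), [[2.0, 6.0]], atol=1e-3)
```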
So we masked our last key, which happens to also align perfectly with our query, and we want to make sure that its values are not picked up in the output. Now we simply re-run the tests and see the new one fail, since we are passing one extra argument (`mask`) to the model. We then add:
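Presumably just the extra argument, ignored for now (a sketch):

```python
def scaled_dot_product_attention(query, keys, values, mask=None):
    # mask is accepted but not used yet, so the new test should still fail.
    dim_k = tf.cast(tf.shape(keys)[-1], tf.float32)
    logits = tf.matmul(query, keys, transpose_b=True) / tf.sqrt(dim_k)
    weights = tf.nn.softmax(logits, axis=-1)
    return tf.matmul(weights, values)
```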
Again, the test fails as expected: our existing implementation (which doesn’t do anything with the mask) has picked up the extreme values $[1000, 1000]^T$. Finally, we correct our script:
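One common way to implement the masking (a sketch: add a large negative number to the masked logits so their softmax weights vanish):

```python
def scaled_dot_product_attention(query, keys, values, mask=None):
    dim_k = tf.cast(tf.shape(keys)[-1], tf.float32)
    logits = tf.matmul(query, keys, transpose_b=True) / tf.sqrt(dim_k)
    if mask is not None:
        # Push the logits of masked keys (mask == 0) towards -inf so that
        # softmax assigns them ~0 weight.
        logits += (1.0 - mask) * -1e9
    weights = tf.nn.softmax(logits, axis=-1)
    return tf.matmul(weights, values)
```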
Voilà, all tests pass, and now we can refactor our code if needed.