asb: head /dev/brain > /dev/www

My home, musings, and wanderings on the world wide web.

Python's AST module: Bringing a gun to a knife fight!

So, I’ve been writing unit tests for some statistical code using py.test. One of the sweet things about py.test is that it gives you cute context-specific comparison assertions, so you can check one data structure against another with a plain assert.

The problem I ran into shows up when floating point numbers are involved. A minimal (if convoluted) example:

>>> import pandas as pd

>>> pd.np.random.seed(3141)
>>> xx = pd.np.random.random(17)

>>> print pd.np.percentile(xx, 25)
0.386739093187

>>> assert 0.386739093187 == pd.np.percentile(xx, 25)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-16-f6262184b4b6> in <module>()
----> 1 assert 0.386739093187 == pd.np.percentile(xx, 25)

AssertionError:
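
A note on why that assertion fails: Python 2's print shows a float via str(), which keeps only 12 significant digits, so the number pasted into the assert is not quite the number percentile returned. Here is a minimal sketch in Python 3 syntax. It assumes the full value is the 0.38673909318708044 that turns up as a lower quartile further down, and it uses '%.12g' formatting as a stand-in for what Python 2's print did:

```python
import math

# Assumption: this is the full-precision value behind the printed 0.386739093187.
full_value = 0.38673909318708044

# '%.12g' keeps 12 significant digits, like str() on a float in Python 2.
printed = '%.12g' % full_value
pasted = float(printed)

print(printed)
print(pasted == full_value)                             # exact comparison
print(math.isclose(pasted, full_value, abs_tol=1e-12))  # comparison with a tolerance
```

The exact comparison is the one that bites; comparing with a tolerance (or rounding both sides first) is the usual way around it.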

Now, this is not that annoying on its own, but it gets painful when you wanna do it for complicated data structures such as dicts of lists of floats and still want the assertion goodness that comes with py.test. I had exactly this situation at hand. Disclaimer: Don’t try this at work! Even though I did.

I was thinking that it’d be easy as pie to do this in a Lisp by traversing the structure with a lambda that rounds all the floats. Then I figured that Python can do this for me using the ast module. There is a good introduction to the module here and a discussion on appropriate things to do with ASTs here.
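
For comparison, the plain traversal idea (no AST involved) can be sketched in a few lines of Python. This is only an illustration in Python 3 syntax; the name round_floats is made up here and is not part of the code in this post:

```python
def round_floats(obj, digits=2):
    """Return a copy of obj with every float rounded to `digits` places."""
    if isinstance(obj, float):
        return round(obj, digits)
    if isinstance(obj, dict):
        # Keys are left alone; only the values are traversed.
        return {key: round_floats(value, digits) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Rebuild the same container type from the rounded items.
        return type(obj)(round_floats(item, digits) for item in obj)
    # Anything else (ints, strings, None, ...) passes through untouched.
    return obj


nested = {'q': [0.074264470377902181, 0.83386867290874311], 'n': 3}
print(round_floats(nested, digits=4))
```

It only knows about dicts, lists and tuples, so any other container type would need its own branch.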

So, here is what I ended up implementing (this is miles away from being safe to use) to solve the problem.

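import ast
import codegen  # third-party module that turns an AST back into source code

class NumberRounder(ast.NodeTransformer):
    """Rewrite every float literal in the tree, rounded to `digits` places."""

    def __init__(self, digits):
        self.digits = digits

    def visit_Num(self, node):
        # Only floats are rounded; ints and other numbers pass through untouched.
        if isinstance(node.n, float):
            return ast.Num(round(node.n, self.digits))
        return node


def round_numbers_in(literal, digits=2):
    nr = NumberRounder(digits=digits)
    # str() of a container of basic types is valid Python source for that literal.
    original_ast = ast.parse(str(literal))
    rewritten_ast = nr.visit(original_ast)
    rewritten_source = codegen.to_source(rewritten_ast)
    # literal_eval only accepts literals, so nothing gets executed here.
    rewritten_literal = ast.literal_eval(rewritten_source)
    return rewritten_literal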
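
The code above targets Python 2 and needs the codegen package. On a recent Python 3 the same trick might look like the sketch below. It rests on two assumptions: number literals arrive as ast.Constant nodes (Python 3.8+), and ast.literal_eval is handed the rewritten tree directly, which makes the trip back through source code unnecessary. The names FloatRounder and round_floats_via_ast are made up for this sketch:

```python
import ast


class FloatRounder(ast.NodeTransformer):
    """Replace each float constant in the tree with its rounded value."""

    def __init__(self, digits):
        self.digits = digits

    def visit_Constant(self, node):
        # Only floats are rewritten; strings, ints, etc. are returned unchanged.
        if isinstance(node.value, float):
            rounded = ast.Constant(value=round(node.value, self.digits))
            return ast.copy_location(rounded, node)
        return node


def round_floats_via_ast(literal, digits=2):
    # mode='eval' parses a single expression into an ast.Expression node.
    tree = ast.parse(repr(literal), mode='eval')
    rewritten = FloatRounder(digits).visit(tree)
    # literal_eval accepts an AST node as well as a string.
    return ast.literal_eval(rewritten)


sample = {'a': [0.074264470377902181, 1, 'x'], 'b': {'m': -0.83386867290874311}}
result = round_floats_via_ast(sample, digits=4)
print(result)
```

Like the original, this only works for structures whose repr is a plain Python literal, and it is no safer to use at work.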

And voila!

>>> sample_complicated_literal = {
    'a': {
        'test': {
            'outliers': [0.074264470377902181, 0.83386867290874311],
            'quantiles': {
                'median': 0.090294684490804245,
                'upper_quartile': 0.46208167869977368,
                'lower_quartile': 0.082279577434353213,
                'minimum': 0.075867491789192387,
                'maximum': 0.75951127406694918
            }
        }, 'target': {
            'outliers': [0.90590397810859369, 0.074264470377902181],
            'quantiles': {
                'median': 0.51193055816399746,
                'upper_quartile': 0.83386867290874311,
                'lower_quartile': 0.38673909318708044,
                'minimum': 0.087088641668223832,
                'maximum': 0.90457554405068796
            }
        }
    }, 'c': {
        'test': {
            'outliers': [0.76596994956051723, 0.18885210718343348],
            'quantiles': {
                'median': 0.42025570188230343,
                'upper_quartile': 0.59311282572141033,
                'lower_quartile': 0.30455390453286846,
                'minimum': 0.2119924666533205,
                'maximum': 0.73139852479269585
            }
        }, 'target': {
            'outliers': [0.024942395243662818, 0.90001151823365477],
            'quantiles': {
                'median': 0.42025570188230343,
                'upper_quartile': 0.54953955294935286,
                'lower_quartile': 0.18885210718343348,
                'minimum': 0.057602330989601373,
                'maximum': 0.89242588191298255
            }
        }
    }
}

>>> import pprint

>>> pprint.pprint(round_numbers_in(sample_complicated_literal, digits=4))

{'a': {'target': {'outliers': [0.9059, 0.0743],
                  'quantiles': {'lower_quartile': 0.3867,
                                'maximum': 0.9046,
                                'median': 0.5119,
                                'minimum': 0.0871,
                                'upper_quartile': 0.8339}},
       'test': {'outliers': [0.0743, 0.8339],
                'quantiles': {'lower_quartile': 0.0823,
                              'maximum': 0.7595,
                              'median': 0.0903,
                              'minimum': 0.0759,
                              'upper_quartile': 0.4621}}},
 'c': {'target': {'outliers': [0.0249, 0.9],
                  'quantiles': {'lower_quartile': 0.1889,
                                'maximum': 0.8924,
                                'median': 0.4203,
                                'minimum': 0.0576,
                                'upper_quartile': 0.5495}},
       'test': {'outliers': [0.766, 0.1889],
                'quantiles': {'lower_quartile': 0.3046,
                              'maximum': 0.7314,
                              'median': 0.4203,
                              'minimum': 0.212,
                              'upper_quartile': 0.5931}}}}