
Don’t be afraid of the names in the title. Although they may seem scary or strange, you have probably already used tools that rely on this kind of machinery, for example pytest and numba.

Intro: our previous problem

In the previous post, I talked about Python frames and the inspect module. I started by showing how we can use inspect.signature to build a decorator that validates arguments:

@math_validator()
def simple_method(x: r"\in R", y: r"\in R_+", z: float = 2) -> float:
    ...
simple_method(1, 0)
simple_method((1, 2)) -> 1.5
---> 19 simple_method(1, 0)
...
<locals>.decorate.<locals>.decorated(*_args)
     11         continue
     13     if not MATH_SPACES[annotation]["validator"](_args[i]):
---> 14         raise ValueError(f"{k} doesn't belong to the {MATH_SPACES[annotation]['name']}")
     15 result = func(*_args)
     16 print(f"{func.__name__}({_args}) -> {result}")

ValueError: y doesn't belong to the space of real numbers greater than zero

After that, I combined inspect.signature with sys.settrace to build a decorator that exposes the local variables of a decorated function. This lets us do cool things, like creating a generic report decorator that has access to the local variables of the decorated function:

import numpy as np

@report('{args.n_bananas} Monkey {gluttonous_monkey} ate too many bananas.  Num monkeys {num_monkeys}')
def feed_monkeys(n_bananas): 
    num_monkeys = 3
    monkeys = {
        f"monkey_{i}": {"bananas": 0}
        for i in range(num_monkeys)
    }
    while n_bananas > 0:
        if np.random.uniform() < 0.4:
            continue
        monkey = monkeys[np.random.choice(list(monkeys.keys()))]
        if n_bananas > 0:
            monkey["bananas"] += 1
            n_bananas -= 1
    gluttonous_monkey = max(monkeys, key=lambda k: monkeys[k]["bananas"]) 

Both examples reflect real application scenarios. But at the end of my previous post, I pointed out some issues with using sys.settrace. Here is the code of the previous solution:

import sys
import inspect
from types import SimpleNamespace


def call_and_extract_frame(func, *args, **kwargs):
    frame_var = None
    trace = sys.gettrace()
    def update_frame_var(stack_frame, event_name, arg_frame):
        """
        Args:
            stack_frame: (frame)
                The current stack frame.
            event_name: (str)
                The name of the event that triggered the call. 
                Can be 'call', 'line', 'return' and 'exception'.
            arg_frame: 
                Depends on the event. Can be a None type
        """
        nonlocal frame_var  # nonlocal lets us rebind the variable from the enclosing scope
        if event_name != 'call':
            return trace
        frame_var = stack_frame
        sys.settrace(trace)
        return trace
    sys.settrace(update_frame_var)
    try:
        func_result = func(*args, **kwargs)
    finally:
        sys.settrace(trace)
    return frame_var, func_result


def report(formater):
    def decorate(func):
        def decorated(*_args):
            sig = inspect.signature(func)
            named_args = {}
            num_args = len(_args)
            for i, (k, v) in enumerate(sig.parameters.items()):
                if i < num_args:
                    named_args[k] = repr(_args[i])
                else:
                    named_args[k] = repr(v.default)
            frame_func, _result = call_and_extract_frame(func, *_args)
            name = func.__name__
            result = repr(_result)
            
            args_dict = {
                "args": SimpleNamespace(**named_args), 
                "args_repr": repr(SimpleNamespace(**named_args)),
                **locals(),
                **frame_func.f_locals,
            }
            print(formater.format(**args_dict))
            # do other stuff here
            return _result 
        return decorated
    return decorate

What are the problems with this solution?

  • Tracing always has a cost, so we should expect it to reduce the performance of our system. If you use it only for debugging, that’s OK.
  • It can conflict with other tools and libraries that also rely on the trace hook, such as debuggers and coverage tools.
  • It feels dirty!

OK, maybe you’re asking yourself: “This guy is overthinking. Why didn’t he just do this?”

@report('stuff goes here')
def func(x, y):
    random_var = np.random.uniform()
    ... #more local vars
    result = (x+y)**random_var
    return result, locals()

The reason is:

  • The whole point of using this decorator is to avoid any change in other parts of the codebase. With this approach, every place in the codebase that calls func would have to change from
result = func(x, y)
to
result = func(x, y)[0]

If you later decide to remove the decorator from the function, you will need to roll back all of these changes.

  • You increase the cognitive load on every team member, including those who don’t care about what your decorator needs to do.
  • If you are going to propose this as a solution, you might as well create another function and accept the increased complexity in the original codebase.

OK, maybe now you’re thinking: “Right, this makes sense, but you’re avoiding these issues by creating other issues in performance and debugging. It doesn’t sound good except for some special cases.” And I have to agree with you: it’s not a good solution for most cases!

The problem we’re facing is that Python doesn’t have context managers that can deal with namespaces, although there is an active discussion about this on the python-ideas mailing list (https://mail.python.org/archives/list/python-ideas@python.org/). But don’t worry about this big name. The important point here is:

If a language doesn’t have a feature that I need, what can I do?

In Python we are fine, because the language makes it easy to manipulate the Abstract Syntax Tree (AST) of a function and recompile the function from the manipulated tree. By doing that, we enter the realm of metaprogramming: writing code that writes code. If that’s not clear yet, I’ll try to make it clearer now.

ASTs: What are they?

A programming language is, obviously, a language. OK, but what is a language? Do all human languages share the same building blocks? How can we compare different sentences? These questions seem better suited to philosophers. Maybe so, but mathematicians and computer scientists can answer them too, and they usually prefer mathematical formalism to long debates about meaning. In essence, an AST is a formalism that represents a sentence as a tree, built according to a well-defined set of rules and structures.

How do you know that a sentence is grammatically correct?

Intuitively, you probably remember a set of rules you learned over your life about how to organize and combine verbs, nouns, adjectives, adverbs, etc. This set of rules and guidelines is the syntax of a language. A syntax tree is a structure that shows how those pieces fit together in a sentence.

After constructing the syntax tree, we can look in the rulebook of our language and check whether the tree has a valid structure.

Take, for example, the sentence “I drive a car to my college”. Its syntax tree is the following:

A syntax tree for the sentence “I drive a car to my college”. Source: GeeksforGeeks, “Syntax Tree – Natural Language Processing”.

What is the advantage of using syntax trees? Notice that we don’t need to talk about how many spaces you’re using or about your handwriting. Besides that, we have a hierarchical structure that lets us analyze the validity of the sentence level by level! If we want to change an element of the sentence, we can directly manipulate the node that represents it. As long as the replacement fits the rules for that position in the tree, the manipulated sentence is still grammatically correct!
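The same idea carries over to Python. As a small illustration using only the standard ast module, two snippets that differ only in spacing and redundant parentheses produce the same AST once location information is left out. A change that really alters the structure, such as moving the parentheses, produces a different tree:

```python
import ast

# Same assignment, written with different spacing and redundant parentheses
compact = ast.parse("total = (price*quantity)+tax")
spaced = ast.parse("total   =   price * quantity   +   tax")

# ast.dump() omits line/column attributes by default, so only the structure is compared
same_structure = ast.dump(compact) == ast.dump(spaced)

# Moving the parentheses changes the tree: price * (quantity + tax)
different_structure = ast.dump(ast.parse("total = price * (quantity + tax)")) != ast.dump(compact)

print(same_structure, different_structure)  # True True
```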

It’s no surprise that ASTs are also a common tool in computer science, both for analyzing the correctness of a piece of code and as a standard step in compiling or interpreting it. Here we will extend the behavior of a Python decorator by manipulating the AST. But before that, I would like to ask you a question:

Python: interpreted or compiled?

Often, when I meet a Python hater (or even an enthusiast), I hear statements like these:

  • “Python is slow because it’s an interpreted language!”
  • “Python sucks because it doesn’t have a compiler!”

Well, these assertions are not accurate. The important point is that when people say “Python”, they are usually talking about both the Python language and the CPython virtual machine that runs it. Let’s talk more about these misconceptions.

First, the distinction between interpreted and compiled languages is very blurry today. Second, look at this:

hello_world = "print('Hello, world!')"
hello_world_obj = compile(hello_world, '<string>', 'single')

Yeah, if you’re trying to argue that Python is purely interpreted, things start to get harder for you. Why is there a built-in compile function?

exec(hello_world_obj)
Hello, world!

I’m executing something that has been compiled??? What is this hello_world_obj?

print(f"Bad news for you:\n\tContent: {hello_world_obj.co_code}\n\tType: {type(hello_world_obj.co_code)}")
Bad news for you:
	Content: b'e\x00d\x00\x83\x01F\x00d\x01S\x00'
	Type: <class 'bytes'>

But what is this stuff? (The exact bytes depend on your Python version.)

It is important to understand what happens behind the scenes.

When you run some code with the python command, CPython first goes through a compilation phase: it builds the AST and generates the bytecode that is stored in code objects. These code objects are then executed by the CPython virtual machine. The diagram below is a simplified representation of this process, with some details hidden:

graph LR; A[Source Code]-->|parsing|B[Parse Tree]; B-->C[AST]; C-->E[Bytecode]; E-->F[Code Object]; F-->|execution by|G[CPython Virtual Machine];

The compilation phase consists of the first steps of the diagram above:

graph LR; A[Source Code]-->|parsing|B[Parse Tree]; B-->C[AST]; C-->E[Bytecode]; E-->F[Code Object];

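But don’t worry about most of the big names above. The only concepts that matter to us are the AST, bytecode, and code objects. Bytecode is just a compact set of instructions telling the virtual machine what to do. A code object wraps the bytecode generated from the AST, together with the data needed to run it, such as constants, variable names, the file name, and line numbers. (Since Python 3.9, CPython’s PEG parser builds the AST directly, without a separate parse tree.)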
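To peek at the bytecode and at the data a code object carries, the standard dis module is handy. Here is a minimal sketch reusing the hello_world snippet from above. The instructions printed by dis vary between Python versions, so no disassembly output is shown:

```python
import dis

hello_world = "print('Hello, world!')"
hello_world_obj = compile(hello_world, "<string>", "single")

# A code object bundles the raw bytecode with the data it refers to
print(type(hello_world_obj).__name__)  # code
print(hello_world_obj.co_consts)       # constants used by the bytecode
print(hello_world_obj.co_names)        # global names used by the bytecode

# Human-readable disassembly of the bytecode
dis.dis(hello_world_obj)

# The same information as a list of Instruction objects
opnames = [instr.opname for instr in dis.get_instructions(hello_world_obj)]
print(opnames)
```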

But how does this help us?

Our solution will manipulate the AST of a function and then generate a new code object from the manipulated AST!
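Here is a tiny sketch of that pipeline using only the standard library. It parses some source code into an AST, tweaks one node, compiles the tree into a code object, and executes it. The snippet and its variable names are illustrative only, not the final solution:

```python
import ast

source = "greeting = 'Hello'\nmessage = greeting + ', world!'\n"

# 1. Source code -> AST
tree = ast.parse(source)

# 2. Manipulate the AST: replace the string constant 'Hello' with 'Goodbye'
for node in ast.walk(tree):
    if isinstance(node, ast.Constant) and node.value == "Hello":
        node.value = "Goodbye"

# 3. AST -> code object (the bytecode lives inside it)
code_obj = compile(tree, "<generated>", "exec")

# 4. Execute the code object in a fresh namespace
namespace = {}
exec(code_obj, namespace)
print(namespace["message"])  # Goodbye, world!
```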

Extracting ASTs and interpreting them

Let’s see a simple example of a function and the extracted AST.

import inspect
import ast
import astor  # third-party package for pretty printing: pip install astor

def example(a: float, b: float = 2) -> float:
    s = a+b
    return s

tree = ast.parse(inspect.getsource(example))
print(astor.dump_tree(tree))
astor.to_source(tree)
Module(
    body=[
        FunctionDef(name='example',
            args=arguments(posonlyargs=[],
                args=[arg(arg='a', annotation=Name(id='float'), type_comment=None),
                    arg(arg='b', annotation=Name(id='float'), type_comment=None)],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None,
                defaults=[Constant(value=2, kind=None)]),
            body=[
                Assign(targets=[Name(id='s')],
                    value=BinOp(left=Name(id='a'), op=Add, right=Name(id='b')),
                    type_comment=None),
                Return(value=Name(id='s'))],
            decorator_list=[],
            returns=Name(id='float'),
            type_comment=None)],
    type_ignores=[])

The output above is our AST, and the image below shows its graph representation. Take some time looking at it to see how our code is organized.

Each capitalized element in the output above (Name, BinOp, FunctionDef, etc.) is a node, that is, an instance of a subclass of ast.AST, the base class of all AST nodes. One of the most important node types is ast.Name. For example, in

value=BinOp(left=Name(id='a'), op=Add, right=Name(id='b')),

ast.Name refers to a variable by its name, which is stored in the id field.
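A quick way to list every ast.Name in a tree is ast.walk, which yields all nodes. The parser also sets a ctx field on each Name, which the dump above does not show. The field is Store when the variable is assigned and Load when it is read. A small sketch, parsing the function from a string so it does not depend on inspect.getsource:

```python
import ast

source = '''
def example(a, b=2):
    s = a + b
    return s
'''

tree = ast.parse(source)

# Collect (variable name, context) for every Name node in the tree
names = [
    (node.id, type(node.ctx).__name__)
    for node in ast.walk(tree)
    if isinstance(node, ast.Name)
]
print(names)
```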

Now let’s come back to our problem. Remember that one bad solution was rewriting every function

def func(x, y):
    random_var = np.random.uniform()
    ... #more local vars
    result = (x+y)**random_var
    return result

as

def func_transformed(x, y):
    random_var = np.random.uniform()
    ... #more local vars
    result = (x+y)**random_var
    return result, locals()

What we will do is write a function that writes these new functions for us! This is metaprogramming! At the same time, we will write a decorator so that nothing else in our codebase has to change!

How can I be efficient in metaprogramming?

We must create a function that generates a new function similar to func_transformed. How do we figure out what needs to be done?

The 6 simple steps

  1. Create an example function
  2. Code the transformed function from the example function
  3. Code a simple test to check if the transformed function is correct
  4. Extract the AST from the example and the transformed function
  5. Compare the ASTs. What is the difference? Annotate this difference somewhere
    • You can use the difflib module that comes with Python to diff strings
  6. Create a new and more complex example function and repeat the process until you get a good idea of what you need to do.

After you have a good idea of what you need to do, you can start writing your metaprogramming function.
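Steps 4 and 5 can be wrapped in a tiny helper. The ast_diff function below is a hypothetical sketch, not part of any library. It works on source strings and uses the standard library’s ast.dump(..., indent=4) (Python 3.9+) instead of astor, so its output format differs slightly from the dumps shown in this post:

```python
import ast
import difflib


def ast_diff(source_a: str, source_b: str) -> list:
    """Hypothetical helper: unified diff between the ASTs of two code snippets."""
    dump_a = ast.dump(ast.parse(source_a), indent=4).splitlines()
    dump_b = ast.dump(ast.parse(source_b), indent=4).splitlines()
    # n=0 keeps only the changed lines, without surrounding context
    return list(difflib.unified_diff(dump_a, dump_b, lineterm="", n=0))


original = "def f(x):\n    return x\n"
transformed = "def f(x):\n    return x, locals()\n"
for line in ast_diff(original, transformed):
    print(line)
```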

Creating our metaprogramming function

First six-step iteration

Let’s start our first iteration by writing a function, the expected transformed function, and a test to check whether the transformation is correct.

def example_1(x, y):
    internal_var  =  222
    result = (x+y)**internal_var
    return result
def example_1_expected(x, y):
    internal_var = 222
    result = (x+y)**internal_var
    return result, locals()

def test_meta_example_1(meta_func, x, y):
    expected_result, expected_locals = example_1_expected(x, y)
    result, locals_dict = meta_func(x, y)
    assert result == expected_result
    assert expected_locals == locals_dict

Everything looks fine. Now we will use difflib to see the differences between the two ASTs.

import difflib
from pprint import pprint

example_1_ast_str = astor.dump_tree(ast.parse(inspect.getsource(example_1)))
example_1_expected_str = astor.dump_tree(ast.parse(inspect.getsource(example_1_expected)))


pprint(
    list(
        difflib.unified_diff(example_1_ast_str.splitlines(), example_1_expected_str.splitlines(), n=0)
    )
)
['--- \n',
 '+++ \n',
 '@@ -3 +3 @@\n',
 "-        FunctionDef(name='example_1',",
 "+        FunctionDef(name='example_1_expected',",
 '@@ -19 +19 @@\n',
 "-                Return(value=Name(id='result'))],",
 "+                Return(value=Tuple(elts=[Name(id='result'), "
 "Call(func=Name(id='locals'), args=[], keywords=[])]))],"]

Now we know that we must change this node in the AST

Return(value=Name(id='result'))],

To this

Return(value=Tuple(elts=[Name(id='result'), Call(func=Name(id='locals'), args=[], keywords=[])]))],

How can we do this? With the help of the NodeTransformer class.

The NodeTransformer class

ast.NodeTransformer lets us create objects with a walker-like interface. The walker visits the nodes of the AST, and during each visit it can remove, replace, modify, or add nodes. After that, it can either continue walking into the children of the node (by calling generic_visit) or stop there.
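To make the walker idea concrete, here is a minimal, self-contained transformer that renames a variable. RenameVariable and the variable names are illustrative, and ast.unparse requires Python 3.9+:

```python
import ast


class RenameVariable(ast.NodeTransformer):
    """Illustrative transformer: renames every variable called `old` to `new`."""

    def __init__(self, old: str, new: str):
        super().__init__()
        self.old = old
        self.new = new

    def visit_Name(self, node):
        if node.id == self.old:
            # Build a replacement node and keep the original location info
            return ast.copy_location(ast.Name(id=self.new, ctx=node.ctx), node)
        return node  # returning the node unchanged keeps it in the tree


tree = ast.parse("total = total + price")
tree = RenameVariable("total", "subtotal").visit(tree)
print(ast.unparse(tree))  # subtotal = subtotal + price
```

Returning None from a visit method would instead remove the node from the tree.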

How can we use this for our problem? First, we create a new class derived from ast.NodeTransformer:

class ASTTransformer(ast.NodeTransformer):
    def visit_Return(self, node):
        ...  # body filled in step by step below

If you want to inspect, change, or delete nodes of type Something, you override the visit_Something method. Because we need to change the Return node, we override visit_Return. If we just define the class as follows, our walker will not change our AST:

class ASTTransformer(ast.NodeTransformer):
    ...
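The visit_Something dispatch works the same way for the read-only ast.NodeVisitor. The hypothetical ReturnCounter below counts Return statements. It also shows that a visit method which does not call generic_visit stops the walker from descending into that node’s children, which becomes relevant once functions are nested:

```python
import ast

SOURCE = '''
def outer(x):
    def inner(y):
        return y * 2
    return inner(x)
'''


class ReturnCounter(ast.NodeVisitor):
    """Illustrative visitor: counts Return nodes, optionally skipping nested functions."""

    def __init__(self, skip_nested: bool):
        super().__init__()
        self.skip_nested = skip_nested
        self.inside_function = False
        self.count = 0

    def visit_FunctionDef(self, node):  # called for every FunctionDef node
        if self.skip_nested and self.inside_function:
            return  # no generic_visit: the nested body is never walked
        self.inside_function = True
        self.generic_visit(node)  # keep walking into the children

    def visit_Return(self, node):  # called for every Return node
        self.count += 1
        self.generic_visit(node)


all_returns = ReturnCounter(skip_nested=False)
all_returns.visit(ast.parse(SOURCE))
outer_only = ReturnCounter(skip_nested=True)
outer_only.visit(ast.parse(SOURCE))
print(all_returns.count, outer_only.count)  # 2 1
```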

Let’s start the modifications. First, we need a new node responsible for calling locals():

class ASTTransformer(ast.NodeTransformer):
    def visit_Return(self, node):
        node_locals = ast.Call(
            func=ast.Name(id='locals', ctx=ast.Load()),
            args=[], keywords=[]
        )
        self.generic_visit(node)
        return node

We used a Name node to refer to the locals function. Now, according to the diff, our Return node must be transformed into a Return whose value is a Tuple node:

class ASTTransformer(ast.NodeTransformer):
    def visit_Return(self, node):
        new_node = node  # we modify the Return node in place
        node_locals = ast.Call(
            func=ast.Name(id='locals', ctx=ast.Load()),
            args=[], keywords=[]
        )
        new_node.value = ast.Tuple(
            elts=[
                node.value,
                node_locals
            ],
            ctx=ast.Load()
        )
        self.generic_visit(new_node)
        return new_node

A new thing appeared: the elts argument. Don’t worry, it is just the list of nodes (the elements) that the Tuple contains. Whenever you have doubts about AST nodes, you can check the official documentation of the ast module. It is easy to understand because Python is simple!
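One way to double-check hand-built nodes is to compare them with what ast.parse produces for the equivalent source. The sketch below builds the expression `result, locals()` with the same node constructors used in visit_Return. It then compares the dumps, which ignore line and column information by default:

```python
import ast

# Hand-built version of the expression `result, locals()`
handmade = ast.Tuple(
    elts=[
        ast.Name(id="result", ctx=ast.Load()),
        ast.Call(func=ast.Name(id="locals", ctx=ast.Load()), args=[], keywords=[]),
    ],
    ctx=ast.Load(),
)

# What the parser produces for the same source ("eval" mode parses a single expression)
parsed = ast.parse("result, locals()", mode="eval").body

print(ast.dump(handmade) == ast.dump(parsed))  # True
```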

We’re almost done. The last thing is to fix our AST: the nodes we created are missing location information such as lineno and col_offset, which compile() requires. Thanks to Python, we just need to call ast.copy_location and ast.fix_missing_locations to fill this in for us.


class ASTTransformer(ast.NodeTransformer):
    def visit_Return(self, node):
        new_node = node
        node_locals = ast.Call(
            func=ast.Name(id='locals', ctx=ast.Load()),
            args=[], keywords=[]
        )
        new_node.value = ast.Tuple(
            elts=[
                node.value,
                node_locals
            ],
            ctx=ast.Load()
        )
        ast.copy_location(new_node, node)
        ast.fix_missing_locations(new_node)
        self.generic_visit(new_node)
        return new_node
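To see why the location step matters, here is a standalone sketch that builds the statement `x = 1 + 1` entirely by hand. Without location information, compile() rejects the tree. After ast.fix_missing_locations, the same tree compiles and runs:

```python
import ast

# `x = 1 + 1`, built by hand: none of these nodes has lineno/col_offset
module = ast.Module(
    body=[
        ast.Assign(
            targets=[ast.Name(id="x", ctx=ast.Store())],
            value=ast.BinOp(left=ast.Constant(value=1), op=ast.Add(), right=ast.Constant(value=1)),
        )
    ],
    type_ignores=[],
)

try:
    compile(module, "<generated>", "exec")
    compiled_without_fix = True
except (TypeError, ValueError) as exc:  # the missing location info is rejected
    compiled_without_fix = False
    print("compile failed:", exc)

ast.fix_missing_locations(module)  # recursively fills in lineno, col_offset, etc.
namespace = {}
exec(compile(module, "<generated>", "exec"), namespace)
print(namespace["x"])  # 2
```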

OK, let’s see whether our ASTTransformer works. We instantiate the transformer and call its visit method, which tells the walker to walk the AST and apply all the modifications we asked for:

tree_meta = ast.parse(inspect.getsource(example_1))
transformer = ASTTransformer()
transformer.visit(tree_meta)
example_1_meta_ast_str = astor.dump_tree(tree_meta)
example_1_expected_str = astor.dump_tree(ast.parse(inspect.getsource(example_1_expected)))


pprint(
    list(
        difflib.unified_diff(example_1_meta_ast_str.splitlines(), example_1_expected_str.splitlines(), n=0)
    )
)
['--- \n',
 '+++ \n',
 '@@ -3 +3 @@\n',
 "-        FunctionDef(name='example_1',",
 "+        FunctionDef(name='example_1_expected',"]

Our first iteration was successful! Let’s try a more complex example.

Second six-step iteration

We’ll just add more complexity without any particular meaning; we can be creative!

def example_2(x, y):
    internal_var  =  222
    def sub(x, y):
        ommit_this_var = 1
        return x - y
    result = sub(x,y)**internal_var
    return (result, False)
def example_2_expected(x, y):
    internal_var  =  222
    def sub(x, y):
        ommit_this_var = 1
        return x - y
    result = sub(x,y)**internal_var
    return ((result, False), locals())
def test_meta_example_2(meta_func, x, y):
    expected_result, expected_locals = example_2_expected(x, y)
    result, locals_dict = meta_func(x, y)
    del locals_dict["sub"]
    del expected_locals["sub"]
    assert result == expected_result
    assert expected_locals == locals_dict
example_2_ast_str = astor.dump_tree(ast.parse(inspect.getsource(example_2)))
example_2_expected_str = astor.dump_tree(ast.parse(inspect.getsource(example_2_expected)))


pprint(
    list(
        difflib.unified_diff(example_2_ast_str.splitlines(), example_2_expected_str.splitlines(), n=0)
    )
)
['--- \n',
 '+++ \n',
 '@@ -3 +3 @@\n',
 "-        FunctionDef(name='example_2',",
 "+        FunctionDef(name='example_2_expected',",
 '@@ -37 +37,4 @@\n',
 "-                Return(value=Tuple(elts=[Name(id='result'), "
 'Constant(value=False, kind=None)]))],',
 '+                Return(',
 '+                    value=Tuple(',
 "+                        elts=[Tuple(elts=[Name(id='result'), "
 'Constant(value=False, kind=None)]),',
 "+                            Call(func=Name(id='locals'), args=[], "
 'keywords=[])]))],']

Now it’s time to cross our fingers and see whether we need more work:

tree_meta = ast.parse(inspect.getsource(example_2))
transformer = ASTTransformer()
transformer.visit(tree_meta)
example_2_meta_ast_str = astor.dump_tree(tree_meta)
example_2_expected_str = astor.dump_tree(ast.parse(inspect.getsource(example_2_expected)))


pprint(
    list(
        difflib.unified_diff(example_2_meta_ast_str.splitlines(), example_2_expected_str.splitlines(), n=0)
    )
)
['--- \n',
 '+++ \n',
 '@@ -3 +3 @@\n',
 "-        FunctionDef(name='example_2',",
 "+        FunctionDef(name='example_2_expected',",
 '@@ -27,4 +27 @@\n',
 '-                        Return(',
 '-                            value=Tuple(',
 "-                                elts=[BinOp(left=Name(id='x'), op=Sub, "
 "right=Name(id='y')),",
 "-                                    Call(func=Name(id='locals'), args=[], "
 'keywords=[])]))],',
 "+                        Return(value=BinOp(left=Name(id='x'), op=Sub, "
 "right=Name(id='y')))],"]

Unfortunately, our ASTTransformer was not able to deal with this crazy guy. What is the problem? If you look carefully, you will notice that the inner function def sub is the culprit: its return statement was rewritten too. We don’t want to change nested functions, so we need to tell our walker to skip them. To do so, we create a flag that tells whether the walker has already entered the outer function, and we override the visit_FunctionDef method to check this flag:

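class ASTTransformer(ast.NodeTransformer):
    def __init__(self):
        super().__init__()
        # becomes True once the walker has entered the outermost function
        self._sub = False

    def visit_FunctionDef(self, node):
        if self._sub:
            # nested function: return it untouched, without visiting its body
            return node
        self._sub = True
        self.generic_visit(node)
        return node

    def visit_Module(self, node):
        self._sub = False
        self.generic_visit(node)
        return node  # a NodeTransformer must return the (possibly new) node

    def visit_Return(self, node):
        new_node = node
        node_locals = ast.Call(
            func=ast.Name(id='locals', ctx=ast.Load()),
            args=[], keywords=[]
        )
        # a bare `return` has no value; use None so the result is (None, locals())
        value = node.value if node.value is not None else ast.Constant(value=None)
        new_node.value = ast.Tuple(
            elts=[
                value,
                node_locals
            ],
            ctx=ast.Load()
        )
        ast.copy_location(new_node, node)
        ast.fix_missing_locations(new_node)
        self.generic_visit(new_node)
        return new_node
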
tree_meta = ast.parse(inspect.getsource(example_2))
transformer = ASTTransformer()
transformer.visit(tree_meta)
example_2_meta_ast_str = astor.dump_tree(tree_meta)
example_2_expected_str = astor.dump_tree(ast.parse(inspect.getsource(example_2_expected)))


pprint(
    list(
        difflib.unified_diff(example_2_meta_ast_str.splitlines(), example_2_expected_str.splitlines(), n=0)
    )
)
['--- \n',
 '+++ \n',
 '@@ -3 +3 @@\n',
 "-        FunctionDef(name='example_2',",
 "+        FunctionDef(name='example_2_expected',"]

Our new ASTTransformer was able to deal with our new complicated example!

Creating a new function at runtime

Now that we have an ASTTransformer, we must compile the transformed AST into a new function. In Python, we can create a function object at runtime with types.FunctionType:

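import textwrap
from types import FunctionType, CodeType

def transform_and_compile(func: FunctionType) -> FunctionType:
    # dedent so that indented definitions (e.g. methods inside a class) still parse
    source = textwrap.dedent(inspect.getsource(func))
    tree = ast.parse(source)
    # drop the decorators (e.g. @report(...)) instead of filtering lines that start with "@"
    tree.body[0].decorator_list = []
    transformer = ASTTransformer()
    transformer.visit(tree)
    code_obj = compile(tree, func.__code__.co_filename, 'exec')
    # the function's code object is stored among the module-level constants
    function_code = next(
        c for c in code_obj.co_consts
        if isinstance(c, CodeType) and c.co_name == func.__name__
    )
    # we must pass the globals of the original function; default values live on
    # the function object (not in its code object), so we copy them as well.
    # Limitation: closures (free variables from an enclosing function) are not supported.
    transformed_func = FunctionType(function_code, func.__globals__, func.__name__, func.__defaults__)
    transformed_func.__kwdefaults__ = func.__kwdefaults__
    return transformed_func

test_meta_example_1(transform_and_compile(example_1), 4, 2)
test_meta_example_2(transform_and_compile(example_2), 1, 2)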
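The FunctionType mechanics can also be seen in isolation. In the sketch below, which uses a plain source string and illustrative names, the function’s code object is picked out of the module’s constants by name. The defaults are passed explicitly because they are evaluated when the def statement runs, and here the module code containing that statement is never executed:

```python
import ast
from types import CodeType, FunctionType

source = '''
def scale(value, factor=10):
    return value * factor
'''

module_code = compile(ast.parse(source), "<generated>", "exec")

# The module's code object stores the function's code object among its constants
scale_code = next(
    c for c in module_code.co_consts
    if isinstance(c, CodeType) and c.co_name == "scale"
)

# FunctionType(code, globals, name, argdefs): defaults must be supplied separately
scale = FunctionType(scale_code, {}, "scale", (10,))
print(scale(3))     # 30
print(scale(3, 2))  # 6
```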

transform_and_compile creates new functions that pass both tests! We can now move on to the final and easy step: integrating this function with our decorator!

Integrating the AST manipulation with a decorator

We call transform_and_compile directly inside decorate, so the transformation and compilation happen only once, when the function is decorated, rather than every time the decorated function is called.

def report(fmt):
    def decorate(func):
        meta_func = transform_and_compile(func)
        ...

Inside decorated, we call meta_func and return only the result, so callers get exactly the same value as before and nothing else in our codebase has to change.

def report(fmt):
    def decorate(func):
        meta_func = transform_and_compile(func)
        ...
        def decorated(*_args):
            _result, internal_locals = meta_func(*_args)
            ...
            return _result

Combining this with what we learned in the previous post, our report decorator becomes:


def report(fmt):
    def decorate(func):
        meta_func = transform_and_compile(func)
        sig = inspect.signature(func)
        def decorated(*_args):
            _result, internal_locals = meta_func(*_args)
            named_args = {}
            num_args = len(_args)
            for i, (k, v) in enumerate(sig.parameters.items()):
                if i < num_args:
                    named_args[k] = repr(_args[i])
                else:
                    named_args[k] = repr(v.default)
            
            name = func.__name__
            result = repr(_result)
            args_dict = {
                **internal_locals,
                **locals(),
                **named_args
            }
            print(fmt.format(**args_dict))
            # store the information in some place
            return _result  # the original value, not its repr
        return decorated 
    return decorate

Let’s see the result with a dummy function

@report(fmt='{name}(a={a}, b={b}, c={c}); sum_ab {sum_ab}, diff_ab {dif_ab}; r={result}')
def dummy_example(a, b, c=2):
    sum_ab = a + b
    dif_ab = a - b
    r = sum_ab**c + dif_ab**c
    return r

r = dummy_example(2, 3, 1)
print("r:", r)
dummy_example(a=2, b=3, c=1); sum_ab 5, diff_ab -1; r=4
r: 4

I know this post is quite dense, but I think it’s worth sharing. I hope you enjoyed it!

Bruno Messias
Ph.D Candidate/Software Developer

Free-software enthusiast, researcher, and software developer. Currently, working in the field of Graphs, Complex Systems and Machine Learning.

