Don’t be afraid of the names in the title. Although they can seem scary or strange, you have probably already used tools that rely on this kind of thing, for example pytest and numba.
Intro: our previous problem
In the previous post, I talked about Python frames and the inspect module. I started by showing how we can use inspect.signature to construct a decorator that validates arguments:
@math_validator()
def simple_method(x: "\in R", y: "\in R_+", z: float = 2) -> float:
    ...
simple_method(1, 0)
simple_method((1, 2)) -> 1.5
---> 19 simple_method(1, 0)
...
<locals>.decorate.<locals>.decorated(*_args)
11 continue
13 if not MATH_SPACES[annotation]["validator"](_args[i]):
---> 14 raise ValueError(f"{k} doesn't belong to the {MATH_SPACES[annotation]['name']}")
15 result = func(*_args)
16 print(f"{func.__name__}({_args}) -> {result}")
ValueError: y doesn't belong to the space of real numbers greater than zero
And after that, I combined inspect.signature with sys.settrace to construct a decorator that exposes the local variables of a decorated function. All this allows us to do cool things, like creating a generic report decorator that has access to the local variables of the decorated method:
@report('{args.n_bananas} Monkey {gluttonous_monkey} ate too much bananas. Num monkeys {num_monkeys}')
def feed_monkeys(n_bananas):
    num_monkeys = 3
    monkeys = {
        f"monkey_{i}": {"bananas": 0}
        for i in range(num_monkeys)
    }
    while n_bananas > 0:
        if np.random.uniform() < 0.4:
            continue
        monkey = monkeys[np.random.choice(list(monkeys.keys()))]
        if n_bananas > 0:
            monkey["bananas"] += 1
            n_bananas -= 1
    gluttonous_monkey = max(monkeys, key=lambda k: monkeys[k]["bananas"])
These two examples can be found in real application scenarios. But at the end of my previous post I mentioned some issues with using sys.settrace. Here is the code of the previous solution:
import sys
import inspect
from types import SimpleNamespace

def call_and_extract_frame(func, *args, **kwargs):
    frame_var = None
    trace = sys.gettrace()
    def update_frame_var(stack_frame, event_name, arg_frame):
        """
        Args:
            stack_frame: (frame)
                The current stack frame.
            event_name: (str)
                The name of the event that triggered the call.
                Can be 'call', 'line', 'return' and 'exception'.
            arg_frame:
                Depends on the event. Can be a None type
        """
        nonlocal frame_var  # nonlocal is a keyword which allows us to modify the outside scope variable
        if event_name != 'call':
            return trace
        frame_var = stack_frame
        sys.settrace(trace)
        return trace
    sys.settrace(update_frame_var)
    try:
        func_result = func(*args, **kwargs)
    finally:
        sys.settrace(trace)
    return frame_var, func_result

def report(formater):
    def decorate(func):
        def decorated(*_args):
            sig = inspect.signature(func)
            named_args = {}
            num_args = len(_args)
            for i, (k, v) in enumerate(sig.parameters.items()):
                if i < num_args:
                    named_args[k] = repr(_args[i])
                else:
                    named_args[k] = repr(v.default)
            frame_func, _result = call_and_extract_frame(func, *_args)
            name = func.__name__
            result = repr(_result)
            args_dict = {
                "args": SimpleNamespace(**named_args),
                "args_repr": repr(SimpleNamespace(**named_args)),
                **locals(),
                **frame_func.f_locals,
            }
            print(formater.format(**args_dict))
            # do other stuff here
            return _result
        return decorated
    return decorate
What are the problems with this solution?
- Tracing always has a cost, so we should expect our system to run slower. If you use this just for debugging purposes, it’s ok.
- It can conflict with other tools and libs that also rely on the trace hook (debuggers and coverage tools, for example)
- Seems dirty!
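The conflict mentioned in the second point can be seen in a small sketch. The two tracer functions below are hypothetical stand-ins (one for a debugger or coverage tool, one for our decorator); the only real API used is sys.settrace and sys.gettrace. There is a single trace hook per thread, so installing a second tracer silently replaces the first one.

```python
import sys

def tracer_a(frame, event, arg):
    # Hypothetical tracer standing in for a debugger or coverage tool.
    return None

def tracer_b(frame, event, arg):
    # Hypothetical tracer standing in for our decorator's hook.
    return None

previous = sys.gettrace()          # remember whatever was installed before
try:
    sys.settrace(tracer_a)
    installed_first = sys.gettrace()
    sys.settrace(tracer_b)         # silently replaces tracer_a
    installed_second = sys.gettrace()
finally:
    sys.settrace(previous)         # always restore the original hook

print(installed_first is tracer_a)
print(installed_second is tracer_b)
```

This is why the previous solution saves the current tracer with sys.gettrace() and restores it in a finally block.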
Ok, maybe you’re asking yourself “This guy is overthinking. Why didn’t he just do this?"
@report('stuff goes here')
def func(x, y):
    random_var = np.random.uniform()
    ...  # more local vars
    result = (x+y)**random_var
    return result, locals()
The reasons are:
- The main point of using this decorator is to avoid any change in other parts of the codebase. For example, if func has been called anywhere in the codebase, you will have to change every call from
result = func(x, y) # to
result = func(x, y)[0]
If you later choose to remove the decorator from a function, you will need to roll back all the above changes.
- You will increase the cognitive load on all members of the team who don’t care about what your decorator needs to do.
- If you propose this as a solution, you’d better just create another function and face the consequences of the increased complexity in the original codebase.
Ok, maybe you’re now thinking: “Right, this makes sense, but you’re avoiding these issues creating other issues in performance and debugging. It doesn’t sound good except for some special cases”. And I need to agree with you, it’s not a good solution for most of the cases!
The problem we’re facing is that python doesn’t have context managers that can deal with namespaces, although there is an active discussion about this https://mail.python.org/archives/list/python-ideas@python.org/. But don’t worry about this big name. The important point here is that:
In Python we are fine with this because it is a language in which it is easy to manipulate the Abstract Syntax Tree and recompile a function with the manipulated tree. Doing it that way puts us in the realm of metaprogramming: writing code which writes code. If it’s not clear, I’ll try to be clearer now.
ASTs: What are they?
A programming language obviously is at least a language. OK, but what is a language? Do all the human languages share the same building blocks? How can we compare different sentences? These questions seem more proper to be answered by philosophers. Well, maybe this is true, but these questions can also be answered by mathematicians and computer scientists. However, mathematicians and CS people usually prefer to talk using mathematical formalism rather than long debates about the meaning of the stuff. In essence, an AST is a mathematical formalism that allows us to represent a sentence using a well-defined set of rules and structures represented by a tree.
How do you know that a sentence is grammatically correct?
Intuitively, probably you remember a set of rules that you learned during your life about how to organize and compose verbs, nouns, adjectives, adverbs, etc. This set of rules and guidelines is the Syntax of a language. A Syntax Tree is a structure that helps us to understand a sentence.
Take for example the sentence: “I drive a car to my college”, the syntax tree is the following:
What is the advantage of using ASTs? Notice that we don’t need to talk about how many spaces you’re using or about your calligraphy. Besides that, we have a hierarchical structure that allows us to analyze the validity of the sentence level by level! If we want to change any element of the sentence, we can directly manipulate the node that represents that element, which makes it much easier to keep the manipulated sentence grammatically correct!
It’s not a surprise that ASTs are also a common tool used in computer science to analyze the correctness of a piece of code and as a common part of the process of compiling/interpreting a code. Here we will extend the behavior of a Python decorator manipulating the AST. But before that, I would like to ask you a question:
Python: interpreted or compiled?
Usually, when I meet a Python hater (or even an enthusiast) they make statements like these:
- “Python is slow because it’s an interpreted language!"
- “Python sucks because it doesn’t have a compiler!"
Well, these assertions are not true. The important point is that: when people refer to Python commonly they are actually talking about the language Python and the CPython virtual machine. Let’s talk more about these misconceptions.
First, the distinction between interpreted and compiled languages is very blurry today. Second, let’s see something
hello_world = "print('Hello, world!')"
hello_world_obj = compile(hello_world, '<string>', 'single')
Yeah, if you’re trying to defend that Python is interpreted the things start to get harder for you. Why is there a compile available?
exec(hello_world_obj)
Hello, world!
I’m executing a thing that has been compiled??? What is this hello_world_obj?
print(f"Bad news for you:\n\tContent: {hello_world_obj.co_code}\n\tType: {type(hello_world_obj.co_code)}")
Bad news for you:
Content: b'e\x00d\x00\x83\x01F\x00d\x01S\x00'
Type: <class 'bytes'>
But what is this stuff?
It is important to understand what happens behind the scenes.
After you write some code and call the python command, Python starts a compilation phase: it creates the AST and generates the bytecode that will be attached to code objects. Then these code objects are interpreted by the CPython virtual machine. The diagram below is a simple representation of this process with some details hidden
The compilation phase corresponds to the first steps of the above diagram
But don’t worry about most of the big names above. The only concepts that will matter to us are the AST, bytecodes, and Code object. Bytecodes are just a compact way to tell the interpreter what we want to do. The code object is just a way to encapsulate the bytecodes extracted from the AST.
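To make the three concepts concrete, here is a minimal sketch that walks one small snippet through the pipeline using only the standard library. The source string and the "<sketch>" filename are made up for illustration, and the exact opcodes printed by dis depend on the Python version.

```python
import ast
import dis

source = "x = 1 + 2\nprint(x)"

tree = ast.parse(source)                      # source text -> AST
code_obj = compile(tree, "<sketch>", "exec")  # AST -> code object holding the bytecode

print(type(tree).__name__)              # the root node of the AST
print(type(code_obj).__name__)          # the code object
print(type(code_obj.co_code).__name__)  # the raw bytecode

dis.dis(code_obj)                 # human-readable listing of the bytecode

namespace = {}
exec(code_obj, namespace)         # the virtual machine runs the bytecode
```

Note that compile accepts either a source string or an AST, which is exactly what lets us edit the tree before the bytecode is generated.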
But how does this help us?
Extracting ASTs and interpreting them
Let’s see a simple example of a function and the extracted AST.
import inspect
import ast
import astor # install this for pretty printing
def example(a: float, b:float = 2) -> float:
    s = a+b
    return s
tree = ast.parse(inspect.getsource(example))
print(astor.dump(tree))
astor.to_source(tree)
Module(
    body=[
        FunctionDef(name='example',
            args=arguments(posonlyargs=[],
                args=[arg(arg='a', annotation=Name(id='float'), type_comment=None),
                    arg(arg='b', annotation=Name(id='float'), type_comment=None)],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None,
                defaults=[Constant(value=2, kind=None)]),
            body=[
                Assign(targets=[Name(id='s')],
                    value=BinOp(left=Name(id='a'), op=Add, right=Name(id='b')),
                    type_comment=None),
                Return(value=Name(id='s'))],
            decorator_list=[],
            returns=Name(id='float'),
            type_comment=None)],
    type_ignores=[])
The above output is our AST, and the image below shows its graph representation. Take some time to look at it and see how all our code is organized.
Each element in the above output that starts with an upper-case letter (Name, BinOp, FunctionDef, etc.) is a node deriving from the base class ast.AST. One of the most important node types is ast.Name. For example, in
value=BinOp(left=Name(id='a'), op=Add, right=Name(id='b')),
ast.Name is used to refer to a variable by its name, id.
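As a small illustration of ast.Name, the sketch below collects every name in the example function and separates names being assigned (ctx is ast.Store) from names being read (ctx is ast.Load). It parses the source from a string so it can run anywhere; the variable names loaded and stored are my own.

```python
import ast

source = """
def example(a: float, b: float = 2) -> float:
    s = a + b
    return s
"""

tree = ast.parse(source)

loaded, stored = [], []
for node in ast.walk(tree):            # ast.walk yields every node in the tree
    if isinstance(node, ast.Name):
        if isinstance(node.ctx, ast.Store):
            stored.append(node.id)     # name on the left side of an assignment
        else:
            loaded.append(node.id)     # name being read

print(sorted(set(stored)))  # ['s']
print(sorted(set(loaded)))  # ['a', 'b', 'float', 's']
```

The parameters a and b in the signature do not show up as stored names because they are ast.arg nodes, not ast.Name nodes.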
Now let’s come back to our problem. Remember that one bad solution was rewriting every function
def func(x, y):
    random_var = np.random.uniform()
    ...  # more local vars
    result = (x+y)**random_var
    return result
as
def func_transformed(x, y):
    random_var = np.random.uniform()
    ...  # more local vars
    result = (x+y)**random_var
    return result, locals()
The big thing we will do is write a function that writes new functions for us! This is metaprogramming! And at the same time we will write a decorator that avoids any change in our codebase!
How can I be efficient in metaprogramming?
We must create a function that generates a new one similar to func_transformed. How do we get an idea of what we need to do?
The 6 simple steps
- Create an example function
- Code the transformed function from the example function
- Code a simple test to check if the transformed function is correct
- Extract the AST from the example and the transformed function
- Compare the ASTs. What is the difference? Annotate this difference somewhere
  - You can use the difflib module that comes with Python to diff strings
- Create a new and more complex example function and repeat the process until you get a good idea of what you need to do.
After you have a good idea of what you need to do, you can start writing your metaprogramming function.
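The extract-and-compare steps can be sketched with the standard library alone. This version assumes Python 3.9 or newer, because it relies on the indent argument of ast.dump; the helper name dump_lines and the two source strings are made up for illustration.

```python
import ast
import difflib

before = """
def f(x, y):
    result = x + y
    return result
"""

after = """
def f(x, y):
    result = x + y
    return result, locals()
"""

def dump_lines(source):
    # ast.dump(..., indent=...) requires Python 3.9 or newer.
    return ast.dump(ast.parse(source), indent=2).splitlines()

diff = list(
    difflib.unified_diff(dump_lines(before), dump_lines(after), lineterm="", n=0)
)
print("\n".join(diff))
```

Lines starting with + show the nodes that only exist in the transformed function, which is the annotation the fifth step asks for.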
Creating our metaprogramming function
First six-steps iteration
Let’s start our first iteration by writing one function, the expected transformed function, and the test to check if it is correct.
def example_1(x, y):
    internal_var = 222
    result = (x+y)**internal_var
    return result

def example_1_expected(x, y):
    internal_var = 222
    result = (x+y)**internal_var
    return result, locals()

def test_meta_example_1(meta_func, x, y):
    expected_result, expected_locals = example_1_expected(x, y)
    result, locals_dict = meta_func(x, y)
    assert result == expected_result
    assert expected_locals == locals_dict
Everything looks fine. Now we will use difflib to see the differences between the two ASTs.
import difflib
from pprint import pprint
example_1_ast_str = astor.dump_tree(ast.parse(inspect.getsource(example_1)))
example_1_expected_str = astor.dump_tree(ast.parse(inspect.getsource(example_1_expected)))
pprint(
    list(
        difflib.unified_diff(example_1_ast_str.splitlines(), example_1_expected_str.splitlines(), n=0)
    )
)
['--- \n',
'+++ \n',
'@@ -3 +3 @@\n',
"- FunctionDef(name='example_1',",
"+ FunctionDef(name='example_1_expected',",
'@@ -19 +19 @@\n',
"- Return(value=Name(id='result'))],",
"+ Return(value=Tuple(elts=[Name(id='result'), "
"Call(func=Name(id='locals'), args=[], keywords=[])]))],"]
Now we know that we must change this Node in the AST
Return(value=Name(id='result'))],
To this
Return(value=Tuple(elts=[Name(id='result'), Call(func=Name(id='locals'), args=[], keywords=[])]))],
How can we do this? With the help of the NodeTransformer class.
The NodeTransformer class
ast.NodeTransformer allows us to create objects with a walker-like interface. The walker visits each node in the AST and, during each visit, it can remove, replace, modify or add nodes. After that, it can continue to walk to the children of the node or stop there.
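Before applying this to our problem, here is a toy sketch of the walker idea. The AddToSub class is a made-up example unrelated to the decorator: it visits every BinOp node and swaps additions for subtractions, then the modified tree is compiled and executed.

```python
import ast

class AddToSub(ast.NodeTransformer):
    def visit_BinOp(self, node):
        self.generic_visit(node)          # walk into the children first (nested BinOps)
        if isinstance(node.op, ast.Add):
            node.op = ast.Sub()           # replace the operator in place
        return node                       # the returned node takes the old node's place

tree = ast.parse("value = 10 + 4 + 1")
tree = AddToSub().visit(tree)
ast.fix_missing_locations(tree)

namespace = {}
exec(compile(tree, "<sketch>", "exec"), namespace)
print(namespace["value"])  # (10 - 4) - 1 = 5
```

If visit_BinOp did not call generic_visit, the walker would stop at the outer addition and the inner one would be left untouched.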
How can we use this?
First, we start by creating a new class derived from ast.NodeTransformer
class ASTTransformer(ast.NodeTransformer):
    def visit_Return(self, node):
        ...  # the body is filled in step by step below
If you want to interact with, change, or delete a node of type Something, you must override the visit_Something method. Thus, because we need to change the Return node, we override visit_Return. If we do just the following, our walker will not change our AST,
class ASTTransformer(ast.NodeTransformer):
    ...
Let’s start the modifications. We need to create a new node responsible for calling locals
class ASTTransformer(ast.NodeTransformer):
    def visit_Return(self, node):
        node_locals = ast.Call(
            func=ast.Name(id='locals', ctx=ast.Load()),
            args=[], keywords=[]
        )
        self.generic_visit(node)
        return node
We used a Name node to identify the locals function. Now, according to the diff result, our Return node must be transformed into a Return of a Tuple node
class ASTTransformer(ast.NodeTransformer):
    def visit_Return(self, node):
        new_node = node  # we modify the Return node in place
        node_locals = ast.Call(
            func=ast.Name(id='locals', ctx=ast.Load()),
            args=[], keywords=[]
        )
        new_node.value = ast.Tuple(
            elts=[
                node.value,
                node_locals
            ],
            ctx=ast.Load()
        )
        self.generic_visit(new_node)
        return new_node
A new thing appeared: the elts argument. But don’t worry, this is just an argument that lists the nodes the Tuple contains. Whenever you have some doubt about AST stuff, you can check the ast documentation. The documentation is simple to understand because Python is simple!
Everything is almost done. The last thing is to fix our AST, because when we change a node we need to fill in missing information like the line number (lineno) and the column offset (col_offset). Thanks to Python, we just need to call fix_missing_locations to fill this in for us.
class ASTTransformer(ast.NodeTransformer):
    def visit_Return(self, node):
        new_node = node
        node_locals = ast.Call(
            func=ast.Name(id='locals', ctx=ast.Load()),
            args=[], keywords=[]
        )
        new_node.value = ast.Tuple(
            elts=[
                node.value,
                node_locals
            ],
            ctx=ast.Load()
        )
        ast.copy_location(new_node, node)
        ast.fix_missing_locations(new_node)
        self.generic_visit(new_node)
        return new_node
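To see why the location fix matters, here is a small standalone sketch that builds the expression 1 + 41 by hand. It is an illustration only, and the exact error message (or whether compile complains at all) may vary between Python versions, which is why the first attempt is wrapped in a try block.

```python
import ast

# Build `1 + 41` by hand; none of these nodes carry lineno/col_offset yet.
expr = ast.Expression(
    body=ast.BinOp(
        left=ast.Constant(value=1),
        op=ast.Add(),
        right=ast.Constant(value=41),
    )
)

try:
    compile(expr, "<sketch>", "eval")
    compiled_without_fix = True
except (TypeError, ValueError) as error:
    compiled_without_fix = False
    print("compile refused the tree:", error)

ast.fix_missing_locations(expr)   # fills lineno/col_offset recursively
result = eval(compile(expr, "<sketch>", "eval"))
print(result)
```

Nodes produced by ast.parse already carry their locations, so only the nodes we create ourselves (the Call, Name and Tuple in our transformer) need this treatment.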
Ok, let’s see if it is working. We must instantiate our transformer and call the visit method, which tells the walker to walk the AST and make all the modifications we’re asking for
tree_meta = ast.parse(inspect.getsource(example_1))
transformer = ASTTransformer()
transformer.visit(tree_meta)
example_1_meta_ast_str = astor.dump_tree(tree_meta)
example_1_expected_str = astor.dump_tree(ast.parse(inspect.getsource(example_1_expected)))
pprint(
    list(
        difflib.unified_diff(example_1_meta_ast_str.splitlines(), example_1_expected_str.splitlines(), n=0)
    )
)
['--- \n',
'+++ \n',
'@@ -3 +3 @@\n',
"- FunctionDef(name='example_1',",
"+ FunctionDef(name='example_1_expected',"]
Our first iteration was successful! Let’s try a more complex example.
The second six-steps iteration
We’ll just add more complexity without any particular meaning. We can be creative!
def example_2(x, y):
    internal_var = 222
    def sub(x, y):
        ommit_this_var = 1
        return x - y
    result = sub(x,y)**internal_var
    return (result, False)

def example_2_expected(x, y):
    internal_var = 222
    def sub(x, y):
        ommit_this_var = 1
        return x - y
    result = sub(x,y)**internal_var
    return ((result, False), locals())

def test_meta_example_2(meta_func, x, y):
    expected_result, expected_locals = example_2_expected(x, y)
    result, locals_dict = meta_func(x, y)
    del locals_dict["sub"]
    del expected_locals["sub"]
    assert result == expected_result
    assert expected_locals == locals_dict

example_2_ast_str = astor.dump_tree(ast.parse(inspect.getsource(example_2)))
example_2_expected_str = astor.dump_tree(ast.parse(inspect.getsource(example_2_expected)))
pprint(
    list(
        difflib.unified_diff(example_2_ast_str.splitlines(), example_2_expected_str.splitlines(), n=0)
    )
)
['--- \n',
'+++ \n',
'@@ -3 +3 @@\n',
"- FunctionDef(name='example_2',",
"+ FunctionDef(name='example_2_expected',",
'@@ -37 +37,4 @@\n',
"- Return(value=Tuple(elts=[Name(id='result'), "
'Constant(value=False, kind=None)]))],',
'+ Return(',
'+ value=Tuple(',
"+ elts=[Tuple(elts=[Name(id='result'), "
'Constant(value=False, kind=None)]),',
"+ Call(func=Name(id='locals'), args=[], "
'keywords=[])]))],']
Now, it’s time to cross our fingers and see if we need to do more work
tree_meta = ast.parse(inspect.getsource(example_2))
transformer = ASTTransformer()
transformer.visit(tree_meta)
example_2_meta_ast_str = astor.dump_tree(tree_meta)
example_2_expected_str = astor.dump_tree(ast.parse(inspect.getsource(example_2_expected)))
pprint(
    list(
        difflib.unified_diff(example_2_meta_ast_str.splitlines(), example_2_expected_str.splitlines(), n=0)
    )
)
['--- \n',
'+++ \n',
'@@ -3 +3 @@\n',
"- FunctionDef(name='example_2',",
"+ FunctionDef(name='example_2_expected',",
'@@ -27,4 +27 @@\n',
'- Return(',
'- value=Tuple(',
"- elts=[BinOp(left=Name(id='x'), op=Sub, "
"right=Name(id='y')),",
"- Call(func=Name(id='locals'), args=[], "
'keywords=[])]))],',
"+ Return(value=BinOp(left=Name(id='x'), op=Sub, "
"right=Name(id='y')))],"]
Unfortunately, our ASTTransformer was not able to deal with this crazy guy. What is the problem? If you check carefully, you will notice that the inner function def sub is the problem. We don’t want to change any “sub” function, so we need to tell our walker to avoid changing this kind of thing. To do so, we will create a flag that tells whether the walker is inside a sub-function, and we will override the visit_FunctionDef method to check this flag
class ASTTransformer(ast.NodeTransformer):
    def visit_FunctionDef(self, node):
        # nested functions are returned untouched: only the outermost one is rewritten
        if self._sub:
            return node
        self._sub = True
        self.generic_visit(node)
        return node
    def visit_Module(self, node):
        self._sub = False
        self.generic_visit(node)
        return node  # without this, transformer.visit(tree) would return None
    def visit_Return(self, node):
        new_node = node
        node_locals = ast.Call(
            func=ast.Name(id='locals', ctx=ast.Load()),
            args=[], keywords=[]
        )
        new_node.value = ast.Tuple(
            elts=[
                node.value,
                node_locals
            ],
            ctx=ast.Load()
        )
        ast.copy_location(new_node, node)
        ast.fix_missing_locations(new_node)
        self.generic_visit(new_node)
        return new_node
tree_meta = ast.parse(inspect.getsource(example_2))
transformer = ASTTransformer()
transformer.visit(tree_meta)
example_2_meta_ast_str = astor.dump_tree(tree_meta)
example_2_expected_str = astor.dump_tree(ast.parse(inspect.getsource(example_2_expected)))
pprint(
    list(
        difflib.unified_diff(example_2_meta_ast_str.splitlines(), example_2_expected_str.splitlines(), n=0)
    )
)
['--- \n',
'+++ \n',
'@@ -3 +3 @@\n',
"- FunctionDef(name='example_2',",
"+ FunctionDef(name='example_2_expected',"]
Our new ASTTransformer was able to deal with our new complicated example!
Creating a new function at runtime
We have an ASTTransformer. Now we must compile the transformed AST into a new function. In Python, we can create a new function using FunctionType, see below
from types import FunctionType, CodeType

def transform_and_compile(func: FunctionType)->FunctionType:
    source = inspect.getsource(func)
    # we put this to remove the line from source code with the decorator
    source = "\n".join([l for l in source.splitlines() if not l.startswith("@")])
    tree = ast.parse(source)
    transformer = ASTTransformer()
    transformer.visit(tree)
    code_obj = compile(tree, func.__code__.co_filename, 'exec')
    # the function body is the nested code object stored in the module constants
    function_code = [c for c in code_obj.co_consts if isinstance(c, CodeType)][0]
    # we must pass the globals context to the function; the name and the
    # default values are passed too, otherwise arguments like z=2 lose their defaults
    transformed_func = FunctionType(function_code, func.__globals__, func.__name__, func.__defaults__)
    return transformed_func

test_meta_example_1(transform_and_compile(example_1), 4, 2)
test_meta_example_2(transform_and_compile(example_2), 1, 2)
transform_and_compile was able to create new functions that passed all the tests! We can now move on to the final and easy step, which is just to integrate this function with our decorator!
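One detail of FunctionType is worth a small sketch: a code object holds the body of a function but not its default values, so they have to be supplied separately through the fourth argument (argdefs). The scale function and the variable names below are made up for illustration, and the sketch assumes CPython, where the function body is stored as a nested code object in the module constants.

```python
import types

source = """
def scale(x, factor=3):
    return x * factor
"""

module_code = compile(source, "<sketch>", "exec")
# Pick the nested code object that holds the body of `scale`.
func_code = next(c for c in module_code.co_consts if isinstance(c, types.CodeType))

no_defaults = types.FunctionType(func_code, globals())
with_defaults = types.FunctionType(func_code, globals(), "scale", (3,))

print(with_defaults(2))      # 6
try:
    no_defaults(2)
except TypeError as error:
    print("defaults were lost:", error)
```

Keyword-only defaults and closures would need similar care, which this sketch does not cover.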
Integrating the AST manipulation with a decorator
We will call transform_and_compile right after def decorate to avoid unnecessary compilations every time the decorated function is called.
def report(fmt):
    def decorate(func):
        meta_func = transform_and_compile(func)
        ....
Inside def decorated we call meta_func and return just the result, because we don’t want to change our codebase.
def report(fmt):
    def decorate(func):
        meta_func = transform_and_compile(func)
        ...
        def decorated(*_args):
            _result, internal_locals = meta_func(*_args)
            ....
            return _result
With all the stuff we learned in the previous post, our report decorator with the above changes will be
def report(fmt):
    def decorate(func):
        meta_func = transform_and_compile(func)
        sig = inspect.signature(func)
        def decorated(*_args):
            _result, internal_locals = meta_func(*_args)
            named_args = {}
            num_args = len(_args)
            for i, (k, v) in enumerate(sig.parameters.items()):
                if i < num_args:
                    named_args[k] = repr(_args[i])
                else:
                    named_args[k] = repr(v.default)
            name = func.__name__
            result = repr(_result)
            args_dict = {
                **internal_locals,
                **locals(),
                **named_args
            }
            print(fmt.format(**args_dict))
            # store the information in some place
            # return the original value (not its repr) so callers are unaffected
            return _result
        return decorated
    return decorate
Let’s see the result with a dummy function
@report(fmt='{name}(a={a}, b={b}, c={c}); sum_ab {sum_ab}, diff_ab {dif_ab}; r={result}')
def dummy_example(a, b, c=2):
    sum_ab = a + b
    dif_ab = a - b
    r = sum_ab**c + dif_ab**c
    return r
r = dummy_example(2, 3, 1)
print("r:", r)
dummy_example(a=2, b=3, c=1); sum_ab 5, diff_ab -1; r=4
r: 4
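The decorator above only accepts positional arguments. As a possible extension, the sketch below uses Signature.bind and apply_defaults from the inspect module to build the same kind of name-to-repr mapping while also accepting keyword arguments. The helper named_arguments and the function dummy are hypothetical and not part of the decorator shown in this post.

```python
import inspect

def dummy(a, b, c=2):
    return a + b + c

def named_arguments(func, *args, **kwargs):
    # Hypothetical helper: map every parameter name to the repr of its value.
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()            # fill in parameters that were not passed
    return {name: repr(value) for name, value in bound.arguments.items()}

print(named_arguments(dummy, 2, b=3))  # {'a': '2', 'b': '3', 'c': '2'}
```

bind also raises a TypeError when the arguments do not match the signature, which mirrors what the real call would do.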
I know this post is quite hard to read, but I think it’s worth sharing it. I hope you enjoyed it!