Modifying Python 3 code using abstract syntax trees - python

I'm currently playing around with abstract syntax trees, using the ast and astor modules. The documentation taught me how to retrieve and pretty-print source code for various functions, and various examples on the web show how to modify parts of the code by replacing the contents of one line with another or changing all occurrences of + to *.
However, I would like to insert additional code in various places, specifically when a function calls another function. For instance, the following hypothetical function:
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
would yield (once we've used astor.to_source(modified_ast)):
def some_function(param):
    if param == 0:
        print ("Hey, we're calling case_0")
        return case_0(param)
    elif param < 0:
        print ("Hey, we're calling negative_case")
        return negative_case(param)
    print ("Seems we're in the general case, calling all_other_cases")
    return all_other_cases(param)
Is this possible with abstract syntax trees? (note: I'm aware that decorating functions that are called would produce the same results when running the code, but this is not what I'm after; I need to actually output the modified code, and insert more complicated things than print statements).

It's not clear from your question if you're asking about how to insert nodes into an AST tree at a low level, or more specifically about how to do node insertions with a higher level tool to walk the AST tree (e.g. a subclass of ast.NodeVisitor or astor.TreeWalk).
Inserting nodes at a low level is exceedingly easy. You just use list.insert on an appropriate list in the tree. For instance, here's some code that adds the last of the three print calls you want (the other two would be almost as easy, they'd just require more indexing). Most of the code is building the new AST node for the print call. The actual insertion is very short:
import ast
import astor  # third-party package

source = """
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
"""
tree = ast.parse(source) # parse an ast tree from the source code
# build a new tree of AST nodes to insert into the main tree
# (ast.Constant replaces ast.Str, which was deprecated in Python 3.8 and removed in 3.12)
message = ast.Constant("Seems we're in the general case, calling all_other_cases")
print_func = ast.Name("print", ast.Load())
print_call = ast.Call(print_func, [message], []) # add two None args in Python<=3.4
print_statement = ast.Expr(print_call)
tree.body[0].body.insert(1, print_statement) # doing the actual insert here!
# now, do whatever you want with the modified ast tree.
print(astor.to_source(tree))
The output will be:
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    print("Seems we're in the general case, calling all_other_cases")
    return all_other_cases(param)
(Note that the arguments for ast.Call changed between Python 3.4 and 3.5+. If you're using an older version of Python, you may need to add two additional None arguments: ast.Call(print_func, [message], [], None, None))
If you're using a higher level approach, things are a little bit trickier, since the code needs to figure out where to insert the new nodes, rather than using your own knowledge of the input to hard code things.
Here's a quick and dirty implementation of a TreeWalk subclass that adds a print call as the statement before any statement that has a Call node under it. Note that Call nodes include calls to classes (to create instances), not only function calls. This code only handles the outermost of a nested set of calls, so if the code had foo(bar()) the inserted print will only mention foo:
class PrintBeforeCall(astor.TreeWalk):
    def pre_body_name(self):
        body = self.cur_node
        print_func = ast.Name("print", ast.Load())
        inserted = 0  # earlier insertions shift the positions of later statements
        for i, child in enumerate(body[:]):
            self.__name = None
            self.walk(child)
            if self.__name is not None:
                message = ast.Constant("Calling {}".format(self.__name))
                print_statement = ast.Expr(ast.Call(print_func, [message], []))
                body.insert(i + inserted, print_statement)
                inserted += 1
        self.__name = None
        return True
    def pre_Call(self):
        # assumes the callee is a plain name; a method call such as obj.f() has no .id
        self.__name = self.cur_node.func.id
        return True
You'd call it like this:
source = """
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
"""
tree = ast.parse(source)
walker = PrintBeforeCall() # create an instance of the TreeWalk subclass
walker.walk(tree) # modify the tree in place
print(astor.to_source(tree))
The output this time is:
def some_function(param):
    if param == 0:
        print('Calling case_0')
        return case_0(param)
    elif param < 0:
        print('Calling negative_case')
        return negative_case(param)
    print('Calling all_other_cases')
    return all_other_cases(param)
Those aren't quite the exact messages you wanted, but they're close. The walker can't describe the cases being handled in detail, since it only looks at the names of the functions being called, not the conditions that led to each call. If you have a very well defined set of things to look for, you could perhaps change it to look at the ast.If nodes, but I suspect that would be a lot more challenging.
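If you would rather avoid the astor dependency, the same idea can be sketched with the standard library alone. This is only a rough sketch under a few assumptions: it needs Python 3.9+ for ast.unparse, it only reports calls to plain names, and it only descends into statement lists such as body and orelse (so except handlers, for example, are left untouched). The helper names outermost_call_name and add_prints are made up for this example.

```python
import ast

SOURCE = '''
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
'''

def outermost_call_name(expr):
    """Return the name of the first call to a plain name inside expr, else None."""
    # ast.walk is breadth-first, so for foo(bar()) it reaches foo first.
    for node in ast.walk(expr):
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            return node.func.id
    return None

def add_prints(body):
    """Return a new statement list with a print before each statement that calls."""
    new_body = []
    for stmt in body:
        name = None
        for field, value in ast.iter_fields(stmt):
            if isinstance(value, ast.expr):
                # an expression owned directly by this statement (test, value, ...)
                name = name or outermost_call_name(value)
            elif isinstance(value, list) and value and isinstance(value[0], ast.stmt):
                # a nested statement list (body, orelse, ...): process it recursively
                setattr(stmt, field, add_prints(value))
        if name is not None:
            message = ast.Constant("Calling {}".format(name))
            call = ast.Call(ast.Name("print", ast.Load()), [message], [])
            new_body.append(ast.Expr(call))
        new_body.append(stmt)
    return new_body

tree = ast.parse(SOURCE)
tree.body = add_prints(tree.body)
ast.fix_missing_locations(tree)
modified_source = ast.unparse(tree)
print(modified_source)
```

Building a new list instead of inserting into the one being iterated sidesteps the index bookkeeping that in-place insertion needs.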

Related

Determine if a python function has changed

Context
I am trying to cache executions in a data processing framework (kedro). For this, I want to develop a unique hash for a python function to determine if anything in the function body (or the functions and modules this function calls) has changed. I looked into __code__.co_code. While that nicely ignores comments, spacing etc, it also doesn't change when two functions are obviously different. E.g.
def a():
    a = 1
    return a
def b():
    b = 2
    return b
assert a.__code__.co_code != b.__code__.co_code
fails. So the byte code for these two functions is equal.
The ultimate goal: Determine if either a function's code or any of its data inputs have changed. If not and the result already exists, skip execution to save runtime.
Question: How can one get a fingerprint of a function's code in Python?
Another idea brought forward by a colleague was this:
import dis
def compare_instructions(func1, func2):
    """compare instructions of two functions"""
    func1_instructions = list(dis.get_instructions(func1))
    func2_instructions = list(dis.get_instructions(func2))
    # compare every attribute of instructions except for starts_line
    for line1, line2 in zip(func1_instructions, func2_instructions):
        assert line1.opname == line2.opname
        assert line1.opcode == line2.opcode
        assert line1.arg == line2.arg
        assert line1.argval == line2.argval
        assert line1.argrepr == line2.argrepr
        assert line1.offset == line2.offset
    return True
This seems rather like a hack. Other tools like pytest-testmon try to solve this as well but they appear to be using a number of heuristics.
__code__.co_code returns the raw bytecode, which refers to constants by their index in co_consts rather than by value. Ignore the constants in your functions and they are the same.
__code__.co_consts contains information about the constants so would need to be accounted for in your comparison.
assert a.__code__.co_code != b.__code__.co_code \
or a.__code__.co_consts != b.__code__.co_consts
Looking at inspect highlights a few other considerations for 'sameness'. For example, to ensure the functions below are considered different, default arguments must be accounted for.
def a(a1, a2=1):
    return a1 * a2
def b(b1, b2=2):
    return b1 * b2
One way to fingerprint is to use the built-in hash function. Assume the same function definitions as in the OP's example:
def finger_print(func):
    return hash(func.__code__.co_consts) + hash(func.__code__.co_code)
assert finger_print(a) != finger_print(b)
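One caveat for the caching use case: the built-in hash of str and bytes objects is randomised per process unless PYTHONHASHSEED is fixed, so a fingerprint meant to be compared across runs is probably better built on hashlib. The following is a rough sketch rather than a complete solution. The name fingerprint is made up here, it assumes repr() of the constants is stable between runs, and it does not follow calls into other functions or modules.

```python
import hashlib
import types

def _feed(digest, code):
    """Feed the bytecode, constants and referenced names of a code object."""
    digest.update(code.co_code)
    for const in code.co_consts:
        if isinstance(const, types.CodeType):
            # nested functions, lambdas and comprehensions: recurse rather than
            # hash their repr, which contains a memory address
            _feed(digest, const)
        else:
            digest.update(repr(const).encode())
    digest.update(repr(code.co_names).encode())

def fingerprint(func):
    """Return a hex digest that changes when the function's code or defaults change."""
    digest = hashlib.sha256()
    _feed(digest, func.__code__)
    digest.update(repr(func.__defaults__).encode())
    digest.update(repr(func.__kwdefaults__).encode())
    return digest.hexdigest()

def a():
    a = 1
    return a

def b():
    b = 2
    return b

def scaled_1(x, factor=1):
    return x * factor

def scaled_2(x, factor=2):
    return x * factor

print(fingerprint(a) != fingerprint(b))                # differ in a constant
print(fingerprint(scaled_1) != fingerprint(scaled_2))  # differ in a default
```

Local variable names and line numbers are deliberately left out, so renaming a local or adding a comment does not change the digest.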

Override Standard Assert Messaging in Pytest Assert

I'm using Pytest to test some SQL queries my team runs programmatically over time.
My SQL queries are lists of JSONs - one JSON corresponds to one row of data.
I've got a function that diffs the JSON key:value pairs so that we can point to exactly which values are different for a given row. Ideally, I'd output a list of these diffs instead of the standard output of an assert statement, which ends up looking clunky and not-very-useful for the end user.
You can use Python built-in capability to show custom exception message:
assert response.status_code == 200, "My custom message: actual status code {}".format(response.status_code)
Check it out: https://wiki.python.org/moin/UsingAssertionsEffectively
Pytest gives us the hook pytest_assertrepr_compare to add a custom explanation about why an assertion failed.
You can create a class to wrap the JSON string and implement your comparator algorithm overloading the equal operator.
class JSONComparator:
    def __init__(self, value):
        self.value = value
    def __eq__(self, other):
        # Here your algorithm to compare two JSON strings
        ...
        # If they are different, save that information
        # We will need it later
        self.diff = "..."
        return True  # placeholder: return the real result of your comparison
# Put the hook in conftest.py or import it in order to make pytest aware of the hook.
def pytest_assertrepr_compare(config, op, left, right):
    if isinstance(left, JSONComparator) and op == "==":
        # Return the diff inside an array.
        return [left.diff]
# Create a reference as an alias if you want
compare = JSONComparator
Usage
def test_something():
    original = '{"cartoon": "bugs"}'
    expected = '{"cartoon": "bugs"}'
    assert compare(original) == expected
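To make the skeleton above concrete, here is one possible way to fill in the comparison. It is a sketch under assumptions: each row is a flat JSON object, a missing key is treated the same as a null value, and the diff line format is invented for this example. The hook itself is an ordinary function, so it can be exercised without running pytest.

```python
import json

class JSONComparator:
    """Wrap one JSON row and remember which keys differ after a comparison."""

    def __init__(self, value):
        self.value = json.loads(value) if isinstance(value, str) else value
        self.diff = []

    def __eq__(self, other):
        other_value = other.value if isinstance(other, JSONComparator) else other
        if isinstance(other_value, str):
            other_value = json.loads(other_value)
        keys = sorted(set(self.value) | set(other_value))
        # one line per key whose values differ (flat objects only)
        self.diff = [
            "{!r}: {!r} != {!r}".format(key, self.value.get(key), other_value.get(key))
            for key in keys
            if self.value.get(key) != other_value.get(key)
        ]
        return not self.diff

def pytest_assertrepr_compare(config, op, left, right):
    # pytest calls this hook for a failing comparison; returning a list of
    # strings replaces the default explanation
    if isinstance(left, JSONComparator) and op == "==":
        return ["JSON rows differ:"] + left.diff

row = JSONComparator('{"cartoon": "bugs", "age": 3}')
rows_equal = row == '{"cartoon": "daffy", "age": 3}'
explanation = pytest_assertrepr_compare(None, "==", row, None)
print(rows_equal, explanation)
```

Returning the real comparison result from __eq__ matters: if it always returned True, the assertion would never fail and the hook would never be called.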

Use AST module to mutate and delete assignment/function calls

For example if I wanted to change greater than to less than or equal to I have successfully executed:
def visit_Gt(self, node):
    new_node = ast.GtE()
    return ast.copy_location(new_node, node)
How would I visit/detect an assignment operation (=) and a function call () and simply delete them? I'm reading through the AST documentation and I can't find a way to visit the assignment or function call classes and then return nothing.
An example of what I'm seeking for assignment operations:
print("Start")
x = 5
print("End")
Becomes:
print("Start")
print("End")
And an example of what I'm seeking for deleting function calls:
print("Start")
my_function_call(Args)
print("End")
Becomes
print("Start")
print("End")
You can use an ast.NodeTransformer subclass to mutate an existing AST:
import ast
class RemoveAssignments(ast.NodeTransformer):
    def visit_Assign(self, node):
        return None
    def visit_AugAssign(self, node):
        return None
new_tree = RemoveAssignments().visit(old_tree)
The above class returns None from each visit_* method, which removes the node from the input tree completely. The Assign and AugAssign nodes contain the whole assignment statement: the expression producing the result, and the target or targets (one or more names to assign the result to).
This means that the above will turn
print('Start!')
foo = 'bar'
foo += 'eggs'
print('Done!')
into
print('Start!')
print('Done!')
If you need to make more fine-grained decisions, look at the child nodes of the assignment, either directly, or by passing the child nodes to self.visit() to have the transformer further call visit_* hooks for them if they exist:
class RemoveFunctionCallAssignments(ast.NodeTransformer):
    """Remove assignments of the form "target = name()", so a single name being called
    The target list size plays no role.
    """
    def visit_Assign(self, node):
        if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name):
            return None
        return node
Here, we only return None if the value side of the assignment (the expression on the right-hand side) is a Call node that is applied to a straight-forward Name node. Returning the original node object passed in means that it'll not be replaced.
To remove top-level function calls (those that are not part of an assignment or a larger expression), look at Expr nodes; these are expression statements, not just expressions that are part of some other construct. If you have an Expr node wrapping a Call, you can remove it:
def visit_Expr(self, node):
    # stand-alone call to a single name is to be removed
    if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name):
        return None
    return node
Also see the excellent Green Tree Snakes documentation, which covers working on the AST tree with further examples.
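Putting the pieces together for the example in the question, a sketch might look like the following. The class name RemoveStatements and its names_to_remove parameter are invented for illustration, and ast.unparse requires Python 3.9+. Calls are filtered by name so that the print calls survive.

```python
import ast

class RemoveStatements(ast.NodeTransformer):
    """Drop plain assignments and stand-alone calls to the given function names."""

    def __init__(self, names_to_remove):
        self.names_to_remove = set(names_to_remove)

    def visit_Assign(self, node):
        return None  # returning None deletes the statement

    def visit_Expr(self, node):
        call = node.value
        if (isinstance(call, ast.Call)
                and isinstance(call.func, ast.Name)
                and call.func.id in self.names_to_remove):
            return None
        return node

source = (
    'print("Start")\n'
    'x = 5\n'
    'my_function_call(x)\n'
    'print("End")\n'
)
tree = RemoveStatements({"my_function_call"}).visit(ast.parse(source))
result = ast.unparse(tree)
print(result)
```

Be aware that deleting every statement in a block (a function body, say) leaves an empty body, which no longer compiles; in that case you would need to return an ast.Pass() node instead of None.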

How to write pop(item) method for unsorted list

I'm implementing some basic data structures in preparation for an exam and have come across the following issue. I want to implement an unsorted linked list, and have already implemented a pop() method, however I don't know, either syntactically or conceptually, how to make a function sometimes take an argument, sometimes not take an argument. I hope that makes sense.
def pop(self):
    current = self.head
    found = False
    endOfList = None
    while current != None and not found:
        if current.getNext() == None:
            found = True
            endOfList = current.getData()
            self.remove(endOfList)
            self.count = self.count - 1
        else:
            current = current.getNext()
    return endOfList
I want to know how to make the statement unsortedList.pop(3) valid, 3 being just an example and unsortedList being a new instance of the class.
The basic syntax (and a common use case) for using a parameter with a default value looks like this:
def pop(self, index=None):
    if index is None:
        # Do whatever your default behaviour should be
        ...
You then just have to identify how you want your behaviour to change based on the argument. I am just guessing that the argument should specify the index of the element that should be pop'ed from the list.
If that is the case you can directly use a valid default value instead of None e.g. 0
def pop(self, index=0):
First, add a parameter with a default value to the function:
def pop(self, item=None):
Now, in the code, if item is None:, you can do the "no param" thing; otherwise, use item. Whether you want to switch at the top, or lower down in the logic, depends on your logic. In this case, item is None probably means "match the first item", so you probably want a single loop that checks item is None or current.data == item:.
Sometimes you'll want to do this for a parameter that can legitimately be None, in which case you need to pick a different sentinel. There are a few questions around here (and blog posts elsewhere) on the pros and cons of different choices. But here's one way:
class LinkedList(object):
    _sentinel = object()
    def pop(self, item=_sentinel):
Unless it's valid for someone to use the private _sentinel class member of LinkedList as a list item, this works. (If that is valid—e.g., because you're building a debugger out of these things—you have to get even trickier.)
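For illustration, here is a small self-contained sketch of the sentinel approach. The names UnsortedList, Node, add and _MISSING are assumptions made for this example, and unlike the pop() in the question it removes from the head of the list when no argument is given. Because the sentinel is a private object, pop(None) can still search for a stored None.

```python
_MISSING = object()  # hypothetical sentinel meaning "no argument was passed"

class Node:
    def __init__(self, data, next_node=None):
        self.data = data
        self.next = next_node

class UnsortedList:
    def __init__(self):
        self.head = None
        self.count = 0

    def add(self, data):
        """Insert data at the head of the list."""
        self.head = Node(data, self.head)
        self.count += 1

    def pop(self, item=_MISSING):
        """Remove and return the head, or the first element equal to item if given."""
        prev, current = None, self.head
        while current is not None:
            if item is _MISSING or current.data == item:
                # unlink current from the chain
                if prev is None:
                    self.head = current.next
                else:
                    prev.next = current.next
                self.count -= 1
                return current.data
            prev, current = current, current.next
        if item is _MISSING:
            raise IndexError("pop from empty list")
        raise ValueError("item not in list")

unsorted_list = UnsortedList()
for value in (1, 2, 3, None):
    unsorted_list.add(value)      # list is now None, 3, 2, 1 from the head
popped_three = unsorted_list.pop(3)
popped_none = unsorted_list.pop(None)
popped_head = unsorted_list.pop()
```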
The terminology on this is a bit tricky. Quoting the docs:
When one or more top-level parameters have the form parameter = expression, the function is said to have “default parameter values.”
To understand this: "Parameters" (or "formal parameters") are the things the function is defined to take; "arguments" are things passed to the function in a call expression; "parameter values" (or "actual parameters", but this just makes things more confusing) are the values the function body receives. So, it's technically incorrect to refer to either "default parameters" or "parameters with default arguments", but both are quite common, because even experts find this stuff confusing. (If you're curious, or just not confused yet, see function definitions and calls in the reference documentation for full details.)
Is your exam using Python specifically? If not, you may want to look into function overloading. Python doesn't support this feature, but many other languages do, and it is a very common approach to solving this kind of problem.
In Python, you can get a lot of mileage out of using parameters with default values (as Michael Mauderer's example points out).
def pop(self, index=None):
    prev = None
    current = self.head
    if current is None:
        raise IndexError("can't pop from empty list")
    if index is None:
        index = 0 # the first item by default (counting from head)
    if index < 0:
        index += self.count
    if not (0 <= index < self.count):
        raise IndexError("index out of range")
    i = 0
    while i != index:
        i += 1
        prev = current
        current = current.getNext()
        assert current is not None # never happens if list is self-consistent
    assert i == index
    value = current.getData()
    self.remove(current, prev)
    ##self.count -= 1 # this should be in self.remove()
    return value

Disjoint-Set forests in Python alternate implementation

I'm implementing a disjoint set system in Python, but I've hit a wall. I'm using a tree implementation for the system and am implementing Find(), Merge() and Create() functions for the system.
I am implementing a rank system and path compression for efficiency.
The catch is that these functions must take the set of disjoint sets as a parameter, making traversing hard.
class Node(object):
    def __init__(self, value):
        self.parent = self
        self.value = value
        self.rank = 0
def Create(values):
    l = [Node(value) for value in values]
    return l
The Create function takes in a list of values and returns a list of singular Nodes containing the appropriate data.
I'm thinking the Merge function would look similar to this,
def Merge(set, value1, value2):
    value1Root = Find(set, value1)
    value2Root = Find(set, value2)
    if value1Root == value2Root:
        return
    if value1Root.rank < value2Root.rank:
        value1Root.parent = value2Root
    elif value1Root.rank > value2Root.rank:
        value2Root.parent = value1Root
    else:
        value2Root.parent = value1Root
        value1Root.rank += 1
but I'm not sure how to implement the Find() function since it is required to take the list of Nodes and a value (not just a node) as the parameters. Find(set, value) would be the prototype.
I understand how to implement path compression when a Node is taken as a parameter for Find(x), but this method is throwing me off.
Any help would be greatly appreciated. Thank you.
Edited for clarification.
The implementation of this data structure becomes simpler when you realize that the operations union and find can also be implemented as methods of a disjoint set forest class, rather than on the individual disjoint sets.
If you can read C++, then have a look at my take on the data structure; it hides the actual sets from the outside world, representing them only as numeric indices in the API. In Python, it would be something like
class DisjSets(object):
    def __init__(self, n):
        # list(...) is needed in Python 3, where range objects are immutable
        self._parent = list(range(n))
        self._rank = [0] * n
    def find(self, i):
        if self._parent[i] == i:
            return i
        else:
            self._parent[i] = self.find(self._parent[i])
            return self._parent[i]
    def union(self, i, j):
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i != root_j:
            if self._rank[root_i] < self._rank[root_j]:
                self._parent[root_i] = root_j
            elif self._rank[root_i] > self._rank[root_j]:
                self._parent[root_j] = root_i
            else:
                self._parent[root_i] = root_j
                self._rank[root_j] += 1
(Not tested.)
If you choose not to follow this path, the client of your code will indeed have to have knowledge of Nodes and Find must take a Node argument.
Clearly, the merge function should be applied to a pair of nodes.
So the find function should take a single node parameter and look like this:
def find(node):
    if node.parent != node:
        node.parent = find(node.parent)
    return node.parent
Wikipedia also has pseudocode that is easily translatable to Python.
Find is always done on an item. Find(item) is defined as returning the set to which the item belongs. Merge, as such, must not take nodes; it always takes two items or sets. Merge or union (item1, item2) must first call find(item1) and find(item2), which return the sets to which each of these belongs. After that, the shorter up-tree must be attached to the taller one. When a find is issued, always retrace the path and compress it.
A tested implementation with path compression is here.
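If the Find(set, value) prototype from the question really is fixed, one way to reconcile it with the node-based answers is to look up the node for the value first and then run the usual node-based find. The sketch below assumes values are unique, uses a linear scan for the lookup (a dict from value to node would be faster), and uses lowercase function names plus a helper called _find_root, none of which come from the original posts.

```python
class Node:
    def __init__(self, value):
        self.parent = self   # a root is its own parent
        self.value = value
        self.rank = 0

def create(values):
    """Return a list of singleton nodes, one per value."""
    return [Node(value) for value in values]

def _find_root(node):
    """Node-based find with path compression."""
    if node.parent is not node:
        node.parent = _find_root(node.parent)
    return node.parent

def find(nodes, value):
    """Return the root node of the set containing value."""
    for node in nodes:
        if node.value == value:
            return _find_root(node)
    raise KeyError(value)

def merge(nodes, value1, value2):
    """Union by rank of the sets containing value1 and value2."""
    root1 = find(nodes, value1)
    root2 = find(nodes, value2)
    if root1 is root2:
        return
    if root1.rank < root2.rank:
        root1.parent = root2
    elif root1.rank > root2.rank:
        root2.parent = root1
    else:
        root2.parent = root1
        root1.rank += 1

nodes = create(["a", "b", "c", "d"])
merge(nodes, "a", "b")
merge(nodes, "c", "d")
same_ab = find(nodes, "a") is find(nodes, "b")
same_ac = find(nodes, "a") is find(nodes, "c")
merge(nodes, "b", "d")
same_all = len({id(find(nodes, v)) for v in "abcd"}) == 1
```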
