Use AST module to mutate and delete assignment/function calls - python

For example if I wanted to change greater than to less than or equal to I have successfully executed:
def visit_Gt(self, node):
    new_node = ast.GtE()
    return ast.copy_location(new_node, node)
How would I visit/detect an assignment operation (=) and a function call () and simply delete them? I'm reading through the AST documentation and I can't find a way to visit the assignment or function call classes and then return nothing.
An example of what I'm seeking for assignment operations:
print("Start")
x = 5
print("End")
Becomes:
print("Start")
print("End")
And an example of what I'm seeking for deleting function calls:
print("Start")
my_function_call(Args)
print("End")
Becomes
print("Start")
print("End")

You can use an ast.NodeTransformer subclass to mutate an existing AST:
import ast
class RemoveAssignments(ast.NodeTransformer):
    def visit_Assign(self, node):
        return None
    def visit_AugAssign(self, node):
        return None
new_tree = RemoveAssignments().visit(old_tree)
The above class returns None from its visit methods to completely remove the node from the input tree. The Assign and AugAssign nodes contain the whole assignment statement: the expression producing the result, and the target list (one or more names to assign the result to).
This means that the above will turn
print('Start!')
foo = 'bar'
foo += 'eggs'
print('Done!')
into
print('Start!')
print('Done!')
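To see what an Assign node actually carries, a small inspection sketch can help. The statement parsed here is made up for illustration; the attribute names (targets, value) are those of the standard ast module.

```python
import ast

# Hypothetical statement: two targets share one right-hand side.
stmt = ast.parse("a = b = compute(1)").body[0]

# An Assign node holds a list of targets and a single value expression.
target_names = [t.id for t in stmt.targets]
value_is_call = isinstance(stmt.value, ast.Call)

print(type(stmt).__name__, target_names, value_is_call)
```

Returning None for this one node therefore drops both targets and the call on the right-hand side at once.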
If you need to make more fine-grained decisions, look at the child nodes of the assignment, either directly, or by passing the child nodes to self.visit() to have the transformer further call visit_* hooks for them if they exist:
class RemoveFunctionCallAssignments(ast.NodeTransformer):
    """Remove assignments of the form "target = name()", i.e. a single name being called.
    The size of the target list plays no role.
    """
    def visit_Assign(self, node):
        if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name):
            return None
        return node
Here, we only return None if the value side of the assignment (the expression on the right-hand side) is a Call node that is applied to a straight-forward Name node. Returning the original node object passed in means that it'll not be replaced.
To remove top-level function calls (those without an assignment or further expressions), look at Expr nodes; these are expression statements, not just expressions that are part of some other construct. If you have an Expr node whose value is a Call, you can remove it:
def visit_Expr(self, node):
    # stand-alone call to a single name is to be removed
    if isinstance(node.value, ast.Call) and isinstance(node.value.func, ast.Name):
        return None
    return node
Also see the excellent Green Tree Snakes documentation, which covers working on the AST tree with further examples.
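Putting the pieces together, here is a minimal end-to-end sketch. The class name and the KEEP allow-list are assumptions added for illustration (without the allow-list, the print calls in the question's example would be removed as well), and ast.unparse requires Python 3.9 or later.

```python
import ast

# Hypothetical allow-list: stand-alone calls to these names are kept.
KEEP = {"print"}


class RemoveAssignmentsAndCalls(ast.NodeTransformer):
    def visit_Assign(self, node):
        return None  # drop "x = ..."

    def visit_AugAssign(self, node):
        return None  # drop "x += ..."

    def visit_Expr(self, node):
        call = node.value
        # Drop stand-alone calls to a plain name unless it is allow-listed.
        if (
            isinstance(call, ast.Call)
            and isinstance(call.func, ast.Name)
            and call.func.id not in KEEP
        ):
            return None
        return node


source = '''
print("Start")
x = 5
my_function_call(x)
print("End")
'''

tree = RemoveAssignmentsAndCalls().visit(ast.parse(source))
result = ast.unparse(tree)  # Python 3.9+
print(result)
```

One caveat: if every statement in a function or loop body is removed, the body becomes empty and the regenerated source is no longer valid Python, so a real tool would probably need to substitute a pass statement in that case.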

Related

How to remove a child node from Python AST parent node with NodeTransformer?

Given a part of an AST for some code, I need to remove particular default assignments from function definition. To be specific, I need to remove all the variables that are contained in a list vars_to_remove from function definitions where these variables are used as parameter.
For example, take vars_to_remove = ['sum1'] and a function def do_smth(sum = sum1):. Assume sum1 has been defined globally previously, and added to that list.
I cannot find a way to remove what I want without either removing the whole FunctionDef node or being unable to uniquely identify the node I need to remove. See my attempts below.
I tried to solve the problem in two ways:
Parent way:
If I override visit_FunctionDef(self, node), I can access the parameter I need to remove. However, returning None removes the whole FunctionDef node from the tree, because that is the node passed in, whereas I only need to remove the nodes corresponding to node.args.defaults[0]...node.args.defaults[n].
Child way:
I override visit_Name(self, node). When I do that, I am able to return None, removing the node. However, this deals with all Name nodes in the whole code (from which the AST is derived), not the ones defined exclusively within function definitions. I am removing id:'sum1', but I do not need to necessarily remove all of the other occurrences of Name nodes in the whole program where the id = sum1!
I think I am missing an easy solution.
You can do this by modifying the function's arguments inside visit_FunctionDef:
class Transformer(ast.NodeTransformer):
    def visit_FunctionDef(self, node):
        defaults = node.args.defaults
        # Track args with and without defaults. We want to keep args
        # without defaults. For args with defaults, we need to sync the
        # defaults to be removed with their corresponding args.
        # Slice from an explicit index: with no defaults, [: -0] would be
        # empty and every argument would be dropped.
        split = len(node.args.args) - len(defaults)
        args_without_defaults = node.args.args[:split]
        args_with_defaults = node.args.args[split:]
        # Filter out unwanted args and defaults. Only Name defaults have an
        # id, so other defaults (constants, calls) are always kept.
        retain = [
            (a, d)
            for (a, d) in zip(args_with_defaults, defaults)
            if not (isinstance(d, ast.Name) and d.id in vars_to_remove)
        ]
        node.args.args = args_without_defaults + [a for (a, d) in retain]
        node.args.defaults = [d for (a, d) in retain]
        return node
Given this source:
def func1(sum=sum1):
    pass
def func2(a=spam, b=sum1):
    pass
def func3(a=spam):
    pass
def func4(a=spam, b=sum1, c=eggs):
    pass
this result is output:
def func1():
    pass
def func2(a=spam):
    pass
def func3(a=spam):
    pass
def func4(a=spam, c=eggs):
    pass
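For a version that can be run directly, the following sketch passes the names in through the constructor instead of relying on a global vars_to_remove. The class name DropDefaultedParams and its names_to_remove parameter are assumptions for illustration. It only looks at ordinary positional parameters (positional-only and keyword-only parameters are ignored), and ast.unparse requires Python 3.9 or later.

```python
import ast


class DropDefaultedParams(ast.NodeTransformer):
    """Remove parameters whose default is a Name listed in names_to_remove."""

    def __init__(self, names_to_remove):
        self.names_to_remove = set(names_to_remove)

    def visit_FunctionDef(self, node):
        self.generic_visit(node)  # also handle nested function definitions
        args = node.args.args
        defaults = node.args.defaults
        # Defaults always belong to the trailing parameters.
        # Assumption: no positional-only parameters carry defaults here.
        split = len(args) - len(defaults)
        kept = [
            (a, d)
            for a, d in zip(args[split:], defaults)
            if not (isinstance(d, ast.Name) and d.id in self.names_to_remove)
        ]
        node.args.args = args[:split] + [a for a, _ in kept]
        node.args.defaults = [d for _, d in kept]
        return node


source = """
def func1(sum=sum1):
    pass

def func4(x, a=spam, b=sum1, c=eggs):
    pass
"""

tree = DropDefaultedParams(["sum1"]).visit(ast.parse(source))
result = ast.unparse(tree)  # Python 3.9+
print(result)
```

Because only node.args is rewritten and the FunctionDef itself is returned, the rest of the function, and every other use of sum1 in the program, is left untouched.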

Printing linked-list in python

In my task first I need to make single linked-list from array.
My code:
class Node:
    def __init__(self,data):
        self.data = data
        self.next = next
class Lista:
    def __init__(self, lista=None):
        self.head = None
    def ispis(self):
        printval = self.head
        while printval.next is not None:
            print(printval.next.data)
            printval = printval.next
if __name__ == '__main__':
    L = Lista([2, "python", 3, "bill", 4, "java"])
    ispis(L)
With function ispis I need to print elements of linked-list. But it says name "ispis" is not defined. Cannot change ispis(L) !
EDIT: removed next from and ispis(self) is moved outside Lista class
while printval.next is not None:
EDIT2:
It shows that L is empty, so that's why it won't print anything. Should I add elements to class Node?
ispis is a method of the class, but you are calling it as if it were a plain function defined outside the class.
At least you have created the object correctly. The correct way to call the method is:
L.ispis()
This question sounds suspiciously like a homework assignment. If the instructor is trying to teach you how to create linked lists, you need to go back to what you need to do:
A node when set up for the first time only needs the data. Typically the next pointer/value would be set to None (meaning no next member).
Your __init__ method for your Lista class needs to do something with its argument.
If ispis(L) has to be called exactly as written, then the function probably isn't supposed to be a method of Lista.
I think your ispis loop shouldn't be testing the .next member. This would fail if the head were None to begin with. You should test the current node rather than its next. That way, when you move on to the next node and it is None, the loop ends.
Be careful with the name next. It is not a keyword but a built-in function, so in Node.__init__ the bare name next refers to that built-in, and self.next = next stores the function instead of None. I would also avoid using it as an attribute name.
At the bare minimum, you would want to iterate over the lista argument in __init__, creating a Node for each one, saving the previous node for the next operation.
if lista is None:
    self.head = None
    return
prev = None
for data in lista:
    node = Node(data)
    if prev is None:
        self.head = node
    else:
        prev.next = node
    prev = node
But again, I believe that is what the instructor wanted you to figure out. Hope this helped.
--B
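The points from both answers can be combined into one sketch. It assumes that ispis(L) must stay a module-level call, as the question requires; the helper iter_values is a hypothetical addition that keeps the traversal separate from the printing.

```python
class Node:
    def __init__(self, data):
        self.data = data
        self.next = None  # a fresh node has no successor yet


class Lista:
    def __init__(self, lista=None):
        self.head = None
        prev = None
        # Link one Node per element, remembering the previous node.
        for data in lista or []:
            node = Node(data)
            if prev is None:
                self.head = node
            else:
                prev.next = node
            prev = node


def iter_values(lista):
    """Yield the data of every node, testing the current node, not its next."""
    current = lista.head
    while current is not None:
        yield current.data
        current = current.next


def ispis(lista):
    # Module-level function, so the call ispis(L) works unchanged.
    for value in iter_values(lista):
        print(value)


if __name__ == '__main__':
    L = Lista([2, "python", 3, "bill", 4, "java"])
    ispis(L)
```

Testing the current node rather than current.next means an empty list prints nothing and the last element is not skipped.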

Recursive loop not having expected behavior

So I create a tree in python. I am trying to change some value of every child of the root node. But, every node in my tree is not being hit.
class Node(object):
    def __init__(self, value, priority):
        self.parent = None
        self.children = []
        self.value = value
        self.priority = priority
    def add_child(self, obj):
        self.children.insert(obj)
        obj.parent = self
    def getChildren(self):
        return self.children.getAll()
tom = Node("DD",1)
tom.add_child(Node("a",0.3))
tom.add_child(Node("b", 0.6))
tom.getChildren()[0].add_child(Node("c",1))
tom.getChildren()[1].add_child(Node("d",1))
#print(tom.popHighestValue().value)
def getAll(currentNode):
    print(currentNode.value)
    if(currentNode.getChildren != []):
        for sibling in currentNode.getChildren():
            sibling.priority = 1
            return getAll(sibling)
The tree should look like:
    DD
   /  \
  a    b
 /
c
But only DD->a->c are being hit. I thought the for loop state would be saved and continued after DD -> c was traversed.
The goal is that every node in the tree is hit. And the priority value is set to 1.
A return statement always exits the current function. If you're in a loop when you do this, the rest of the loop is never executed.
If you need to return all the values of the recursive calls, you need to collect them in a list during the loop, then return that list after the loop is done.
But in this case there doesn't seem to be anything you need to return. This function is just for setting an attribute, there's nothing being extracted. So just make the recursive call without returning.
def getAll(currentNode):
    print(currentNode.value)
    for sibling in currentNode.getChildren():
        sibling.priority = 1
        getAll(sibling)
BTW, this won't set the priority of the root node, since it only sets the priority of children. If you want to include the root node, it should be:
def getAll(currentNode):
    print(currentNode.value)
    currentNode.priority = 1
    for sibling in currentNode.getChildren():
        getAll(sibling)
Also, you shouldn't call getAll() in the getChildren() method. It should just return self.children, not self.children.getAll().
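If the recursive results are needed after all, the "collect them in a list" approach mentioned above could look like the sketch below. The function name set_priority_and_collect is hypothetical, and the Node class is a corrected copy of the one in the question (append instead of insert, and getChildren returning the list itself).

```python
class Node:
    def __init__(self, value, priority):
        self.parent = None
        self.children = []
        self.value = value
        self.priority = priority

    def add_child(self, obj):
        self.children.append(obj)  # list.insert would need an index
        obj.parent = self

    def getChildren(self):
        return self.children


def set_priority_and_collect(node, priority=1):
    """Set priority on node and all descendants; return the visited values."""
    node.priority = priority
    visited = [node.value]
    for child in node.getChildren():
        # Extend the list inside the loop; return only after the loop ends.
        visited.extend(set_priority_and_collect(child, priority))
    return visited


tom = Node("DD", 1)
tom.add_child(Node("a", 0.3))
tom.add_child(Node("b", 0.6))
tom.getChildren()[0].add_child(Node("c", 1))
tom.getChildren()[1].add_child(Node("d", 1))

order = set_priority_and_collect(tom)
print(order)
```

The single return sits after the loop, so every sibling is visited before the function exits.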
If you remove the return from the recursive getAll() call and, if a return is needed at all, place it after the enclosing for loop, that would fix your problem.
In your code, not all children are processed because the return statement exits the function during the first iteration of the loop. As a result, only the first sibling is explored at every depth.

How can I replace OrderedDict with dict in a Python AST before literal_eval?

I have a string with Python code in it that I could evaluate as Python with literal_eval if it only had instances of OrderedDict replaced with {}.
I am trying to use ast.parse and ast.NodeTransformer to do the replacement, but when I catch the node with nodetype == 'Name' and node.id == 'OrderedDict', I can't find the list that is the argument in the node object so that I can replace it with a Dict node.
Is this even the right approach?
Some code:
from ast import NodeTransformer, parse
py_str = "[OrderedDict([('a', 1)])]"
class Transformer(NodeTransformer):
    def generic_visit(self, node):
        nodetype = type(node).__name__
        if nodetype == 'Name' and node.id == 'OrderedDict':
            pass # ???
        return NodeTransformer.generic_visit(self, node)
t = Transformer()
tree = parse(py_str)
t.visit(tree)
The idea is to replace all OrderedDict nodes, represented as ast.Call having specific attributes (which can be seen from ordered_dict_conditions below), with ast.Dict nodes whose key / value arguments are extracted from the ast.Call arguments.
import ast
class Transformer(ast.NodeTransformer):
    def generic_visit(self, node):
        # Need to call super() in any case to visit child nodes of the current one.
        super().generic_visit(node)
        ordered_dict_conditions = (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == 'OrderedDict'
            and len(node.args) == 1
            and isinstance(node.args[0], ast.List)
        )
        if ordered_dict_conditions:
            return ast.Dict(
                [x.elts[0] for x in node.args[0].elts],
                [x.elts[1] for x in node.args[0].elts]
            )
        return node
def transform_eval(py_str):
    return ast.literal_eval(Transformer().visit(ast.parse(py_str, mode='eval')).body)
print(transform_eval("[OrderedDict([('a', 1)]), {'k': 'v'}]")) # [{'a': 1}, {'k': 'v'}]
print(transform_eval("OrderedDict([('a', OrderedDict([('b', 1)]))])")) # {'a': {'b': 1}}
Notes
Because we want to replace the innermost node first, we place a call to super() at the beginning of the function.
Whenever an OrderedDict node is encountered, the following things are used:
node.args is a list containing the arguments to the OrderedDict(...) call.
This call has a single argument, a list of key-value tuples, which is accessible as node.args[0] (an ast.List); node.args[0].elts is the Python list holding those tuple nodes.
So node.args[0].elts[i] are the different ast.Tuples (for i in range(len(node.args[0].elts))) whose elements are accessible again via the .elts attribute.
Finally node.args[0].elts[i].elts[0] are the keys and node.args[0].elts[i].elts[1] are the values which are used in the OrderedDict call.
The latter keys and values are then used to create a fresh ast.Dict instance which is then used to replace the current node (which was ast.Call).
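The attribute path described in these notes can be checked directly on a parsed call. This is only an inspection sketch: the input string is made up, and reading .value from the literal nodes assumes Python 3.8 or later, where literals are parsed as ast.Constant.

```python
import ast

# Parse a single expression and take the Call node itself.
call = ast.parse("OrderedDict([('a', 1), ('b', 2)])", mode="eval").body

pairs = call.args[0].elts                    # the ast.Tuple nodes
keys = [p.elts[0].value for p in pairs]      # first element of each tuple
values = [p.elts[1].value for p in pairs]    # second element of each tuple

print(type(call).__name__, call.func.id, keys, values)
```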
You could use the ast.NodeVisitor class to observe the OrderedDict tree in order to build the {} tree manually from the encountered nodes, using the parsed nodes from an empty dict as a basis.
import ast
class Builder(ast.NodeVisitor):
    def __init__(self):
        super().__init__()
        self._tree = ast.parse('[{}]')
        self._list_node = self._tree.body[0].value
        self._dict_node = self._list_node.elts[0]
        self._new_item = False
    def visit_Tuple(self, node):
        self._new_item = True
        self.generic_visit(node)
    def visit_Constant(self, node):
        # Python 3.8+ parses str and number literals as ast.Constant;
        # the older visit_Str / visit_Num hooks are deprecated.
        if self._new_item and isinstance(node.value, str):
            self._dict_node.keys.append(node)
        elif self._new_item and isinstance(node.value, (int, float)):
            self._dict_node.values.append(node)
            self._new_item = False
    def literal_eval(self):
        return ast.literal_eval(self._list_node)
builder = Builder()
builder.visit(ast.parse("[OrderedDict([('a', 1)])]"))
print(builder.literal_eval())
Note that this only works for the simple structure of your example which uses str as keys and int as values. However extensions to more complex structures should be possible in a similar fashion.
Instead of using ast for parsing and transforming the expression you could also use a regular expression for doing that. For example:
>>> import re
>>> re.sub(
...     r"OrderedDict\(\[((\(('[a-z]+'), (\d+)\)),?\s*)+\]\)",
...     r'{\3: \4}',
...     "[OrderedDict([('a', 1)])]"
... )
"[{'a': 1}]"
The above expression is based on the example string of the OP and considers single quoted strings as keys and positive integers as values, but of course it can be extended to more complex cases.
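Before extending the pattern, it is worth checking how it behaves with more than one pair. In Python's re module a capture group inside a repeated group keeps only its last match, so the sketch below (same pattern, with a second, made-up input) suggests the substitution as written is only safe for single-pair dictionaries.

```python
import re

PATTERN = r"OrderedDict\(\[((\(('[a-z]+'), (\d+)\)),?\s*)+\]\)"

one_pair = re.sub(PATTERN, r'{\3: \4}', "[OrderedDict([('a', 1)])]")
two_pairs = re.sub(PATTERN, r'{\3: \4}', "[OrderedDict([('a', 1), ('b', 2)])]")

print(one_pair)   # the single pair is converted
print(two_pairs)  # only the last pair survives; ('a', 1) is lost
```

For inputs with several pairs or nested dictionaries, the AST-based approaches above are probably the more robust choice.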

Modifying Python 3 code using abstract syntax trees

I'm currently playing around with abstract syntax trees, using the ast and astor modules. The documentation taught me how to retrieve and pretty-print source code for various functions, and various examples on the web show how to modify parts of the code by replacing the contents of one line with another or changing all occurrences of + to *.
However, I would like to insert additional code in various places, specifically when a function calls another function. For instance, the following hypothetical function:
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
would yield (once we've used astor.to_source(modified_ast)):
def some_function(param):
    if param == 0:
        print ("Hey, we're calling case_0")
        return case_0(param)
    elif param < 0:
        print ("Hey, we're calling negative_case")
        return negative_case(param)
    print ("Seems we're in the general case, calling all_other_cases")
    return all_other_cases(param)
Is this possible with abstract syntax trees? (note: I'm aware that decorating functions that are called would produce the same results when running the code, but this is not what I'm after; I need to actually output the modified code, and insert more complicated things than print statements).
It's not clear from your question if you're asking about how to insert nodes into an AST tree at a low level, or more specifically about how to do node insertions with a higher level tool to walk the AST tree (e.g. a subclass of ast.NodeVisitor or astor.TreeWalk).
Inserting nodes at a low level is exceedingly easy. You just use list.insert on an appropriate list in the tree. For instance, here's some code that adds the last of the three print calls you want (the other two would be almost as easy, they'd just require more indexing). Most of the code is building the new AST node for the print call. The actual insertion is very short:
import ast
import astor
source = """
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
"""
tree = ast.parse(source) # parse an ast tree from the source code
# build a new tree of AST nodes to insert into the main tree
# (ast.Constant replaces ast.Str, which is deprecated since Python 3.8)
message = ast.Constant("Seems we're in the general case, calling all_other_cases")
print_func = ast.Name("print", ast.Load())
print_call = ast.Call(print_func, [message], []) # add two None args in Python<=3.4
print_statement = ast.Expr(print_call)
tree.body[0].body.insert(1, print_statement) # doing the actual insert here!
# now, do whatever you want with the modified ast tree.
print(astor.to_source(tree))
The output will be:
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    print("Seems we're in the general case, calling all_other_cases")
    return all_other_cases(param)
(Note that the arguments for ast.Call changed between Python 3.4 and 3.5+. If you're using an older version of Python, you may need to add two additional None arguments: ast.Call(print_func, [message], [], None, None))
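If astor is not available, the same low-level insertion can be sketched with the standard library alone. This is a variant rather than the code above: it assumes Python 3.9 or later for ast.unparse and builds the message with ast.Constant.

```python
import ast

source = """
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
"""

tree = ast.parse(source)

# Build the statement: print("Seems we're in the general case, ...")
print_call = ast.Call(
    func=ast.Name(id="print", ctx=ast.Load()),
    args=[ast.Constant("Seems we're in the general case, calling all_other_cases")],
    keywords=[],
)
print_statement = ast.Expr(value=print_call)

# Insert it as the second statement of the function body, before the final return.
tree.body[0].body.insert(1, print_statement)

# New nodes have no line numbers; this fills them in, which compile() needs.
ast.fix_missing_locations(tree)

result = ast.unparse(tree)  # Python 3.9+
print(result)
```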
If you're using a higher level approach, things are a little bit trickier, since the code needs to figure out where to insert the new nodes, rather than using your own knowledge of the input to hard code things.
Here's a quick and dirty implementation of a TreeWalk subclass that adds a print call as the statement before any statement that has a Call node under it. Note that Call nodes include calls to classes (to create instances), not only function calls. This code only handles the outermost of a nested set of calls, so if the code had foo(bar()) the inserted print will only mention foo:
class PrintBeforeCall(astor.TreeWalk):
    def pre_body_name(self):
        body = self.cur_node
        print_func = ast.Name("print", ast.Load())
        for i, child in enumerate(body[:]):
            self.__name = None
            self.walk(child)
            if self.__name is not None:
                # ast.Constant replaces ast.Str, deprecated since Python 3.8
                message = ast.Constant("Calling {}".format(self.__name))
                print_statement = ast.Expr(ast.Call(print_func, [message], []))
                body.insert(i, print_statement)
        self.__name = None
        return True
    def pre_Call(self):
        self.__name = self.cur_node.func.id
        return True
You'd call it like this:
source = """
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
"""
tree = ast.parse(source)
walker = PrintBeforeCall() # create an instance of the TreeWalk subclass
walker.walk(tree) # modify the tree in place
print(astor.to_source(tree))
The output this time is:
def some_function(param):
    if param == 0:
        print('Calling case_0')
        return case_0(param)
    elif param < 0:
        print('Calling negative_case')
        return negative_case(param)
    print('Calling all_other_cases')
    return all_other_cases(param)
Those aren't quite the exact messages you wanted, but they're close. The walker can't describe the cases being handled in detail, since it only looks at the names of the functions being called, not the conditions that led there. If you have a very well defined set of things to look for, you could perhaps change it to look at the ast.If nodes, but I suspect that would be a lot more challenging.
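A comparable effect can be sketched with the standard library's ast.NodeTransformer, which splices a returned list of statements into the enclosing body. This is a stdlib-only variant, not the astor approach: the class name PrintBeforeDirectCall is hypothetical, it only handles return and expression statements whose value is directly a call to a plain name, and it assumes Python 3.9 or later for ast.unparse.

```python
import ast


class PrintBeforeDirectCall(ast.NodeTransformer):
    """Insert print('Calling <name>') before 'return f(...)' and 'f(...)'."""

    def _with_print(self, stmt):
        value = stmt.value
        if isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
            message = ast.Constant("Calling {}".format(value.func.id))
            print_call = ast.Call(
                func=ast.Name(id="print", ctx=ast.Load()),
                args=[message],
                keywords=[],
            )
            # A list return value is spliced into the enclosing statement list.
            return [ast.Expr(value=print_call), stmt]
        return stmt

    def visit_Return(self, node):
        return self._with_print(node)

    def visit_Expr(self, node):
        return self._with_print(node)


source = """
def some_function(param):
    if param == 0:
        return case_0(param)
    elif param < 0:
        return negative_case(param)
    return all_other_cases(param)
"""

tree = PrintBeforeDirectCall().visit(ast.parse(source))
ast.fix_missing_locations(tree)
result = ast.unparse(tree)  # Python 3.9+
print(result)
```

The transformer does not revisit the nodes it returns, so the inserted print statements are not themselves wrapped in further prints.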
