How to change this tree-recursion to a tail-recursion? - python

I am writing a function ChnNumber that converts an Arabic-numeral string to a Chinese financial-numeral string. I have worked out a tree-recursive form, but when I tried to write a tail-recursive form, I found it really difficult to handle the cases where bit equals 6, 7 or 8, or 10 and above.
You can see how it works at the end of my question.
Here's the tree-recursion solution. It works:
# -*- coding:utf-8 -*-
unitArab=(2,3,4,5,9)
#unitStr=u'十百千万亿' #this is an alternative
unitStr=u'拾佰仟万亿'
unitDic=dict(zip(unitArab,(list(unitStr))))
numArab=list(u'0123456789')
#numStr=u'零一二三四五六七八九' #this is an alternative
numStr=u'零壹贰叁肆伍陆柒捌玖'
numDic=dict(zip(numArab,list(numStr)))
def ChnNumber(s):
    def wrapper(v):
        'this adapts the string to its abbreviated form'
        if u'零零' in v:
            return wrapper(v.replace(u'零零',u'零'))
        return v[:-1] if v[-1]==u'零' else v
    def recur(s,bit):
        'receives the number string and its length'
        if bit==1:
            return numDic[s]
        if s[0]==u'0':
            return wrapper(u'%s%s' % (u'零',recur(s[1:],bit-1)))
        if bit<6 or bit==9:
            return wrapper(u'%s%s%s' % (numDic[s[0]],unitDic[bit],recur(s[1:],bit-1)))
        'below is the hard part to be converted to tail-recursion'
        if bit<9:
            return u'%s%s%s' % (recur(s[:-4],bit-4),u"万",recur(s[-4:],4))
        if bit>9:
            return u'%s%s%s' % (recur(s[:-8],bit-8),u"亿",recur(s[-8:],8))
    return recur(s,len(s))
My attempt only changes the recur function: I use a closure variable res and compute bit inside recur, so it needs fewer arguments:
res=[]
def recur(s):
    bit=len(s)
    print s,bit,res
    if bit==0:
        return ''.join(res)
    if bit==1:
        res.append(numDic[s])
        return recur(s[1:])
    if s[0]==u'0':
        res.append(u'零')
        return recur(s[1:])
    if bit<6 or bit==9:
        res.append(u'%s%s' %(numDic[s[0]],unitDic[bit]))
        return recur(s[1:])
    if bit<9:
        #...can't work it out
        pass
    if bit>9:
        #...can't work it out
        pass
The test code is:
for i in range(17):
    v1='9'+'0'*(i+1)
    v2='9'+'0'*i+'9'
    v3='1'*(i+2)
    print '%s->%s\n%s->%s\n%s->%s'% (v1,ChnNumber(v1),v2,ChnNumber(v2),v3,ChnNumber(v3))
which should output:
>>>
90->玖拾
99->玖拾玖
11->壹拾壹
900->玖佰
909->玖佰零玖
111->壹佰壹拾壹
9000->玖仟
9009->玖仟零玖
1111->壹仟壹佰壹拾壹
90000->玖万
90009->玖万零玖
11111->壹万壹仟壹佰壹拾壹
900000->玖拾万
900009->玖拾万零玖
111111->壹拾壹万壹仟壹佰壹拾壹
9000000->玖佰万
9000009->玖佰万零玖
1111111->壹佰壹拾壹万壹仟壹佰壹拾壹
90000000->玖仟万
90000009->玖仟万零玖
11111111->壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
900000000->玖亿
900000009->玖亿零玖
111111111->壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
9000000000->玖拾亿
9000000009->玖拾亿零玖
1111111111->壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
90000000000->玖佰亿
90000000009->玖佰亿零玖
11111111111->壹佰壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
900000000000->玖仟亿
900000000009->玖仟亿零玖
111111111111->壹仟壹佰壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
9000000000000->玖万亿
9000000000009->玖万亿零玖
1111111111111->壹万壹仟壹佰壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
90000000000000->玖拾万亿
90000000000009->玖拾万亿零玖
11111111111111->壹拾壹万壹仟壹佰壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
900000000000000->玖佰万亿
900000000000009->玖佰万亿零玖
111111111111111->壹佰壹拾壹万壹仟壹佰壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
9000000000000000->玖仟万亿
9000000000000009->玖仟万亿零玖
1111111111111111->壹仟壹佰壹拾壹万壹仟壹佰壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
90000000000000000->玖亿亿
90000000000000009->玖亿亿零玖
11111111111111111->壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹
900000000000000000->玖拾亿亿
900000000000000009->玖拾亿亿零玖
111111111111111111->壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹亿壹仟壹佰壹拾壹万壹仟壹佰壹拾壹

Python supports neither tail-call elimination nor tail-call optimization. However, there are a number of ways to mimic it (trampolines being the most widely used technique in other languages).
Tail call recursive functions should look like the following pseudo code:
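def tail_call(*args, acc):
    if condition(*args):
        return acc
    else:
        # Operations happen here, producing new_args and new_acc.
        # acc is keyword-only, so it must be passed by name; passed
        # positionally it would be swallowed by *args.
        return tail_call(*new_args, acc=new_acc)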
For your example I would not form a closure over anything, as you are introducing side effects and stateful manipulation. Instead, anything that needs to change should be changed in isolation from everything else, which makes it easier to reason about.
Build the new value you want (Python strings are immutable, so concatenation already produces a new object and nothing needs to be copied) and pass it in as an argument to the next recursive call. That's where the acc variable comes into play: it "accumulates" all your changes up to that point.
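As a small generic illustration of the accumulator idea (the names `total` and `total_acc` are made up for this example): in `total` the addition happens after the recursive call returns, while in `total_acc` the recursive call is the last thing the function does and the running result travels forward in `acc`. CPython still adds a stack frame per call in both cases, so this shape only pays off once it is combined with a loop or a trampoline.

```python
def total(xs):
    # Not tail-recursive: the `+` runs after the recursive call returns,
    # so every frame must stay alive until the inner call finishes.
    if not xs:
        return 0
    return xs[0] + total(xs[1:])


def total_acc(xs, acc=0):
    # Tail-recursive: the recursive call is the final action, and all the
    # state needed later is carried forward in `acc` instead of a closure.
    if not xs:
        return acc
    return total_acc(xs[1:], acc + xs[0])


print(total([1, 2, 3, 4]))      # 10
print(total_acc([1, 2, 3, 4]))  # 10
```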
A classical trampoline can be had from this snippet. There, the function is wrapped in an object that eventually either returns a result or returns another function object that should be called. I prefer this approach because I find it easier to reason about.
This isn't the only way. Take a look at this code snippet. The "magic" occurs when it reaches a point which "solves" the condition and it throws an exception to escape the infinite loop.
Finally, you can read about Trampolines here, here and here.
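A minimal trampoline sketch, assuming a convention chosen here for illustration: a function returns a zero-argument callable (a thunk) when it wants to make a tail call, and any non-callable value when it is done. The driver loop keeps calling thunks until a plain value comes back, so the Python stack never grows. The convention breaks down if the real result could itself be callable; wrapper-object approaches like the classical snippet mentioned above avoid that by checking for a dedicated class instead.

```python
def trampoline(fn, *args, **kwargs):
    """Call fn, then keep calling returned thunks until a non-callable value appears."""
    result = fn(*args, **kwargs)
    while callable(result):
        result = result()
    return result


def sum_to(n, acc=0):
    # Instead of recursing directly, return a thunk describing the next call.
    if n == 0:
        return acc
    return lambda: sum_to(n - 1, acc + n)


# 100000 steps would blow past CPython's default recursion limit if sum_to
# called itself directly; the trampoline runs them in a loop instead.
print(trampoline(sum_to, 100000))  # 5000050000
```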

I kept studying this question off and on over the past few days, and now I have worked it out!
NOTE: it's not just tail recursion, it's also pure functional programming!
The key is to think about the problem differently: the tree-recursion version processes the digits from left to right, while this version processes them from right to left.
unitDic=dict(zip(range(8),u'拾佰仟万拾佰仟亿'))
numDic=dict(zip('0123456789',u'零壹贰叁肆伍陆柒捌玖'))
wapDic=[(u'零拾',u'零'),(u'零佰',u'零'),(u'零仟',u'零'),
        (u'零万',u'万'),(u'零亿',u'亿'),(u'亿万',u'亿'),
        (u'零零',u'零'),]
#pure FP
def ChnNumber(s):
    def wrapper(s,wd=wapDic):
        def rep(s,k,v):
            if k in s:
                return rep(s.replace(k,v),k,v)
            return s
        if not wd:
            return s
        return wrapper(rep(s,*wd[0]),wd[1:])
    def recur(s,acc='',ind=0):
        if s=='':
            return acc
        return recur(s[:-1],numDic[s[-1]]+unitDic[ind%8]+acc,ind+1)
    def end(s):
        if s[-1]!='0':
            return numDic[s[-1]]
        return ''
    def result(start,end):
        if end=='' and start[-1]==u'零':
            return start[:-1]
        return start+end
    return result(wrapper(recur(s[:-1])),end(s))
for i in range(18):
    v1='9'+'0'*(i+1)
    v2='9'+'0'*i+'9'
    v3='1'*(i+2)
    print ('%s->%s\n%s->%s\n%s->%s'% (v1,ChnNumber(v1),v2,ChnNumber(v2),v3,ChnNumber(v3)))
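Because CPython does not eliminate tail calls, `recur` above still uses one stack frame per digit. Since it is already tail-recursive, though, it converts mechanically into a loop: each recursive call becomes a reassignment of the arguments. A sketch that reuses the same `unitDic`/`numDic` tables (redefined here so the block runs on its own); `recur_loop` is a name introduced for this example:

```python
unitDic = dict(zip(range(8), u'拾佰仟万拾佰仟亿'))
numDic = dict(zip('0123456789', u'零壹贰叁肆伍陆柒捌玖'))


def recur(s, acc='', ind=0):
    # Tail-recursive version, as in the answer above.
    if s == '':
        return acc
    return recur(s[:-1], numDic[s[-1]] + unitDic[ind % 8] + acc, ind + 1)


def recur_loop(s):
    # Same computation: each "tail call" becomes an update of (s, acc, ind).
    # The right-hand side is evaluated with the old values before rebinding.
    acc, ind = '', 0
    while s != '':
        s, acc, ind = s[:-1], numDic[s[-1]] + unitDic[ind % 8] + acc, ind + 1
    return acc


print(recur_loop('12'))  # 壹佰贰拾
```

The loop version handles digit strings of any length, where the recursive one would hit the recursion limit after roughly a thousand digits.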
If anyone objects that this won't work for a huge number (one with thousands of digits, since recur recurses once per digit and CPython's default recursion limit is 1000), yes, I admit that. The version below solves it, although it is no longer pure FP (a pure-FP setting wouldn't need this version anyway):
class TailCaller(object) :
    def __init__(self, f) :
        self.f = f
    def __call__(self, *args, **kwargs) :
        ret = self.f(*args, **kwargs)
        while type(ret) is TailCall :
            ret = ret.handle()
        return ret
class TailCall(object) :
    def __init__(self, call, *args, **kwargs) :
        self.call = call
        self.args = args
        self.kwargs = kwargs
    def handle(self) :
        if type(self.call) is TailCaller :
            return self.call.f(*self.args, **self.kwargs)
        else :
            return self.call(*self.args, **self.kwargs)
def ChnNumber(s):
    def wrapper(s,wd=wapDic):
        @TailCaller
        def rep(s,k,v):
            if k in s:
                return TailCall(rep,s.replace(k,v),k,v)
            return s
        if not wd:
            return s
        return wrapper(rep(s,*wd[0]),wd[1:])
    @TailCaller
    def recur(s,acc='',ind=0):
        if s=='':
            return acc
        return TailCall(recur,s[:-1],numDic[s[-1]]+unitDic[ind%8]+acc,ind+1)
    def end(s):
        if s[-1]!='0':
            return numDic[s[-1]]
        return ''
    def result(start,end):
        if end=='' and start[-1]==u'零':
            return start[:-1]
        return start+end
    return result(wrapper(recur(s[:-1])),end(s))

Related

Getting the value of a mutable keyword argument of a decorator

I have the following code, in which I simply have a decorator for caching a function's results, and as a concrete implementation, I used the Fibonacci function.
After playing around with the code, I wanted to print the cache variable that's initialized in the cache wrapper.
(Not because I suspect the cache might be faulty; I simply want to know how to access it without going into debug mode and putting a breakpoint inside the decorator.)
I tried to explore the fib_w_cache function in debug mode, which is supposed to actually be the wrapped fib_w_cache, but with no success.
import timeit
def cache(f, cache = dict()):
    def args_to_str(*args, **kwargs):
        return str(args) + str(kwargs)
    def wrapper(*args, **kwargs):
        args_str = args_to_str(*args, **kwargs)
        if args_str in cache:
            #print("cache used for: %s" % args_str)
            return cache[args_str]
        else:
            val = f(*args, **kwargs)
            cache[args_str] = val
            return val
    return wrapper
@cache
def fib_w_cache(n):
    if n == 0: return 0
    elif n == 1: return 1
    else:
        return fib_w_cache(n-2) + fib_w_cache(n-1)
def fib_wo_cache(n):
    if n == 0: return 0
    elif n == 1: return 1
    else:
        return fib_wo_cache(n-1) + fib_wo_cache(n-2)
print(timeit.timeit('[fib_wo_cache(i) for i in range(0,30)]', globals=globals(), number=1))
print(timeit.timeit('[fib_w_cache(i) for i in range(0,30)]', globals=globals(), number=1))
I admit this is not an "elegant" solution in a sense, but keep in mind that python functions are also objects. So with some slight modification to your code, I managed to inject the cache as an attribute of a decorated function:
import timeit
def cache(f):
    def args_to_str(*args, **kwargs):
        return str(args) + str(kwargs)
    def wrapper(*args, **kwargs):
        args_str = args_to_str(*args, **kwargs)
        if args_str in wrapper._cache:
            #print("cache used for: %s" % args_str)
            return wrapper._cache[args_str]
        else:
            val = f(*args, **kwargs)
            wrapper._cache[args_str] = val
            return val
    wrapper._cache = {}
    return wrapper
@cache
def fib_w_cache(n):
    if n == 0: return 0
    elif n == 1: return 1
    else:
        return fib_w_cache(n-2) + fib_w_cache(n-1)
@cache
def fib_w_cache_1(n):
    if n == 0: return 0
    elif n == 1: return 1
    else:
        return fib_w_cache_1(n-2) + fib_w_cache_1(n-1)
def fib_wo_cache(n):
    if n == 0: return 0
    elif n == 1: return 1
    else:
        return fib_wo_cache(n-1) + fib_wo_cache(n-2)
print(timeit.timeit('[fib_wo_cache(i) for i in range(0,30)]', globals=globals(), number=1))
print(timeit.timeit('[fib_w_cache(i) for i in range(0,30)]', globals=globals(), number=1))
print(fib_w_cache._cache)
print(fib_w_cache_1._cache) # to prove that caches are different instances for different functions
cache is of course a perfectly normal local variable in scope within the cache function, and a perfectly normal nonlocal cellvar in scope within the wrapper function, so if you want to access the value from there, you just do it—as you already are.
But what if you wanted to access it from somewhere else? Then there are two options.
First, cache happens to be defined at the global level, meaning any code anywhere (that hasn't hidden it with a local variable named cache) can access the function object.
And if you're trying to access the values of a function's default parameters from outside the function, they're available in the attributes of the function object. The inspect module docs explain the inspection-oriented attributes of each builtin type:
__defaults__ is a sequence of the values for all positional-or-keyword parameters, in order.
__kwdefaults__ is a mapping from keywords to values for all keyword-only parameters.
So:
>>> def f(a, b=0, c=1, *, d=2, e=3): pass
>>> f.__defaults__
(0, 1)
>>> f.__kwdefaults__
{'d': 2, 'e': 3}
So, for a simple case where you know there's exactly one default value and know which argument it belongs to, all you need is:
>>> cache.__defaults__[0]
{}
If you need to do something more complicated or dynamic, like get the default value for c in the f function above, you need to dig into other information—the only way to know that c's default value will be the second one in __defaults__ is to look at the attributes of the function's code object, like f.__code__.co_varnames, and figure it out from there. But usually, it's better to just use the inspect module's helpers. For example:
>>> import inspect
>>> inspect.signature(f).parameters['c'].default
1
>>> inspect.signature(cache).parameters['cache'].default
{}
Alternatively, if you're trying to access the cache from inside fib_w_cache, while there's no variable in lexical scope in that function body you can look at, you do know that the function body is only called by the decorator wrapper, and it is available there.
So, you can get your stack frame
frame = inspect.currentframe()
… follow it back to your caller:
back = frame.f_back
… and grab it from that frame's locals:
back.f_locals['cache']
It's worth noting that f_locals works like the locals function: it's actually a copy of the internal locals storage, so modifying it may have no effect, and that copy flattens nonlocal cell variables to regular local variables. If you wanted to access the actual cell variable, you'd have to grub around in things like back.f_code.co_freevars to get the index and then dig it out of the function object's __closure__. But usually, you don't care about that.
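A small self-contained sketch of both routes, using a cut-down version of the question's decorator (the `square` function is just an example): reading `cache` from the wrapper's frame via `f_back.f_locals`, and reading it from outside by pairing the wrapper's `__code__.co_freevars` with its `__closure__` cells. These are CPython introspection details, so treat this as a debugging aid rather than something to build on.

```python
import inspect


def cache(f, cache=dict()):
    def wrapper(*args):
        key = str(args)
        if key not in cache:
            cache[key] = f(*args)
        return cache[key]
    return wrapper


@cache
def square(n):
    # Only ever called from `wrapper`, so the caller's frame is the wrapper's;
    # its f_locals includes the free variable `cache`.
    back = inspect.currentframe().f_back
    print('cache seen from the wrapper frame:', back.f_locals['cache'])
    return n * n


square(3)

# From outside: closure cells are stored in the same order as the names
# in the wrapper's co_freevars, so zip them into a name -> value mapping.
cells = dict(zip(square.__code__.co_freevars,
                 (c.cell_contents for c in square.__closure__)))
print(cells['cache'])                           # {'(3,)': 9}
print(cells['cache'] is cache.__defaults__[0])  # True
```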
Just for the sake of completeness, Python has a built-in caching decorator, functools.lru_cache, with some inspection mechanisms:
from functools import lru_cache
@lru_cache(maxsize=None)
def fib_w_cache(n):
    if n == 0: return 0
    elif n == 1: return 1
    else:
        return fib_w_cache(n-2) + fib_w_cache(n-1)
print('fib_w_cache(10) =', fib_w_cache(10))
print(fib_w_cache.cache_info())
Prints:
fib_w_cache(10) = 55
CacheInfo(hits=8, misses=11, maxsize=None, currsize=11)
I managed to find a solution (in some sense thanks to @Patrick Haugh's advice).
I simply accessed cache.__defaults__[0], which holds the cache's dict.
The insights about the shared cache and how to avoid it were also quite useful.
Just as a note, the cache dictionary can only be accessed through the cache function object. It cannot be accessed through the decorated functions (at least as far as I understand). It logically aligns well with the fact that the cache is shared in my implementation, where on the other hand, in the alternative implementation that was proposed, it is local per decorated function.
You can make a class into a wrapper.
def args_to_str(*args, **kwargs):
    return str(args) + str(kwargs)
class Cache(object):
    def __init__(self, func):
        self.func = func
        self.cache = {}
    def __call__(self, *args, **kwargs):
        args_str = args_to_str(*args, **kwargs)
        if args_str in self.cache:
            return self.cache[args_str]
        else:
            val = self.func(*args, **kwargs)
            self.cache[args_str] = val
            return val
Each function has its own cache; you can access it as function.cache. This also lets you attach any other methods you want to your function.
If you wanted all decorated functions to share the same cache, you could use a class variable instead of an instance variable:
class SharedCache(object):
    cache = {}
    def __init__(self, func):
        self.func = func
    # rest of the code is the same
@SharedCache
def function_1(stuff):
    things
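A runnable sketch of how the class-based decorator is used (the `square` and `cube` functions are illustrative): each decorated function gets its own `Cache` instance, so its cache is simply an attribute on it. The `functools.update_wrapper` call is an addition, not part of the answer above; it copies `__name__`, `__doc__` and similar attributes onto the instance so the decorated object still looks like the original function.

```python
import functools


def args_to_str(*args, **kwargs):
    return str(args) + str(kwargs)


class Cache(object):
    def __init__(self, func):
        self.func = func
        self.cache = {}  # one dict per decorated function
        functools.update_wrapper(self, func)  # keep __name__, __doc__, ...

    def __call__(self, *args, **kwargs):
        key = args_to_str(*args, **kwargs)
        if key not in self.cache:
            self.cache[key] = self.func(*args, **kwargs)
        return self.cache[key]


@Cache
def square(n):
    return n * n


@Cache
def cube(n):
    return n ** 3


square(4)
cube(2)
print(square.__name__, square.cache)  # square {'(4,){}': 16}
print(cube.cache)                     # {'(2,){}': 8}
```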

Collect python source code comments along execution path

E.g. I've got the following python function:
def func(x):
    """Function docstring."""
    result = x + 1
    if result > 0:
        # comment 2
        return result
    else:
        # comment 3
        return -1 * result
And I want to have some function that would print all function docstrings and comments that are met along the execution path, e.g.
> trace(func(2))
Function docstring.
Comment 2
3
In fact what I try to achieve is to provide some comments how the result has been calculated.
What could be used? As far as I understand, the AST does not keep comments in the tree.
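A quick way to confirm this, and to see where comments do survive: `ast.parse` discards them, while the standard `tokenize` module reports each one as a `COMMENT` token together with its position. A minimal sketch on a source string:

```python
import ast
import io
import tokenize

source = (
    "def func(x):\n"
    "    result = x + 1\n"
    "    # comment 2\n"
    "    return result\n"
)

# The AST has no node for the comment, so its text never appears in the dump.
print('comment 2' in ast.dump(ast.parse(source)))  # False

# tokenize keeps comments as COMMENT tokens; tok.start is (row, col).
comments = [
    (tok.start[0], tok.string)
    for tok in tokenize.generate_tokens(io.StringIO(source).readline)
    if tok.type == tokenize.COMMENT
]
print(comments)  # [(3, '# comment 2')]
```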
I thought this was an interesting challenge, so I decided to give it a try. Here is what I came up with:
import ast
import inspect
import re
import sys
import __future__
if sys.version_info >= (3,5):
    ast_Call = ast.Call
else:
    def ast_Call(func, args, keywords):
        """Compatibility wrapper for ast.Call on Python 3.4 and below.
        Used to have two additional fields (starargs, kwargs)."""
        return ast.Call(func, args, keywords, None, None)
COMMENT_RE = re.compile(r'^(\s*)#\s?(.*)$')
def convert_comment_to_print(line):
    """If `line` contains a comment, it is changed into a print
    statement, otherwise nothing happens. Only acts on full-line comments,
    not on trailing comments. Returns the (possibly modified) line."""
    match = COMMENT_RE.match(line)
    if match:
        return '{}print({!r})\n'.format(*match.groups())
    else:
        return line
def is_string_literal(node):
    """True if `node` is a string literal. Python 3.8+ parses literals as
    ast.Constant (and 3.12 removed ast.Str); older versions use ast.Str."""
    if hasattr(ast, 'Constant') and isinstance(node, ast.Constant):
        return isinstance(node.value, str)
    return hasattr(ast, 'Str') and isinstance(node, ast.Str)
def convert_docstrings_to_prints(syntax_tree):
    """Walks an AST and changes every docstring (i.e. every expression
    statement consisting only of a string) to a print statement.
    The AST is modified in-place."""
    ast_print = ast.Name('print', ast.Load())
    nodes = list(ast.walk(syntax_tree))
    for node in nodes:
        for bodylike_field in ('body', 'orelse', 'finalbody'):
            if hasattr(node, bodylike_field):
                for statement in getattr(node, bodylike_field):
                    if (isinstance(statement, ast.Expr) and
                            is_string_literal(statement.value)):
                        arg = statement.value
                        statement.value = ast_Call(ast_print, [arg], [])
def get_future_flags(module_or_func):
    """Get the compile flags corresponding to the features imported from
    __future__ by the specified module, or by the module containing the
    specific function. Returns a single integer containing the bitwise OR
    of all the flags that were found."""
    # A function doesn't expose its module's names as attributes, so look
    # in its __globals__; for a module, use its namespace dict.
    namespace = getattr(module_or_func, '__globals__', None)
    if namespace is None:
        namespace = vars(module_or_func)
    result = 0
    for feature_name in __future__.all_feature_names:
        feature = getattr(__future__, feature_name)
        if (namespace.get(feature_name) is feature and
                hasattr(feature, 'compiler_flag')):
            result |= feature.compiler_flag
    return result
def eval_function(syntax_tree, func_globals, filename, lineno, compile_flags,
                  *args, **kwargs):
    """Helper function for `trace`. Execute the function defined by
    the given syntax tree, and return its return value."""
    func = syntax_tree.body[0]
    func.decorator_list.insert(0, ast.Name('_trace_exec_decorator', ast.Load()))
    ast.increment_lineno(syntax_tree, lineno-1)
    ast.fix_missing_locations(syntax_tree)
    code = compile(syntax_tree, filename, 'exec', compile_flags, True)
    result = [None]
    def _trace_exec_decorator(compiled_func):
        result[0] = compiled_func(*args, **kwargs)
    func_locals = {'_trace_exec_decorator': _trace_exec_decorator}
    exec(code, func_globals, func_locals)
    return result[0]
def trace(func, *args, **kwargs):
    """Run the given function with the given arguments and keyword arguments,
    and whenever a docstring or (whole-line) comment is encountered,
    print it to stdout."""
    filename = inspect.getsourcefile(func)
    lines, lineno = inspect.getsourcelines(func)
    lines = map(convert_comment_to_print, lines)
    modified_source = ''.join(lines)
    compile_flags = get_future_flags(func)
    syntax_tree = compile(modified_source, filename, 'exec',
                          ast.PyCF_ONLY_AST | compile_flags, True)
    convert_docstrings_to_prints(syntax_tree)
    return eval_function(syntax_tree, func.__globals__,
                         filename, lineno, compile_flags, *args, **kwargs)
It is a bit long because I tried to cover most important cases, and the code might not be the most readable, but I hope it is nice enough to follow.
How it works:
First, read the function's source code using inspect.getsourcelines. (Warning: inspect does not work for functions that were defined interactively. If you need that, maybe you can use dill instead, see this answer.)
Search for lines that look like comments, and replace them with print statements. (Right now only whole-line comments are replaced, but it shouldn't be difficult to extend that to trailing comments if desired.)
Parse the source code into an AST.
Walk the AST and replace all docstrings with print statements.
Compile the AST.
Execute the AST. This and the previous step contain some trickery to try to reconstruct the context that the function was originally defined in (e.g. globals, __future__ imports, line numbers for exception tracebacks). Also, since just executing the source would only re-define the function and not call it, we fix that with a simple decorator.
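The "re-define and immediately call" trick from the last step can be seen in isolation below. This is a sketch with made-up names (`_call_now`, `results`), and it puts the decorator into the source text rather than into the AST as `trace` does: the decorator receives the freshly compiled function and calls it right away, so `exec`-ing the source runs the function instead of merely defining it.

```python
source = (
    "@_call_now\n"
    "def func(x):\n"
    "    return x + 1\n"
)

results = []


def _call_now(compiled_func):
    # Runs as soon as the `def` statement executes; record the call's result.
    results.append(compiled_func(2))
    # Returning None leaves `func` bound to None in the exec namespace,
    # which doesn't matter because that namespace is thrown away.


code = compile(source, '<traced source>', 'exec')
exec(code, {'_call_now': _call_now})
print(results)  # [3]
```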
It works in Python 2 and 3 (at least with the tests below, which I ran in 2.7 and 3.6).
To use it, simply do:
result = trace(func, 2) # result = func(2)
Here is a slightly more elaborate test that I used while writing the code:
#!/usr/bin/env python
from trace_comments import trace
from dateutil.easter import easter, EASTER_ORTHODOX
def func(x):
    """Function docstring."""
    result = x + 1
    if result > 0:
        # comment 2
        return result
    else:
        # comment 3
        return -1 * result
if __name__ == '__main__':
    result1 = trace(func, 2)
    print("result1 = {}".format(result1))
    result2 = trace(func, -10)
    print("result2 = {}".format(result2))
    # Test that trace() does not permanently replace the function
    result3 = func(42)
    print("result3 = {}".format(result3))
    print("-----")
    print(trace(easter, 2018))
    print("-----")
    print(trace(easter, 2018, EASTER_ORTHODOX))

How to eliminate recursion in Python function containing control flow

I have a function of the form:
def my_func(my_list):
    for i, thing in enumerate(my_list):
        my_val = another_func(thing)
        if i == 0:
            # do some stuff
            pass
        else:
            if my_val == something:
                return my_func(my_list[:-1])
        # do some other stuff
The recursive part is getting called enough that I am getting a RecursionError, so I am trying to replace it with a while loop as explained here, but I can't work out how to reconcile this with the control flow statements in the function. Any help would be gratefully received!
There may be a good exact answer, but the most general (or maybe quick-and-dirty) way to switch from recursion to iteration is to manage the stack yourself. Just do manually what programming language does implicitly and have your own unlimited stack.
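As a generic illustration of managing the stack yourself (not tied to the question's `my_func`): a recursive sum over nested lists becomes a loop over an explicit Python list used as a stack. The nesting depth is then limited only by memory, not by the recursion limit.

```python
def nested_sum(items):
    total = 0
    stack = [items]  # our own stack replaces the interpreter's call stack
    while stack:
        current = stack.pop()
        for x in current:
            if isinstance(x, list):
                stack.append(x)  # where the recursive version would recurse
            else:
                total += x
    return total


print(nested_sum([1, [2, [3, 4]], 5]))  # 15

# 10,000 levels of nesting: far beyond CPython's default recursion limit.
deep = [1]
for _ in range(10000):
    deep = [deep, 1]
print(nested_sum(deep))  # 10001
```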
In this particular case there is tail recursion: the result of the recursive my_func call is not used by the caller in any way; it is returned immediately. In the end, the deepest recursive call's result bubbles up and is returned as-is. This is what makes @outoftime's solution possible. We only care about the into-recursion pass, since the return-from-recursion pass is trivial, so the into-recursion pass is replaced with iterations.
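def my_func(my_list):
    while True:
        for i, thing in enumerate(my_list):
            my_val = another_func(thing)
            if i == 0:
                # do some stuff
                pass
            else:
                if my_val == something:
                    # replaces `return my_func(my_list[:-1])`: rebind, start over
                    my_list = my_list[:-1]
                    break
            # do some other stuff
        else:
            # the for loop finished without a "recursive call", so we're done
            return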
This is an iterative method.
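To check that the loop really behaves like the recursion, here is a runnable instance of the same shape. Everything concrete in it is an assumption made for illustration: `thing * 2` stands in for `another_func`, `10` for `something`, appending to `log` for the "stuff", and "do some other stuff" is taken to run at the end of each loop iteration. The `for ... else` clause runs only when the loop finishes without `break`, which is exactly the case where the recursive version falls off the end and returns.

```python
def process_recursive(my_list, log):
    for i, thing in enumerate(my_list):
        my_val = thing * 2                 # stand-in for another_func(thing)
        if i == 0:
            log.append(('first', thing))   # "do some stuff"
        elif my_val == 10:                 # stand-in for `my_val == something`
            return process_recursive(my_list[:-1], log)
        log.append(('seen', thing))        # "do some other stuff"
    return log


def process_iterative(my_list, log):
    while True:
        for i, thing in enumerate(my_list):
            my_val = thing * 2
            if i == 0:
                log.append(('first', thing))
            elif my_val == 10:
                my_list = my_list[:-1]     # the tail call becomes a rebinding...
                break                      # ...followed by a fresh pass
            log.append(('seen', thing))
        else:
            return log                     # no tail call happened: done


for data in ([], [1, 2, 3], [1, 5, 3], [5, 5, 5]):
    assert process_recursive(data, []) == process_iterative(data, [])

# Deep enough to raise RecursionError in process_recursive:
print(len(process_iterative([0] + [5] * 5000, [])))  # 10002
```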
Decorator
class TailCall(object):
    def __init__(self, __function__):
        self.__function__ = __function__
        self.args = None
        self.kwargs = None
        self.has_params = False
    def __call__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs
        self.has_params = True
        return self
    def __handle__(self):
        if not self.has_params:
            raise TypeError
        if type(self.__function__) is TailCaller:
            return self.__function__.call(*self.args, **self.kwargs)
        return self.__function__(*self.args, **self.kwargs)
class TailCaller(object):
    def __init__(self, call):
        self.call = call
    def __call__(self, *args, **kwargs):
        ret = self.call(*args, **kwargs)
        while type(ret) is TailCall:
            ret = ret.__handle__()
        return ret
@TailCaller
def factorial(n, prev=1):
    if n < 2:
        return prev
    return TailCall(factorial)(n-1, n * prev)
To use this decorator, simply wrap your function with the @TailCaller decorator and return a TailCall instance initialized with the required params.
I'd like to thank @o2genum for the inspiration, and Kyle Miller, who wrote an excellent article about this problem.
However nice it is to remove this limitation, you should probably be aware of why this feature is not officially supported.

Local const in recursion call

Well... Code first.
def magic(node):
    spells_dict = {"AR_OP":ar_op_magic, "PRE_OP":pre_op_magic}
    if node:
        if node.text in spells_dict:
            return spells_dict[node.text](node)
        else:
            return magic(node.l) + magic(node.r)
    else:
        return ""
During the recursive calls, a new copy of spells_dict is created on every call. I know I could make the dict global, but I don't want to, because it relates only to the magic function. I could also create a class and put spells_dict and the function in it, but that doesn't look like a good solution.
Is there any way to do this with only one copy of spells_dict?
I don't see any problem with a MAGIC_SPELLS constant. You can place it right next to the magic function, so it's clear they belong together:
def magic_default(node):
    return magic(node.l) + magic(node.r)
MAGIC_SPELLS = {
    'AR_OP': ar_op_magic,
    'PRE_OP': pre_op_magic,
}
def magic(node):
    if node:
        func = MAGIC_SPELLS.get(node.text, magic_default)
        return func(node)
    return ""
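A self-contained way to try the constant-based version: the `Node` tuple and the two handler functions below are hypothetical stand-ins for the question's own node type, `ar_op_magic` and `pre_op_magic`. The dict is built once when the module is imported, and every recursive call reads that single object.

```python
from collections import namedtuple

# Hypothetical node type: a label plus left/right children (None for none).
Node = namedtuple('Node', 'text l r')


def ar_op_magic(node):   # hypothetical stand-in handler
    return '<ar>'


def pre_op_magic(node):  # hypothetical stand-in handler
    return '<pre>'


def magic_default(node):
    return magic(node.l) + magic(node.r)


# Built once at import time; all recursive calls share this one dict.
MAGIC_SPELLS = {
    'AR_OP': ar_op_magic,
    'PRE_OP': pre_op_magic,
}


def magic(node):
    if node:
        func = MAGIC_SPELLS.get(node.text, magic_default)
        return func(node)
    return ""


tree = Node('ROOT',
            Node('AR_OP', None, None),
            Node('X', None, Node('PRE_OP', None, None)))
print(magic(tree))  # <ar><pre>
```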

Workaround for equality of nested functions

I have a nested function that I'm using as a callback in pyglet:
def get_stop_function(stop_key):
    def stop_on_key(symbol, _):
        if symbol == getattr(pyglet.window.key, stop_key):
            pyglet.app.exit()
    return stop_on_key
pyglet.window.set_handler('on_key_press', get_stop_function('ENTER'))
But then I run into problems later when I need to reference the nested function again:
pyglet.window.remove_handler('on_key_press', get_stop_function('ENTER'))
This doesn't work because of the way python treats functions:
my_stop_function = get_stop_function('ENTER')
my_stop_function is get_stop_function('ENTER') # False
my_stop_function == get_stop_function('ENTER') # False
Thanks to two similar questions I understand what is going on but I'm not sure what the workaround is for my case. I'm looking through the pyglet source code and it looks like pyglet uses equality to find the handler to remove.
So my final question is: how can I override the inner function's __eq__ method (or some other dunder) so that identical nested functions will be equal?
(Another workaround would be to store a reference to the function myself, but that is duplicating pyglet's job, will get messy with many callbacks, and anyways I'm curious about this question!)
Edit: actually, in the questions I linked above, it's explained that methods have value equality but not reference equality. With nested functions, you don't even get value equality, which is all I need.
Edit2: I will probably accept Bi Rico's answer, but does anyone know why the following doesn't work:
def get_stop_function(stop_key):
    def stop_on_key(symbol, _):
        if symbol == getattr(pyglet.window.key, stop_key):
            pyglet.app.exit()
    stop_on_key.__name__ = '__stop_on_' + stop_key + '__'
    stop_on_key.__eq__ = lambda x: x.__name__ == '__stop_on_' + stop_key + '__'
    return stop_on_key
get_stop_function('ENTER') == get_stop_function('ENTER') # False
get_stop_function('ENTER').__eq__(get_stop_function('ENTER')) # True
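The behaviour in Edit2 follows from how Python dispatches operators: for `==`, special methods such as `__eq__` are looked up on the object's type, not on the instance. An `__eq__` attribute stored on one particular function object is therefore ignored by `==`, but still found by an explicit `f.__eq__(...)` call. A minimal sketch with throwaway functions `f` and `g`:

```python
def f():
    pass


def g():
    pass


# Stored in f.__dict__, i.e. on this one function object only.
f.__eq__ = lambda other: True

print(f == g)       # False: `==` uses type(f).__eq__, the default identity check
print(f.__eq__(g))  # True: plain attribute lookup finds the instance attribute


class AlwaysEqual(object):
    def __eq__(self, other):  # defined on the type, so `==` does use it
        return True


print(AlwaysEqual() == AlwaysEqual())  # True
```

This is why wrapping the callback in a class that defines `__eq__` works where patching the function object does not.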
You could create a class for your stop functions and define your own comparison method.
class StopFunction(object):
    def __init__(self, stop_key):
        self.stop_key = stop_key
    def __call__(self, symbol, _):
        if symbol == getattr(pyglet.window.key, self.stop_key):
            pyglet.app.exit()
    def __eq__(self, other):
        try:
            return self.stop_key == other.stop_key
        except AttributeError:
            return False
    def __hash__(self):
        # defining __eq__ removes the default __hash__ in Python 3
        return hash(self.stop_key)
StopFunction('ENTER') == StopFunction('ENTER')
# True
StopFunction('ENTER') == StopFunction('FOO')
# False
The solution is to keep a dictionary of the generated functions around, so that when you make the second call you get the same object as in the first call. That is, simply build some memoization logic, or use one of the existing libraries that provide memoizing decorators:
ALL_FUNCTIONS = {}
def get_stop_function(stop_key):
    if stop_key not in ALL_FUNCTIONS:
        def stop_on_key(symbol, _):
            if symbol == getattr(pyglet.window.key, stop_key):
                pyglet.app.exit()
        ALL_FUNCTIONS[stop_key] = stop_on_key
    else:
        stop_on_key = ALL_FUNCTIONS[stop_key]
    return stop_on_key
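With the standard library, `functools.lru_cache` can play the role of `ALL_FUNCTIONS`. As a sketch: calling `get_stop_function` twice with the same key returns the very same object, so both identity and default equality hold. To keep it runnable without pyglet, the inner function here just compares `symbol` with `stop_key`; in the real callback that line would be the `pyglet.window.key` comparison and the `pyglet.app.exit()` call from the question.

```python
import functools


@functools.lru_cache(maxsize=None)
def get_stop_function(stop_key):
    def stop_on_key(symbol, _):
        # Simplified stand-in for the pyglet key check and app exit.
        return symbol == stop_key
    return stop_on_key


handlers = [get_stop_function('ENTER')]
handlers.remove(get_stop_function('ENTER'))  # same object, so removal works
print(handlers)  # []
```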
You can generalize Bi Rico's solution to allow wrapping any functions up with some particular equality function pretty easily.
The first problem is defining what the equality function should check. I'm guessing for this case, you want the code to be identical (meaning functions created from the same def statement will be equal, but two functions created from character-for-character copies of the def statement will not), and the closures to be equal (meaning that if you call get_stop_function with two equal but non-identical stop_keys the functions will be equal), and nothing else to be relevant. But that's just a guess, and there are many other possibilities.
Then you just wrap a function the same way you'd wrap any other kind of object; just make sure __call__ is one of the things you delegate:
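class EqualFunction(object):
    def __init__(self, f):
        self.f = f
    def __eq__(self, other):
        # functions without free variables have __closure__ == None
        self_cells = self.__closure__ or ()
        other_cells = getattr(other, '__closure__', None) or ()
        return (self.__code__ == getattr(other, '__code__', None) and
                all(x.cell_contents == y.cell_contents
                    for x, y in zip(self_cells, other_cells)))
    def __hash__(self):
        # defining __eq__ removes the default __hash__ in Python 3
        return hash(self.__code__)
    def __getattr__(self, attr):
        return getattr(self.f, attr)
    def __call__(self, *args, **kwargs):
        return self.f(*args, **kwargs)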
Special methods invoked implicitly by operators and builtins are looked up on the type, so they bypass __getattr__. If you want to support other dunder methods (I don't think any of them are critical for functions, but I could be wrong…), either define them explicitly (as with __call__) or loop over them and add a generic wrapper to the type for each one.
To use the wrapper:
def make_f(i):
    def f():
        return i
    return EqualFunction(f)
f1 = make_f(0)
f2 = make_f(0.0)
assert f1 == f2
Or, notice that EqualFunction actually works as a decorator, which may be more readable.
So, for your code:
def get_stop_function(stop_key):
    @EqualFunction
    def stop_on_key(symbol, _):
        if symbol == getattr(pyglet.window.key, stop_key):
            pyglet.app.exit()
    return stop_on_key
