How to compute the nth root of a very big integer - python

I need a way to compute the nth root of a long integer in Python.
I tried pow(m, 1.0/n), but it doesn't work:
OverflowError: long int too large to convert to float
Any ideas?
By long integer I mean REALLY long integers like:

If it's a REALLY big number, you could use a binary search.
def find_invpow(x,n):
    """Finds the integer component of the n'th root of x,
    an integer such that y ** n <= x < (y + 1) ** n.
    """
    high = 1
    while high ** n <= x:
        high *= 2
    low = high // 2  # floor division keeps low an integer on Python 3 as well
    while low < high:
        mid = (low + high) // 2
        if low < mid and mid**n < x:
            low = mid
        elif high > mid and mid**n > x:
            high = mid
        else:
            return mid
    return mid + 1
For example:
>>> x = 237734537465873465
>>> n = 5
>>> y = find_invpow(x,n)
>>> y
>>> y**n <= x <= (y+1)**n
>>> x = 119680039660309643568856114803834088331723464504673392511960931441>
>>> n = 45
>>> y = find_invpow(x,n)
>>> y
>>> y**n <= x < (y+1)**n
>>> find_invpow(y**n,n) == y

Gmpy is a C-coded Python extension module that wraps the GMP library to provide to Python code fast multiprecision arithmetic (integer, rational, and float), random number generation, advanced number-theoretical functions, and more.
Includes a root function:
x.root(n): returns a 2-element tuple (y,m), such that y is the
(possibly truncated) n-th root of x; m, an ordinary Python int,
is 1 if the root is exact (x==y**n), else 0. n must be an ordinary
Python int, >=0.
For example, 20th root:
>>> import gmpy
>>> i0=11968003966030964356885611480383408833172346450467339251
>>> m0=gmpy.mpz(i0)
>>> m0
>>> m0.root(20)
(mpz(567), 0)

You can make it run slightly faster by avoiding the first while loop in favor of setting low to 10 ** (len(str(x)) / n) and high to low * 10. Probably better is to replace len(str(x)) with the bit length and use a bit shift. Based on my tests, I estimate a 5% speedup from the first and a 25% speedup from the second, although the speedups vary with the number chosen. Don't trust my code without testing it carefully: I did some basic testing but may have missed an edge case.
If the actual data you're using is much bigger than what you posted here, this change may be worthwhile.
from timeit import Timer
def find_invpow(x,n):
    """Finds the integer component of the n'th root of x,
    an integer such that y ** n <= x < (y + 1) ** n.
    """
    high = 1
    while high ** n <= x:  # <= so that high ** n > x even when x is an exact n'th power
        high *= 2
    low = high // 2
    while low < high:
        mid = (low + high) // 2
        if low < mid and mid**n < x:
            low = mid
        elif high > mid and mid**n > x:
            high = mid
        else:
            return mid
    return mid + 1
def find_invpowAlt(x,n):
    """Finds the integer component of the n'th root of x,
    an integer such that y ** n <= x < (y + 1) ** n.
    """
    # Using (digits - 1) keeps low at or below the root; low * 10 is then above it.
    low = 10 ** ((len(str(x)) - 1) // n)
    high = low * 10
    while low < high:
        mid = (low + high) // 2
        if low < mid and mid**n < x:
            low = mid
        elif high > mid and mid**n > x:
            high = mid
        else:
            return mid
    return mid + 1
x = 237734537465873465
n = 5
tests = 10000
print "Norm", Timer('find_invpow(x,n)', 'from __main__ import find_invpow, x,n').timeit(number=tests)
print "Alt", Timer('find_invpowAlt(x,n)', 'from __main__ import find_invpowAlt, x,n').timeit(number=tests)
Norm 0.626754999161
Alt 0.566340923309
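The bit-length variant mentioned above was not posted, so here is a minimal sketch of what it might look like, assuming Python 3 and a non-negative x. The name iroot_bits is made up for illustration, and the timing claims above were not re-measured with it.

```python
def iroot_bits(x, n):
    """Floor of the n'th root of x, with bounds taken from the bit length."""
    if x < 0 or n < 1:
        raise ValueError("expects x >= 0 and n >= 1")
    if x < 2:
        return x
    # 2 ** (bits - 1) <= x < 2 ** bits, so the root lies in [low, high).
    shift = (x.bit_length() - 1) // n
    low, high = 1 << shift, 1 << (shift + 1)
    # Invariant: low ** n <= x < high ** n
    while high - low > 1:
        mid = (low + high) // 2
        if mid ** n <= x:
            low = mid
        else:
            high = mid
    return low


x = 237734537465873465
y = iroot_bits(x, 5)
print(y, y ** 5 <= x < (y + 1) ** 5)
```

Because the starting interval is only a factor of two wide, the search needs about as many steps as the root has bits.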

If you are looking for something standard and fast to write, with high precision, I would use decimal and adjust the precision (getcontext().prec) to at least the length of x.
Code (Python 3.0)
from decimal import *
x = '11968003966030964356885611480383408833172346450467339251\
minprec = 27
if len(x) > minprec: getcontext().prec = len(x)
else: getcontext().prec = minprec
x = Decimal(x)
power = Decimal(1)/Decimal(3)
answer = x**power
ranswer = answer.quantize(Decimal('1.'), rounding=ROUND_UP)
diff = x - ranswer**Decimal(3)
if diff == Decimal(0):
    print("x is the cubic number of", ranswer)
else:
    print("x has a cubic root of ", answer)
x is the cubic number of 22873918786185635329056863961725521583023133411

Oh, for numbers that big, you would use the decimal module.
ns: your number as a string
ns = "11968003966030964356885611480383408833172346450467339251196093144141045683463085291115677488411620264826942334897996389485046262847265769280883237649461122479734279424416861834396522819159219215308460065265520143082728303864638821979329804885526557893649662037092457130509980883789368448042961108430809620626059287437887495827369474189818588006905358793385574832590121472680866521970802708379837148646191567765584039175249171110593159305029014037881475265618958103073425958633163441030267478942720703134493880117805010891574606323700178176718412858948243785754898788359757528163558061136758276299059029113119763557411729353915848889261125855717014320045292143759177464380434854573300054940683350937992500211758727939459249163046465047204851616590276724564411037216844005877918224201569391107769029955591465502737961776799311859881060956465198859727495735498887960494256488224613682478900505821893815926193600121890632"
from decimal import Decimal
d = Decimal(ns)
one_third = Decimal("0.3333333333333333")
print d ** one_third
and the answer is: 2.287391878618402702753613056E+305
TZ pointed out that this isn't accurate... and he's right. Here's my test.
from decimal import Decimal
def nth_root(num_decimal, n_integer):
exponent = Decimal("1.0") / Decimal(n_integer)
return num_decimal ** exponent
def test():
ns = "11968003966030964356885611480383408833172346450467339251196093144141045683463085291115677488411620264826942334897996389485046262847265769280883237649461122479734279424416861834396522819159219215308460065265520143082728303864638821979329804885526557893649662037092457130509980883789368448042961108430809620626059287437887495827369474189818588006905358793385574832590121472680866521970802708379837148646191567765584039175249171110593159305029014037881475265618958103073425958633163441030267478942720703134493880117805010891574606323700178176718412858948243785754898788359757528163558061136758276299059029113119763557411729353915848889261125855717014320045292143759177464380434854573300054940683350937992500211758727939459249163046465047204851616590276724564411037216844005877918224201569391107769029955591465502737961776799311859881060956465198859727495735498887960494256488224613682478900505821893815926193600121890632"
nd = Decimal(ns)
cube_root = nth_root(nd, 3)
print (cube_root ** Decimal("3.0")) - nd
if __name__ == "__main__":
    test()
It's off by about 10**891
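The error comes from the default 28-digit context, not from decimal itself. As a sketch only (iroot_decimal is a hypothetical helper, and the choice of ten guard digits is an assumption, not something from the answers above), one can size the precision to the number of digits expected in the root and then correct the last unit with exact integer arithmetic:

```python
from decimal import Decimal, localcontext


def iroot_decimal(x, n):
    """Floor of the n'th root of a non-negative int x, using decimal for the estimate."""
    if x < 0 or n < 1:
        raise ValueError("expects x >= 0 and n >= 1")
    with localcontext() as ctx:
        # The root has roughly len(str(x)) / n digits; add guard digits (assumed: 10).
        ctx.prec = len(str(x)) // n + 10
        estimate = int(Decimal(x) ** (Decimal(1) / Decimal(n)))
    # The estimate can be off by a small amount; fix it up exactly with integers.
    while estimate ** n > x:
        estimate -= 1
    while (estimate + 1) ** n <= x:
        estimate += 1
    return estimate


print(iroot_decimal(10 ** 300, 3) == 10 ** 100)
```

The two correction loops make the result independent of how the decimal power happened to round.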

Possibly for your curiosity:
This could be the technique that Maple would use to actually find the nth root of large numbers.
Pose the fact that x^n - 11968003.... = 0 mod p, and go from there...

I can suggest four methods for solving your task. The first is based on Binary Search, the second on Newton's Method, the third on the Shifting n-th Root Algorithm, and the fourth is what I call the Chord-Tangent Method, which I sketched in a picture in the original post.
Binary Search was already implemented in several answers above. I just present my own version of it and its implementation here.
As an alternative I also implement an Optimized Binary Search method (marked Opt). This method starts from the range [hi / 2, hi), where hi is about 2^(num_bit_length / k) if we're computing the k-th root.
Newton's Method is new here, as I see it wasn't implemented in other answers. It is usually considered to be faster than Binary Search, although my own timings in the code below don't show any speedup. Hence this method is here just for reference/interest.
The Shifting Method is 30-50% faster than the optimized binary search method, and should be even faster if implemented in C++, because C++ has fast 64-bit arithmetic, which is partially used in this method.
Chord-Tangent Method:
I invented the Chord-Tangent Method on a piece of paper; it is inspired by, and is an improvement on, Newton's method. Basically I draw a chord and a tangent line and find their intersections with the horizontal line y = n. These two intersections form lower and upper bound approximations of the location of the root solution (x0, n), where n = x0 ^ k. This method turned out to be the fastest of all: while all other methods do more than 2000 iterations, this method does just 8 iterations for 8192-bit numbers. So this method is 200-300 times faster than the next fastest one, the Shifting Method.
As an example I generate a really huge random integer, 8192 bits in size, and measure the timings of finding the cube root with all the methods.
In the test() function you can see that I passed k = 3 as the root's power (cube root); you can pass any power instead of 3.
def binary_search(begin, end, f, *, niter = [0]):
    # Returns the first value in [begin, end) for which f is false (or end if none).
    while begin < end:
        niter[0] += 1
        mid = (begin + end) >> 1
        if f(mid):
            begin = mid + 1
        else:
            end = mid
    return begin
def binary_search_kth_root(n, k, *, verbose = False):
    niter = [0]
    res = binary_search(0, n + 1, lambda root: root ** k < n, niter = niter)
    if verbose:
        print('Binary Search iterations:', niter[0])
    return res
def binary_search_opt_kth_root(n, k, *, verbose = False):
    niter = [0]
    hi = 1 << (n.bit_length() // k - 1)
    while hi ** k <= n:
        niter[0] += 1
        hi <<= 1
    res = binary_search(hi >> 1, hi, lambda root: root ** k < n, niter = niter)
    if verbose:
        print('Binary Search Opt iterations:', niter[0])
    return res
def newton_kth_root(n, k, *, verbose = False):
    f = lambda x: x ** k - n
    df = lambda x: k * x ** (k - 1)
    x, px, niter = n, 2 * n, [0]
    while abs(px - x) > 1:
        niter[0] += 1
        px = x
        x -= f(x) // df(x)
    if verbose:
        print('Newton Method iterations:', niter[0])
    mini, minv = None, None
    for i in range(-2, 3):
        v = abs(f(x + i))
        if minv is None or v < minv:
            mini, minv = i, v
    return x + mini
def shifting_kth_root(n, k, *, verbose = False):
    B_bits = 64
    r, y = 0, 0
    B = 1 << B_bits
    Bk_bits = B_bits * k
    Bk_mask = (1 << Bk_bits) - 1
    niter = [0]
    for i in range((n.bit_length() + Bk_bits - 1) // Bk_bits - 1, -1, -1):
        alpha = (n >> (i * Bk_bits)) & Bk_mask
        B_y = y << B_bits
        Bk_yk = (y ** k) << Bk_bits
        Bk_r_alpha = (r << Bk_bits) + alpha
        Bk_yk_Bk_r_alpha = Bk_yk + Bk_r_alpha
        beta = binary_search(1, B, lambda beta: (B_y + beta) ** k <= Bk_yk_Bk_r_alpha, niter = niter) - 1
        y, r = B_y + beta, Bk_r_alpha - ((B_y + beta) ** k - Bk_yk)
    if verbose:
        print('Shifting Method iterations:', niter[0])
    return y
def chord_tangent_kth_root(n, k, *, verbose = False):
    niter = [0]
    hi = 1 << (n.bit_length() // k - 1)
    while hi ** k <= n:
        niter[0] += 1
        hi <<= 1
    f = lambda x: x ** k
    df = lambda x: k * x ** (k - 1)
    x_begin, x_end = hi >> 1, hi
    y_begin, y_end = f(x_begin), f(x_end)
    for icycle in range(1 << 30):
        if x_end - x_begin <= 1:
            break
        niter[0] += 1
        if 0: # Do Binary Search step if needed
            x_mid = (x_begin + x_end) >> 1
            y_mid = f(x_mid)
            if y_mid > n:
                x_end, y_end = x_mid, y_mid
            else:
                x_begin, y_begin = x_mid, y_mid
        # (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
        x_n = x_begin + (n - y_begin) * (x_end - x_begin) // (y_end - y_begin)
        y_n = f(x_n)
        tangent_x = x_n + (n - y_n) // df(x_n) + 1
        chord_x = x_n + (n - y_n) * (x_end - x_n) // (y_end - y_n)
        assert chord_x <= tangent_x, (chord_x, tangent_x)
        x_begin, x_end = chord_x, tangent_x
        y_begin, y_end = f(x_begin), f(x_end)
        assert y_begin <= n, (chord_x, y_begin, n, n - y_begin)
        assert y_end > n, (icycle, tangent_x - binary_search_kth_root(n, k), y_end, n, y_end - n)
    if verbose:
        print('Chord Tangent Method iterations:', niter[0])
    return x_begin
def test():
    import random, timeit
    nruns = 3
    bits = 8192
    n = random.randrange(1 << (bits - 1), 1 << bits)
    a = binary_search_kth_root(n, 3, verbose = True)
    b = binary_search_opt_kth_root(n, 3, verbose = True)
    c = newton_kth_root(n, 3, verbose = True)
    d = shifting_kth_root(n, 3, verbose = True)
    e = chord_tangent_kth_root(n, 3, verbose = True)
    assert abs(a - b) <= 0 and abs(a - c) <= 1 and abs(a - d) <= 1 and abs(a - e) <= 1, (a - b, a - c, a - d, a - e)
    print('Binary Search timing:', round(timeit.timeit(lambda: binary_search_kth_root(n, 3), number = nruns) / nruns, 3), 'sec')
    print('Binary Search Opt timing:', round(timeit.timeit(lambda: binary_search_opt_kth_root(n, 3), number = nruns) / nruns, 3), 'sec')
    print('Newton Method timing:', round(timeit.timeit(lambda: newton_kth_root(n, 3), number = nruns) / nruns, 3), 'sec')
    print('Shifting Method timing:', round(timeit.timeit(lambda: shifting_kth_root(n, 3), number = nruns) / nruns, 3), 'sec')
    print('Chord Tangent Method timing:', round(timeit.timeit(lambda: chord_tangent_kth_root(n, 3), number = nruns) / nruns, 3), 'sec')
if __name__ == '__main__':
    test()
Binary Search iterations: 8192
Binary Search Opt iterations: 2732
Newton Method iterations: 9348
Shifting Method iterations: 2752
Chord Tangent Method iterations: 8
Binary Search timing: 0.506 sec
Binary Search Opt timing: 0.05 sec
Newton Method timing: 2.09 sec
Shifting Method timing: 0.03 sec
Chord Tangent Method timing: 0.001 sec
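One plausible reason for the high Newton iteration count above is that the iteration starts from x = n, so it spends most of its steps just shrinking toward the root. The following is only a sketch of the usual integer Newton iteration with a closer starting point, not the code that was timed above; newton_iroot is a hypothetical name and no timings are claimed for it.

```python
def newton_iroot(n, k):
    """Floor of the k'th root of n (n >= 0, k >= 1) by integer Newton iteration."""
    if n < 0 or k < 1:
        raise ValueError("expects n >= 0 and k >= 1")
    if n < 2:
        return n
    # Start from 2 ** ceil(bits / k), which is guaranteed to be >= the true root.
    x = 1 << -(-n.bit_length() // k)
    while True:
        # Newton step for f(x) = x ** k - n, done in integers.
        y = ((k - 1) * x + n // x ** (k - 1)) // k
        if y >= x:  # the sequence stopped decreasing: x is the floor root
            return x
        x = y


n = 7 ** 300 + 12345
r = newton_iroot(n, 3)
print(r ** 3 <= n < (r + 1) ** 3)
```

Starting above the root makes the sequence decrease monotonically, so the stopping test needs no extra search around the final value.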

I came up with my own answer, which takes @Mahmoud Kassem's idea, simplifies the code, and makes it more reusable:
import decimal
def cube_root(x):
    return decimal.Decimal(x) ** (decimal.Decimal(1) / decimal.Decimal(3))
I tested it in Python 3.5.1 and Python 2.7.8, and it seemed to work fine.
The result will have as many digits as specified by the decimal context the function is run in, which by default is 28 significant digits. According to the documentation for the power function in the decimal module, "The result is well-defined but only 'almost always correctly-rounded'." If you need a more accurate result, it can be done as follows:
with decimal.localcontext() as context:
    context.prec = 50
    # call cube_root(...) inside this block so that it uses the higher precision

In older versions of Python, 1/3 is equal to 0. In Python 3.0, 1/3 is equal to 0.33333333333 (and 1//3 is equal to 0).
So, either change your code to use 1/3.0 or switch to Python 3.0.

Try converting the exponent to a floating-point number, as the default behaviour of / on integers in Python 2 is integer division.

Well, if you're not particularly worried about precision, you could convert it to a string, chop off some digits, use the exponent function, and then multiply the result by the root of how much you chopped off.
E.g. 32123 is about equal to 32 * 1000, so its cube root is about equal to the cube root of 32 times the cube root of 1000. The latter can be calculated by dividing the number of 0s by 3.
This avoids the need for the use of extension modules.
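A rough sketch of that idea, under a few assumptions of mine: the number of chopped digits is rounded down to a multiple of n so the power-of-ten part has an exact root, and the result is returned as a (mantissa, exponent) pair so the answer cannot overflow a float. The name approx_root and the keep parameter are made up for illustration.

```python
def approx_root(x, n, keep=15):
    """Approximate n'th root of a positive int x as (m, e), meaning m * 10 ** e."""
    s = str(x)
    chopped = max(len(s) - keep, 0)
    chopped -= chopped % n              # chop a multiple of n digits
    head = int(s[:len(s) - chopped])    # leading digits that are kept
    # The root of 10 ** chopped is exactly 10 ** (chopped // n).
    return head ** (1.0 / n), chopped // n


m, e = approx_root(10 ** 900, 3)
print(m, e)
```

Only about 15 significant digits are meaningful, since the mantissa is an ordinary float.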


Reduce time /space complexity of simple loop

So basically I have an iteration like this in Python.
I've edited the question to include my full code:
class Solution:
def myPow(self, x: float, n: int) -> float:
temp = [];
span = range(1,abs(n))
if n ==0:
return 1
if abs(n)==1:
for y in span:
if y == 1:
temp = []
temp.append(temp[-1] * x)
if(n < 0):
return 1/temp[-1]
return temp[-1]
The problem link is : Pow(x,n)-leetcode
How can I modify this to conserve memory and time? Is there another data structure I can use? I'm just learning Python...
I've modified the code to use a variable instead of a list for the temp data:
class Solution:
    def myPow(self, x: float, n: int) -> float:
        span = range(1,abs(n))
        if n ==0:
            return 1
        if abs(n)==1:
            temp = x
        else:
            for y in span:
                if y == 1:
                    temp = x*x
                else:
                    temp = temp * x
        if(n < 0):
            return 1/temp
        return temp
I still have a problem with my time complexity.
It works for many test cases; however, when it tries to run with x = 0.00001 and n = 2147483647, it hits the time limit.
To reduce the time complexity you can divide the work each time by taking x to the power of 2 and dividing the exponent by two. This makes a logarithmic time algorithm since the exponent is halved at each step.
Consider the following examples:
10^8 = 10^(2*4) = (10^2)^4 = (10*10)^4
Now, there is one edge case. When the exponent is an odd number you can't integer divide it by 2. So in that case you need to multiply the results by the base one additional time.
The following is a direct recursive implementation of the above idea:
class Solution:
    def myPow(self, x: float, n: int) -> float:
        sign = -1 if n < 0 else 1
        n = abs(n)
        def helper(x, n):
            if n == 1: return x
            if n == 0: return 1
            if n % 2 == 1:
                return helper(x*x, n // 2) * x
            return helper(x*x, n // 2)
        res = helper(x, n)
        if sign == -1:
            return 1/res
        return res
Note that we have taken abs of the exponent and stored the sign and deal with it at the end.
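The same halving idea can also be written without recursion. This is a sketch of the standard iterative square-and-multiply loop, not part of the answer above; my_pow_iter is a hypothetical name.

```python
def my_pow_iter(x: float, n: int) -> float:
    """Compute x ** n with O(log |n|) multiplications, using no recursion."""
    if n < 0:
        x, n = 1 / x, -n
    result = 1.0
    while n:
        if n & 1:        # lowest bit set: this power of x is part of the product
            result *= x
        x *= x           # x, x^2, x^4, x^8, ...
        n >>= 1          # move on to the next bit of the exponent
    return result


print(my_pow_iter(2.0, 10))
print(my_pow_iter(2.0, -2))
print(my_pow_iter(0.00001, 2147483647))
```

Each loop iteration consumes one bit of the exponent, so n = 2147483647 takes 31 iterations instead of about two billion.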
Instead of iterating from 1 to n, use divide-and-conquer: divide the exponent by 2 and use recursion to get that power, and then square that result. If n was odd, multiply one time more with x:
class Solution:
    def myPow(self, x: float, n: int) -> float:
        if n == 0:
            return 1
        if n == 1:
            return x
        if n < 0:
            return self.myPow(1/x, -n)
        temp = self.myPow(x, n // 2)
        temp *= temp
        if n % 2:
            temp *= x
        return temp
A simple naive solution might be:
def myPow(x: float, n: int) -> float:
    ## -----------------------
    ## if we have a negative n then invert x and take the absolute value of n
    ## -----------------------
    if n < 0:
        x = 1/x
        n = -n
    ## -----------------------
    retval = 1
    for _ in range(n):
        retval *= x
    return retval
While this technically works, you will wait until the cows come home to get a result for:
x = 0.00001 and n = 2147483647
So we need to find a shortcut. Let's consider 2^5. Our naïve method would calculate that as:
(((2 * 2) * 2) * 2) * 2 == 32
However, what might we observe about the problem if we group some stuff together in a different way:
(2 * 2) * (2 * 2) * 2 == 32
((2 * 2) * (2 * 2) * 2) * ((2 * 2) * (2 * 2) * 2) == 32 * 32 = 1024
We might observe that we only technically need to calculate
(2 * 2) * (2 * 2) * 2 == 32
once and use it twice to get 2^10.
Similarly we only need to calculate:
2 * 2 = 4
once and use it twice to get 2^5....
This suggests a recursion to me.
Let's modify our first attempt to use this divide and conquer method.
def myPow2(x: float, n: int) -> float:
    ## -----------------------
    ## if we have a negative n then invert x and take the absolute value of n
    ## -----------------------
    if n < 0:
        x = 1/x
        n = -n
    ## -----------------------
    ## -----------------------
    ## We only need to calculate approximately half the work and use it twice
    ## at any step.
    ## -----------------------
    def _recurse(x, n):
        if n == 0:
            return 1
        res = _recurse(x, n//2) # calculate it once
        res = res * res # use it twice
        return res * x if n % 2 else res # if n is odd, multiply by x one more time (see 2^5 above)
    ## -----------------------
    return _recurse(x, n)
Now let's try:
print(myPow2(2.0, 0))
print(myPow2(2.0, 1))
print(myPow2(2.0, 5))
print(myPow2(2.1, 3))
print(myPow2(2.0, -2))
print(myPow2(0.00001, 2147483647))
That gives me:
If you have to loop, you have to loop, and there is nothing that can be done. Loops in Python are slow. That said, you may not have to loop, and if you do have to loop, it may be possible to push this loop to a highly optimised internal function. Tell us what you are trying to do (not how you think you have to do it; appending elements to a list may or may not be needed). Always recall the two rules of program optimisation. General rule: don't do it. Rule for experts: don't do it yet. Make it work before you make it fast; who knows, it may be fast enough.

Why classical division ("/") for large integers is much slower than integer division ("//")?

I ran into a problem: The code was very slow for 512 bit odd integers if you use classical division for (p-1)/2. But with floor division it works instantly. Is it caused by float conversion?
def solovayStrassen(p, iterations):
    for i in range(iterations):
        a = random.randint(2, p - 1)
        if gcd(a, p) > 1:
            return False
        first = pow(a, int((p - 1) / 2), p)
        j = (Jacobian(a, p) + p) % p
        if first != j:
            return False
    return True
The full code
import random
from math import gcd
#Jacobian symbol
def Jacobian(a, n):
    if (a == 0):
        return 0
    ans = 1
    if (a < 0):
        a = -a
        if (n % 4 == 3):
            ans = -ans
    if (a == 1):
        return ans
    while (a):
        if (a < 0):
            a = -a
            if (n % 4 == 3):
                ans = -ans
        while (a % 2 == 0):
            a = a // 2
            if (n % 8 == 3 or n % 8 == 5):
                ans = -ans
        a, n = n, a
        if (a % 4 == 3 and n % 4 == 3):
            ans = -ans
        a = a % n
        if (a > n // 2):
            a = a - n
    if (n == 1):
        return ans
    return 0
def solovayStrassen(p, iterations):
    for i in range(iterations):
        a = random.randint(2, p - 1)
        if gcd(a, p) > 1:
            return False
        first = pow(a, int((p - 1) / 2), p)
        j = (Jacobian(a, p) + p) % p
        if first != j:
            return False
    return True
def findFirstPrime(n, k):
while True:
if solovayStrassen(n,k):
return n
a = random.getrandbits(512)
if a%2==0:
As noted in comments, int((p - 1) / 2) can produce garbage if p is an integer with more than 53 bits. Only the first 53 bits of p-1 are retained when converting to float for the division.
>>> p = 123456789123456789123456789
>>> (p-1) // 2
>>> hex(_)
>>> int((p-1) / 2)
>>> hex(_) # lots of trailing zeroes
Of course the theory underlying the primality test relies on using exactly the infinitely precise value of (p-1)/2, not some approximation more-or-less good to only the first 53 most-significant bits.
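To make the loss concrete, here is a small self-contained check. The value of p is an arbitrary odd 512-bit integer picked for illustration (not one from the question), and half_both_ways is a made-up helper name.

```python
def half_both_ways(p):
    """Return (p - 1) // 2 and int((p - 1) / 2) so the two can be compared."""
    exact = (p - 1) // 2            # pure integer arithmetic, no rounding
    via_float = int((p - 1) / 2)    # passes through a float with a 53-bit significand
    return exact, via_float


p = 2 ** 511 + 111                  # arbitrary odd 512-bit integer (assumption)
exact, via_float = half_both_ways(p)
print(exact == via_float)           # do the two results agree?
print(exact - via_float)            # how far off the float route is
```

For small p the two agree, which is why the bug only shows up once p outgrows what a float can represent exactly.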
As also noted in a comment, using garbage is likely to make this part return earlier, not later:
if first != j:
return False
So why is it much slower over all? Because findFirstPrime() has to call solovayStrassen() many more times to find garbage that passes by sheer blind luck.
To see this, change the code to show how often the loop is trying:
def findFirstPrime(n, k):
count = 0
while True:
count += 1
if count % 1000 == 0:
print(f"at count {count:,}")
if solovayStrassen(n,k):
return n, count
Then add, e.g.,
at the start of the main program so you can get reproducible results.
Using floor (//) division, it runs fairly quickly, displaying
(6170518232878265099306454685234429219657996228748920426206889067017854517343512513954857500421232718472897893847571955479036221948870073830638539006377457, 906)
So it found a probable prime on the 906th try.
But with float (/) division, I never saw it succeed by blind luck:
at count 1,000
at count 2,000
at count 3,000
at count 1,000,000
Gave up then - "garbage in, garbage out".
One other thing to note, in passing: the + p in:
j = (Jacobian(a, p) + p) % p
has no effect on the value of j. Right? p % p is 0.

Efficiently generating Stern's Diatomic Sequence

Stern's Diatomic Sequence can be read about in more detail elsewhere; however, for my purposes I will define it here.
Definition of Stern's Diatomic Sequence
Let n be the number passed to the fusc function, denoted fusc(n).
If n is 0 then the returned value is 0.
If n is 1 then the returned value is 1.
If n is even then the returned value is fusc(n / 2).
If n is odd then the returned value is fusc((n - 1) / 2) + fusc((n + 1) / 2).
Currently, my Python code brute forces through most of the generation, other than the dividing by two part since it will always yield no change.
def fusc (n):
    if n <= 1:
        return n
    while n > 2 and n % 2 == 0:
        n /= 2
    return fusc((n - 1) / 2) + fusc((n + 1) / 2)
However, my code must be able to handle numbers on the order of thousands to millions of bits, and recursively calling the function thousands or millions of times does not seem very efficient or practical.
Is there any way I could algorithmically improve my code such that massive numbers can be passed through without having to recursively call the function so many times?
With memoization for a million bits, the recursion stack would be extremely large. We can first try to look at a sufficiently large number which we can work by hand, fusc(71) in this case:
fusc(71) = fusc(35) + fusc(36)
fusc(35) = fusc(17) + fusc(18)
fusc(36) = fusc(18)
fusc(71) = 1 * fusc(17) + 2 * fusc(18)
fusc(17) = fusc(8) + fusc(9)
fusc(18) = fusc(9)
fusc(71) = 1 * fusc(8) + 3 * fusc(9)
fusc(8) = fusc(4)
fusc(9) = fusc(4) + fusc(5)
fusc(71) = 4 * fusc(4) + 3 * fusc(5)
fusc(4) = fusc(2)
fusc(5) = fusc(2) + fusc(3)
fusc(71) = 7 * fusc(2) + 3 * fusc(3)
fusc(2) = fusc(1)
fusc(3) = fusc(1) + fusc(2)
fusc(71) = 10 * fusc(1) + 3 * fusc(2)
fusc(2) = fusc(1)
fusc(71) = 13 * fusc(1) = 13
We realize that we can avoid recursion completely in this case as we can always express fusc(n) in the form a * fusc(m) + b * fusc(m+1) while reducing the value of m to 0. From the example above, you may find the following pattern:
if m is odd:
a * fusc(m) + b * fusc(m+1) = a * fusc((m-1)/2) + (b+a) * fusc((m+1)/2)
if m is even:
a * fusc(m) + b * fusc(m+1) = (a+b) * fusc(m/2) + b * fusc((m/2)+1)
Therefore, you may use a simple loop function to solve the problem in O(lg(n)) time
def fusc(n):
    if n == 0: return 0
    a = 1
    b = 0
    while n > 0:
        if n%2:
            b = b + a
            n = (n-1)//2  # floor division keeps n an integer
        else:
            a = a + b
            n = n//2
    return b
lru_cache works wonders in your case. make sure maxsize is a power of 2. may need to fiddle a bit with that size for your application. cache_info() will help with that.
also use // instead of / for integer division.
from functools import lru_cache
@lru_cache(maxsize=512, typed=False)
def fusc(n):
    if n <= 1:
        return n
    while n > 2 and n % 2 == 0:
        n //= 2
    return fusc((n - 1) // 2) + fusc((n + 1) // 2)
and yes, this is just memoization as proposed by Filip Malczak.
you might gain an additional tiny speedup using bit-operations in the while loop:
while not n & 1: # as long as the lowest bit is not 1
    n >>= 1 # shift n right by one
here is a simple way of doing memoization 'by hand':
def fusc(n, _mem={}): # _mem will be the cache of the values
                      # that have been calculated before
    if n in _mem: # if we know that one: just return the value
        return _mem[n]
    if n <= 1:
        return n
    while not n & 1:
        n >>= 1
    if n == 1:
        return 1
    ret = fusc((n - 1) // 2) + fusc((n + 1) // 2)
    _mem[n] = ret # store the value for next time
    return ret
after reading a short article by Dijkstra himself, a minor update.
the article states that f(n) = f(m) if the first and last bit of m are the same as those of n and the bits in between are inverted. the idea is to get n as small as possible.
that is what the bitmask (1<<n.bit_length()-1)-2 is for (first and last bits are 0; those in the middle 1; xoring n with that gives m as described above).
i was only able to do small benchmarks; i'm interested if this is any help at all for the magnitude of your input... this will reduce the memory for the cache and hopefully bring some speedup.
def fusc_ed(n, _mem={}):
    if n <= 1:
        return n
    while not n & 1:
        n >>= 1
    if n == 1:
        return 1
    # bit invert the middle bits and check if this is smaller than n
    m = n ^ (1<<n.bit_length()-1)-2
    n = m if m < n else n
    if n in _mem:
        return _mem[n]
    ret = fusc(n >> 1) + fusc((n >> 1) + 1)
    _mem[n] = ret
    return ret
i had to increase the recursion limit:
import sys
sys.setrecursionlimit(10000) # default limit was 1000
benchmarking gave strange results; using the code below and making sure that i always started a fresh interpreter (having an empty _mem) i sometimes got significantly better runtimes; on other occasions the new code was slower...
benchmarking code:
from timeit import timeit
ti = timeit('fusc(n)', setup='from __main__ import fusc, n', number=1)
ti = timeit('fusc_ed(n)', setup='from __main__ import fusc_ed, n', number=1)
and these are three random results i got:
that is where i stopped...
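The bit-inversion property quoted above can be checked on small inputs. The sketch below applies the same mask to odd n >= 3 (which is what fusc_ed does after stripping trailing zero bits) and compares the values using a plain iterative fusc; fusc_iter and invert_middle_bits are names made up for this check.

```python
def fusc_iter(n):
    """Iterative fusc: walk the bits of n from lowest to highest."""
    a, b = 1, 0
    while n > 0:
        if n & 1:
            b += a
        else:
            a += b
        n >>= 1
    return b


def invert_middle_bits(n):
    """Flip every bit of n except the highest and the lowest one."""
    return n ^ ((1 << (n.bit_length() - 1)) - 2)


# Example: 19 = 0b10011 becomes 0b11101 = 29.
print(invert_middle_bits(19), fusc_iter(19), fusc_iter(29))
print(all(fusc_iter(n) == fusc_iter(invert_middle_bits(n))
          for n in range(3, 2001, 2)))
```

For odd n = 2^k + j the flipped value is 2^(k+1) - j, so this is the familiar left-right symmetry of the sequence between consecutive powers of two.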

How to make perfect power algorithm more efficient?

I have the following code:
def isPP(n):
pos = [int(i) for i in range(n+1)]
pos = pos[2:] ##to ignore the trivial n** 1 == n case
y = []
for i in pos:
for it in pos:
if i** it == n:
#return list((i,it))
if len(y) <1:
return None
return list(y[0])
Which works perfectly up until ~2000, since I'm storing far too much in memory. What can I do to make it work efficiently for large numbers (say, 50000 or 100000). I tried to make it end after finding one case, but my algorithm is still far too inefficient if the number is large.
Any tips?
A number n is a perfect power if there exists a b and e for which b^e = n. For instance 216 = 6^3 = 2^3 * 3^3 is a perfect power, but 72 = 2^3 * 3^2 is not.
The trick to determining if a number is a perfect power is to know that, if the number is a perfect power, then the exponent e must be at most log2 n, because if e is greater then 2^e will be greater than n. Further, it is only necessary to test prime values of e, because if a number is a perfect power with a composite exponent, it is also a perfect power with each prime factor of that exponent; for instance, 2^15 = 32768 = 32^3 = 8^5 is both a perfect cube and a perfect fifth power.
The function isPerfectPower shown below tests each prime up to log2 n by first computing the integer root using Newton's method, then raising the result to that power to check whether it equals n. The auxiliary function primes computes a list of prime numbers by the Sieve of Eratosthenes, iroot computes the integer kth root by Newton's method, and ilog computes the integer logarithm to base b by binary search.
def primes(n): # sieve of eratosthenes
    i, p, ps, m = 0, 3, [2], n // 2
    sieve = [True] * m
    while p <= n:
        if sieve[i]:
            ps.append(p)
            for j in range((p*p-3)//2, m, p):
                sieve[j] = False
        i, p = i+1, p+2
    return ps
def iroot(k, n): # assume n > 0
    u, s, k1 = n, n+1, k-1
    while u < s:
        s = u
        u = (k1 * u + n // u ** k1) // k
    return s
def ilog(b, n): # max e where b**e <= n
    lo, blo, hi, bhi = 0, 1, 1, b
    while bhi < n:
        lo, blo, hi, bhi = hi, bhi, hi+hi, bhi*bhi
    while 1 < (hi - lo):
        mid = (lo + hi) // 2
        bmid = blo * pow(b, (mid - lo))
        if n < bmid: hi, bhi = mid, bmid
        elif bmid < n: lo, blo = mid, bmid
        else: return mid
    if bhi == n: return hi
    return lo
def isPerfectPower(n): # x if n == x ** y, or False
    for p in primes(ilog(2,n)):
        x = iroot(p, n)
        if pow(x, p) == n: return x
    return False
There is further discussion of the perfect power predicate at my blog.
IIRC, it's far easier to iteratively check "Does it have a square root? Does it have a cube root? Does it have a fourth root? ..." You will very quickly get to the point where putative roots have to be between 1 and 2, at which point you can stop.
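A minimal sketch of that approach, assuming n >= 2: try exponents 2, 3, 4, ... and stop as soon as the integer root drops below 2. The helper names integer_root and perfect_power are made up for illustration, and the root is found with a plain binary search rather than anything cleverer.

```python
def integer_root(n, k):
    """Floor of the k'th root of n (n >= 1, k >= 1) by binary search."""
    lo, hi = 0, 1 << -(-n.bit_length() // k)   # hi ** k > n by construction
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if mid ** k <= n:
            lo = mid
        else:
            hi = mid
    return lo


def perfect_power(n):
    """Return (base, exponent) with base ** exponent == n, or None if there is none."""
    for e in range(2, n.bit_length() + 1):
        r = integer_root(n, e)
        if r < 2:              # roots are now between 1 and 2: nothing left to find
            break
        if r ** e == n:
            return r, e
    return None


print(perfect_power(216), perfect_power(72), perfect_power(10000))
```

The loop runs at most about log2(n) times, so even 100000 needs fewer than 20 root computations.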
I think a better way would be implementing this "hack":
import math
def isPP(n):
range = math.log(n)/math.log(2)
range = (int)(range)
result = []
for i in xrange(n):
exponent = (int)(math.log(n)/math.log(i))
for j in [exponent-1, exponent, exponent+1]:
if i ** j == n:
return result
print isPP(10000)
The hack uses the fact that:
if log(a)/log(b) = c,
then power(b,c) = a
Since this calculation can be a bit off in floating points giving really approximate results, exponent is checked to the accuracy of +/- 1.
You can make necessary adjustments for handling corner cases like n=1, etc.
a relevant improvement would be:
import math
def isPP(n):
# first have a look at the length of n in binary representation
ln = int(math.log(n)/math.log(2)) + 1
y = []
for i in range(n+1):
if (i <= 1):
# calculate max power
li = int(math.log(i)/math.log(2))
mxi = ln / li + 1
for it in range(mxi):
if (it <= 1):
if i ** it == n:
# break if you only need 1
if len(y) <1:
return None
return list(y[0])

Root of polynomial using bisection

I'm new to Python and I'm having a hard time trying to find the root of a polynomial using the bisection method. So far I have 2 methods. One for evaluating the polynomial at value x:
def eval(x, poly):
    """
    Evaluate the polynomial at the value x.
    poly is a list of coefficients from lowest to highest.
    :param x: Argument at which to evaluate
    :param poly: The polynomial coefficients, lowest order to highest
    :return: The result of evaluating the polynomial at x
    """
    result = poly[0]
    for i in range(1, len(poly)):
        result = result + poly[i] * x**i
    return result
The next method is supposed to use bisection to find the root of the polynomials given
def bisection(a, b, poly, tolerance):
    """
    poly(a) <= 0
    poly(b) >= 0
    Assume that poly(a) <= 0 and poly(b) >= 0.
    :param a: poly(a) <= 0 Raises an exception if not true
    :param b: poly(b) >= 0 Raises an exception if not true
    :param poly: polynomial coefficients, low order first
    :param tolerance: greater than 0
    :return: a value between a and b that is within tolerance of a root of the polynomial
    """
How would I find the root using bisection? I have been provided a test script to test these out.
EDIT: I followed the pseudocode and ended up with this:
def bisection(a, b, poly, tolerance):
#poly(a) <= 0
#poly(b) >= 0
difference = abs(a-b)
xmid = (a-b)/2
n = 1
nmax = 60
while n <= nmax:
mid = (a-b) / 2
if poly(mid) == 0 or (b - a)/2 < tolerance:
n = n + 1
if sign(poly(mid)) == sign(poly(a)):
a = mid
b = mid
return xmid
Is this correct? I haven't been able to test it because of indentation errors with the return xmid statement.
Your code seems fine, besides the mess with xmid and mid. mid = (a + b) / 2 instead of mid = (a - b) / 2 and you don't need the difference variable.
Cleaned it up a bit:
def sign(x):
    return -1 if x < 0 else (1 if x > 0 else 0)
def bisection(a, b, poly, tolerance):
    mid = a # will be overwritten
    for i in range(60):
        mid = (a+b) / 2
        if poly(mid) == 0 or (b - a)/2 < tolerance:
            return mid
        if sign(poly(mid)) == sign(poly(a)):
            a = mid
        else:
            b = mid
    return mid
print(bisection(-10**10, 10**10, lambda x: x**5 - x**4 - x**3 - x**2 - x + 9271, 0.00001))
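Note that the cleaned-up version takes poly as a callable, while the question's signature passes a list of coefficients. A sketch that keeps the coefficient-list interface might look like the following; evaluate and bisect_root are hypothetical names, a < b is assumed, and Horner's rule is used for the evaluation instead of the question's x**i loop.

```python
def evaluate(x, coeffs):
    """Evaluate a polynomial at x; coeffs are ordered lowest degree first."""
    result = 0.0
    for c in reversed(coeffs):      # Horner's rule: ((c_n * x + c_{n-1}) * x + ...)
        result = result * x + c
    return result


def bisect_root(a, b, coeffs, tolerance):
    """Bisection on [a, b], assuming a < b, poly(a) <= 0 and poly(b) >= 0."""
    if evaluate(a, coeffs) > 0 or evaluate(b, coeffs) < 0:
        raise ValueError("need poly(a) <= 0 and poly(b) >= 0")
    while (b - a) / 2 >= tolerance:
        mid = (a + b) / 2
        value = evaluate(mid, coeffs)
        if value == 0:
            return mid
        if value < 0:               # same side as a: the root is to the right
            a = mid
        else:                       # same side as b: the root is to the left
            b = mid
    return (a + b) / 2


# x**2 - 2 has a root at sqrt(2) between 0 and 2
print(bisect_root(0, 2, [-2, 0, 1], 1e-9))
```

Raising ValueError when the sign condition fails matches the docstring in the question, which asks for an exception in that case.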

