z3 power mod constraint throws Z3Exception - python

I have set up a large table of constraints, as such:
value_n=1319677575750664593269928592823122088231102756079484792845079638972350094036132691958059688916820689747030184510113829740113041751835106779040361964704732781674939111970921386382076965209456061359009064103696057581141587585573457409644899626416953339600078178839886677162434056466419378007334549881514299423667452655244230384294862385143577667675301828094532154331478209619704980444350466007158636501159025323438259949492591777519923746198537834074290240109117233
value_a=342691274423584349791056473125411134098438481683180895903867322580082465637958239934448846612075098614285219068292647847040005389205050272757677433858953494823759242244829180636673330238802274884481886959589470615694177803698174718690765848886907721839332228324388423577817424201793491303906581758910921445168317877902384290187158732544836897687264045955879735934334920222645677067477720579027115764510363046519727124832263309244176383203195395120143396393280178
value_c=966614148477848624030200621625258849031380625915218872283078816676928740125344244698257425964002821854099461580094436358123535723290845114669846893220163142039503329214545667563764436310988727558365456532947041619144359286499858057080380889267035885314605337688144688305837511055021190348080983083350965687703714494963398464234969815079998913589315977406087449734769882954669673024108621512376903207095402574078110055132403183173806026764738715566878703793779398
value_sigma=8464844919528024050776743166493391951099685330616502566136514764715071037185126965166263456592507665693764888945402100735300264199432720007286127252491425
value_tau=10382212427405516607551590772908169835382410408959523296597160297199683505639053103035865826384731049385602808678770211265514520018139117856623297634063128
value_rho=2180308065203712074036416623837957892343118044869765657224934920507007488415740097667016690791656923501004556879088800046225392947371826678430341848539725784764496743600951899723067587234367926146679537689197024965709655834286819439034454072201704826252187806261154259276367446581026005274199093558591833001447262700832466548485390911986069977032843293472245605677321290489737169381053552716457700040526454320831401386889040870267821074531233371107830697865854527
sigma = Int('sigma')
tau = Int('tau')
rho = Int('rho')
random_1340 = Int('random_1340')
random_1160 = Int('random_1160')
random_335 = Int('random_335')
n = Int('n')
a = Int('a')
c = Int('c')
p = Int('p')
q = Int('q')
a_to_n = Int('a_to_n')
a_to_flag = Int('a_to_flag')
a_to_p_1 = Int('a_to_p_1')
kappa = Int('kappa')
flag = Int('flag')
equations = [
flag < p,
n == value_n, a == value_a, c == value_c, sigma == value_sigma, tau == value_tau, rho == value_rho,
a_to_n == a ** n,
a_to_flag == a ** flag,
a_to_p_1 == a ** (p-1),
c == ((((a_to_n % n) ** kappa) % n) * (a_to_flag % n)) % n,
n == p ** 2 * q,
kappa >= 1, kappa <= n - 1,
p > q, 2 * q > p,
a >= 2, a <= n - 1, a_to_p_1 % (p * p) != 1,
sigma >= 2, sigma <= q - 1,
random_1340 >= 1, random_1340 <= 2**1340, random_1160 >= 1, random_1160 <= 2**1160, random_335 >= 1, random_335 <= 2 ** 335,
tau == (p * p * sigma - random_1340) / q ** 2 + random_335,
rho == p * p * sigma + q * q * tau + random_1160
]
When I try to solve, two of the equations are giving me issues:
[
c == ((((a_to_n % n) ** kappa) % n) * (a_to_flag % n)) % n,
a_to_p_1 % (p * p) != 1,
]
I get a z3.z3types.Z3Exception: Z3 integer expression expected:
c == ((((a_to_n % n) ** kappa) % n) * (a_to_flag % n)) % n,
File "/home/retep/.local/lib/python3.8/site-packages/z3/z3.py", line 2411, in __mod__
_z3_assert(a.is_int(), "Z3 integer expression expected")
A is defined as an Int type, so why does this happen and how do I resolve it?

The problem here is that ** produces a Real number, not an Int, and thus you cannot combine the result with % later on. Here's a simple example to illustrate:
>>> from z3 import *
>>> a, n = Ints('a n')
>>> (a ** n).sort()
Real
That is, the expression a ** n is Real valued; so you cannot take the modulus with an integer later on. To convert back to integer, use ToInt:
>>> ToInt(a ** n).sort()
Int
So, whenever you use **, wrap the result in a call to ToInt to avoid this issue.
Having said that, mixing reals and integers this way will no doubt make your problem quite hard for z3 to deal with. I wouldn't be surprised if z3 had a very hard time solving your equations, unless they reduce to mere constant folding. Perhaps you should simply stick to Real values for everything: nonlinear real arithmetic over polynomials is a decidable theory, although exponentiation with a symbolic exponent goes beyond it. (Decidable doesn't mean z3 can answer your queries easily; it just means z3 can decide them if you are willing to wait long enough and have enough computing resources.)

Related

Given a rod of length N , you need to cut it into R pieces , such that each piece's length is positive, how many ways are there to do so?

Description:
Given two positive integers N and R, how many different ways are there to cut a rod of length N into R pieces, such that the length of each piece is a positive integer? Output this answer modulo 1,000,000,007.
Example:
With N = 7 and R = 3, there are 15 ways to cut a rod of length 7 into 3 pieces: (1,1,5) , (1,5,1), (5,1,1) , (1,2,4) , (1,4,2) (2,1,4), (2,4,1) , (4,1,2), (4,2,1) , (1,3,3), (3,1,3), (3,3,1), (2,2,3), (2,3,2), (3,2,2).
Constraints:
1 <= R <= N <= 200,000
Testcases:
N R Output
7 3 15
36 6 324632
81 66 770289477
96 88 550930798
My approach:
I know that the answer is (N-1 choose R-1) mod 1000000007. I have tried several ways to calculate it, but 7 out of 10 test cases always exceed the time limit. Here is my code; can anyone tell me what other approach I can use to get it down to O(1) time complexity?
from math import factorial

def new(n, r):
    D = factorial(n - 1) // (factorial(r - 1) * factorial(n - r))
    return (D % 1000000007)

if __name__ == '__main__':
    N = [7, 36, 81, 96]
    R = [3, 6, 66, 88]
    answer = [new(n, r) for n, r in zip(N, R)]
    print(answer)
I think there are two big optimizations that the problem is looking for you to exploit. The first is to cache intermediate values of factorial() to save computational effort across large batches (large T). The second is to reduce your value mod 1000000007 incrementally, so your numbers stay small and multiplication stays constant-time. I've updated the example below to precompute a factorial table using a custom function and itertools.accumulate, instead of merely caching the calls in a recursive implementation (which eliminates the issues with recursion depth you were seeing).
from itertools import accumulate

MOD_BASE = 1000000007
N_BOUND = 200000

def modmul(m):
    # Returns a multiplication function that reduces mod m after every step.
    def mul(x, y):
        return x * y % m
    return mul

# FACTORIALS[i] == i! mod MOD_BASE
FACTORIALS = [1] + list(accumulate(range(1, N_BOUND+1), modmul(MOD_BASE)))

def nck(n, k, m):
    numerator = FACTORIALS[n]
    denominator = FACTORIALS[k] * FACTORIALS[n-k]
    # pow(x, -1, m) computes the modular inverse (Python 3.8+)
    return numerator * pow(denominator, -1, m) % m

def solve(n, k):
    return nck(n-1, k-1, MOD_BASE)
Running this against the example:
>>> pairs = [(36, 6), (81, 66), (96, 88)]
>>> print([solve(n, k) for n, k in pairs])
[324632, 770289477, 550930798]
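For interpreters older than Python 3.8, where pow(x, -1, m) is not available, the inverse can be taken with Fermat's little theorem instead, since 1000000007 is prime. The following is only a sketch of that variant: the names build_factorials and nck_mod are made up for this illustration, and the cross-check against math.comb (itself Python 3.8+) is there purely to validate the sample cases from the question.

```python
from math import comb

MOD = 1_000_000_007  # prime, so Fermat's little theorem applies

def build_factorials(limit, mod=MOD):
    """Return a table with fact[i] == i! mod `mod` for 0 <= i <= limit."""
    fact = [1] * (limit + 1)
    for i in range(1, limit + 1):
        fact[i] = fact[i - 1] * i % mod
    return fact

def nck_mod(n, k, fact, mod=MOD):
    """n choose k modulo a prime `mod`, using a precomputed factorial table."""
    if k < 0 or k > n:
        return 0
    denom = fact[k] * fact[n - k] % mod
    # denom^(mod-2) is the inverse of denom because mod is prime
    # and denom is not a multiple of mod (all factors are < mod).
    return fact[n] * pow(denom, mod - 2, mod) % mod

fact = build_factorials(200)
for n, r in [(7, 3), (36, 6), (81, 66), (96, 88)]:
    # Cross-check against exact big-integer arithmetic.
    assert nck_mod(n - 1, r - 1, fact) == comb(n - 1, r - 1) % MOD
print(nck_mod(6, 2, fact))  # 15
```

Building the table once and answering each query with a single modular exponentiation is what keeps the per-query cost low.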
I literally translated code from accepted answer of Ivaylo Strandjev here and it works much faster:
def get_degree(n, p):
    # returns the degree with which p is in n!
    degree_num = 0
    u = p
    temp = n
    while (u <= temp):
        degree_num += temp // u
        u *= p
    return degree_num

def degree(a, k, p):
    res = 1
    cur = a
    while (k):
        if (k % 2):
            res = (res * cur) % p
        k //= 2
        cur = (cur * cur) % p
    return res

def CNKmodP(n, k, p):
    num_degree = get_degree(n, p) - get_degree(n - k, p)
    den_degree = get_degree(k, p)
    if (num_degree > den_degree):
        return 0
    res = 1
    for i in range(n, n - k, -1):
        ti = i
        while (ti % p == 0):
            ti //= p
        res = (res * ti) % p
    denom = 1
    for i in range(1, k + 1):
        ti = i
        while (ti % p == 0):
            ti //= p
        denom = (denom * ti) % p
    res = (res * degree(denom, p-2, p)) % p
    return res
To apply this approach, you just need to call
result = CNKmodP(n-1, r-1, 1000000007)
In Java we can use BigInteger because the value of factorials that we calculate may not fit in integer. Additionally BigInteger provides built in methods multiply and divide.
static int CNRmodP(int N, int R, int P) {
    BigInteger ret = BigInteger.ONE;
    for (int i = 0; i < R; i++) {
        ret = ret.multiply(BigInteger.valueOf(N - i))
                 .divide(BigInteger.valueOf(i + 1));
    }
    BigInteger p = BigInteger.valueOf(P);
    // Calculate modulus
    BigInteger answer = ret.mod(p);
    // Convert BigInteger to integer and return it
    return answer.intValue();
}
To apply the above approach, you just need to call
result = CNRmodP(N-1, R-1, 1000000007);

Pythagorean Triplet with given sum

The following code prints the pythagorean triplet if it is equal to the input, but the problem is that it takes a long time for large numbers like 90,000 to answer.
What can I do to optimize the following code?
1 ≤ n ≤ 90 000
def pythagoreanTriplet(n):
    # Considering triplets in
    # sorted order. The value
    # of first element in sorted
    # triplet can be at-most n/3.
    for i in range(1, int(n / 3) + 1):
        # The value of second element
        # must be less than equal to n/2
        for j in range(i + 1,
                       int(n / 2) + 1):
            k = n - i - j
            if (i * i + j * j == k * k):
                print(i, ", ", j, ", ",
                      k, sep="")
                return
    print("Impossible")

# Driver Code
vorodi = int(input())
pythagoreanTriplet(vorodi)
Your source code does a brute force search for a solution so it's slow.
Faster Code
def solve_pythagorean_triplets(n):
    " Solves for triplets whose sum equals n "
    solutions = []
    for a in range(1, n):
        denom = 2*(n-a)
        num = 2*a**2 + n**2 - 2*n*a
        if denom > 0 and num % denom == 0:
            c = num // denom
            b = n - a - c
            if b > a:
                solutions.append((a, b, c))
    return solutions
OP code
Modified OP code so it returns all solutions rather than printing the first found to compare performance
def pythagoreanTriplet(n):
    # Considering triplets in
    # sorted order. The value
    # of first element in sorted
    # triplet can be at-most n/3.
    results = []
    for i in range(1, int(n / 3) + 1):
        # The value of second element
        # must be less than equal to n/2
        for j in range(i + 1,
                       int(n / 2) + 1):
            k = n - i - j
            if (i * i + j * j == k * k):
                results.append((i, j, k))
    return results
Timing
n        pythagoreanTriplet (OP code)       solve_pythagorean_triplets (new)
900      0.084 seconds                      0.039 seconds
5000     3.130 seconds                      0.012 seconds
90000    Timed out after several minutes    0.430 seconds
Explanation
Function solve_pythagorean_triplets is an O(n) algorithm that works as follows.
Searching for:
a^2 + b^2 = c^2 (triplet)
a + b + c = n (sum equals input)
Solve by searching over a (i.e. a fixed for an iteration). With a fixed, we have two equations and two unknowns (b, c):
b + c = n - a
c^2 - b^2 = a^2
Solution is:
denom = 2*(n-a)
num = 2*a**2 + n**2 - 2*n*a
if denom > 0 and num % denom == 0:
    c = num // denom
    b = n - a - c
    if b > a:
        (a, b, c)  # is a solution
Iterate a over range(1, n) to get the different solutions.
Edit June 2022 by @AbhijitSarkar:
For those who like to see the missing steps:
c^2 - b^2 = a^2
b + c = n - a
=> b = n - a - c
c^2 - (n - a - c)^2 = a^2
=> c^2 - (n - a - c) * (n - a - c) = a^2
=> c^2 - n(n - a - c) + a(n - a - c) + c(n - a - c) = a^2
=> c^2 - n^2 + an + nc + an - a^2 - ac + cn - ac - c^2 = a^2
=> -n^2 + 2an + 2nc - a^2 - 2ac = a^2
=> -n^2 + 2an + 2nc - 2a^2 - 2ac = 0
=> 2c(n - a) = n^2 - 2an + 2a^2
=> c = (n^2 - 2an + 2a^2) / (2(n - a))
DarrylG's answer is correct, and I've added the missing steps to it as well, but there's another solution that's faster than iterating from [1, n). Let me explain it, but I'll leave the code up to the reader.
We use Euclid's formula of generating a tuple.
a = m^2 - n^2, b = 2mn, c = m^2 + n^2, where m > n > 0 ---(i)
a + b + c = P ---(ii)
Combining equations (i) and (ii), we have:
2m^2 + 2mn = P ---(iii)
Since m > n > 0, 1 <= n <= m - 1.
Putting n=1 in equation (iii), we have:
2m^2 + 2m - P = 0, ax^2 + bx + c = 0, a=2, b=2, c=-P
m = (-b +- sqrt(b^2 - 4ac)) / 2a
=> (-2 +- sqrt(4 + 8P)) / 4
=> (-1 +- sqrt(1 + 2P)) / 2
Since m > 0, sqrt(b^2 - 4ac) > -b, the only solution is
(-1 + sqrt(1 + 2P)) / 2 ---(iv)
Putting n=m-1 in equation (iii), we have:
2m^2 + 2m(m - 1) - P = 0
=> 4m^2 - 2m - P = 0, ax^2 + bx + c = 0, a=4, b=-2, c=-P
m = (-b +- sqrt(b^2 - 4ac)) / 2a
=> (2 +- sqrt(4 + 16P)) / 8
=> (1 +- sqrt(1 + 4P)) / 4
Since m > 0, the only solution is
(1 + sqrt(1 + 4P)) / 4 ---(v)
From equation (iii), m^2 + mn = P/2; since P/2 is constant,
when n is the smallest, m must be the largest, and vice versa.
Thus:
(1 + sqrt(1 + 4P)) / 4 <= m <= (-1 + sqrt(1 + 2P)) / 2 ---(vi)
Solving equation (iii) for n, we have:
n = (P - 2m^2) / 2m ---(vii)
We iterate for m within the bounds given by the inequality (vi)
and check when the corresponding n given by equation (vii) is
an integer.
Despite generating all primitive triples, Euclid's formula does not
produce all triples - for example, (9, 12, 15) cannot be generated using
integer m and n. This can be remedied by inserting an additional
parameter k to the formula. The following will generate all Pythagorean
triples uniquely.
a = k(m^2 - n^2), b = 2kmn, c = k(m^2 + n^2), for k >= 1.
Thus, we iterate over the integer values of P/k (that is, over the divisors k of P)
until P/k < 12, the lowest possible perimeter, corresponding to the triple (3, 4, 5).
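Since the code is left to the reader above, here is one possible sketch of the Euclid-based search. It is an illustration under stated assumptions, not the answer author's implementation: the function name triples_with_perimeter is invented here, and instead of evaluating the floating-point bounds of inequality (vi) it simply loops m up to isqrt(P / 2k) and keeps only the cases where 0 < n < m. It does not require m and n to be coprime, so the same triple may be reached through several values of k; the set removes those duplicates.

```python
from math import isqrt

def triples_with_perimeter(P):
    """Return the set of triples (a, b, c), a <= b < c, with a + b + c == P."""
    found = set()
    # k scales a smaller triple of perimeter P // k, which must be at least 12.
    for k in range(1, P // 12 + 1):
        if P % k:
            continue
        q = P // k
        if q % 2:
            continue  # 2m(m + n) is always even
        half = q // 2  # equation (iii): m * (m + n) == half
        for m in range(2, isqrt(half) + 1):
            if half % m:
                continue  # equation (vii) would give a non-integer n
            n = half // m - m
            if 0 < n < m:
                a = k * (m * m - n * n)
                b = 2 * k * m * n
                c = k * (m * m + n * n)
                found.add(tuple(sorted((a, b, c))))
    return found

print(triples_with_perimeter(12))  # {(3, 4, 5)}
```

For each divisor k the inner loop runs about sqrt(P / 2k) times, which is far fewer iterations than scanning every a in [1, P).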
Yo
I don't know if you still need the answer or not but hopefully, this can help.
n = int(input())
ans = [(a, b, c)
       for a in range(1, n)
       for b in range(a, n)
       for c in range(b, n)
       if (a**2 + b**2 == c**2 and a + b + c == n)]
if ans:
    print(ans[0][0], ans[0][1], ans[0][2])
else:
    print("Impossible")

Implement the function fast modular exponentiation

I am trying to implement the function fast modular exponentiation(b, k, m) which computes:
b^(2^k) mod m using only around 2k modular multiplications.
I tried this method:
def FastModularExponentiation(b, k, m):
    res = 1
    b = b % m
    while (k > 0):
        if ((k & 1) == 1):
            res = (res * b) % m
        k = k >> 1
        b = (b * b) % m
    return res
but I am still stuck on the same problem: if I try b = 2, k = 1, m = 10, my code returns 2. However, the correct answer is:
2^(2^1) mod 10 = 2^2 mod 10 = 4
and I cannot find the reason why.
Update: I finally understood that you do not want regular modular exponentiation (i.e., b^k mod m), but b^(2^k) mod m (as you plainly stated).
Using the regular built-in Python function pow this would be:
def FastModularExponentiation(b, k, m):
    return pow(b, pow(2, k), m)
Or, without using pow:
def FastModularExponentiation(b, k, m):
    b %= m
    for _ in range(k):
        b = b ** 2 % m
    return b
If you know r = phi(m) (Euler's totient function) and b is coprime to m, you could reduce the exponent first: exp = pow(2, k, r) and then calculate pow(b, exp, m). Depending on the input values, this might speed things up.
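As a small illustration of that exponent reduction, here is a sketch for the special case where the modulus is a prime p, so that phi(p) = p - 1. The function name pow_2k_mod_prime is made up for this example, and the primality of p is an assumption the caller has to guarantee; the code does not check it.

```python
def pow_2k_mod_prime(b, k, p):
    """Compute b^(2^k) mod p, assuming p is prime (so phi(p) == p - 1)."""
    if b % p == 0:
        # b is not coprime to p, so Euler's theorem does not apply,
        # but any positive power of a multiple of p is 0 mod p.
        return 0
    exp = pow(2, k, p - 1)  # reduce the exponent modulo phi(p)
    return pow(b, exp, p)

p = 1_000_000_007
# Same result as the direct computation, without building 2**100 first.
assert pow_2k_mod_prime(3, 100, p) == pow(3, 2**100, p)
```

For a composite modulus the same idea needs phi(m) and gcd(b, m) == 1, which this sketch does not attempt to handle.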
(This was the original answer when I thought you wanted, b^k mod m)
This is what works for me:
def fast_mod_exp(b, exp, m):
    if exp == 0:
        return 1 % m  # guard: the loop below assumes exp >= 1
    res = 1
    while exp > 1:
        if exp & 1:
            res = (res * b) % m
        b = b ** 2 % m
        exp >>= 1
    return (b * res) % m
The only significant differences I spot are in the last line, return (b * res) % m, and that my while loop terminates earlier, while exp > 1 (which should be the same thing you do, except it saves an unnecessary squaring operation).
Also note that the built-in function pow will do all that for free (if you supply a third argument):
pow(4, 13, 497)
# 445
def fast_exponentiation(k, x, q):
    # make sure all variables are non-negative
    assert (k >= 0 and x >= 0 and q >= 1)
    result = 1  # running product
    while x:
        if x % 2 == 1:
            result = (result * k) % q
        k = (k * k) % q  # square the base (note: ^ is XOR in Python, not a power)
        x >>= 1  # bit shift: divides x by 2
    return result

How to find mod of large numbers?

How to find (a^(b^c)) % (10^9 + 7) in Python for large inputs?
My code just gets terminated after a few test cases.
My code:
numbers = list(map(int, input().split()))
x = numbers[2]
y = numbers[1]
z = numbers[0]
m = pow(10,9) + 7
a = pow(y,x)
r = z % m
for i in range (0,a):
    r = r*z
    r = r % m
print(r)
You should use Fermat's Little Theorem and Pingala's algorithm for exponentiation.
Since p = 10^9 + 7 is prime, a^(p - 1) is 1 modulo p whenever p does not divide a, so the exponent b^c can be reduced modulo p - 1 = 10^9 + 6 before raising a to it. To compute a power modulo a number M you can use the binary approach: a^(2k) % M = ((a^k % M)^2) % M and a^(2k + 1) % M = ((a % M) * (a^k % M)^2) % M.
f a b c p = modPow a (modPow b c (p - 1)) p

modPow x n p
  | n == 0 = 1
  | even n =
      let val = modPow x (n `div` 2) p  -- recurse so intermediate values stay below p
      in (val * val) `mod` p
  | otherwise = (x * modPow x (n - 1) p) `mod` p
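Since the question asks about Python, the same idea can be sketched with the built-in three-argument pow. This is a hedged illustration rather than a drop-in solution: the name tower_mod is invented here, the inputs are assumed to be non-negative integers, and the special case for a being a multiple of p is needed because Fermat's little theorem only applies when p does not divide a.

```python
P = 10**9 + 7  # prime modulus from the question

def tower_mod(a, b, c, p=P):
    """Compute a^(b^c) mod p for a prime p and non-negative integers a, b, c."""
    if a % p == 0:
        # The exponent b^c is zero only when b == 0 and c > 0 (0^0 counts as 1);
        # in that case a^0 == 1, otherwise a positive power of a is 0 mod p.
        return 1 % p if (b == 0 and c > 0) else 0
    # Fermat: a^(p-1) == 1 (mod p), so reduce the exponent modulo p - 1.
    return pow(a, pow(b, c, p - 1), p)

print(tower_mod(2, 3, 4) == pow(2, 3**4, P))  # True
```

Both pow calls use fast modular exponentiation, so the loop over b^c iterations in the question's code is avoided entirely.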

How to compute the nth root of a very big integer

I need a way to compute the nth root of a long integer in Python.
I tried pow(m, 1.0/n), but it doesn't work:
OverflowError: long int too large to convert to float
Any ideas?
By long integer I mean REALLY long integers like:
11968003966030964356885611480383408833172346450467339251
196093144141045683463085291115677488411620264826942334897996389
485046262847265769280883237649461122479734279424416861834396522
819159219215308460065265520143082728303864638821979329804885526
557893649662037092457130509980883789368448042961108430809620626
059287437887495827369474189818588006905358793385574832590121472
680866521970802708379837148646191567765584039175249171110593159
305029014037881475265618958103073425958633163441030267478942720
703134493880117805010891574606323700178176718412858948243785754
898788359757528163558061136758276299059029113119763557411729353
915848889261125855717014320045292143759177464380434854573300054
940683350937992500211758727939459249163046465047204851616590276
724564411037216844005877918224201569391107769029955591465502737
961776799311859881060956465198859727495735498887960494256488224
613682478900505821893815926193600121890632
If it's a REALLY big number, you could use a binary search.
def find_invpow(x,n):
    """Finds the integer component of the n'th root of x,
    an integer such that y ** n <= x < (y + 1) ** n.
    """
    high = 1
    while high ** n <= x:
        high *= 2
    low = high // 2  # integer division, so low stays an int on Python 3 too
    while low < high:
        mid = (low + high) // 2
        if low < mid and mid**n < x:
            low = mid
        elif high > mid and mid**n > x:
            high = mid
        else:
            return mid
    return mid + 1
For example:
>>> x = 237734537465873465
>>> n = 5
>>> y = find_invpow(x,n)
>>> y
2986
>>> y**n <= x <= (y+1)**n
True
>>>
>>> x = 119680039660309643568856114803834088331723464504673392511960931441>
>>> n = 45
>>> y = find_invpow(x,n)
>>> y
227661383982863143360L
>>> y**n <= x < (y+1)**n
True
>>> find_invpow(y**n,n) == y
True
>>>
Gmpy is a C-coded Python extension module that wraps the GMP library to provide to Python code fast multiprecision arithmetic (integer, rational, and float), random number generation, advanced number-theoretical functions, and more.
Includes a root function:
x.root(n): returns a 2-element tuple (y,m), such that y is the
(possibly truncated) n-th root of x; m, an ordinary Python int,
is 1 if the root is exact (x==y**n), else 0. n must be an ordinary
Python int, >=0.
For example, 20th root:
>>> import gmpy
>>> i0=11968003966030964356885611480383408833172346450467339251
>>> m0=gmpy.mpz(i0)
>>> m0
mpz(11968003966030964356885611480383408833172346450467339251L)
>>> m0.root(20)
(mpz(567), 0)
You can make it run slightly faster by avoiding the while loops in favor of setting low to 10 ** (len(str(x)) / n) and high to low * 10. It is probably better still to replace len(str(x)) with the bit length and use a bit shift. Based on my tests, I estimate a 5% speedup from the first and a 25% speedup from the second. If the ints are big enough, this might matter (and the speedups may vary). Don't trust my code without testing it carefully: I did some basic testing but may have missed an edge case. Also, these speedups vary with the number chosen.
If the actual data you're using is much bigger than what you posted here, this change may be worthwhile.
from timeit import Timer

def find_invpow(x,n):
    """Finds the integer component of the n'th root of x,
    an integer such that y ** n <= x < (y + 1) ** n.
    """
    high = 1
    while high ** n <= x:  # <= so that high ** n > x also holds for exact powers
        high *= 2
    low = high // 2
    while low < high:
        mid = (low + high) // 2
        if low < mid and mid**n < x:
            low = mid
        elif high > mid and mid**n > x:
            high = mid
        else:
            return mid
    return mid + 1

def find_invpowAlt(x,n):
    """Finds the integer component of the n'th root of x,
    an integer such that y ** n <= x < (y + 1) ** n.
    """
    # (digits - 1) // n keeps low ** n <= x even when the digit count
    # is a multiple of n (e.g. x = 10, n = 2).
    low = 10 ** ((len(str(x)) - 1) // n)
    high = low * 10
    while low < high:
        mid = (low + high) // 2
        if low < mid and mid**n < x:
            low = mid
        elif high > mid and mid**n > x:
            high = mid
        else:
            return mid
    return mid + 1

x = 237734537465873465
n = 5
tests = 10000
print("Norm", Timer('find_invpow(x,n)', 'from __main__ import find_invpow, x,n').timeit(number=tests))
print("Alt", Timer('find_invpowAlt(x,n)', 'from __main__ import find_invpowAlt, x,n').timeit(number=tests))
Norm 0.626754999161
Alt 0.566340923309
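The bit-length variant mentioned above is not shown in the answer, so here is one way it might look. This is a sketch under assumptions: the name iroot_bits is invented for this illustration, x is assumed to be a non-negative int and n a positive int, and no timing claim is made for it.

```python
def iroot_bits(x, n):
    """Integer n-th root: the y with y ** n <= x < (y + 1) ** n."""
    if x < 0 or n < 1:
        raise ValueError("x must be >= 0 and n must be >= 1")
    if x < 2:
        return x
    bits = x.bit_length()           # 2 ** (bits - 1) <= x < 2 ** bits
    low = 1 << ((bits - 1) // n)    # low ** n <= 2 ** (bits - 1) <= x
    high = 1 << -(-bits // n)       # ceil(bits / n), so high ** n >= 2 ** bits > x
    # Invariant of the search: low ** n <= x < high ** n
    while high - low > 1:
        mid = (low + high) // 2
        if mid ** n <= x:
            low = mid
        else:
            high = mid
    return low

y = iroot_bits(237734537465873465, 5)
print(y ** 5 <= 237734537465873465 < (y + 1) ** 5)  # True
```

Starting from shifts avoids both the doubling loop and the decimal string conversion, at the cost of a starting interval that can be wider than the decimal one by a constant factor.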
If you are looking for something standard and fast to write, with high precision, I would use decimal and set the precision (getcontext().prec) to at least the length of x.
Code (Python 3.0)
from decimal import *
x = '11968003966030964356885611480383408833172346450467339251\
196093144141045683463085291115677488411620264826942334897996389\
485046262847265769280883237649461122479734279424416861834396522\
819159219215308460065265520143082728303864638821979329804885526\
557893649662037092457130509980883789368448042961108430809620626\
059287437887495827369474189818588006905358793385574832590121472\
680866521970802708379837148646191567765584039175249171110593159\
305029014037881475265618958103073425958633163441030267478942720\
703134493880117805010891574606323700178176718412858948243785754\
898788359757528163558061136758276299059029113119763557411729353\
915848889261125855717014320045292143759177464380434854573300054\
940683350937992500211758727939459249163046465047204851616590276\
724564411037216844005877918224201569391107769029955591465502737\
961776799311859881060956465198859727495735498887960494256488224\
613682478900505821893815926193600121890632'
minprec = 27
if len(x) > minprec: getcontext().prec = len(x)
else: getcontext().prec = minprec
x = Decimal(x)
power = Decimal(1)/Decimal(3)
answer = x**power
ranswer = answer.quantize(Decimal('1.'), rounding=ROUND_UP)
diff = x - ranswer**Decimal(3)
if diff == Decimal(0):
    print("x is the cubic number of", ranswer)
else:
    print("x has a cubic root of ", answer)
Answer
x is the cubic number of 22873918786185635329056863961725521583023133411
451452349318109627653540670761962215971994403670045614485973722724603798
107719978813658857014190047742680490088532895666963698551709978502745901
704433723567548799463129652706705873694274209728785041817619032774248488
2965377218610139128882473918261696612098418
Oh, for numbers that big, you would use the decimal module.
ns: your number as a string
ns = "11968003966030964356885611480383408833172346450467339251196093144141045683463085291115677488411620264826942334897996389485046262847265769280883237649461122479734279424416861834396522819159219215308460065265520143082728303864638821979329804885526557893649662037092457130509980883789368448042961108430809620626059287437887495827369474189818588006905358793385574832590121472680866521970802708379837148646191567765584039175249171110593159305029014037881475265618958103073425958633163441030267478942720703134493880117805010891574606323700178176718412858948243785754898788359757528163558061136758276299059029113119763557411729353915848889261125855717014320045292143759177464380434854573300054940683350937992500211758727939459249163046465047204851616590276724564411037216844005877918224201569391107769029955591465502737961776799311859881060956465198859727495735498887960494256488224613682478900505821893815926193600121890632"
from decimal import Decimal
d = Decimal(ns)
one_third = Decimal("0.3333333333333333")
print(d ** one_third)
and the answer is: 2.287391878618402702753613056E+305
TZ pointed out that this isn't accurate... and he's right. Here's my test.
from decimal import Decimal
def nth_root(num_decimal, n_integer):
exponent = Decimal("1.0") / Decimal(n_integer)
return num_decimal ** exponent
def test():
ns = "11968003966030964356885611480383408833172346450467339251196093144141045683463085291115677488411620264826942334897996389485046262847265769280883237649461122479734279424416861834396522819159219215308460065265520143082728303864638821979329804885526557893649662037092457130509980883789368448042961108430809620626059287437887495827369474189818588006905358793385574832590121472680866521970802708379837148646191567765584039175249171110593159305029014037881475265618958103073425958633163441030267478942720703134493880117805010891574606323700178176718412858948243785754898788359757528163558061136758276299059029113119763557411729353915848889261125855717014320045292143759177464380434854573300054940683350937992500211758727939459249163046465047204851616590276724564411037216844005877918224201569391107769029955591465502737961776799311859881060956465198859727495735498887960494256488224613682478900505821893815926193600121890632"
nd = Decimal(ns)
cube_root = nth_root(nd, 3)
print(cube_root ** Decimal("3.0") - nd)
if __name__ == "__main__":
test()
It's off by about 10**891
Possibly for your curiosity:
http://en.wikipedia.org/wiki/Hensel_Lifting
This could be the technique that Maple would use to actually find the nth root of large numbers.
Pose the fact that x^n - 11968003.... = 0 mod p, and go from there...
I can suggest four methods for solving your task. The first is based on Binary Search, the second on Newton's Method, and the third on the Shifting n-th Root Algorithm. The fourth, which I call the Chord-Tangent method, is described in the picture linked from the code below (https://i.stack.imgur.com/et9O0.jpg).
Binary Search was already implemented in many answers above. Here I just present my own view of it and my implementation.
As an alternative I also implement an Optimized Binary Search method (marked Opt). This method just starts from the range [hi / 2, hi), where hi is equal to 2^(num_bit_length / k) if we're computing the k-th root.
Newton's Method is new here; as far as I can see it wasn't implemented in other answers. It is usually considered to be faster than Binary Search, although my own timings in the code below don't show any speedup. Hence this method is here just for reference/interest.
The Shifting Method is 30-50% faster than the optimized binary search method, and should be even faster if implemented in C++, because C++ has fast 64-bit arithmetic, which is partially used in this method.
Chord-Tangent Method:
I came up with the Chord-Tangent Method on a piece of paper (see the linked image); it is inspired by Newton's method and improves on it. Basically I draw a chord and a tangent line and find their intersections with the horizontal line y = n. These two intersections form lower and upper bound approximations of the location of the root solution (x0, n), where n = x0 ^ k. This method turned out to be the fastest of all: while all the other methods do more than 2000 iterations, this method does just 8 iterations for 8192-bit numbers. So this method is 200-300x faster than the Shifting Method, the previous fastest.
As an example I generate a really huge random integer of 8192 bits, and measure the time taken to find its cubic root with all the methods.
In the test() function you can see that I passed k = 3 as the root's power (cubic root); you can pass any power instead of 3.
def binary_search(begin, end, f, *, niter = [0]):
    while begin < end:
        niter[0] += 1
        mid = (begin + end) >> 1
        if f(mid):
            begin = mid + 1
        else:
            end = mid
    return begin

def binary_search_kth_root(n, k, *, verbose = False):
    # https://en.wikipedia.org/wiki/Binary_search_algorithm
    niter = [0]
    res = binary_search(0, n + 1, lambda root: root ** k < n, niter = niter)
    if verbose:
        print('Binary Search iterations:', niter[0])
    return res

def binary_search_opt_kth_root(n, k, *, verbose = False):
    # https://en.wikipedia.org/wiki/Binary_search_algorithm
    niter = [0]
    hi = 1 << (n.bit_length() // k - 1)
    while hi ** k <= n:
        niter[0] += 1
        hi <<= 1
    res = binary_search(hi >> 1, hi, lambda root: root ** k < n, niter = niter)
    if verbose:
        print('Binary Search Opt iterations:', niter[0])
    return res

def newton_kth_root(n, k, *, verbose = False):
    # https://en.wikipedia.org/wiki/Newton%27s_method
    f = lambda x: x ** k - n
    df = lambda x: k * x ** (k - 1)
    x, px, niter = n, 2 * n, [0]
    while abs(px - x) > 1:
        niter[0] += 1
        px = x
        x -= f(x) // df(x)
    if verbose:
        print('Newton Method iterations:', niter[0])
    mini, minv = None, None
    for i in range(-2, 3):
        v = abs(f(x + i))
        if minv is None or v < minv:
            mini, minv = i, v
    return x + mini

def shifting_kth_root(n, k, *, verbose = False):
    # https://en.wikipedia.org/wiki/Shifting_nth_root_algorithm
    B_bits = 64
    r, y = 0, 0
    B = 1 << B_bits
    Bk_bits = B_bits * k
    Bk_mask = (1 << Bk_bits) - 1
    niter = [0]
    for i in range((n.bit_length() + Bk_bits - 1) // Bk_bits - 1, -1, -1):
        alpha = (n >> (i * Bk_bits)) & Bk_mask
        B_y = y << B_bits
        Bk_yk = (y ** k) << Bk_bits
        Bk_r_alpha = (r << Bk_bits) + alpha
        Bk_yk_Bk_r_alpha = Bk_yk + Bk_r_alpha
        beta = binary_search(1, B, lambda beta: (B_y + beta) ** k <= Bk_yk_Bk_r_alpha, niter = niter) - 1
        y, r = B_y + beta, Bk_r_alpha - ((B_y + beta) ** k - Bk_yk)
    if verbose:
        print('Shifting Method iterations:', niter[0])
    return y

def chord_tangent_kth_root(n, k, *, verbose = False):
    niter = [0]
    hi = 1 << (n.bit_length() // k - 1)
    while hi ** k <= n:
        niter[0] += 1
        hi <<= 1
    f = lambda x: x ** k
    df = lambda x: k * x ** (k - 1)
    # https://i.stack.imgur.com/et9O0.jpg
    x_begin, x_end = hi >> 1, hi
    y_begin, y_end = f(x_begin), f(x_end)
    for icycle in range(1 << 30):
        if x_end - x_begin <= 1:
            break
        niter[0] += 1
        if 0: # Do Binary Search step if needed
            x_mid = (x_begin + x_end) >> 1
            y_mid = f(x_mid)
            if y_mid > n:
                x_end, y_end = x_mid, y_mid
            else:
                x_begin, y_begin = x_mid, y_mid
        # (y_end - y_begin) / (x_end - x_begin) = (n - y_begin) / (x_n - x_begin) ->
        x_n = x_begin + (n - y_begin) * (x_end - x_begin) // (y_end - y_begin)
        y_n = f(x_n)
        tangent_x = x_n + (n - y_n) // df(x_n) + 1
        chord_x = x_n + (n - y_n) * (x_end - x_n) // (y_end - y_n)
        assert chord_x <= tangent_x, (chord_x, tangent_x)
        x_begin, x_end = chord_x, tangent_x
        y_begin, y_end = f(x_begin), f(x_end)
        assert y_begin <= n, (chord_x, y_begin, n, n - y_begin)
        assert y_end > n, (icycle, tangent_x - binary_search_kth_root(n, k), y_end, n, y_end - n)
    if verbose:
        print('Chord Tangent Method iterations:', niter[0])
    return x_begin

def test():
    import random, timeit
    nruns = 3
    bits = 8192
    n = random.randrange(1 << (bits - 1), 1 << bits)
    a = binary_search_kth_root(n, 3, verbose = True)
    b = binary_search_opt_kth_root(n, 3, verbose = True)
    c = newton_kth_root(n, 3, verbose = True)
    d = shifting_kth_root(n, 3, verbose = True)
    e = chord_tangent_kth_root(n, 3, verbose = True)
    assert abs(a - b) <= 0 and abs(a - c) <= 1 and abs(a - d) <= 1 and abs(a - e) <= 1, (a - b, a - c, a - d, a - e)
    print()
    print('Binary Search timing:', round(timeit.timeit(lambda: binary_search_kth_root(n, 3), number = nruns) / nruns, 3), 'sec')
    print('Binary Search Opt timing:', round(timeit.timeit(lambda: binary_search_opt_kth_root(n, 3), number = nruns) / nruns, 3), 'sec')
    print('Newton Method timing:', round(timeit.timeit(lambda: newton_kth_root(n, 3), number = nruns) / nruns, 3), 'sec')
    print('Shifting Method timing:', round(timeit.timeit(lambda: shifting_kth_root(n, 3), number = nruns) / nruns, 3), 'sec')
    print('Chord Tangent Method timing:', round(timeit.timeit(lambda: chord_tangent_kth_root(n, 3), number = nruns) / nruns, 3), 'sec')

if __name__ == '__main__':
    test()
Output:
Binary Search iterations: 8192
Binary Search Opt iterations: 2732
Newton Method iterations: 9348
Shifting Method iterations: 2752
Chord Tangent Method iterations: 8
Binary Search timing: 0.506 sec
Binary Search Opt timing: 0.05 sec
Newton Method timing: 2.09 sec
Shifting Method timing: 0.03 sec
Chord Tangent Method timing: 0.001 sec
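The high Newton iteration count above appears to come mostly from starting the iteration at x = n: while x is far above the root, each step only shrinks it by a factor of roughly (k - 1) / k. A common variant starts from a power of two just above the root instead. The sketch below shows that variant; the name newton_iroot is invented for this illustration, it is not part of the answer's code, and no timing is claimed for it.

```python
def newton_iroot(n, k):
    """Integer k-th root of n >= 0: the x with x ** k <= n < (x + 1) ** k."""
    if n < 0 or k < 1:
        raise ValueError("n must be >= 0 and k must be >= 1")
    if n < 2:
        return n
    # Start from a power of two that is guaranteed to be above the true root:
    # n < 2 ** bit_length <= (2 ** ceil(bit_length / k)) ** k.
    x = 1 << -(-n.bit_length() // k)
    while True:
        # Integer Newton step for f(x) = x ** k - n.
        y = ((k - 1) * x + n // x ** (k - 1)) // k
        if y >= x:
            # The sequence decreases strictly until it reaches the floor root.
            return x
        x = y

r = newton_iroot(10**50, 3)
print(r ** 3 <= 10**50 < (r + 1) ** 3)  # True
```

Because the start is within a factor of two of the root, the quadratic convergence of Newton's method applies almost immediately.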
I came up with my own answer, which takes @Mahmoud Kassem's idea, simplifies the code, and makes it more reusable:
import decimal

def cube_root(x):
    return decimal.Decimal(x) ** (decimal.Decimal(1) / decimal.Decimal(3))
I tested it in Python 3.5.1 and Python 2.7.8, and it seemed to work fine.
The result will have as many digits as specified by the decimal context the function is run in, which by default is 28 significant digits. According to the documentation for the power function in the decimal module, "The result is well-defined but only 'almost always correctly-rounded'." If you need a more precise result, it can be done as follows:
with decimal.localcontext() as context:
    context.prec = 50
    print(cube_root(42))
In older versions of Python, 1/3 is equal to 0. In Python 3.0, 1/3 is equal to 0.33333333333 (and 1//3 is equal to 0).
So, either change your code to use 1/3.0 or switch to Python 3.0 .
Try converting the exponent to a floating-point number, as the default behaviour of / in Python 2 is integer division
n**(1/float(3))
Well, if you're not particularly worried about precision, you could convert it to a string, chop off some digits, use the exponent function, and then multiply the result by the root of how much you chopped off.
E.g. 32123 is about equal to 32 * 1000, so its cubic root is about equal to the cubic root of 32 times the cubic root of 1000. The latter can be calculated by dividing the number of 0s by 3, since the cubic root of 10^(3j) is 10^j.
This avoids the need for the use of extension modules.
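A rough sketch of that digit-chopping idea follows. The function name approx_root and the keep parameter are invented for this illustration, x is assumed to be a positive int, and the result is only as precise as a float (about 15 significant digits). The number of chopped digits is rounded down to a multiple of n so that the removed power of ten has an exact n-th root, and the answer is returned as a (mantissa, power-of-ten) pair so that very large roots do not overflow a float.

```python
def approx_root(x, n, keep=15):
    """Approximate x ** (1 / n) as (mantissa, exp10): root ~= mantissa * 10 ** exp10."""
    digits = str(x)
    drop = max(len(digits) - keep, 0)  # how many trailing digits to chop off
    drop -= drop % n                   # keep it a multiple of n: root of 10**drop is 10**(drop // n)
    lead = int(digits[:len(digits) - drop])
    return lead ** (1.0 / n), drop // n

mantissa, exp10 = approx_root((10**100 + 7) ** 3, 3)
print(mantissa, exp10)  # mantissa is close to 1e5 and exp10 is 95, i.e. about 1e100
```

If an exact integer root is needed, the approximation can serve as a starting point for one of the integer methods above.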
