Parallelising function that updates numpy matrix - python

Suppose I have a function that takes in a 4D matrix and 4 other arrays and does an update of the matrix entries:
@njit
def func(matrix, A, B, C, D):
    for i in range(len(A)):
        for j in range(len(B)):
            for k in range(len(C)):
                for l in range(len(D)):
                    matrix[i][j][k][l] = another_func(A[i], B[j], C[k], D[l])
An example of another_func is of the form below:
def another_func(a, b, c, d):
    if c <= 5:
        return b
    else:
        return 1 + d * (c - a) + (b - d)
Is there a way to parallelise these updates or ways to speed it up more?
I tried using numba's @njit, but the speedup is not good enough.
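One possible way to avoid the four nested loops altogether is NumPy broadcasting. The sketch below is not from the original thread: the name func_vectorised is hypothetical, and it assumes A, B, C and D are 1-D numeric arrays and that another_func has exactly the form shown above.

```python
import numpy as np

def func_vectorised(A, B, C, D):
    """Return the 4D array whose [i, j, k, l] entry equals
    another_func(A[i], B[j], C[k], D[l]) for the example another_func."""
    A, B, C, D = (np.asarray(x) for x in (A, B, C, D))
    # Give each input its own axis so the four arrays broadcast
    # against each other to shape (len(A), len(B), len(C), len(D)).
    a = A[:, None, None, None]
    b = B[None, :, None, None]
    c = C[None, None, :, None]
    d = D[None, None, None, :]
    # np.where evaluates both branches for the whole grid and picks per element.
    return np.where(c <= 5, b, 1 + d * (c - a) + (b - d))
```

To update an existing array in place, as the original func does, something like matrix[...] = func_vectorised(A, B, C, D) should work. Note that this builds the full 4D result in memory at once, so it trades memory for speed.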

Related

A function works fine with one number, but not an array. What am I missing?

Pardon my ignorance, but I'm very new to coding in python. I have a pretty simple function; it just needs to make a calculation based on b's relative location to a and c:
a = 6
b = 3
c = 2
def function(a, b, c):
    if ((a >= b) & (b >= c)):
        return b - c
    elif ((a <= b) & (b >= c)):
        return a - c
    else:
        return 0
t = function(a, b, c)
print(t)
When I run it with simple numbers like above, it gives me the right answer no matter what I make b (in this case, 1).
But when I run it with a, b, and c as NumPy arrays, it only returns b - c across the entire "t" array.
It's not too much different, but here's what I'm using for the array version:
def function(a, b, c):
    if ((a >= b) & (b >= c)).any():
        return b - c
    elif ((a <= b) & (b >= c)).any():
        return a - c
    else:
        return 0
t = function(a, b, c[i>1])
print(t)
(The [i>1] is there because there is a variable amount of array input, and another function will be used for when [i = 0])
I've also tried this:
t = np.where(((prev2 >= Head_ELV) & (Head_ELV >= Bottom_ELV)).any, Head_ELV - Bottom_ELV, 0)
but ran into the same result.
Would a while-loop work better?
I don't think you need looping here as the problem can be solved using array operations. You could try the below, assuming the arrays are of the same length.
# import numpy to be able to work with arrays
import numpy as np
def function(a, b, c):
    # declare array t with only zeros
    t = np.zeros_like(a)
    # declare filters (boolean masks, combined element-wise)
    mask_1 = (a >= b) & (b >= c)
    mask_2 = (a <= b) & (b >= c)
    # modifying t based on the filters above
    t[mask_1] = (b - c)[mask_1]
    t[mask_2] = (a - c)[mask_2]
    return t
# example 2d arrays
a = np.array([[1800,5], [5,5]])
b = np.array([[3416,2], [3,4]])
c = np.array([[1714,2], [3,4]])
# run function
function(a, b, c)
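The same if/elif/else logic can also be written with np.select, which takes a list of conditions and a matching list of values and uses the first condition that holds for each element. This is a sketch rather than part of the original answer; the name function_select is made up for illustration, and it assumes a, b and c are arrays that broadcast against each other.

```python
import numpy as np

def function_select(a, b, c):
    # Conditions are checked in order, like if / elif; the first match wins.
    conditions = [(a >= b) & (b >= c), (a <= b) & (b >= c)]
    choices = [b - c, a - c]
    # Elements matching no condition get the default value 0 (the else branch).
    return np.select(conditions, choices, default=0)

a = np.array([[1800, 5], [5, 5]])
b = np.array([[3416, 2], [3, 4]])
c = np.array([[1714, 2], [3, 4]])
print(function_select(a, b, c))
```

This also shows why the .any() version in the question behaves as it does: .any() collapses the whole mask into a single True or False, so one branch is chosen for the entire array instead of per element.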

How to iterate over tuples using jax.lax.scan

I am looking to translate a bit of code from a NumPy version listed here, to a JAX compatible version. The NumPy code iteratively calculates the value of a matrix, E from the values of other matrices, A, B, D, as well as the value of E from the previous iteration: E_jm1.
Both the NumPy and JAX version work in their listed forms and produce identical results. How can I get the JAX version to work when passing A, B, D as a tuple instead of as a concatenated array? I have a specific use case where a tuple would be more useful.
I found a question asking something similar, but it just confirmed that this should be possible. There are no examples in the documentation or elsewhere that I could find.
Original NumPy version
import numpy as np
import jax
import jax.numpy as jnp
def BAND_J(A, B, D, E_jm1):
    '''
    output: E(N x N)
    input: A(N x N), B(N x N), D(N x N), E_jm1(N x N)
    𝐄ⱼ = -[𝐁 + 𝐀𝐄ⱼ₋₁]⁻¹ 𝐃
    '''
    B_inv = np.linalg.inv(B + np.dot( A, E_jm1 ))
    E = -np.dot(B_inv, D)
    return E
key = jax.random.PRNGKey(0)
N = 2
NJ = 4
# initialize matrices with random values
A, B, D = [ jax.random.normal(key, shape=(N,N,NJ)),
            jax.random.normal(key, shape=(N,N,NJ)),
            jax.random.normal(key, shape=(N,N,NJ)) ]
A_np, B_np, D_np = [np.asarray(A), np.asarray(B), np.asarray(D)]
# initialize E_0
E_0 = jax.random.normal(key+2, shape=(N,N))
E_np = np.empty((N,N,NJ))
E_np[:,:,0] = np.asarray(E_0)
# iteratively calculate E from A, B, D, and 𝐄ⱼ₋₁
for j in range(1,NJ):
    E_jm1 = E_np[:,:,j-1]
    E_np[:,:,j] = BAND_J(A_np[:,:,j], B_np[:,:,j], D_np[:,:,j], E_jm1)
JAX scan version
from jax import lax
def BAND_J(E, ABD):
    '''
    output: E(N x N)
    input: A(N x N), B(N x N), D(N x N), E_jm1(N x N)
    '''
    A, B, D = ABD
    B_inv = jnp.linalg.inv(B + jnp.dot( A, E ))
    E = -jnp.dot(B_inv, D)
    return E, E # ("carryover", "accumulated")
abd = jnp.asarray([(A[:,:,j], B[:,:,j], D[:,:,j]) for j in range(NJ)])
# abd = tuple([(A[:,:,j], B[:,:,j], D[:,:,j]) for j in range(NJ)]) # this produces error
# ValueError: too many values to unpack (expected 3)
_, E = lax.scan(BAND_J, E_0, abd)
for j in range(1, NJ):
    print(np.isclose(E[j-1], E_np[:,:,j]))
The short answer is "you can't". By design, jax.lax.scan scans over the leading axes of arrays, not over the entries of arbitrary Python collections.
So if you want to use scan, you'll have to stack your entries into an array.
That said, since your tuple only has three elements, a good alternative would be to skip the scan and simply JIT-compile the for loop approach. JAX tracing will effectively unroll the loop and optimize the flattened sequence of operations. While this can lead to long compile times for large loops, since your application is only 3 iterations it shouldn't be problematic.
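A related layout that may cover the tuple use case: as far as the documented behaviour of jax.lax.scan goes, its xs argument may be a pytree (for example a tuple) of arrays that all share the same leading axis, in which case the step function receives a tuple of per-step slices. This is still scanning along array axes, not over a tuple of NJ separate entries. The sketch below is an illustration under assumptions, not the original poster's code: the name band_step, the (NJ, N, N) layout and the scaled random inputs are all made up for the example.

```python
import jax
import jax.numpy as jnp
from jax import lax

def band_step(E_prev, abd):
    # abd is a tuple of three (N, N) slices, one from each stacked array.
    A, B, D = abd
    E = -jnp.dot(jnp.linalg.inv(B + jnp.dot(A, E_prev)), D)
    return E, E  # (carry, per-step output)

N, NJ = 2, 4
k_a, k_b, k_d, k_e = jax.random.split(jax.random.PRNGKey(0), 4)
# The scan axis must be the leading one, so each array has shape (NJ, N, N).
# Inputs are scaled (an assumption for the demo) so B + A @ E stays well conditioned.
A = 0.1 * jax.random.normal(k_a, (NJ, N, N))
B = jnp.eye(N) + 0.1 * jax.random.normal(k_b, (NJ, N, N))
D = 0.1 * jax.random.normal(k_d, (NJ, N, N))
E_0 = 0.1 * jax.random.normal(k_e, (N, N))

# xs is a plain tuple of arrays; scan slices each one along axis 0.
_, E_scan = lax.scan(band_step, E_0, (A, B, D))
print(E_scan.shape)
```

If the data is stored with the step index last, as in the question (shape (N, N, NJ)), jnp.moveaxis(A, -1, 0) would bring it into this layout.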

Numpy Equivalent of Logic Operation on Multidimensional Arrays in MATLAB

I have the following logic operation coded in MATLAB, where [A, B, C, and D] are all 5x3x16 doubles and [a, b, c, and d] are all 240x1 doubles. I am trying to implement the same logic operation in python using numpy.
D = zeros(size(A));
for i = 1:numel(D)
    flag = ...
        (a == A(i)) & ...
        (b == B(i)) & ...
        (c == C(i));
    D(i) = d(flag);
end
d is a column vector that is already populated with data. a, b, and c are also populated column vectors of equal size. Meshgrid was used to construct A, B, and C into a LxMxN grid of the unique values within a, b, and c. Now I want to use d to populate a LxMxN D with the appropriate values using the boolean expression.
I have tried:
D = np.zeros(np.shape(N))
for i in range(len(D)):
    for j in range(len(D[0])):
        for k in range(len(D[0][0])):
            flag = np.logical_and(
                (a == A[i][j][k]),
                (b == B[i][j][k]),
                (c == C[i][j][k])
            )
            D[i][j][k] = d[flag];
The syntax will be a little messier, but you can use the np.logical_* functions to do this.
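One detail worth knowing here: np.logical_and takes only two input arrays, and a third positional argument is treated as the output array, so the three-way test in the attempt above does not do what the MATLAB version does. A minimal sketch of one way to write it follows. The function name fill_grid is hypothetical, and it assumes a, b, c and d are 1-D arrays of equal length and that every grid point matches exactly one entry.

```python
import numpy as np

def fill_grid(a, b, c, d, A, B, C):
    D = np.zeros(A.shape)
    # np.ndindex walks every (i, j, k) index of the grid.
    for idx in np.ndindex(A.shape):
        # Chain the three comparisons with &; the equivalent spelling is
        # np.logical_and.reduce([a == A[idx], b == B[idx], c == C[idx]]).
        flag = (a == A[idx]) & (b == B[idx]) & (c == C[idx])
        # .item() raises an error unless exactly one element matched.
        D[idx] = d[flag].item()
    return D
```

The result array is created from A.shape so that it always has the same L x M x N shape as the grids.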

Solving linear equation with large numbers

I have an equation of the form a+b*n1=c+d*n2, where a, b, c, d are known numbers with around 1000 digits, and I need to solve for n1.
I tried:
i = 1
while True:
    a = (a + b) % d
    if a == c:
        break
    i += 1
print(i)
but this method is too slow for numbers this big. Is there a better method to use in this kind of situation?
You want to find x such that x = a (mod b) and x = c (mod d). For then, n1 = (x - a) / b and n2 = (x - c) / d.
If b and d are coprime, then the existence of x is guaranteed by the Chinese Remainder Theorem -- and a solution can be found using the Extended Euclidean Algorithm.
If b and d aren't coprime (that is, if gcd(b, d) != 1), then (noting that a = c (mod gcd(b, d))), we can subtract a % gcd(b, d) from both sides, and divide through by gcd(b, d) to reduce to a problem as above.
Putting it into code
Here's code that finds n1 and n2 using this method:
def egcd(a, b):
    if a == 0:
        return (b, 0, 1)
    else:
        g, y, x = egcd(b % a, a)
        return (g, x - (b // a) * y, y)
def modinv(a, m):
    return egcd(a, m)[1] % m
def solve(a, b, c, d):
    gcd = egcd(b, d)[0]
    if gcd != 1:
        if a % gcd != c % gcd:
            raise ValueError('no solution')
        a, c = a - a % gcd, c - c % gcd
        a //= gcd
        b //= gcd
        c //= gcd
        d //= gcd
    x = a * d * modinv(d, b) + c * b * modinv(b, d)
    return (x - a) // b, (x - c) // d
And here's some test code that runs 1000 random trials of 1000-digit inputs:
import sys
sys.setrecursionlimit(10000)
import random
digit = '0123456789'
def rn(k):
    return int(''.join(random.choice(digit) for _ in range(k)), 10)
k = 1000
for _ in range(1000):
    a, b, c, d = rn(k), rn(k), rn(k), rn(k)
    print(a, b, c, d)
    try:
        n1, n2 = solve(a, b, c, d)
    except ValueError:
        print('no solution')
        print()
        continue
    if a + b * n1 != c + d * n2:
        raise AssertionError('failed!')
    print('found solution:', n1, n2)
    print()
(Note, the recursion limit has to be increased because the egcd function which implements the Extended Euclidean algorithm is recursive, and running it on 1000 digit numbers can require a quite deep stack).
Also note that this checks the result only when a solution is returned. When a != c (mod gcd(b, d)) and the exception is raised to signal that there is no solution, no check is done. So you need to think through whether this can fail to find results when solutions do exist.
This runs (1000 trials) in around 7-8 seconds on my machine, so it performs reasonably well.
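For comparison, here is a shorter sketch of the same idea that avoids recursion by leaning on the standard library: math.gcd, and the three-argument pow with exponent -1, which computes a modular inverse in Python 3.8 and later. It is not the code from the answer above. The name solve_linear is made up, and it assumes b and d are positive integers. It rewrites the equation as b*n1 = c - a (mod d) and solves that congruence directly.

```python
from math import gcd

def solve_linear(a, b, c, d):
    """Find integers (n1, n2) with a + b*n1 == c + d*n2, assuming b, d > 0."""
    g = gcd(b, d)
    # b*n1 - d*n2 = c - a has integer solutions only if g divides c - a.
    if (c - a) % g != 0:
        raise ValueError('no solution')
    b_r, d_r = b // g, d // g
    # Now b_r and d_r are coprime, so b_r has an inverse modulo d_r.
    n1 = (c - a) // g * pow(b_r, -1, d_r) % d_r
    n2 = (a + b * n1 - c) // d
    return n1, n2

n1, n2 = solve_linear(3, 4, 5, 6)
print(n1, n2, 3 + 4 * n1 == 5 + 6 * n2)
```

This returns the smallest non-negative n1; adding any multiple of d // gcd(b, d) to n1 gives another solution.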

Dynamic programming solution to maximizing an expression by placing parentheses

I'm trying to implement an algorithm from Algorithmic Toolbox course on Coursera that takes an arithmetic expression such as 5+8*4-2 and computes its largest possible value. However, I don't really understand the choice of indices in the last part of the shown algorithm; my implementation fails to compute values using the ones initialized in 2 tables (which are used to store maximized and minimized values of subexpressions).
The evalt function just takes the operator character and applies the corresponding operation to two operands:
def evalt(a, b, op):
    if op == '+':
        return a + b
    #and so on
MinMax computes the minimum and the maximum values of subexpressions
def MinMax(i, j, op, m, M):
    mmin = 10000
    mmax = -10000
    for k in range(i, j-1):
        a = evalt(M[i][k], M[k+1][j], op[k])
        b = evalt(M[i][k], m[k+1][j], op[k])
        c = evalt(m[i][k], M[k+1][j], op[k])
        d = evalt(m[i][k], m[k+1][j], op[k])
        mmin = min(mmin, a, b, c, d)
        mmax = max(mmax, a, b, c, d)
    return(mmin, mmax)
And this is the body of the main function
def get_maximum_value(dataset):
    op = dataset[1:len(dataset):2]
    d = dataset[0:len(dataset)+1:2]
    n = len(d)
    #initializing matrices/tables
    m = [[0 for i in range(n)] for j in range(n)] #minimized values
    M = [[0 for i in range(n)] for j in range(n)] #maximized values
    for i in range(n):
        m[i][i] = int(d[i]) #so that the tables will look like
        M[i][i] = int(d[i]) #[[i, 0, 0...], [0, i, 0...], [0, 0, i,...]]
    for s in range(n): #here's where I get confused
        for i in range(n-s):
            j = i + s
            m[i][j], M[i][j] = MinMax(i,j,op,m,M)
    return M[0][n-1]
Sorry to bother, here's what had to be improved:
for s in range(1,n):
in the main function, and
for k in range(i, j):
in MinMax function. Now it works.
The following change should work.
for s in range(1,n):
    for i in range(0,n-s):
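Putting both fixes together, a complete version might look like the sketch below. It is an illustration rather than the poster's final code: the names OPS and min_and_max are made up, only +, - and * are supported, and operands are assumed to be single digits, as in the question. The reason s starts at 1 is that s = 0 corresponds to the diagonal entries, which are already initialised; recomputing them would overwrite the digits with the sentinel values because the inner loop over k is empty.

```python
import operator

# Map each operator character to the function that applies it.
OPS = {'+': operator.add, '-': operator.sub, '*': operator.mul}

def min_and_max(i, j, ops, m, M):
    lo, hi = float('inf'), float('-inf')
    # k is the position of the last operator applied: (i..k) op[k] (k+1..j).
    for k in range(i, j):
        f = OPS[ops[k]]
        vals = [f(M[i][k], M[k + 1][j]), f(M[i][k], m[k + 1][j]),
                f(m[i][k], M[k + 1][j]), f(m[i][k], m[k + 1][j])]
        lo = min(lo, *vals)
        hi = max(hi, *vals)
    return lo, hi

def get_maximum_value(expr):
    digits = expr[0::2]   # characters at even positions are operands
    ops = expr[1::2]      # characters at odd positions are operators
    n = len(digits)
    m = [[0] * n for _ in range(n)]  # minimum value of subexpression i..j
    M = [[0] * n for _ in range(n)]  # maximum value of subexpression i..j
    for i in range(n):
        m[i][i] = M[i][i] = int(digits[i])
    # s is the distance j - i; s = 0 is the diagonal, already filled in above.
    for s in range(1, n):
        for i in range(n - s):
            j = i + s
            m[i][j], M[i][j] = min_and_max(i, j, ops, m, M)
    return M[0][n - 1]

print(get_maximum_value("5+8*4-2"))
```

Both tables are needed because subtraction and multiplication can turn a minimum of one subexpression into the maximum of the whole.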
