Permutation sign for batched vectors - Python

I would like to find the permutation parity (sign) for a given batch of vectors (in Python / JAX).
import jax
import jax.numpy as jnp

n = jnp.array([[[0., 0., 1., 1.],
                [0., 0., 1., 1.],
                [1., 1., 0., 0.],
                [1., 1., 0., 0.]],
               [[1., 0., 1., 0.],
                [0., 1., 0., 1.],
                [1., 0., 1., 0.],
                [0., 1., 0., 1.]],
               [[0., 1., 1., 0.],
                [1., 0., 0., 1.],
                [1., 0., 0., 1.],
                [0., 1., 1., 0.]]])
sorted_index = jax.vmap(sorted_idx)(n)
sorted_perms = jax.vmap(jax.vmap(sorted_perm, in_axes=(0, 0)), in_axes=(0, 0))(n, sorted_index)
parities = jax.vmap(parities)(sorted_index)
I expect the following solution:
sorted_elements = [[[0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.]],
                   [[0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.]],
                   [[0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.],
                    [0., 0., 1., 1.]]]
parities = [[1, 1, 1, 1],
            [-1, -1, -1, -1],
            [1, 1, 1, 1]]
I tried the following:
# sort the array and return the arg_sort indices
def sorted_idx(permutations):
    sort_idx = jnp.argsort(permutations)
    return sort_idx

# sort the permutations (vectors) given the sorted_indices
def sorted_perm(permutations, sort_idx):
    perm = permutations[sort_idx]
    return perm

# Calculate the permutation cycle, from which we compute the permutation parity
#jax.vmap
def parities(sort_idx):
    length = len(sort_idx)
    elements_seen = jnp.zeros(length)
    cycles = 0
    for index in range(length):
        if elements_seen[index] == True:
            continue
        cycles += 1
        current = index
        if elements_seen[current] == False:
            elements_seen.at[current].set(True)
            current = sort_idx[current]
    is_even = (length - cycles) % 2 == 0
    return +1 if is_even else -1
But I get the following: parities = [[1 1 1 1], [1 1 1 1], [1 1 1 1]]
Every permutation vector gets a parity factor of 1, which is wrong.

The reason your routine doesn't work is that you're attempting to vmap over Python control flow, and this must be done very carefully (see JAX Sharp Bits: Control Flow). I suspect it would be a bit complicated to construct your iterative parity approach in terms of jax.lax control flow operators, but there might be another way.
The parity of a permutation equals the determinant of its permutation matrix, and the Jacobian of a sort happens to be exactly that permutation matrix, so you could (ab)use JAX's automatic differentiation of the sort operator to compute the parities very concisely:
def compute_parity(p):
    return jnp.linalg.det(jax.jacobian(jnp.sort)(p.astype(float))).astype(int)

sorted_index = jnp.argsort(n, axis=-1)
parities = jax.vmap(jax.vmap(compute_parity))(sorted_index)
print(parities)
# [[ 1  1  1  1]
#  [-1 -1 -1 -1]
#  [ 1  1  1  1]]
This does end up being O(N^3), where N is the length of the permutations, but due to the nature of XLA computations, particularly on accelerators like GPU, the vectorized approach will likely be more efficient than an iterative approach for reasonably-sized N.
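If the cubic cost of the determinant ever becomes a concern, one alternative (not part of the answer above) is to count inversions: the sign of a permutation is (-1) raised to the number of pairs i < j with p[i] > p[j]. The sketch below does this with broadcasting in O(N^2) and handles the batch dimensions directly. It is written with NumPy so it can be run on its own; the name parity_by_inversions is made up for illustration, and the assumption is that the same expressions carry over to jax.numpy.

```python
import numpy as np

def parity_by_inversions(perm):
    """Sign of each permutation along the last axis: (-1) ** (number of inversions)."""
    perm = np.asarray(perm)
    size = perm.shape[-1]
    # greater[..., i, j] is True where perm[..., i] > perm[..., j]
    greater = perm[..., :, None] > perm[..., None, :]
    # keep only the pairs with i < j (strict upper triangle)
    upper = np.triu(np.ones((size, size), dtype=bool), k=1)
    inversions = np.sum(greater & upper, axis=(-2, -1))
    # even inversion count -> +1, odd -> -1
    return 1 - 2 * (inversions % 2)

n = np.array([[[0., 0., 1., 1.],
               [0., 0., 1., 1.],
               [1., 1., 0., 0.],
               [1., 1., 0., 0.]],
              [[1., 0., 1., 0.],
               [0., 1., 0., 1.],
               [1., 0., 1., 0.],
               [0., 1., 0., 1.]],
              [[0., 1., 1., 0.],
               [1., 0., 0., 1.],
               [1., 0., 0., 1.],
               [0., 1., 1., 0.]]])

# a stable sort keeps tied entries in their original order
sorted_index = np.argsort(n, axis=-1, kind="stable")
parities = parity_by_inversions(sorted_index)
print(parities)
# [[ 1  1  1  1]
#  [-1 -1 -1 -1]
#  [ 1  1  1  1]]
```

Because the vectors contain repeated values, the result depends on how ties are ordered, which is why the sketch asks for a stable sort explicitly.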
Also note that there's no need to compute sorted_index for compute_parity; you could call it directly on your array n instead.
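For reference, here is a minimal sketch of the cycle-counting idea from the question in plain Python, for a single permutation and without any vmap. Two things in the question's loop appear to keep it from ever walking a cycle: the cycle has to be followed with a while loop rather than a single if, and in JAX `.at[...].set(...)` returns a new array instead of modifying `elements_seen` in place, so its result must be assigned. The name parity_by_cycles is hypothetical.

```python
def parity_by_cycles(perm):
    """Return +1 for an even permutation and -1 for an odd one."""
    perm = [int(p) for p in perm]
    length = len(perm)
    seen = [False] * length
    cycles = 0
    for start in range(length):
        if seen[start]:
            continue
        cycles += 1
        current = start
        # follow the cycle until it returns to an already visited element
        while not seen[current]:
            seen[current] = True
            current = perm[current]
    # a permutation of `length` elements with `cycles` cycles
    # decomposes into (length - cycles) transpositions
    return 1 if (length - cycles) % 2 == 0 else -1

print(parity_by_cycles([2, 3, 0, 1]))  # 1
print(parity_by_cycles([1, 3, 0, 2]))  # -1
```

This version relies on Python control flow over concrete values, so it is meant for checking results on small inputs, not for use under jax.vmap or jax.jit.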

Related

Generate all n choose k binary vectors python

Is there any efficient way (numpy style) to generate all n choose k binary vectors (with k ones)?
for example, if n=3 and k=2, then I want to generate (1,1,0), (1,0,1), (0,1,1).
Thanks
I do not know how efficient this is, but here is a way:
from itertools import combinations
import numpy as np

n, k = 5, 3
np.array(
    [
        [1 if i in comb else 0 for i in range(n)]
        for comb in combinations(np.arange(n), k)
    ]
)
Output:
array([[1, 1, 1, 0, 0],
       [1, 1, 0, 1, 0],
       [1, 1, 0, 0, 1],
       [1, 0, 1, 1, 0],
       [1, 0, 1, 0, 1],
       [1, 0, 0, 1, 1],
       [0, 1, 1, 1, 0],
       [0, 1, 1, 0, 1],
       [0, 1, 0, 1, 1],
       [0, 0, 1, 1, 1]])
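A variation on the same idea, sketched under the assumption that 1 <= k <= n: collect the index combinations into one integer array and set all the ones with a single fancy-indexing assignment instead of a nested list comprehension. The function name binary_vectors is made up for illustration.

```python
from itertools import combinations
import numpy as np

def binary_vectors(n, k):
    """All length-n 0/1 vectors with exactly k ones, one per row."""
    # each row of idx holds the positions of the ones for one vector
    idx = np.array(list(combinations(range(n), k)), dtype=int)
    out = np.zeros((len(idx), n), dtype=int)
    # row r gets ones at the columns listed in idx[r]
    out[np.arange(len(idx))[:, None], idx] = 1
    return out

print(binary_vectors(3, 2))
# [[1 1 0]
#  [1 0 1]
#  [0 1 1]]
```

The combinations still come from itertools, so the number of rows grows as n choose k either way; only the filling step is vectorized.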

How to interpret numpy advanced indexing solution

I have a piece of numpy code that I know works. I know this because I have tested it in my generic case successfully. However, I arrived at the solution after two hours of back and forth referencing the docs and trial and error. I can't grasp how I would know to do this intuitively.
The setup:
a = np.zeros((5,5,3))
The goal: set to 1 the entries at indices 0,1 of axis 1 and 0,1 of axis 2 (all of axis 3), and likewise the entries at indices 3,4 of axis 1 and 3,4 of axis 2 (all of axis 3).
Clearer goal: set the first two rows of the first two blocks to 1, and the last two rows of the last two blocks to 1.
The result:
ax1 = np.array([np.array([0,1]), np.array([3,4])])
ax1 = np.array([x[:,np.newaxis] for x in ax1])
ax2 = np.array([[[0,1]],[[3,4]]])
a[ax1,ax2,:] = 1
a
Output:
array([[[1., 1., 1.],
        [1., 1., 1.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],
       [[1., 1., 1.],
        [1., 1., 1.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],
       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]],
       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [1., 1., 1.],
        [1., 1., 1.]],
       [[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [1., 1., 1.],
        [1., 1., 1.]]])
I'm inclined to believe I should be able to look at the shape of the matrix in question, the shape of the indices, and the index operation to intuitively know the output. However, I can't put the story together in my head. Like, what's the final shape of the subspace it is altering? How would you explain how this works?
The shapes:
input: (5, 5, 3)
ind1: (2, 2, 1)
ind2: (2, 1, 2)
final_op: input[ind1, ind2, :]
With shapes
ind1: (2, 2, 1)
ind2: (2, 1, 2)
they broadcast together to select a (2,2,2) space
In [4]: ax1
Out[4]:
array([[[0],
        [1]],
       [[3],
        [4]]])
In [5]: ax2
Out[5]:
array([[[0, 1]],
       [[3, 4]]])
So for the first dimension (blocks) it is selecting blocks 0, 1, 3, and 4. In the second dimension it is selecting the rows with those same indices.
Together that's the first 2 rows of the first 2 blocks, and the last 2 rows of the last 2 blocks. That's where the 1s appear in your result.
A simpler way of creating the index arrays:
In [7]: np.array([[0,1],[3,4]])[:,:,None] # (2,2) expanded to (2,2,1)
In [8]: np.array([[0,1],[3,4]])[:,None,:] # expand to (2,1,2)
This is how broadcasting expands them:
In [10]: np.broadcast_arrays(ax1,ax2)
Out[10]:
[array([[[0, 0],      # block indices
         [1, 1]],
        [[3, 3],
         [4, 4]]]),
 array([[[0, 1],      # row indices
         [0, 1]],
        [[3, 4],
         [3, 4]]])]
This may make the pattern clearer:
In [15]: a[ax1,ax2,:] = np.arange(1,5).reshape(2,2,1)
In [16]: a[:,:,0]
Out[16]:
array([[1., 2., 0., 0., 0.],
       [3., 4., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 2.],
       [0., 0., 0., 3., 4.]])
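To make the shape question concrete, here is a small self-contained check. It assumes the same (5, 5, 3) array as in the question; the names pairs, selected_shape and expected are made up for this sketch. The two index arrays broadcast to (2, 2, 2), the trailing slice adds the last axis of length 3, and the assignment is equivalent to two plain slice assignments.

```python
import numpy as np

a = np.zeros((5, 5, 3))
pairs = np.array([[0, 1], [3, 4]])

ax1 = pairs[:, :, None]   # shape (2, 2, 1): block indices
ax2 = pairs[:, None, :]   # shape (2, 1, 2): row indices

# broadcast index shape (2, 2, 2) followed by the sliced axis of length 3
selected_shape = a[ax1, ax2, :].shape
print(selected_shape)  # (2, 2, 2, 3)

a[ax1, ax2, :] = 1

# the same result written as two ordinary slice assignments
expected = np.zeros((5, 5, 3))
expected[0:2, 0:2, :] = 1
expected[3:5, 3:5, :] = 1
print(np.array_equal(a, expected))  # True
```

Reading the first axis of the broadcast result as "which group" (0,1 or 3,4), the second as "which block within the group" and the third as "which row within the group" is one way to keep the story straight.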

Inset a matrix of zeros across another matrix's diagonal

Watered down example
Consider I have the following matrix A,
1 2 4 3
1 7 3 6
2 4 1 1
6 9 3 6
I would want to convert it to the matrix B which looks like,
0 0 4 3
0 0 3 6
2 4 0 0
6 9 0 0
So basically I would want to have, let's say a 2x2 matrix of zeros in the diagonal of the 4x4 matrix given above.
Need for a general solution
What I have provided above is just an example. I am going to use a (1296, 1296) sized matrix as input and I want to inset 3x3 matrices of zeros along its diagonal.
What I have done so far?
A simple range based loop and then setting values to zero like so,
for i in range(0, mat.shape[0] - 1, 3):
    mat[i][i] = 0
    mat[i][i + 1] = 0
    mat[i][i + 2] = 0
    mat[i + 1][i] = 0
    mat[i + 1][i + 1] = 0
    mat[i + 1][i + 2] = 0
    mat[i + 2][i] = 0
    mat[i + 2][i + 1] = 0
    mat[i + 2][i + 2] = 0
I completely understand that this is a very crude and nasty way to do it. Please suggest a fast and "numpy" way of doing this.
You could try something like this:
start = 0
stop = 1296
step = 3
for i in np.arange(start=start, stop=stop, step=step):
    mat[i:i+step, i:i+step] = 0
Here is a loop free solution:
mat = np.ones((12,12))
# view the matrix as a 4x4 grid of 3x3 blocks and zero the diagonal blocks
np.einsum('ijik->ijk', mat.reshape((4,3,4,3)))[...] = 0
mat
# array([[0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
#        [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
#        [0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
#        [1., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
#        [1., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
#        [1., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
#        [1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 1., 1.],
#        [1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 1., 1.],
#        [1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 1., 1.],
#        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
#        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
#        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.]])
This should fix it:
your_matrix = [
    [1, 2, 4, 3],
    [1, 7, 3, 6],
    [2, 4, 1, 1],
    [6, 9, 3, 6]
]

def calculate(matrix, x, y, size):
    # zero out a size-by-size box whose top-left corner is at column x, row y
    for index, x_row in enumerate(matrix):
        for add in range(size):
            if index == y+add:
                for add_2 in range(size):
                    x_row[x+add_2] = 0
    return matrix

your_matrix = calculate(matrix=your_matrix, x=1, y=1, size=2)
print(your_matrix)
Here x, y is the top-left corner of your box of zeros and size is how big you want the box to be.
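Another loop-free option, sketched here as an alternative to the answers above: build a boolean mask that is True wherever the row and the column fall into the same diagonal block, then assign through the mask. The function name zero_block_diagonal is hypothetical, and the sketch assumes a square matrix.

```python
import numpy as np

def zero_block_diagonal(mat, block):
    """Return a copy of a square matrix with block-by-block diagonal blocks set to 0."""
    mat = np.asarray(mat)
    # block_id[i] is the index of the diagonal block that row/column i belongs to
    block_id = np.arange(mat.shape[0]) // block
    # True where row and column lie in the same diagonal block
    mask = block_id[:, None] == block_id[None, :]
    out = mat.copy()
    out[mask] = 0
    return out

A = np.array([[1, 2, 4, 3],
              [1, 7, 3, 6],
              [2, 4, 1, 1],
              [6, 9, 3, 6]])
print(zero_block_diagonal(A, 2))
# [[0 0 4 3]
#  [0 0 3 6]
#  [2 4 0 0]
#  [6 9 0 0]]
```

For the (1296, 1296) case the call would be zero_block_diagonal(mat, 3); the mask costs one extra boolean array of the same shape as the matrix.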

Opposite of binary_dilation

Is there a function that does the opposite of binary_dilation? I'm looking to remove 'islands' from an array of 0's and 1's. That is, if a value of 1 in a 2D array doesn't have at least 1 adjacent neighbor that is also 1, its value gets set to 0 (rather than have its neighbor's values set equal to 1 as in binary_dilation). So for example:
test = np.zeros((5,5))
test[1,1] = test[1,2] = test[3,3] = test[4,3] = test[0,3] = test[3,1] = 1
test
array([[0., 0., 0., 1., 0.],
       [0., 1., 1., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 1., 0., 1., 0.],
       [0., 0., 0., 1., 0.]])
And the function I'm seeking would return:
array([[0., 0., 0., 0., 0.],
       [0., 1., 1., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0.]])
Note the values changed in locations [0,3] and [3,1] from 1 to 0 because they have no adjacent neighbors with a value equal 1 (diagonal doesn't count as a neighbor).
You can create a mask with the cells to check and do a 2d convolution with test to identify the cells with 1s adjacent to them. The logical and of the convolution and test should produce the desired output.
First define your mask. Since you are only looking for up/down and left/right adjacency, you want the following:
mask = np.ones((3, 3))
mask[1,1] = mask[0, 0] = mask[0, 2] = mask[2, 0] = mask[2, 2] = 0
print(mask)
#array([[0., 1., 0.],
#       [1., 0., 1.],
#       [0., 1., 0.]])
If you wanted to include diagonal elements, you'd simply update mask to include 1s in the corners.
Now apply a 2d convolution of test with mask. This will multiply and add the values from the two matrices. With this mask, this will have the effect of returning the sum of all adjacent values for each cell.
from scipy.signal import convolve2d
print(convolve2d(test, mask, mode='same'))
#array([[0., 1., 2., 0., 1.],
#       [1., 1., 1., 2., 0.],
#       [0., 2., 1., 1., 0.],
#       [1., 0., 2., 1., 1.],
#       [0., 1., 1., 1., 1.]])
You have to specify mode='same' so the result is the same size as the first input (test). Notice that the two cells that you wanted to remove from test are 0 in the convolution output.
Finally, do an element-wise logical AND of this output with test to find the desired cells:
res = np.logical_and(convolve2d(test, mask, mode='same'), test).astype(int)
print(res)
#array([[0, 0, 0, 0, 0],
#       [0, 1, 1, 0, 0],
#       [0, 0, 0, 0, 0],
#       [0, 0, 0, 1, 0],
#       [0, 0, 0, 1, 0]])
Update
For the last step, you could also just clip the values in the convolution between 0 and 1 and do an element-wise multiplication.
res = convolve2d(test, mask, mode='same').clip(0, 1)*test
#array([[0., 0., 0., 0., 0.],
#       [0., 1., 1., 0., 0.],
#       [0., 0., 0., 0., 0.],
#       [0., 0., 0., 1., 0.],
#       [0., 0., 0., 1., 0.]])
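If SciPy is not available, the same neighbor test can be sketched with NumPy alone. This swaps the convolution for zero padding plus four shifted slices (up, down, left, right); it is an alternative technique, not the method from the answer above, and the name remove_islands is made up for illustration.

```python
import numpy as np

def remove_islands(arr):
    """Set to 0 every 1 that has no 4-connected neighbor equal to 1."""
    a = np.asarray(arr).astype(bool)
    # pad with False so that border cells see "empty" neighbors outside the array
    p = np.pad(a, 1, mode="constant", constant_values=False)
    has_neighbor = (
        p[:-2, 1:-1]    # neighbor above
        | p[2:, 1:-1]   # neighbor below
        | p[1:-1, :-2]  # neighbor to the left
        | p[1:-1, 2:]   # neighbor to the right
    )
    return (a & has_neighbor).astype(int)

test = np.zeros((5, 5))
test[1, 1] = test[1, 2] = test[3, 3] = test[4, 3] = test[0, 3] = test[3, 1] = 1
print(remove_islands(test))
# [[0 0 0 0 0]
#  [0 1 1 0 0]
#  [0 0 0 0 0]
#  [0 0 0 1 0]
#  [0 0 0 1 0]]
```

To count diagonal neighbors as well, four more shifted slices (the corners of the padded array) would have to be OR-ed in.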

Numpy: Replace every n element in the first half of an array

If I have a numpy array and want to replace every nth element with 0 in the first half of the array (no change in the second half), how can I do this efficiently? My current code is not efficient enough:
for i in range(1, half, n):
    s[i] = 0
Just use a[:a.size//2:n] = 0. For example:
a = np.ones(10)
a[:a.size//2:2] = 0
a
array([ 0.,  1.,  0.,  1.,  0.,  1.,  1.,  1.,  1.,  1.])
Another example:
a = np.ones(20)
n = 3
a[:a.size//2:n] = 0
a
array([ 0.,  1.,  1.,  0.,  1.,  1.,  0.,  1.,  1.,  0.,  1.,  1.,  1.,
        1.,  1.,  1.,  1.,  1.,  1.,  1.])
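Note that the loop in the question starts at index 1, while the slices above start at index 0. A minimal sketch of the slice that touches the same indices as that loop, assuming half means s.size // 2:

```python
import numpy as np

s = np.ones(10)
n = 2
half = s.size // 2

# same indices as the loop that starts at 1 and steps by n: here 1 and 3
s[1:half:n] = 0
print(s.tolist())
# [1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
```

The slice creates no intermediate index array, so it stays cheap even for large arrays.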
You could slice the array by doing something like:
import numpy as np

# make an array of 11 elements filled with zeros
my_arr = np.zeros(11)
# get indexes to change in the array. range is: range(start, stop[, step])
a = list(range(0, 5, 2))
# print the original array
print(my_arr)
# Change the array
my_arr[a] = 1
# print the changes
print(my_arr)
Outputs:
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[1. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0.]
