JAX best way to iterate RNGKeys?

In JAX I find myself needing a PRNGKey that changes on each iteration of a loop. I'm not sure of the best pattern. I've considered:
a) split
for i in range(N):
    rng, _ = jax.random.split(rng)
    # Alternatively.
    rng = jax.random.split(rng, 1)[0]
b) fold_in
for i in range(N):
    rng = jax.random.fold_in(rng, i)
c) use the iteration index? This seems bad since the rng doesn't depend on a prior rng.
for i in range(N):
    rng = jax.random.PRNGKey(i)
Which of these is the best pattern and why? I am leaning towards (b) as it maintains dependency on the previous rng key (e.g. passed in as an argument), but I'm not sure if this is really the intended use case for jax.random.fold_in.

JAX docs (including the PRNG design doc) recommend something similar to (a):
for i in range(N):
    key, subkey = jax.random.split(key)
    values = random.uniform(subkey, shape)
    # key carries over to the next iteration
The reason this is better than splitting and throwing away the subkey is that it ensures that the streams in each iteration are independent.
Your option (b) is also safe, and in fact is the pattern that developers had in mind when creating fold_in (see e.g. https://github.com/google/jax/discussions/12395).
If you have a fixed number of iterations, it may be better to do all the splits once; for example:
for i, key in enumerate(random.split(key, N)):
    values = random.uniform(key, shape)
Or if your iterations do not have sequential dependence, it's better to use vmap to vectorize the operation:
def f(key):
    return random.uniform(key, shape)
jax.vmap(f)(random.split(key, N))


The most efficient way, rather than np.setdiff1d and np.in1d, to remove common values of 1D arrays with unique values

I need much faster code to remove, from a 1D index array (array length ~ 1e5-5e5, rarely up to 7e5), the values it has in common with another, small 1D array (array length ~ 10-15). Both are index arrays containing integers. There are no duplicates in the arrays, they are not sorted, and the order of the values in the main array must be kept after modification. I know this can be achieved with np.setdiff1d or np.in1d (neither of which is supported by Numba in nopython mode), and other similar posts (e.g. this) do not offer a much more efficient way, but performance is important here because all the values in the main index array will gradually be removed in loops.
import numpy as np
import numba as nb
n = 500000
r = 10
arr1 = np.random.permutation(n)
arr2 = np.random.randint(0, n, r)
# @nb.jit
def setdif1d_np(a, b):
    return np.setdiff1d(a, b, assume_unique=True)
# @nb.jit
def setdif1d_in1d_np(a, b):
    return a[~np.in1d(a, b)]
In another related post, norok2 proposed a solution for 2D arrays that is ~15 times faster (a hashing-like approach using Numba) than the usual methods described there. This solution may be the best if it could be adapted to 1D arrays:
@nb.njit
def mul_xor_hash(arr, init=65537, k=37):
    result = init
    for x in arr.view(np.uint64):
        result = (result * k) ^ x
    return result
@nb.njit
def setdiff2d_nb(arr1, arr2):
    # : build `delta` set using hashes
    delta = {mul_xor_hash(arr2[0])}
    for i in range(1, arr2.shape[0]):
        delta.add(mul_xor_hash(arr2[i]))
    # : compute the size of the result
    n = 0
    for i in range(arr1.shape[0]):
        if mul_xor_hash(arr1[i]) not in delta:
            n += 1
    # : build the result
    result = np.empty((n, arr1.shape[-1]), dtype=arr1.dtype)
    j = 0
    for i in range(arr1.shape[0]):
        if mul_xor_hash(arr1[i]) not in delta:
            result[j] = arr1[i]
            j += 1
    return result
I tried to adapt it to 1D arrays, but I have some problems/questions:
First, I don't understand what mul_xor_hash does exactly, and whether init and k are chosen arbitrarily or not.
Why doesn't mul_xor_hash work without nb.njit?
File "C:/Users/Ali/Desktop/test - Copy - Copy.py", line 21, in mul_xor_hash
result = (result * k) ^ x
TypeError: ufunc 'bitwise_xor' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
I don't know how to implement mul_xor_hash for 1D arrays (if it is possible at all), which I guess may make it even faster than for 2D arrays, so I broadcast the input arrays to 2D with [None, :], which gives the following error, only for arr2:
print(mul_xor_hash(arr2[0]))
ValueError: new type not compatible with array
And what does delta do?
I am looking for the most efficient way to do this. In the absence of a better method than norok2's solution, how can I adapt this solution to 1D arrays?
Understanding the hash-based solution
First, I don't understand what mul_xor_hash does exactly, and whether init and k are chosen arbitrarily or not
mul_xor_hash is a custom hash function. Functions mixing xor and multiply (possibly with shifts) are known to be relatively fast for hashing a raw data buffer. The multiplication tends to shuffle bits, and the xor is used to combine/accumulate the result in a small fixed-size value (i.e. the final hash). There are many different hashing functions. Some are faster than others, and some cause more collisions than others in a given context. A fast hashing function causing too many collisions can be useless in practice, as it would result in a pathological situation where all conflicting values need to be compared. This is why fast hash functions are hard to implement.
init and k are parameters that are certainly chosen so that the hash is well balanced. This is pretty common in such hash functions. k needs to be sufficiently big for the multiplication to shuffle bits, and it should typically also be a prime number (values like powers of two tend to increase collisions due to the behaviour of modular arithmetic). init plays a significant role only for very small arrays (e.g. with 1 item): it helps to reduce collisions by xoring the final hash with a non-trivial constant. Indeed, if arr.size = 1, then result = (init * k) ^ arr[0], where init * k is a constant. Having an identity hash function equal to arr[0] is known to be bad since it tends to result in many collisions (this is a complex topic, but to put it shortly: if, for example, many values of arr[0] are multiples of the number of buckets in the hash table, they all land in the same bucket). Thus, init should be a relatively big number and init * k should also be a big non-trivial value (a prime number is a good target value).
Why doesn't mul_xor_hash work without nb.njit?
It depends on the input. The input needs to be a 1D array whose raw size in bytes is divisible by 8 (e.g. n 64-bit items, 2n 32-bit ones, 4n 16-bit ones or 8n 8-bit ones). Here are some examples:
mul_xor_hash(np.random.rand(10))
mul_xor_hash(np.arange(10))  # Does not work with 9 items when the default integer is 32-bit (e.g. on Windows before NumPy 2.0)
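To make the mechanics concrete, here is a pure-Python sketch of the same multiply/xor idea (mul_xor_hash_py and MASK64 are illustrative names, not norok2's code). It converts the 64-bit words to Python ints with tolist() and masks after each multiplication to emulate uint64 wrap-around, which avoids mixing Python ints with NumPy uint64 scalars (the likely source of the TypeError in the question). It is not guaranteed to give the same values as the Numba-compiled version. The view(np.uint64) call is what imposes the byte-size rule described above.

```python
import numpy as np

MASK64 = (1 << 64) - 1  # keep results in 64 bits, like uint64 wrap-around


def mul_xor_hash_py(arr, init=65537, k=37):
    """Pure-Python sketch of the multiply/xor hash over 64-bit words."""
    # Reinterpret the raw bytes as 64-bit words; this raises ValueError
    # when the buffer size is not a multiple of 8 bytes.
    words = np.ascontiguousarray(arr).view(np.uint64).tolist()
    result = init
    for x in words:
        # Multiply to shuffle the bits, then xor in the next word.
        result = ((result * k) & MASK64) ^ x
    return result


print(mul_xor_hash_py(np.arange(4, dtype=np.int64)))  # 122826888161
try:
    mul_xor_hash_py(np.arange(9, dtype=np.int32))  # 36 bytes: not a multiple of 8
except ValueError as exc:
    print("ValueError:", exc)
```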
And what does delta do?
It is a set containing the hashes of the rows of arr2, so that matching rows can be found faster than by comparing whole rows.
How can this solution be adapted to 1D arrays?
AFAIK, hashes are only used to avoid comparing whole rows, which is needed only because the input is a 2D array. In 1D, there is no such problem.
There is a big catch with this method: it only works if there are no hash collisions. Otherwise, the implementation wrongly assumes that values are equal even if they are not! @norok explicitly mentioned it in the comments, though:
Note that the collision handling for the hashings should also be implemented
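One way to get collision handling for free in the 2D case is to let a Python set do both the hashing and the equality check, keying it on each row's raw bytes instead of a custom hash. This is a plain-NumPy sketch of that swapped-in technique (setdiff2d_bytes is a hypothetical name), not norok2's Numba code; note that byte equality treats 0.0 and -0.0 as different rows.

```python
import numpy as np


def setdiff2d_bytes(arr1, arr2):
    """Drop the rows of arr1 that also occur in arr2, keeping arr1's order."""
    arr1 = np.ascontiguousarray(arr1)
    # Same dtype and layout for both inputs, so equal rows have equal bytes.
    arr2 = np.ascontiguousarray(arr2, dtype=arr1.dtype)
    # The set hashes each key *and* compares full keys on a hash match,
    # so a hash collision can never make two different rows look equal.
    delta = {row.tobytes() for row in arr2}
    keep = np.fromiter(
        (row.tobytes() not in delta for row in arr1),
        dtype=bool,
        count=arr1.shape[0],
    )
    return arr1[keep]


a = np.array([[1, 2], [3, 4], [5, 6], [1, 2]])
b = np.array([[1, 2], [7, 8]])
print(setdiff2d_bytes(a, b).tolist())  # [[3, 4], [5, 6]]
```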
Faster implementation
Using @norok2's 2D solution for 1D is not a good idea, since hashes will not make it faster the way they are used. In fact, a set already uses a hash function internally anyway. Not to mention that collision handling needs to be properly implemented (which a set already does).
Using a set is a relatively good idea since it makes the complexity O(n + m), where n = len(arr1) and m = len(arr2). That being said, if arr1 is converted to a set, then it will be too big to fit in the L1 cache (due to the size of arr1 in your case), resulting in slow cache misses. Additionally, the growing size of the set will cause values to be re-hashed, which is not efficient. If arr2 is converted to a set, then the many hash table fetches will not be very efficient since arr2 is very small in your case. This is why this solution is sub-optimal.
One solution is to split arr1 into chunks and then build a set based on the target chunk. You can then check efficiently whether a value is in the set or not. Building the set is still not very efficient due to the growing size. This problem is due to Python itself, which does not provide a way to reserve some space for the data structure like other languages do (e.g. C++). One way to avoid this issue is simply to reimplement a hash table, which is cumbersome and not trivial. Actually, Bloom filters can be used to speed up this process since, on average, they can quickly tell that a value is definitely not shared between arr1 and arr2 (though they are not trivial to implement).
Another optimization is to use multiple threads to compute the chunks in parallel, since they are independent. That being said, appending to the final array is not easy to do efficiently in parallel, especially since you do not want the order to be modified. One solution is to move the copy out of the parallel loop and do it serially, but this is slow and AFAIK there is no simple way to do that in Numba currently (since its parallelism layer is very limited). Consider using native languages like C/C++ for an efficient parallel implementation.
In the end, hashing can be pretty complex, and the speed-up can be quite small compared to a naive implementation with two nested loops, since arr2 has only a few items and modern processors can compare values quickly using SIMD instructions (while hash-based methods can hardly benefit from them on mainstream processors). Unrolling can help to write a pretty simple and fast implementation. Again, unfortunately, Numba uses LLVM-Jit internally, which appears to fail to vectorize such simple code (certainly due to missing optimizations in either LLVM-Jit or even LLVM itself). As a result, the non-vectorized code ends up a bit slower (rather than 4~10 times faster, as it could be on a modern mainstream processor). One solution is to use C/C++ code instead (or possibly Cython).
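For comparison, the naive nested-loop approach can be expressed in plain NumPy as one vectorized equality pass over arr1 per value of arr2, OR-ed into a mask. This is a sketch (setdiff1d_bruteforce is a hypothetical name) with O(n * r) cost, which is only reasonable while arr2 stays tiny; whether NumPy's comparison loops use SIMD depends on the build and CPU, so it is worth measuring before relying on it.

```python
import numpy as np


def setdiff1d_bruteforce(arr1, arr2):
    """Keep the items of arr1 that are not in arr2, preserving order.

    Cost is O(len(arr1) * len(arr2)), fine when arr2 has ~10 items.
    """
    found = np.zeros(arr1.shape, dtype=bool)
    for v in arr2:  # one vectorized pass over arr1 per value of arr2
        found |= arr1 == v
    return arr1[~found]


print(setdiff1d_bruteforce(np.array([5, 1, 4, 2]), np.array([4, 5])))  # [1 2]
```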
Here is a serial implementation using basic Bloom filters:
@nb.njit('uint32(int32)')
def hash_32bit_4k(value):
    return (np.uint32(value) * np.uint32(27_644_437)) & np.uint32(0x0FFF)
@nb.njit(['int32[:](int32[:], int32[:])', 'int32[:](int32[::1], int32[::1])'])
def setdiff1d_nb_faster(arr1, arr2):
    out = np.empty_like(arr1)
    bloomFilter = np.zeros(4096, dtype=np.uint8)
    for j in range(arr2.size):
        bloomFilter[hash_32bit_4k(arr2[j])] = True
    cur = 0
    for i in range(arr1.size):
        # If the bloom-filter value is false, we know arr1[i] is not in arr2.
        # Otherwise, there may be a false positive (collision), so we need to check to be sure.
        if bloomFilter[hash_32bit_4k(arr1[i])] and arr1[i] in arr2:
            continue
        out[cur] = arr1[i]
        cur += 1
    return out[:cur]
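The same two-stage logic can be sketched with vectorized NumPy, which may help to see what the filter buys: the 12-bit hash rules out most items at once, and only the candidates that land on a set slot get an exact membership test. The hash mirrors hash_32bit_4k above; the function names and the use of np.isin for the exact check are my own choices, not part of the Numba implementation.

```python
import numpy as np


def hash_32bit_4k_np(values):
    """Vectorized counterpart of hash_32bit_4k: 32-bit multiply, keep 12 bits."""
    v = values.astype(np.uint32)  # negative values wrap modulo 2**32
    # The uint32 array product wraps modulo 2**32; the mask keeps 0..4095.
    return (v * np.uint32(27_644_437)) & np.uint32(0x0FFF)


def setdiff1d_bloom_np(arr1, arr2):
    """Single-hash Bloom-filter setdiff that keeps arr1's order."""
    bloom = np.zeros(4096, dtype=bool)
    bloom[hash_32bit_4k_np(arr2)] = True  # mark the slots used by arr2
    # Stage 1: an empty slot proves the value is not in arr2.
    maybe = bloom[hash_32bit_4k_np(arr1)]
    # Stage 2: exact check, only for candidates (true hits + false positives).
    found = np.zeros(arr1.shape, dtype=bool)
    found[maybe] = np.isin(arr1[maybe], arr2)
    return arr1[~found]


print(setdiff1d_bloom_np(np.array([5, 1, 4, 2], dtype=np.int32),
                         np.array([4, 5], dtype=np.int32)))  # [1 2]
```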
Here is an untested variant that should work for 64-bit integers (floating point numbers need memory views and possibly a prime constant too):
@nb.njit('uint64(int64)')
def hash_32bit_4k(value):
    return (np.uint64(value) * np.uint64(67_280_421_310_721)) & np.uint64(0x0FFF)
Note that if all the values in the small array are contained in the main array in each loop, then we can speed up the arr1[i] in arr2 part by removing values from arr2 when we find them. That being said, collisions and findings should be very rare so I do not expect this to be significantly faster (not to mention it adds some overhead and complexity). If items are computed in chunks, then the last chunks can be directly copied without any check but the benefit should still be relatively small. Note that this strategy can be effective for the naive (C/C++) SIMD implementation previously mentioned though (it can be about 2x faster).
Generalization and parallel implementation
This section focuses on which algorithm to use depending on the input size. In particular, it details a SIMD-based implementation and discusses the use of multiple threads.
First of all, the best algorithm depends on the value of r. More specifically:
when r is 0, the best thing to do is to return the input array arr1 unmodified (possibly a copy to avoid issues with in-place algorithms);
when r is 1, we can use one basic loop iterating over the array, but the best implementation is likely to use Numpy's np.where, which is highly optimized for that;
when r is small (e.g. < 10), using a SIMD-based implementation should be particularly efficient, especially if the iteration range of the arr2-based loop is known at compile time and is unrolled;
for bigger r values that are still relatively small (e.g. r < 1000 and r << n), the provided hash-based solution should be one of the best;
for larger r values with r << n, the hash-based solution can be optimized by packing boolean values as bits in bloomFilter and by using multiple hash functions instead of one, so as to better handle collisions while being more cache-friendly (in fact, this is what actual Bloom filters do); note that multi-threading can be used to speed up the lookups when r is huge and r << n;
when r is big and not much smaller than n, the problem is pretty hard to solve efficiently, and the best solution is certainly to sort both arrays (typically with a radix sort) and use a merge-based method to remove the duplicates, possibly with multiple threads when both r and n are huge (hard to implement).
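As a rough illustration of choosing the strategy from r, here is a NumPy sketch (setdiff1d_dispatch and its small_r cut-off are hypothetical, not measured). A boolean mask plays the role of np.where for r == 1, a broadcast comparison stands in for the SIMD brute force, and np.isin stands in for the hash-, Bloom- or sort-based methods at larger r.

```python
import numpy as np


def setdiff1d_dispatch(arr1, arr2, small_r=16):
    """Remove the values of arr2 from arr1 (order kept), choosing by r = arr2.size.

    `small_r` is an illustrative cut-off, not a measured one.
    """
    r = arr2.size
    if r == 0:
        return arr1.copy()  # nothing to remove; copy so callers may modify it
    if r == 1:
        return arr1[arr1 != arr2[0]]  # a single vectorized comparison
    if r <= small_r:
        # Brute force: compare every item with every value of arr2 at once
        # (builds an n x r boolean temporary, cheap while r stays small).
        found = (arr1[:, None] == arr2[None, :]).any(axis=1)
        return arr1[~found]
    # Larger r: np.isin stands in for the hash/Bloom/sort-based methods.
    return arr1[~np.isin(arr1, arr2)]


print(setdiff1d_dispatch(np.array([5, 1, 4, 2]), np.array([4, 5])))  # [1 2]
```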
Let's start with the SIMD-based solution. Here is an implementation:
@nb.njit('int32[:](int32[::1], int32[::1])')
def setdiff1d_nb_simd(arr1, arr2):
    out = np.empty_like(arr1)
    limit = arr1.size // 4 * 4
    limit2 = arr2.size // 2 * 2
    cur = 0
    z32 = np.int32(0)
    # Tile (x4) based computation
    for i in range(0, limit, 4):
        f0, f1, f2, f3 = z32, z32, z32, z32
        v0, v1, v2, v3 = arr1[i], arr1[i+1], arr1[i+2], arr1[i+3]
        # Unrolled (x2) loop searching for a match in `arr2`
        for j in range(0, limit2, 2):
            val1 = arr2[j]
            val2 = arr2[j+1]
            f0 += (v0 == val1) + (v0 == val2)
            f1 += (v1 == val1) + (v1 == val2)
            f2 += (v2 == val1) + (v2 == val2)
            f3 += (v3 == val1) + (v3 == val2)
        # Remainder of the previous loop
        if limit2 != arr2.size:
            val = arr2[arr2.size-1]
            f0 += v0 == val
            f1 += v1 == val
            f2 += v2 == val
            f3 += v3 == val
        if f0 == 0: out[cur] = arr1[i+0]; cur += 1
        if f1 == 0: out[cur] = arr1[i+1]; cur += 1
        if f2 == 0: out[cur] = arr1[i+2]; cur += 1
        if f3 == 0: out[cur] = arr1[i+3]; cur += 1
    # Remainder
    for i in range(limit, arr1.size):
        if arr1[i] not in arr2:
            out[cur] = arr1[i]
            cur += 1
    return out[:cur]
It turns out this implementation is always slower than the hash-based one on my machine, since Numba clearly generates inefficient code for the inner arr2-based loop, and this appears to come from broken optimizations related to ==: Numba simply fails to use SIMD instructions for this operation (for no apparent reason). This prevents many alternative SIMD-based codes from being fast as long as they use Numba.
Another issue with Numba is that np.where is slow, since it uses a naive implementation while Numpy's has been heavily optimized. The optimizations done in Numpy can hardly be applied to the Numba implementation due to the previous issue. This prevents any speed-up from using np.where in Numba code.
In practice, the hash-based implementation is pretty fast, and the copy already takes a significant time on my machine. The computing part can be sped up using multiple threads. This is not easy since the parallelism model of Numba is very limited. The copy cannot easily be optimized with Numba (one can use non-temporal stores, but this is not yet supported by Numba) unless the computation can be done in place.
To use multiple threads, one strategy is to first split the range into chunks and then:
build a boolean array determining, for each item of arr1, whether the item is found in arr2 or not (fully parallel);
count the number of items found per chunk (fully parallel);
compute the offset of the destination chunk (hard to parallelize, especially with Numba, but fast thanks to chunks);
copy the chunk to the target location without copying found items (fully parallel).
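The four steps can be sketched serially in NumPy to show how the chunk offsets come out of an exclusive prefix sum. In this sketch (setdiff1d_chunked is a hypothetical name), np.isin replaces the Bloom-filter lookup of step 1, and the final loop is the part whose iterations are independent and could be spread over threads.

```python
import numpy as np


def setdiff1d_chunked(arr1, arr2, chunk_size=1024):
    """Serial sketch of the four chunked steps; keeps arr1's order."""
    n = arr1.size
    if n == 0:
        return arr1.copy()
    chunk_count = (n + chunk_size - 1) // chunk_size
    starts = np.arange(chunk_count) * chunk_size
    # Step 1: boolean mask of the items of arr1 found in arr2.
    found = np.isin(arr1, arr2)
    # Step 2: number of found items in each chunk.
    counts = np.add.reduceat(found.astype(np.int64), starts)
    # Step 3: destination offset = chunk start - items removed before the chunk.
    removed_before = np.concatenate(([0], np.cumsum(counts)[:-1]))
    offsets = starts - removed_before
    out = np.empty(n - int(counts.sum()), dtype=arr1.dtype)
    # Step 4: copy the kept items of each chunk (independent iterations).
    for i in range(chunk_count):
        s, e = starts[i], min(starts[i] + chunk_size, n)
        kept = arr1[s:e][~found[s:e]]
        out[offsets[i]:offsets[i] + kept.size] = kept
    return out


print(setdiff1d_chunked(np.arange(10), np.array([3, 7]), chunk_size=4))  # [0 1 2 4 5 6 8 9]
```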
Here is an efficient parallel hash-based implementation:
@nb.njit('int32[:](int32[:], int32[:])', parallel=True)
def setdiff1d_nb_faster_par(arr1, arr2):
    # Pre-computation of the bloom-filter
    bloomFilter = np.zeros(4096, dtype=np.uint8)
    for j in range(arr2.size):
        bloomFilter[hash_32bit_4k(arr2[j])] = True
    chunkSize = 1024 # To tune regarding the kind of input
    chunkCount = (arr1.size + chunkSize - 1) // chunkSize
    # Find for each item of `arr1` if the value is in `arr2` (parallel)
    # and count the number of item found for each chunk on the fly.
    # Note: thanks to page fault, big parts of `found` are not even written in memory if `arr2` is small
    found = np.zeros(arr1.size, dtype=nb.bool_)
    foundCountByChunk = np.empty(chunkCount, dtype=nb.uint16)
    for i in nb.prange(chunkCount):
        start, end = i * chunkSize, min((i + 1) * chunkSize, arr1.size)
        foundCountInChunk = 0
        for j in range(start, end):
            val = arr1[j]
            if bloomFilter[hash_32bit_4k(val)] and val in arr2:
                found[j] = True
                foundCountInChunk += 1
        foundCountByChunk[i] = foundCountInChunk
    # Compute the location of the destination chunks (sequential)
    outChunkOffsets = np.empty(chunkCount, dtype=nb.uint32)
    foundCount = 0
    for i in range(chunkCount):
        outChunkOffsets[i] = i * chunkSize - foundCount
        foundCount += foundCountByChunk[i]
    # Parallel chunk-based copy
    out = np.empty(arr1.size-foundCount, dtype=arr1.dtype)
    for i in nb.prange(chunkCount):
        srcStart, srcEnd = i * chunkSize, min((i + 1) * chunkSize, arr1.size)
        cur = outChunkOffsets[i]
        # Optimization: we can copy the whole chunk if there is nothing found in it
        if foundCountByChunk[i] == 0:
            out[cur:cur+(srcEnd-srcStart)] = arr1[srcStart:srcEnd]
        else:
            for j in range(srcStart, srcEnd):
                if not found[j]:
                    out[cur] = arr1[j]
                    cur += 1
    return out
This implementation is the fastest for the target input on my machine. It is generally fast when n is quite big and the overhead of creating threads is relatively small on the target platform (e.g. on PCs, but typically not on computing servers with many cores). The overhead of the parallel implementation is significant, so the target machine needs at least 4 cores for the implementation to be significantly faster than the sequential one.
It may be useful to tune the chunkSize variable for the target inputs. If r << n, it is better to use a pretty big chunkSize. That being said, the number of chunks needs to be sufficiently big for multiple threads to operate on many chunks. Thus, chunkSize should be significantly smaller than n / numberOfThreads.
On my machine most of the time (65-70%) is spent in the final copy which is mostly memory-bound and can hardly be optimized further with Numba.
Results
Here are results on my i5-9600KF-based machine (with 6 cores):
setdif1d_np: 2.65 ms
setdif1d_in1d_np: 2.61 ms
setdiff1d_nb: 2.33 ms
setdiff1d_nb_simd: 1.85 ms
setdiff1d_nb_faster: 0.73 ms
setdiff1d_nb_faster_par: 0.49 ms
The best provided implementation is about 4~5 times faster than the other ones.
What I found is that hashing does not help here. It is just a trick for the 2D case: it converts each 1D row to a single number so that rows can be put in a set as such.
Below is norok2's method converted to 1D arrays (with type annotations added for faster compilation).
Note that this is only slightly (20-30%) faster than the methods you already have, and of course only from the second function call on: the first call is slower due to compilation.
@nb.njit('int32[:](int32[:], int32[:])')
def setdiff1d_nb(arr1, arr2):
    delta = set(arr2)
    # : build the result
    result = np.empty(len(arr1), dtype=arr1.dtype)
    j = 0
    for i in range(arr1.shape[0]):
        if arr1[i] not in delta:
            result[j] = arr1[i]
            j += 1
    return result[:j]

How should I improve Python/Cython performance? Parallelization/memoryviews/numpy?

My task: take 3 lists of ints, each with some multiplier, and see if the elements can be rearranged to make two lists (with larger multipliers).
I have code that does this - looped over my whole data set, it takes about 15 seconds: (EDIT: fixed errors)
%%cython
cdef bint my_check(
    list pattern1,
    list pattern2,
    list pattern3,
    int amount1,
    int amount2,
    int amount3
):
    cdef dict all_items = dict()
    cdef int i, total_amount = amount1 + amount2 + amount3, m1, m2
    cdef bint bad_split = False
    # Pool the items together.
    for i in range(len(pattern1)):
        all_items[pattern1[i]] = all_items.get(pattern1[i],0) + amount1
    for i in range(len(pattern2)):
        all_items[pattern2[i]] = all_items.get(pattern2[i],0) + amount2
    for i in range(len(pattern3)):
        all_items[pattern3[i]] = all_items.get(pattern3[i],0) + amount3
    # Iterate through possible split points:
    for m1 in range(total_amount//2, total_amount):
        m2 = total_amount - m1
        bad_split = False  # reset for each split point
        # Split items into those with quantities divisible at this point and those without
        divisible = {i:all_items[i] for i in all_items if all_items[i]%m1 == 0}
        not_divisible = {i:all_items[i] for i in all_items if all_items[i]%m1 != 0}
        # Check that all of the element amounts that are not divisible by m1 are divisible by m2.
        for i in not_divisible:
            if not_divisible[i]%m2 != 0:
                bad_split = True
                break
        # If there is an element that doesn't divide by either, try the next split value.
        if bad_split:
            continue
        items1 = {i:divisible[i]//m1 for i in divisible}
        items2 = {i:not_divisible[i]//m2 for i in not_divisible}
        if <some other stuff here>:
            return True
    # Tried all of the split points
    return False
Then if this returns True, I run another function to do the combination. On my data set, the my_check() function is being called > 150,000 times (and taking the bulk of the time) and the other function < 500 times, so I'm not too concerned with optimizing that one.
I'd like to parallelize this to improve the performance, but what I've found:
my first thought was to use numpy functions to take advantage of vectorization, by converting all_items to a numpy array, using np.mod() and np.logical_not() to split the items, and other numpy functions in the last if clause, but that blows the time up by 3-4x compared to using the dict comprehension
if I switch the m1 range to a Cython prange, the compiler complained about using Python objects without the GIL. I switched the dicts to cdef'd numpy arrays, but that was even slower. I tried using memoryviews, but they don't seem to be easily manipulated? I read in another question here that slices can't be assigned to variables, so I don't know how I'd work with them. It won't let me cdef new variables inside the for loop.
Since I'm running at different values of m1, and terminating as soon as any of them return True, it should be parallelizable without worrying about race conditions.
What should my approach be here? Numpy? Cython? Something else?
I'm happy to post more detailed errors from any of my attempts, but figured that posting them all would get overwhelming. I haven't been able to get profiling or line profiling working for this - I've added the relevant # cython: statements to the top of the Jupyter notebook cell, but it doesn't find anything when I run it.
EDIT:
Per @DavidW's answer, I've replaced the middle chunk of code with the following, which cuts the time in half:
items1 = dict()
items2 = dict()
bad_split = False
for k,v in all_items.items():
    if v % m1 == 0:
        items1[k] = v//m1
    elif v % m2 == 0:
        items2[k] = v//m2
    else:
        bad_split = True
        break
I'd still like to find some way of taking advantage of my multi-core processor if that's possible.
There are definitely some improvements you can make to the loops that don't change the fundamental approach but may be faster. I haven't timed these, so it's worth doing that rather than taking my word for it.
for i in range(len(pattern1)):
    all_items[pattern1[i] = all_items.get(pattern1[i],0) + amount1
(Ignoring the syntax error.) It's generally more idiomatic to iterate by item rather than over a range, and it avoids two lookups (sometimes that isn't true in Cython, for example when iterating over numpy arrays, but for a list it's probably true):
for pattern1_i in pattern1:
    all_items[pattern1_i] = all_items.get(pattern1_i,0) + amount1
More significantly you have two loops:
divisible = {i:all_items[i] for i in all_items if all_items[i]%m1 == 0}
not_divisible = {i:all_items[i] for i in all_items if all_items[i]%m1 != 0}
You're wasting a lot of time doing dict lookups when you could iterate directly over both keys and values. For example:
divisible = {k: v for k, v in all_items.items() if v%m1 == 0}
But you're also looping over the dictionary twice and performing the same test twice.
divisible = {}
not_divisible = {}
for k, v in all_items.items():
    if v%m1 == 0:
        divisible[k] = v
    else:
        not_divisible[k] = v
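Putting these suggestions together (iterating over items, a single pass over all_items per split point), a pure-Python sketch of the function might look like this. The elided <some other stuff here> condition is not reproduced: it is passed in as a hypothetical extra_check callback, the for/else replaces the bad_split flag, and positive amounts are assumed.

```python
def my_check_py(pattern1, pattern2, pattern3, amount1, amount2, amount3,
                extra_check=None):
    """Pure-Python sketch of my_check.

    `extra_check(items1, items2)` is a hypothetical stand-in for the
    elided `<some other stuff here>` condition (None means "accept").
    """
    # Pool the items, iterating over elements rather than indices.
    all_items = {}
    for pattern, amount in ((pattern1, amount1), (pattern2, amount2), (pattern3, amount3)):
        for item in pattern:
            all_items[item] = all_items.get(item, 0) + amount
    total_amount = amount1 + amount2 + amount3
    for m1 in range(total_amount // 2, total_amount):
        m2 = total_amount - m1
        items1, items2 = {}, {}
        # One pass over the (key, value) pairs.
        for k, v in all_items.items():
            if v % m1 == 0:
                items1[k] = v // m1
            elif v % m2 == 0:
                items2[k] = v // m2
            else:
                break  # divisible by neither: this split point fails
        else:
            # Only reached when no item failed (the loop did not break).
            if extra_check is None or extra_check(items1, items2):
                return True
    return False


print(my_check_py([1, 2], [1, 2], [3], 2, 2, 2))  # True
```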
It might well be possible to translate your algorithm to something involving Numpy arrays, but it's a fairly significant change and beyond my interest here.
Addendum: I'm increasingly reluctant to recommend people use C++ classes in Cython these days, mainly because a) it can often lead to quite awkward code, b) people tend to use it in a cargo-culty way because "it's C++ so it must be faster than Python objects", and c) people tend to forget about the cost of converting their objects to/from C++ at the start and end of every function.
However, in this case it might actually be a good choice, since your dict objects are uniformly typed and entirely contained within a single function. The key substitution is dict -> unordered_map.
What you want to do (rough outline) is
from libcpp.unordered_map cimport unordered_map
Then type all_items, items1 and items2 as cdef unordered_map[int, int] (I think...). You do this typing outside the loop. The rest of your code then remains largely the same (you may need to find a substitute for dict.get...).
Once you've got it working as a serial calculation, you should be able to turn your for m1 in range(total_amount//2, total_amount): into a prange loop, and, assuming everything is correctly typed, this should work in parallel. Obviously, <some other stuff here> is a big unknown.
You must treat all_items as strictly read-only during the loop to avoid race conditions. However, items1 and items2 should, I hope, be correctly identified by Cython as loop-local variables.
Here's a fairly similar answer to use as a starting point. For future readers: please think twice about whether you really need to convert all your Python objects to C++ ones; you probably don't.

Is it more performant to access function from variable?

I was reading the source of Python's statistics module and saw a strange variable, partials_get = partials.get, which is then used only once, in a for loop: partials[d] = partials_get(d, 0) + n.
def _sum(data, start=0):
    count = 0
    n, d = _exact_ratio(start)
    partials = {d: n}
    partials_get = partials.get  # STRANGE VARIABLE
    T = _coerce(int, type(start))
    for typ, values in groupby(data, type):
        T = _coerce(T, typ)  # or raise TypeError
        for n, d in map(_exact_ratio, values):
            count += 1
            partials[d] = partials_get(d, 0) + n  # AND ITS USAGE
    if None in partials:
        # The sum will be a NAN or INF. We can ignore all the finite
        # partials, and just look at this special one.
        total = partials[None]
        assert not _isfinite(total)
    else:
        # Sum all the partial sums using builtin sum.
        # FIXME is this faster if we sum them in order of the denominator?
        total = sum(Fraction(n, d) for d, n in sorted(partials.items()))
    return (T, total, count)
So my question: why not just write partials[d] = partials.get(d, 0) + n? Is it slower than storing the method in a variable and calling it from there?
partials.get has to search for the get attribute, starting with the object's dictionary and then going to the dictionary of the class and its parent classes. This will be done each time through the loop.
Assigning it to a variable does this lookup once, rather than repeating it.
This is a microoptimization that's typically only significant if the loop has many repetitions. The statistics library often processes large data sets, so it's reasonable here. It's rarely needed in ordinary application code.
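A quick way to check the effect on your own interpreter is a small timeit comparison (count_attr and count_hoisted are made-up names). The size of the gap depends on the Python version, since newer CPython releases specialize attribute and method lookups, so the difference may be small.

```python
import functools
import timeit


def count_attr(keys):
    counts = {}
    for k in keys:
        counts[k] = counts.get(k, 0) + 1  # looks up the .get attribute every time
    return counts


def count_hoisted(keys):
    counts = {}
    counts_get = counts.get  # look up the bound method once, before the loop
    for k in keys:
        counts[k] = counts_get(k, 0) + 1  # plain local-variable load
    return counts


if __name__ == "__main__":
    keys = [i % 100 for i in range(50_000)]
    for fn in (count_attr, count_hoisted):
        best = min(timeit.repeat(functools.partial(fn, keys), number=10, repeat=3))
        print(f"{fn.__name__}: {best:.4f} s for 10 runs")
```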
Short answer: yes.
Python is an interpreted language, and while dictionary/attribute access is blazingly fast and very optimized, it still incurs a hit.
Since they are running this in a tight loop, they are taking the slight performance advantage of removing the "dot" from accessing partials.get.
There are other slight improvements from doing this in other cases where the variable is enough of a hint to the compiler (for CPython at least) to ensure this stays local, but I'm not sure this is the case here.

How to shuffle an array with n entries without generating range(n)

So, I'm looking at Python and I have a large 2D NumPy array of data, and I want to take m random rows of this large data matrix. I've looked into random.sample, numpy.random.shuffle and numpy.random.permutation; all of these work, but they usually return the whole permutation or at least generate the entire range(n). If I had a very large dataset, then doing something like
data = numpy.random.uniform(size=(n, 100))
myvec = data[random.sample(range(n),m),:]
will allocate a vector range(n), which blows up pretty fast. So I thought I could use xrange, which returns a generator, but hey, you can't just get an arbitrary element from a generator; that's not the way they work.
I tried it out, and it works.
data = numpy.random.uniform(size=(n, 100))
myvec = data[random.sample(xrange(n),m),:]
Any idea how this works?
UPDATE:
I can use
samp = random.sample(range(n),10)
for n up to 100000000 before I get a memory error. If I use
samp = random.sample(xrange(n),10)
on the other hand, I only start getting errors because of int conversion to C, namely, the int gets too long to be converted to C, at around 1000000000. Sure, it's only a factor of 10, but I'm curious. The xrange variant is also much faster.
from random import randrange

def sample(n, m):
    d = set()
    while len(d) < m:
        d.add(randrange(n))
    return d
>>> sample(100000000000000000000000000000000000, 10)
set([5577049102993258248888250482046894L, 86044086231860190654588187118815513L, 2021737354726858669049814270580972L, 6253501639432326715043836478191628L, 5306460388221333758367322518700483L, 62195356583363524099133566314034473L, 376650426515181012918370326724858L, 80588135672357701239461833469588557L, 1978959860575617450893346333245569L, 41904683348442252013350548717573039L])
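Note that a simple {randrange(n) for _ in range(m)} will do the job with very high probability when n is much larger than m (it can only fail by returning fewer than m items, when two draws collide).
So it turns out xrange objects are not generators: like range in Python 3, they are lazy sequences that support len() and indexing, which is exactly what random.sample() uses. So that's how it works.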
a = xrange(10)
print a[5] #this works.
Elazar's solution works just as well though.
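In Python 3, range plays the role of xrange, so the same trick applies: random.sample only needs len() and indexing, so it never builds the full list of indices. Here is a small sketch (sample_rows is a hypothetical helper). Note that len() of a range must still fit in a C Py_ssize_t, which is likely the C conversion error seen for very large n.

```python
import random

import numpy as np


def sample_rows(data, m, rng=random):
    """Return m distinct rows of `data`, chosen uniformly without replacement."""
    # range(n) is a lazy sequence (like xrange in Python 2): random.sample
    # only calls len() and indexes it, so no list of n ints is created.
    idx = rng.sample(range(data.shape[0]), m)
    return data[idx, :]


data = np.random.uniform(size=(1000, 100))
print(sample_rows(data, 5).shape)  # (5, 100)
# Works for huge n too, as long as len(range(n)) fits in a C Py_ssize_t:
print(len(random.sample(range(10**12), 10)))  # 10
```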

Efficient generic Python memoize

I have a generic Python memoizer:
cache = {}
def memoize(f):
"""Memoize any function."""
def decorated(*args):
key = (f, str(args))
result = cache.get(key, None)
if result is None:
result = f(*args)
cache[key] = result
return result
return decorated
It works, but I'm not happy with it, because sometimes it's not efficient. Recently, I used it with a function that takes lists as arguments, and apparently making keys with whole lists slowed everything down. What is the best way to do that? (i.e., to efficiently compute keys, whatever the args, and however long or complex they are)
I guess the question is really about how you would efficiently produce keys from the args and the function for a generic memoizer - I have observed in one program that poor keys (too expensive to produce) had a significant impact on the runtime. My prog was taking 45s with 'str(args)', but I could reduce that to 3s with handcrafted keys. Unfortunately, the handcrafted keys are specific to this prog, but I want a fast memoizer where I won't have to roll out specific, handcrafted keys for the cache each time.
First, if you're pretty sure that O(N) hashing is reasonable and necessary here, and you just want to speed things up with a faster algorithm than hash(str(x)), try this:
def hash_seq(iterable):
    result = hash(type(iterable))
    for element in iterable:
        result ^= hash(element)
    return result
Of course this won't work for possibly-deep sequences, but there's an obvious way around that:
def hash_seq(iterable):
    result = hash(type(iterable))
    for element in iterable:
        try:
            result ^= hash(element)
        except TypeError:
            result ^= hash_seq(element)
    return result
I'm not sure this is a good-enough hash algorithm, because it returns the same value for different permutations of the same list. But I am pretty sure that no good-enough hash algorithm will be much faster, at least if it's written in C or Cython, which you'll probably ultimately want to do if this is the direction you're going.
Also, it's worth noting that this will be correct in many cases where str (or marshal) will not. For example, your list may contain some mutable element whose repr involves its id rather than its value. However, it's still not correct in all cases. In particular, it assumes that "iterates the same elements" means "equal" for any iterable type, which obviously isn't guaranteed to be true. False negatives aren't a huge deal, but false positives are. For example, iterating a dict yields only its keys, so two dicts with the same keys but different values get the same hash. If that hash is the whole cache key, they will spuriously share a memo.
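The following Python 3 snippet makes those weaknesses concrete. It checks the deep hash_seq from above, copied here so the snippet runs on its own. Besides the permutation and dict cases, it shows one more consequence of plain XOR that is easy to miss: equal elements cancel each other out, so a list holding the same value twice hashes exactly like an empty list.

```python
def hash_seq(iterable):
    # Deep, order-insensitive XOR hash, as in the answer above.
    result = hash(type(iterable))
    for element in iterable:
        try:
            result ^= hash(element)
        except TypeError:
            result ^= hash_seq(element)
    return result

# Permutations collide, because XOR is commutative.
assert hash_seq([1, 2, 3]) == hash_seq([3, 2, 1])

# Iterating a dict yields only its keys, so the values are ignored.
assert hash_seq({'a': 1}) == hash_seq({'a': 2})

# x ^ x == 0, so duplicated elements cancel out entirely.
assert hash_seq([1, 1]) == hash_seq([])

# Nested unhashable elements are handled by the recursion.
print(hash_seq([[1, 2], [3]]))
```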
Also, it uses only O(1) extra space, instead of O(N) with a rather large constant factor.
At any rate, it's worth trying this first, and only then deciding whether it's worth analyzing for good-enough-ness and tweaking for micro-optimizations.
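If you do try it, you can keep hash_seq's speed without inheriting its false positives. Use it only as the key's `__hash__`, and let == on the actual arguments decide equality. A collision then costs one comparison instead of producing a wrong answer.

This is a sketch under assumptions, not the code from this answer. `_ArgsKey` and this memoize variant are hypothetical names. The key holds references to the arguments, so callers must not mutate them afterwards. The `_MISSING` sentinel also means results of None are cached, whereas the question's version recomputes them.

```python
cache = {}

def hash_seq(iterable):
    # Deep XOR hash from the answer above.
    result = hash(type(iterable))
    for element in iterable:
        try:
            result ^= hash(element)
        except TypeError:
            result ^= hash_seq(element)
    return result

class _ArgsKey(object):
    """Hypothetical cache key: hash via hash_seq, exact equality via ==."""
    __slots__ = ('f', 'args', '_hash')

    def __init__(self, f, args):
        self.f = f
        self.args = args  # keeps references: args must not be mutated later
        self._hash = hash(f) ^ hash_seq(args)

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        # A hash collision costs one == comparison, never a wrong result.
        return (isinstance(other, _ArgsKey) and self.f is other.f
                and self.args == other.args)

_MISSING = object()  # sentinel, so functions returning None are cached too

def memoize(f):
    """Memoize f on its positional arguments (sketch)."""
    def decorated(*args):
        key = _ArgsKey(f, args)
        result = cache.get(key, _MISSING)
        if result is _MISSING:
            result = f(*args)
            cache[key] = result
        return result
    return decorated
```

With this design, [1, 2, 3] and [3, 2, 1] still share a hash bucket, but they are cached separately, and so are two dicts that differ only in their values.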
Here's a trivial Cython version of the shallow implementation:
def test_cy_xor(iterable):
    cdef int result = hash(type(iterable))
    cdef int h
    for element in iterable:
        h = hash(element)
        result ^= h
    return result
From a quick test, the pure Python implementation is pretty slow (as you'd expect, with all that Python looping, compared to the C looping in str and marshal), but the Cython version wins easily:
test_str( 3): 0.015475
test_marshal( 3): 0.008852
test_xor( 3): 0.016770
test_cy_xor( 3): 0.004613
test_str(10000): 8.633486
test_marshal(10000): 2.735319
test_xor(10000): 24.895457
test_cy_xor(10000): 0.716340
Just iterating the sequence in Cython and doing nothing takes about 70% as long as test_cy_xor. That is effectively just N calls to PyIter_Next plus some refcounting, so you're not going to do much better in native C. You can presumably make it faster by requiring an actual sequence instead of an iterable, and even more so by requiring a list, although either way it might require writing explicit C rather than Cython to get the benefits.
Anyway, how do we fix the ordering problem? The obvious Python solution is to hash (i, element) instead of element, but all that tuple manipulation slows the Cython version down by up to 12x. The standard solution is to multiply by some number between each xor. While you're at it, it's also worth making the values spread out nicely for short sequences, small int elements, and other very common edge cases. Picking the right numbers is tricky, so I just borrowed everything from tuple's hash. Here's the complete test.
_hashtest.pyx:
cdef _test_xor(seq):
    cdef long result = 0x345678
    cdef long mult = 1000003
    cdef long h
    cdef long l = 0
    try:
        l = len(seq)
    except TypeError:
        # NOTE: This probably means very short non-len-able sequences
        # will not be spread as well as they should, but I'm not
        # sure what else to do.
        l = 100
    for element in seq:
        try:
            h = hash(element)
        except TypeError:
            h = _test_xor(element)
        result ^= h
        result *= mult
        mult += 82520 + l + l
    result += 97531
    return result
def test_xor(seq):
    return _test_xor(seq) ^ hash(type(seq))
hashtest.py:
import marshal
import random
import timeit
import pyximport
pyximport.install()
import _hashtest
def test_str(seq):
    return hash(str(seq))
def test_marshal(seq):
    return hash(marshal.dumps(seq))
def test_cy_xor(seq):
    return _hashtest.test_xor(seq)
# This one is so slow that I don't bother to test it...
def test_xor(seq):
    result = hash(type(seq))
    for i, element in enumerate(seq):
        try:
            result ^= hash((i, element))
        except TypeError:
            # Unhashable element (e.g. a nested list): hash it recursively.
            result ^= hash((i, test_xor(element)))
    return result
smalltest = [1, 2, 3]
bigtest = [random.randint(10000, 20000) for _ in range(10000)]
def run():
    for seq in smalltest, bigtest:
        for f in test_str, test_marshal, test_cy_xor:
            # f.__name__ works on both Python 2 and 3 (func_name is 2-only).
            print('%16s(%5d): %9f' % (f.__name__, len(seq),
                  timeit.timeit(lambda: f(seq), number=10000)))
if __name__ == '__main__':
    run()
Output:
test_str( 3): 0.014489
test_marshal( 3): 0.008746
test_cy_xor( 3): 0.004686
test_str(10000): 8.563252
test_marshal(10000): 2.744564
test_cy_xor(10000): 0.904398
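For anyone who wants to experiment with the mixing scheme in _hashtest.pyx without compiling Cython, here is a rough pure-Python transcription. The names ordered_hash and _ordered_hash are hypothetical. Masking to 64 bits stands in for C long overflow. It follows the same steps, but it isn't guaranteed to produce bit-for-bit the same numbers as the Cython build, and it is far slower. Use it only for checking behaviour such as order sensitivity.

```python
MASK64 = (1 << 64) - 1  # keep only the low 64 bits, like C long overflow

def _ordered_hash(seq):
    result = 0x345678
    mult = 1000003
    try:
        n = len(seq)
    except TypeError:
        n = 100  # same fallback as the .pyx for iterables without len()
    for element in seq:
        try:
            h = hash(element)
        except TypeError:
            h = _ordered_hash(element)  # recurse into unhashable elements
        # Multiplying after each xor makes the result depend on position.
        result = ((result ^ h) * mult) & MASK64
        mult += 82520 + n + n
    return (result + 97531) & MASK64

def ordered_hash(seq):
    # Mix in the container type, as test_xor does in the .pyx.
    return _ordered_hash(seq) ^ hash(type(seq))
```

Unlike the plain XOR version, permutations and duplicated elements no longer collide trivially. For example, ordered_hash([1, 2]) differs from ordered_hash([2, 1]).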
Here are some potential ways to make this faster:
If you have lots of deep sequences, instead of using try around hash, call PyObject_Hash and check for -1.
If you know you have a sequence (or, even better, specifically a list), instead of just an iterable, PySequence_ITEM (or PyList_GET_ITEM) is probably going to be faster than the PyIter_Next implicitly used above.
In either case, once you start calling C API calls, it's usually easier to drop Cython and just write the function in C. (You can still use Cython to write a trivial wrapper around that C function, instead of manually coding up the extension module.) And at that point, just borrow the tuplehash code directly instead of reimplementing the same algorithm.
If you're looking for a way to avoid the O(N) in the first place, that's just not possible. If you look at how tuple.__hash__, frozenset.__hash__, and ImmutableSet.__hash__ work (the last one is pure Python and very readable, by the way), they all take O(N). However, they also all cache the hash values. So, if you're frequently hashing the same tuple (rather than non-identical-but-equal ones), it approaches constant time. (It's O(N/M), where M is the number of times you call with each tuple.)
If you can assume that your list objects never mutate between calls, you can obviously do the same thing, e.g., with a dict mapping id to hash as an external cache. But in general, that obviously isn't a reasonable assumption. (If your list objects never mutate, it would be easier to just switch to tuple objects and not bother with all this complexity.)
But you can wrap up your list objects in a subclass that adds a cached hash value member (or slot), and invalidates the cache whenever it gets a mutating call (append, __setitem__, __delitem__, etc.). Then your hash_seq can check for that.
The end result is the same correctness and performance as with tuples: amortized O(N/M), except that for tuple M is the number of times you call with each identical tuple, while for list it's the number of times you call with each identical list without mutating in between.
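Here is a minimal sketch of that list subclass, assuming Python 3. HashCachingList and cached_hash are hypothetical names. For brevity it computes the hash as hash(tuple(self)), so elements must be hashable; a real version would plug in whatever hash_seq you settle on. It only notices mutations made through the list's own methods. Mutating an element in place, such as a nested list, won't invalidate the cache.

```python
import functools

def _invalidating(name):
    """Wrap list.<name> so that calling it clears the cached hash."""
    method = getattr(list, name)

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        self._cached_hash = None
        return method(self, *args, **kwargs)
    return wrapper

class HashCachingList(list):
    """Hypothetical list subclass that caches its hash until mutated."""

    def __init__(self, *args):
        super().__init__(*args)
        self._cached_hash = None

    def cached_hash(self):
        # O(N) the first time, O(1) until the next mutating call.
        if self._cached_hash is None:
            self._cached_hash = hash(tuple(self))
        return self._cached_hash

# Every list method that can change the contents invalidates the cache.
for _name in ('append', 'extend', 'insert', 'pop', 'remove', 'reverse',
              'sort', 'clear', '__setitem__', '__delitem__',
              '__iadd__', '__imul__'):
    setattr(HashCachingList, _name, _invalidating(_name))
```

A hash_seq could then check isinstance(element, HashCachingList) and call element.cached_hash() instead of walking the elements again.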
You could try a couple of things:
Using marshal.dumps instead of str might be slightly faster (at least on my machine):
>>> timeit.timeit("marshal.dumps([1,2,3])","import marshal", number=10000)
0.008287056301007567
>>> timeit.timeit("str([1,2,3])",number=10000)
0.01709315717356219
Also, if your functions are expensive to compute, and could possibly return None themselves, then your memoizing function will be re-computing them each time (I'm possibly reaching here, but without knowing more I can only guess).
Incorporating these 2 things gives:
import marshal
cache = {}
def memoize(f):
    """Memoize any function."""
    def decorated(*args):
        key = (f, marshal.dumps(args))
        if key in cache:
            return cache[key]
        cache[key] = f(*args)
        return cache[key]
    return decorated
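One caveat with marshal is that it only handles core built-in types. marshal.dumps raises ValueError for anything else, such as an instance of a user-defined class. Below is a hedged sketch of the same memoizer that falls back to an uncached call in that case. The fallback policy is an assumption, not part of the answer above.

```python
import marshal

cache = {}

def memoize(f):
    """Memoize f, keyed on marshal.dumps(args) (sketch)."""
    def decorated(*args):
        try:
            key = (f, marshal.dumps(args))
        except ValueError:
            # marshal only supports core built-in types; anything else
            # (e.g. an instance of a user-defined class) raises ValueError,
            # so call f uncached instead of failing.
            return f(*args)
        if key in cache:  # membership test, so cached None results count
            return cache[key]
        result = f(*args)
        cache[key] = result
        return result
    return decorated
```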
