I am using the method of lines with solve_ivp to solve a nonlinear PDE:
import numpy as np
from numba import njit

@njit(fastmath=True, error_model="numpy", cache=True)
def thinFilmEq(t, h, dx, Ma, phiFun, tempFun):
    phi = phiFun(h)
    temperature = tempFun(h)
    # np.roll implements the periodic boundary conditions
    hxx = (np.roll(h, 1) - 2*h + np.roll(h, -1))/dx**2
    p = phi - hxx
    px = (np.roll(p, -1) - np.roll(p, 1))/(2*dx)
    Tx = (np.roll(temperature, -1) - np.roll(temperature, 1))/(2*dx)
    flux = h**3*px/3 + Ma*h**2*Tx/2
    dhdt = (np.roll(flux, -1) - np.roll(flux, 1))/(2*dx)
    return dhdt
I get the following error:
TypingError: non-precise type pyobject
[1] During: typing of argument at C:/Users/yhcha/method_of_lines/test_01_thinFilmEq.py (28)
I suspect it is due to phiFun and tempFun, which are functions I supply at call time. I pass them as arguments to thinFilmEq just to keep things more general. When I remove phiFun and tempFun and write their functional forms explicitly inside thinFilmEq, the error goes away.
Then I see the following error:
TypingError: Use of unsupported NumPy function 'numpy.roll' or unsupported use of the function.
I thought np.roll might not be supported, although it is listed as supported in the official documentation. I tried to 'enlarge' the array so that I could manually apply the same shift that np.roll performs when computing finite differences with periodic BC:
def augment(x):
    # pad the array with one periodic ghost cell on each side
    x2 = np.empty(len(x) + 2)
    x2[1:-1] = x
    x2[0] = x[-1]
    x2[-1] = x[0]
    return x2
H = augment(h)
hx = (H[2:] - H[:-2])/dx  # use this instead of hx=(roll(h,-1)-roll(h,1))/dx
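As a quick check of the padding idea above, the sketch below compares the ghost-cell central difference with the np.roll version on a periodic test function. The names periodic_pad, ddx_roll and ddx_padded are illustrative and not from the original code, and both forms divide by 2*dx, the usual central-difference spacing (an assumption, since the snippet above divides by dx).

```python
import numpy as np

def periodic_pad(x):
    """Return x with one periodic ghost cell on each side (hypothetical helper)."""
    out = np.empty(len(x) + 2, dtype=x.dtype)
    out[1:-1] = x
    out[0] = x[-1]    # left ghost cell wraps around to the last value
    out[-1] = x[0]    # right ghost cell wraps around to the first value
    return out

def ddx_roll(h, dx):
    # central difference with periodic wrap-around via np.roll
    return (np.roll(h, -1) - np.roll(h, 1)) / (2 * dx)

def ddx_padded(h, dx):
    # same stencil, using slices of the padded array instead of np.roll
    H = periodic_pad(h)
    return (H[2:] - H[:-2]) / (2 * dx)

n = 64
dx = 2 * np.pi / n
xgrid = np.arange(n) * dx
h = 1.0 + 0.1 * np.sin(xgrid)   # smooth periodic test profile
```

Both functions use only slicing and element-wise arithmetic, so the padded form is the one that avoids np.roll inside a compiled function.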
My questions are:
It seems that I can get numba to work, but at the expense of making the code less general (I cannot supply an arbitrary function such as phiFun) and less elegant (e.g. I cannot use a one-liner with np.roll). Are there ways around this, or is it just the price I need to pay when using numba to 'compile' the code?
The original version without numba is close to 10x slower than the Matlab version I coded, and the numba version is still around 3-4 times slower than Matlab. I don't really expect scipy to outperform Matlab, but are there other ways to speed up the code to bridge the gap?
Related
I am trying to use JAX on the code from another SO question to evaluate JAX's applicability and performance (that question has useful information about what the code does). For this purpose, I replaced the NumPy calls with their jax.numpy (jnp) equivalents. This was not as easy as I expected because of my limited experience with JAX, and the code could probably be written better. I checked the results against the previous code (the optimized algorithm) and they were the same, but the JAX version takes 7.5 seconds where the previous one took 0.10 seconds for a sample case (using Colab). I think this long runtime may be related to the for loop in the code, which might be replaced by JAX constructs such as fori_loop or vectorization, but I don't know what changes must be made, and how, to get satisfactory performance from this code with JAX.
import numpy as np
from scipy.spatial import cKDTree, distance
import jax
from jax import numpy as jnp
jax.config.update("jax_enable_x64", True)
# ---------------------------- input data ----------------------------
""" For testing by prepared files:
radii = np.load('a.npy')
poss = np.load('b.npy')
"""
rnd = np.random.RandomState(70)
data_volume = 1000
radii = rnd.uniform(0.0005, 0.122, data_volume)
dia_max = 2 * radii.max()
x = rnd.uniform(-1.02, 1.02, (data_volume, 1))
y = rnd.uniform(-3.52, 3.52, (data_volume, 1))
z = rnd.uniform(-1.02, -0.575, (data_volume, 1))
poss = np.hstack((x, y, z))
# --------------------------------------------------------------------
# @jax.jit
def ends_gap(poss, dia_max):
    particle_corsp_overlaps = jnp.array([], dtype=np.float64)
    # kdtree = cKDTree(poss)  # Using SciPy
    for particle_idx in range(len(poss)):
        cur_point = poss[particle_idx]
        # nears_i_ind = jnp.array(kdtree.query_ball_point(cur_point, r=dia_max, return_sorted=True), dtype=np.int64)  # Using SciPy
        # Using NumPy
        unshared_idx = jnp.delete(jnp.arange(len(poss)), particle_idx)
        poss_without = poss[unshared_idx]
        dist_max = radii[particle_idx] + radii.max()
        lx_limit_idx = poss_without[:, 0] <= poss[particle_idx][0] + dist_max
        ux_limit_idx = poss_without[:, 0] >= poss[particle_idx][0] - dist_max
        ly_limit_idx = poss_without[:, 1] <= poss[particle_idx][1] + dist_max
        uy_limit_idx = poss_without[:, 1] >= poss[particle_idx][1] - dist_max
        lz_limit_idx = poss_without[:, 2] <= poss[particle_idx][2] + dist_max
        uz_limit_idx = poss_without[:, 2] >= poss[particle_idx][2] - dist_max
        nears_i_ind = jnp.where(lx_limit_idx & ux_limit_idx & ly_limit_idx & uy_limit_idx & lz_limit_idx & uz_limit_idx)[0]
        # assert len(nears_i_ind) > 0
        # if len(nears_i_ind) <= 1:
        #     continue
        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]
        # dist_i = distance.cdist(poss[tuple(nears_i_ind[None, :])], cur_point[None, :]).squeeze()  # Using SciPy
        dist_i = jnp.linalg.norm(poss[tuple(nears_i_ind[None, :])] - cur_point[None, :], axis=-1)  # Using NumPy
        contact_check = dist_i - (radii[tuple(nears_i_ind[None, :])] + radii[particle_idx])
        connected = contact_check[contact_check <= 0]
        particle_corsp_overlaps = jnp.concatenate((particle_corsp_overlaps, connected))
        contacts_ind = jnp.where(contact_check <= 0)[0]
        contacts_sec_ind = jnp.array(nears_i_ind)[contacts_ind]
        sphere_olps_ind = jnp.sort(contacts_sec_ind)
        ends_ind_mod_temp = jnp.array([jnp.repeat(particle_idx, len(sphere_olps_ind)), sphere_olps_ind], dtype=np.int64).T
        if particle_idx > 0:  # ---> these 4 lines might be better replaced by a one-line list append such as "ends_ind.append(ends_ind_mod_temp)"
            ends_ind = jnp.concatenate((ends_ind, ends_ind_mod_temp))
        else:
            ends_ind = jnp.array(ends_ind_mod_temp, dtype=np.int64)
    ends_ind_org = ends_ind
    ends_ind, ends_ind_idx = jnp.unique(jnp.sort(ends_ind_org), axis=0, return_index=True)
    gap = jnp.array(particle_corsp_overlaps)[ends_ind_idx]
    return gap, ends_ind, ends_ind_idx, ends_ind_org
I have tried to use @jax.jit on this code, but it raises errors on Colab TPU, either TracerArrayConversionError or ConcretizationTypeError:
Using SciPy:
TracerArrayConversionError: The numpy.ndarray conversion method
array() was called on the JAX Tracer object Traced<ShapedArray(float64[1000,3])>with<DynamicJaxprTrace(level=0/1)>
While tracing the function ends_gap at
:1 for jit, this concrete value was not
available in Python because it depends on the value of the argument
'poss'. See
https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
Using NumPy:
ConcretizationTypeError: Abstract tracer value encountered where
concrete value is expected:
Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)> The
size argument of jnp.nonzero must be statically specified to use
jnp.nonzero within JAX transformations. While tracing the function
ends_gap at :1 for jit, this
concrete value was not available in Python because it depends on the
values of the arguments 'poss' and 'dia_max'.
See
https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
I would appreciate any help in speeding up this code with JAX (and jax.jit if possible) and getting past these problems. How can JAX be used to get the best performance on CPU, GPU, or TPU?
Prepared sample test data:
a.npy = Radii data
b.npy = Poss data
Updates
The main aim of this question is how to modify the code to get the best performance from it using the JAX library.
Following jakevdp's answer, I have commented out the SciPy-related lines in the code and uncommented the equivalent NumPy-related sections.
To get better answers, I am listing some important points:
Are scikit-learn's BallTree methods compatible with JAX? They could be a good alternative to SciPy's cKDTree in terms of memory usage (for possible vectorization).
How is the loop section of the code best handled: with fori_loop, or by putting the loop body in a function and then vectorizing it, jitting it, or something else?
I had problems preparing the code for fori_loop. What I tried can be seen from the following line, where particle_corsp_overlaps was the input of the defined function (which contains just the loop body). If fori_loop is recommended, it would be useful to show how to do that.
particle_corsp_overlaps, ends_ind = jax.lax.fori_loop(0, len(poss), jax_loop, particle_corsp_overlaps)
I put the NumPy section in a function and decorated it with @jax.jit to check whether that improves performance (I don't know how much it can help). It raised a ConcretizationTypeError (shape depends on a traced value) relating to poss. So I tried the @partial(jax.jit, static_argnums=0) decorator, importing partial from functools, but now I get the error shown after the following code. How can it be solved, if this approach is recommended?
@partial(jax.jit, static_argnums=0)
def ends_gap(poss):
    for particle_idx in range(len(poss)):
        cur_point = poss[particle_idx]
        unshared_idx = jnp.delete(jnp.arange(len(poss)), particle_idx)
        poss_without = poss[unshared_idx]
        dist_max = radii[particle_idx] + radii.max()
        lx_limit_idx = poss_without[:, 0] <= poss[particle_idx][0] + dist_max
        ux_limit_idx = poss_without[:, 0] >= poss[particle_idx][0] - dist_max
        ly_limit_idx = poss_without[:, 1] <= poss[particle_idx][1] + dist_max
        uy_limit_idx = poss_without[:, 1] >= poss[particle_idx][1] - dist_max
        lz_limit_idx = poss_without[:, 2] <= poss[particle_idx][2] + dist_max
        uz_limit_idx = poss_without[:, 2] >= poss[particle_idx][2] - dist_max
        nears_i_ind = jnp.where(lx_limit_idx & ux_limit_idx & ly_limit_idx & uy_limit_idx & lz_limit_idx & uz_limit_idx)[0]
        nears_i_ind = nears_i_ind[nears_i_ind != particle_idx]
        dist_i = jnp.linalg.norm(poss[tuple(nears_i_ind[None, :])] - cur_point[None, :], axis=-1)
ValueError: Non-hashable static arguments are not supported. An error
occured during a call to 'nearest_neighbors_jax' while trying to hash
an object of type <class 'jaxlib.xla_extension.DeviceArray'>, [[
8.42519143e-01 1.37693422e+00 -7.97775882e-01] [-3.31436445e-01 -1.67346250e+00 -8.61069684e-01] [-1.57500126e-01 -1.17502591e+00 -7.48879998e-01]]. The error was: TypeError: unhashable type: 'DeviceArray'
I did not put the whole loop body into the function because I got stuck on this short function. A function containing the whole loop body that can be jitted (or similar) is what I am after, if possible.
Can the 4-line if-else statement for ends_ind be written in a single line using JAX methods, to avoid possible problems with if during jitting?
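On the last point, one common pattern is to append each per-particle array to a Python list and concatenate once after the loop, which removes the if/else and the repeated reallocation. The sketch below uses plain NumPy with made-up neighbour lists (the data and the names neighbours, pair_block and blocks are assumptions for illustration). jnp.concatenate accepts a list of arrays in the same way, but this alone does not make the loop jit-compatible, because the block lengths still vary.

```python
import numpy as np

rng = np.random.default_rng(0)
n_particles = 5
# Hypothetical neighbour lists of varying length, one per particle
neighbours = [np.sort(rng.choice(n_particles, size=rng.integers(0, 4), replace=False))
              for _ in range(n_particles)]

def pair_block(i, nbrs):
    """Return a (len(nbrs), 2) int64 array whose rows are [i, neighbour]."""
    return np.stack([np.full(len(nbrs), i), nbrs], axis=1).astype(np.int64)

# Pattern from the question: if/else with repeated concatenation
for i, nbrs in enumerate(neighbours):
    block = pair_block(i, nbrs)
    if i > 0:
        ends_ind_loop = np.concatenate((ends_ind_loop, block))
    else:
        ends_ind_loop = block

# Alternative: collect the blocks in a list and concatenate once
blocks = [pair_block(i, nbrs) for i, nbrs in enumerate(neighbours)]
ends_ind = np.concatenate(blocks)
```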
JAX cannot be used to optimize general numpy/scipy code, however it can be used to optimize/compile code written in JAX.
Your example revolves around the use of scipy's cKDTree. This is not implemented in JAX, and so it cannot be optimized or compiled in JAX, and using it within a jitted function will lead to the error you're seeing. If you want to use a KD tree with JAX, you'll have to find one implemented in JAX. I don't know of any such code.
As for why the code becomes slower when you replace np with jnp here, it's because you're really only using JAX as an alternate array container. Every time you pass a JAX array to a cKDTree call, it has to be converted to a numpy array, and then the result has to be converted back to a JAX array. This extra movement of data adds overhead to each call, making the result slower. This is not because JAX itself is slow, it's because you're not really using JAX as anything but a way of temporarily storing your data before converting it back to numpy.
Generally this kind of overhead can be reduced by wrapping the function in jax.jit, but as mentioned before, this is not compatible with non-jax code like scipy.spatial.cKDTree.
I suspect your best course of action would be to avoid using JAX and just use numpy/scipy and the cKDTree. I don't know of any JAX-compatible implementation of tree-based neighbor search algorithms, and full brute force approaches would not be competitive with cKDTree for large arrays.
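A minimal sketch of that NumPy/SciPy route, assuming the goal is the set of overlapping sphere pairs and their gaps as in the question: cKDTree.query_pairs returns all candidate pairs within the largest possible contact distance in one call, and the radii check is then a vectorized filter, so no Python loop over particles is needed. The function name overlapping_pairs is illustrative, and the output layout (each pair listed once with i < j) differs from the question's ends_ind.

```python
import numpy as np
from scipy.spatial import cKDTree

def overlapping_pairs(poss, radii):
    """Return (pairs, gap): index pairs i < j of overlapping spheres and
    their non-positive surface gaps (illustrative helper)."""
    tree = cKDTree(poss)
    # Any two overlapping spheres have centres within 2 * max radius
    cand = tree.query_pairs(r=2 * radii.max(), output_type='ndarray')
    i, j = cand[:, 0], cand[:, 1]
    dist = np.linalg.norm(poss[i] - poss[j], axis=1)
    gap = dist - (radii[i] + radii[j])
    keep = gap <= 0          # spheres touch or overlap
    return cand[keep], gap[keep]

# Same random test data as in the question
rnd = np.random.RandomState(70)
n = 1000
radii = rnd.uniform(0.0005, 0.122, n)
poss = np.hstack((rnd.uniform(-1.02, 1.02, (n, 1)),
                  rnd.uniform(-3.52, 3.52, (n, 1)),
                  rnd.uniform(-1.02, -0.575, (n, 1))))
pairs, gap = overlapping_pairs(poss, radii)
```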
I looked into this earlier this year. I had an existing numba implementation and wanted to port it to jax. I started (repo here) but abandoned the project when I realized that jax's jit performance is currently woeful compared to numba for these types of algorithms with loops and index updates. I believe it may be related to this issue, but I could certainly be wrong.
For the moment, if you want to execute KDTree operations inside a jitted function you can use jax.experimental.host_callback.call to wrap an existing implementation. It won't speed up the external function, but jax's jit may improve other aspects of the jitted code.
I am trying to refactor this code to minimize its runtime and memory usage (if possible):
for i in range(gbl.NumStoreRows):
    cal_effects[i,:,:len(orig_cols)] = cal_effects_vals  # uses ~1 GB of memory on this line
    priors[i,:len(orig_cols)] = orig_prior_coeffs
    priors_SE[i,:len(orig_cols)] = orig_prior_SE
Only the first operation in the loop is time/memory intensive. I tried splitting the memory/runtime-intensive line from the other two into two separate loops, but that just made it a second slower, with no memory impact.
I then tried to create a jit function for this code block, but the application stops later on with an error message. It just stops in one of the LoadFunctions(), so I think jit might be altering the output, or my function is incorrectly structured.
Variations of my jit function
Variation 1
@jit
def populate_cal_effects(cal_effects_vals):
    for i in range(gbl.NumStoreRows):
        cal_effects[i,:,:len(orig_cols)] = cal_effects_vals
populate_cal_effects(cal_effects_vals)
for i in range(gbl.NumStoreRows):
    priors[i,:len(orig_cols)] = orig_prior_coeffs
    priors_SE[i,:len(orig_cols)] = orig_prior_SE
Variation 2: Adding a return statement to the function
@jit
def populate_cal_effects(cal_effects_vals):
    for i in range(gbl.NumStoreRows):
        cal_effects[i,:,:len(orig_cols)] = cal_effects_vals
    return cal_effects[i,:,:len(orig_cols)]
Variation 3: Adding the operations from the other for loop to the function
This was the method I expected to be fastest and not affect data output
@jit(parallel=True)
def populate_cal_effects(cal_effects_vals):
    for i in prange(gbl.NumStoreRows):
        cal_effects[i,:,:len(orig_cols)] = cal_effects_vals
        priors[i,:len(orig_cols)] = orig_prior_coeffs
        priors_SE[i,:len(orig_cols)] = orig_prior_SE
I wanted to utilize parallel mode and use prange for the loop, but I cannot get this to work.
Context/Other:
I have defined this function inside the main load function. My next step is to move it out of the load function and re-run.
If this method doesn't work, I was thinking of trying to process in parallel using Dask (multiple cores, not multiple machines).
Any pointers on this would be great. Maybe I am wasting my time and this is not optimizable; if so, do let me know.
Steps to reproduce
gbl.NumStoreRows = 866 (# of stores)
All data types are numpy arrays
cal_effects = np.zeros((gbl.NumStoreRows, n_days, n_cal_effects), dtype=np.float64)
priors = np.zeros((gbl.NumStoreRows, n_cal_effects), dtype=np.float64)
priors_SE = np.zeros((gbl.NumStoreRows, n_cal_effects), dtype=np.float64)
To illustrate my comment:
for i in range(gbl.NumStoreRows):
    cal_effects[i,:,:len(orig_cols)] = cal_effects_vals  # uses ~1 GB of memory on this line
    priors[i,:len(orig_cols)] = orig_prior_coeffs
    priors_SE[i,:len(orig_cols)] = orig_prior_SE
From this I deduce that cal_effects is an (N,M,L)-shaped array and priors is (N,L):
big_arr = np.zeros((N,M,L))
arr = np.zeros((N,L))
for i in range(N):
    big_arr[i, :, :l] = np.ones((M,l))
    arr[i, :l] = np.ones(l)
And apparently np.ones((M,l)) is large, on the order of 1 GB.
Do cal_effects_vals and orig_prior_coeffs differ with i? It isn't obvious from the code. If they don't differ, why iterate on i?
So this isn't an answer, but it may help you write a question that is more succinct, and attract more answers.
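If the right-hand sides really are the same for every i, the loop can be replaced by broadcast assignments. A small sketch with made-up sizes (N, M, L, l and the array contents are assumptions, scaled down from the question):

```python
import numpy as np

N, M, L, l = 8, 30, 6, 4          # stores, days, effects, len(orig_cols)
rng = np.random.default_rng(0)
cal_effects_vals = rng.random((M, l))
orig_prior_coeffs = rng.random(l)

# Loop version, as in the question
cal_loop = np.zeros((N, M, L))
priors_loop = np.zeros((N, L))
for i in range(N):
    cal_loop[i, :, :l] = cal_effects_vals
    priors_loop[i, :l] = orig_prior_coeffs

# Broadcast version: the (M, l) and (l,) sources are repeated along axis 0
cal_effects = np.zeros((N, M, L))
priors = np.zeros((N, L))
cal_effects[:, :, :l] = cal_effects_vals
priors[:, :l] = orig_prior_coeffs
```

If the per-store copies are never modified afterwards, np.broadcast_to(cal_effects_vals, (N, M, l)) gives a read-only view without allocating N copies at all.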
I'm starting with numba and my first goal is to try and accelerate a not so complicated function with a nested loop.
Given the following class:
class TestA:
    def __init__(self, a, b):
        self.a = a
        self.b = b
    def get_mult(self):
        return self.a * self.b
and a numpy ndarray that contains TestA objects, with shape (N,), where N is usually ~3 million.
Now given the following function:
def test_no_jit(custom_class_obj_container):
    container_length = len(custom_class_obj_container)
    sum = 0
    for i in range(container_length):
        for j in range(i + 1, container_length):
            obj_i = custom_class_obj_container[i]
            obj_j = custom_class_obj_container[j]
            sum += (obj_i.get_mult() + obj_j.get_mult())
    return sum
I've tried to get numba to work with the function above, but I cannot get it to work with the nopython=True flag, and if it is set to False, the runtime is higher than that of the non-jitted function.
Here is my latest try in trying to jit the function (also using nb.prange):
@nb.jit(nopython=False, parallel=True)
def test_jit(custom_class_obj_container):
    container_length = len(custom_class_obj_container)
    sum = 0
    for i in nb.prange(container_length):
        for j in nb.prange(i + 1, container_length):
            obj_i = custom_class_obj_container[i]
            obj_j = custom_class_obj_container[j]
            sum += (obj_i.get_mult() + obj_j.get_mult())
    return sum
I've searched around but cannot find a tutorial on how to define a custom class in the signature, or on how to accelerate a function of this sort. I would also like to get it to run on the GPU, possibly with the CUDA libraries, which are installed and ready to use (previously used with TensorFlow). Any information on that would be highly appreciated.
The numba docs give an example of creating a custom type, even for nopython mode: https://numba.pydata.org/numba-doc/latest/extending/interval-example.html
In your case though, unless this is a really slimmed down version of what you actually want to do, it seems like the easiest approach would be to re-use existing types. Additionally, the construction of a 3M length object array is going to be slow, and produce fragmented memory (as the objects are not being stored in contiguous blocks).
An example of how record arrays might be used to solve the problem:
import numpy as np
import numba

x_dt = np.dtype([('a', np.float64),
                 ('b', np.float64)])
n = 30000
buf = np.arange(n*2).reshape((n, 2)).astype(np.float64)
vec3 = np.recarray(n, dtype=x_dt, buf=buf)

@numba.njit
def mult(a):
    return a.a * a.b

@numba.jit(nopython=True, parallel=True)
def sum_of_prod(vector):
    sum = 0
    vector_len = len(vector)
    for i in numba.prange(vector_len):
        for j in numba.prange(i + 1, vector_len):
            sum += mult(vector[i]) + mult(vector[j])
    return sum

sum_of_prod(vec3)
FWIW, I'm no numba expert. I found this question when searching for how to implement a custom type in numba for non-numerical stuff. In your case, because this is highly numerical, I think a custom type is probably overkill.
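As a side note on this particular toy example, the double loop has a closed form: each element's product appears in exactly N - 1 of the i < j pairs, so the total is (N - 1) times the sum of all products. The sketch below (plain NumPy, small made-up data, illustrative function names) checks this against the explicit loop. It only applies if the real pairwise term is a sum of per-object values like this one.

```python
import numpy as np

n = 200
a = np.arange(n, dtype=np.float64)
b = np.arange(n, dtype=np.float64) + 0.5
m = a * b                      # per-object get_mult() values

def pair_sum_loop(m):
    """Explicit O(N^2) double loop over all pairs i < j."""
    total = 0.0
    for i in range(len(m)):
        for j in range(i + 1, len(m)):
            total += m[i] + m[j]
    return total

def pair_sum_closed_form(m):
    """O(N) equivalent: every m[k] is paired once with each of the
    other len(m) - 1 elements."""
    return (len(m) - 1) * m.sum()
```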
As I'm really struggling to get from R code to Python code, I would like to ask for some help. The code I want to use was provided to me on the mathematics forum of Stack Exchange.
https://math.stackexchange.com/questions/2205573/curve-fitting-on-dataset
I do understand what is going on, but I'm having a hard time with the R code, as I have never used R. I have written the function that returns the sum of squares, but I'm stuck on how to use a function similar to R's optim function. I also don't really like the guesswork for the initial values. I would prefer to run and re-run some kind of optim function until I get the wanted result, because I need a nearly perfect curve fit.
def model (par,x):
    n = len(x)
    res = []
    for i in range(1,n):
        A0 = par[3] + (par[4]-par[1])*par[6] + (par[5]-par[2])*par[6]**2
        if(x[i] == par[6]):
            res[i] = A0 + par[1]*x[i] + par[2]*x[i]**2
        else:
            res[i] = par[3] + par[4]*x[i] + par[5]*x[i]**2
    return res
This is my model function...
def sum_squares (par, x, y):
    ss = sum((y-model(par,x))^2)
    return ss
And this is the sum of squares
But I have no idea on how to convert this:
#I found these initial values with a few minutes of guess and check.
par0 <- c(7,-1,-395,70,-2.3,10)
sol <- optim(par= par0, fn=sqerror, x=x, y=y)$par
To Python code...
I wrote an open source Python package (BSD license) that has a genetic algorithm (Differential Evolution) front end to the scipy Levenberg-Marquardt solver, it functions similarly to what you describe in your question. The github URL is:
https://github.com/zunzun/pyeq3
It comes with a "user-defined function" example that's fairly easy to use:
https://github.com/zunzun/pyeq3/blob/master/Examples/Simple/FitUserDefinedFunction_2D.py
along with command-line, GUI, cluster, parallel, and web-based examples. You can install the package with "pip3 install pyeq3" to see if it might suit your needs.
Seems like I have been able to fix the problem.
def model (par,x):
    n = len(x)
    res = np.array([])
    for i in range(0,n):
        A0 = par[2] + (par[3]-par[0])*par[5] + (par[4]-par[1])*par[5]**2
        if(x[i] <= par[5]):
            res = np.append(res, A0 + par[0]*x[i] + par[1]*x[i]**2)
        else:
            res = np.append(res,par[2] + par[3]*x[i] + par[4]*x[i]**2)
    return res
def sum_squares (par, x, y):
    ss = sum((y-model(par,x))**2)
    print('Sum of squares = {0}'.format(ss))
    return ss
And then I used the functions as follows:
parameter = sy.array([0.0,-8.0,0.0018,0.0018,0,200])
res = least_squares(sum_squares, parameter, bounds=(-360,360), args=(x1,y1),verbose = 1)
The only problem is that it doesn't produce the results I'm looking for. That is mainly because my x values span [0, 360] while the y values only vary by about 0.2, so it's a hard nut to crack for this function, and it produces this (poor) result:
[Image: Result]
I think that the range of x values [0, 360] and y values (which you say is ~0.2) is probably not the problem. Getting good initial values for the parameters is probably much more important.
In Python with numpy / scipy, you would definitely want to not loop over values of x but do something more like
def model(par, x):
    res = par[2] + par[3]*x + par[4]*x**2
    A0 = par[2] + (par[3]-par[0])*par[5] + (par[4]-par[1])*par[5]**2
    low = x <= par[5]  # boolean mask for points at or below the breakpoint
    res[low] = A0 + par[0]*x[low] + par[1]*x[low]**2
    return res
It's not clear to me that that form is really what you want: why should A0 (a value independent of x added to a portion of the model) be so complicated and interdependent on the other parameters?
More importantly, your sum_of_squares() function is actually not what least_squares() wants: you should return the residual array, you should not do the sum of squares yourself. So, that should be
def sum_of_squares(par, x, y):
    return (y - model(par, x))
But most importantly, there is a conceptual problem that is probably going to plague this model: your par[5] is meant to represent a breakpoint where the model changes form. This is going to be very hard for these optimization routines to find. They generally make a very small change to each parameter value to estimate the derivative of the residual array with respect to that parameter, in order to figure out how to change it. With a parameter that is essentially used as an integer, a small change in the initial value will have no effect at all, and the algorithm will not be able to determine its value. With some of the scipy.optimize algorithms (notably leastsq) you can specify a scale for the relative change to make; with leastsq it is called epsfcn. You may need to set this as high as 0.3 or 1.0 for fitting the breakpoint to work. Unfortunately, it cannot be set per variable, only per fit. You might need to experiment with this and other options to least_squares or leastsq.
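Putting the pieces of this answer together, here is a minimal sketch of a vectorized model plus a residual function passed to least_squares. The data are synthetic and noise-free, and the "true" parameters, the starting values and the name residuals are assumptions for illustration only; they are not the asker's data, and real data will depend far more strongly on the starting values.

```python
import numpy as np
from scipy.optimize import least_squares

def model(par, x):
    """Piecewise quadratic that switches form at the breakpoint par[5]."""
    res = par[2] + par[3]*x + par[4]*x**2
    A0 = par[2] + (par[3] - par[0])*par[5] + (par[4] - par[1])*par[5]**2
    low = x <= par[5]            # points at or below the breakpoint
    res[low] = A0 + par[0]*x[low] + par[1]*x[low]**2
    return res

def residuals(par, x, y):
    # least_squares expects the residual vector, not its squared sum
    return y - model(par, x)

# Synthetic data generated from assumed "true" parameters
x = np.linspace(0.0, 360.0, 181)
true_par = np.array([1e-3, -2e-6, 0.1, 5e-4, -1e-6, 180.0])
y = model(true_par, x)

# Start from perturbed values and compare the cost before and after
start = true_par * np.array([1.2, 0.8, 1.1, 0.9, 1.2, 1.05])
initial_cost = 0.5 * np.sum(residuals(start, x, y)**2)
fit = least_squares(residuals, start, args=(x, y), x_scale='jac')
```

Here x_scale='jac' lets the solver rescale parameters whose magnitudes differ by many orders. least_squares also exposes a diff_step argument for the relative finite-difference step, which plays a role similar to epsfcn in leastsq and may be worth experimenting with for the breakpoint.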
The following code simulates extracting binary words from different locations within a set of images.
The Numba wrapped function, wordcalc in the code below, has 2 problems:
It is 3 times slower compared to a similar implementation in C++.
Most strangely, if you switch the order of the "ibase" and "ibit" for-loops, speed drops by a factor of 10 (!). This does not happen in the C++ implementation which remains unaffected.
I'm using Numba 0.18.2 from WinPython 2.7
What could be causing this?
import numpy as np
import numba as nb

imDim = 80
numInsts = 10**4
numInstsSub = 10**4 // 4
bitsNum = 13
Xs = np.random.rand(numInsts, imDim**2)
iInstInds = np.array(range(numInsts)[::4])
baseInds = np.arange(imDim**2 - imDim*20 + 1)
ofst1 = np.random.randint(0, imDim*20, bitsNum)
ofst2 = np.random.randint(0, imDim*20, bitsNum)

@nb.jit(nopython=True)
def wordcalc(Xs, iInstInds, baseInds, ofst, bitsNum, newXz):
    count = 0
    for i in iInstInds:
        Xi = Xs[i]
        for ibit in range(bitsNum):
            for ibase in range(baseInds.shape[0]):
                u = Xi[baseInds[ibase] + ofst[0, ibit]] > Xi[baseInds[ibase] + ofst[1, ibit]]
                newXz[count, ibase] = newXz[count, ibase] | np.uint16(u * (2**ibit))
        count += 1
    return newXz

ret = wordcalc(Xs, iInstInds, baseInds, np.array([ofst1, ofst2]), bitsNum, np.zeros((iInstInds.size, baseInds.size), dtype=np.uint16))
I get 4x speed-up by changing from np.uint16(u * (2**ibit)) to np.uint16(u << ibit); i.e. replace the power of 2 with a bitshift, which should be equivalent (for integers).
It seems reasonably likely that your C++ compiler might be making this substitution itself.
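The equivalence of the two forms is easy to check outside Numba. The sketch below (plain NumPy with random stand-in comparison results; the names u, word_pow and word_shift are illustrative) builds the same 13-bit words with a power of two and with a shift. It checks correctness only, not speed.

```python
import numpy as np

rng = np.random.default_rng(0)
bits_num = 13
# Boolean comparison results, one row per bit (stand-in for the u values)
u = rng.random((bits_num, 1000)) > 0.5

word_pow = np.zeros(u.shape[1], dtype=np.uint16)
word_shift = np.zeros(u.shape[1], dtype=np.uint16)
for ibit in range(bits_num):
    # original form: multiply the comparison result by 2**ibit
    word_pow |= (u[ibit] * (2**ibit)).astype(np.uint16)
    # shifted form: move the 0/1 result into bit position ibit
    word_shift |= u[ibit].astype(np.uint16) << np.uint16(ibit)
```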
Swapping the order of the two loops makes a small difference for me for both your original version (5%) and my optimized version (15%), so I don't think I can make a useful comment on that.
If you really wanted to compare the Numba and C++ you can look at the compiled Numba function by doing os.environ['NUMBA_DUMP_ASSEMBLY']='1' before you import Numba. (That's clearly quite involved though).
For reference, I'm using Numba 0.19.1.