Why can't my covariance be calculated by scipy.optimize.curve_fit? - python

I am trying to fit a curve to this set of data using the optimize.curve_fit method in Spyder. I keep getting the optimize warning: cannot calculate covariance error. Could someone please explain why I keep getting this error?
import numpy as np
from scipy.optimize import curve_fit

time_in_days = np.array([0, 0.0831, 0.1465, 0.2587, 0.4828, 0.7448, 0.9817, 1.2563, 1.4926, 1.7299, 1.9915, 3.0011, 4.0109, 5.009, 5.9943, 7.0028])
viral_load = np.array([106000, 93240, 167000, 154000, 119000, 117000, 110000, 111000, 74388, 83291, 66435, 21125, 20450, 15798, 4785.2])
#defining function
def VL_func(time, A, B, alpha, beta):
    """
    Parameters
    ----------
    time : Time in days
    A : constant
    B : constant
    alpha : constant
    beta : constant
    Returns VL
    ------
    """
    VL = (A * np.exp(-alpha * time_in_days)) + (B * np.exp(-beta * time_in_days))
    return np.round(VL)
popt, pcov = curve_fit(VL_func, time_in_days, viral_load)
print(popt)
print("\n")
print(pcov)
error message:
OptimizeWarning: Covariance of the parameters could not be estimated
category=OptimizeWarning

The shapes of your input arrays matter here.
I reproduced your problem and, digging a little deeper, found that time_in_days and viral_load have different lengths (16 and 15 values), which raises:
ValueError: operands could not be broadcast together with shapes (16,) (15,)
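A minimal sketch of how the fit could be set up once the lengths match. It uses synthetic data generated from made-up parameters, not the poster's measurements, and the names vl_func, true_params and p0 are assumptions for illustration. It also uses the time argument inside the model instead of the global array, and drops np.round, since a rounded model is piecewise constant and can leave the numerically estimated Jacobian at zero, which is another way to end up with the covariance warning.

```python
import numpy as np
from scipy.optimize import curve_fit


def vl_func(time, A, B, alpha, beta):
    # Use the `time` argument (not a global array) and return unrounded values.
    return A * np.exp(-alpha * time) + B * np.exp(-beta * time)


# Synthetic data (assumption): 16 time points and 16 matching load values.
time = np.linspace(0.0, 7.0, 16)
true_params = (1.0e5, 5.0e4, 0.3, 2.0)
rng = np.random.default_rng(0)
load = vl_func(time, *true_params) * (1.0 + 0.01 * rng.standard_normal(time.size))

# Check the shapes before fitting; a mismatch here is what raised the ValueError.
if time.shape != load.shape:
    raise ValueError(f"length mismatch: {time.shape} vs {load.shape}")

# A rough initial guess helps the optimizer with a sum of two exponentials.
p0 = (8.0e4, 8.0e4, 0.2, 1.5)
popt, pcov = curve_fit(vl_func, time, load, p0=p0, maxfev=10000)
print(popt)
print(np.sqrt(np.diag(pcov)))  # one-sigma parameter uncertainties
```

With real data, the first step is to find which measurement is missing so that both arrays have one entry per time point.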

Related

AttributeError: 'str' object has no attribute 'dot'

I am using the qndiag library to try to find a joint diagonalisation of 2 given matrices.
The GitHub repository is here: qndiag library
I am using this Python script to diagonalise these 2 matrices as closely as possible:
import os, sys
import numpy as np
from qndiag import qndiag
# dimension
m=7
# number of matrices
n=2
# Load spectro and WL+GCph+XC
FISH_GCsp = np.loadtxt('Fisher_GCsp_flat.txt')
FISH_XC = np.loadtxt('Fisher_XC_GCph_WL_flat.txt')
# Marginalizing over uncommon parameters between the two matrices
COV_GCsp_first = np.linalg.inv(FISH_GCsp)
COV_XC_first = np.linalg.inv(FISH_XC)
COV_GCsp = COV_GCsp_first[0:m,0:m]
COV_XC = COV_XC_first[0:m,0:m]
# Invert to get Fisher matrix
FISH_sp = np.linalg.inv(COV_GCsp);
FISH_xc = np.linalg.inv(COV_XC);
# Drawing a random set of commuting matrices
C=np.zeros((n,m,m));
C[0,:,:] = FISH_sp
C[1,:,:] = FISH_xc
[D, B] = qndiag(C, 'max_iter', 1000, 'tol', 1e-3);
# Print expected diagonal matrices
B*C[0,:,:]*B.T
B*C[1,:,:]*B.T
Given my two 7x7 matrices FISH_sp and FISH_xc, I get the following error:
approximate_joint_diagonalization_qndiag/qndiag/qndiag/qndiag.py", line 90, in qndiag
D = transform_set(B, C)
followed by:
approximate_joint_diagonalization_qndiag/qndiag/qndiag/qndiag.py", line 151, in transform_set
op[i] = M.dot(d.dot(M.T))
AttributeError: 'str' object has no attribute 'dot'
Here is the function concerned, transform_set:
def transform_set(M, D, diag_only=False):
    n, p, _ = D.shape
    if not diag_only:
        op = np.zeros((n, p, p))
        for i, d in enumerate(D):
            op[i] = M.dot(d.dot(M.T))
and the main function that is called in my initialization :
def qndiag(C, B0=None, weights=None, max_iter=1000, tol=1e-6,
           lambda_min=1e-4, max_ls_tries=10, diag_only=False,
           return_B_list=False, verbose=False):
    """Joint diagonalization of matrices using the quasi-Newton method
    Parameters
    ----------
    C : array-like, shape (n_samples, n_features, n_features)
        Set of matrices to be jointly diagonalized. C[0] is the first matrix,
        etc...
    B0 : None | array-like, shape (n_features, n_features)
        Initial point for the algorithm. If None, a whitener is used.
    weights : None | array-like, shape (n_samples,)
        Weights for each matrix in the loss:
        L = sum(weights * KL(C, C')) / sum(weights).
        No weighting (weights = 1) by default.
    max_iter : int, optional
        Maximum number of iterations to perform.
    tol : float, optional
        A positive scalar giving the tolerance at which the
        algorithm is considered to have converged. The algorithm stops when
        |gradient| < tol.
    lambda_min : float, optional
        A positive regularization scalar. Each eigenvalue of the Hessian
        approximation below lambda_min is set to lambda_min.
    max_ls_tries : int, optional
        Maximum number of line-search tries to perform.
    diag_only : bool, optional
        If true, the line search is done by computing only the diagonals of the
        dataset. The dataset is then computed after the line search.
        Taking diag_only = True might be faster than diag_only=False
        when the matrices are large (n_features > 200)
    return_B_list : bool, optional
        Chooses whether or not to return the list of iterates.
    verbose : bool, optional
        Prints informations about the state of the algorithm if True.
    Returns
    -------
    D : array-like, shape (n_samples, n_features, n_features)
        Set of matrices jointly diagonalized
    B : array, shape (n_features, n_features)
        Estimated joint diagonalizer matrix.
    infos : dict
        Dictionnary of monitoring informations, containing the times,
        gradient norms and objective values.
    References
    ----------
    P. Ablin, J.F. Cardoso and A. Gramfort. Beyond Pham's algorithm
    for joint diagonalization. Proc. ESANN 2019.
    https://www.elen.ucl.ac.be/Proceedings/esann/esannpdf/es2019-119.pdf
    https://hal.archives-ouvertes.fr/hal-01936887v1
    https://arxiv.org/abs/1811.11433
    """
    t0 = time()
    n_samples, n_features, _ = C.shape
    if B0 is None:
        C_mean = np.mean(C, axis=0)
        d, p = np.linalg.eigh(C_mean)
        B = p.T / np.sqrt(d[:, None])
    else:
        B = B0
    if weights is not None:  # normalize
        weights_ = weights / np.mean(weights)
    else:
        weights_ = None
    D = transform_set(B, C)
Why does this error occur on the dot operator? It seems to be a matrix product (like np.dot), but the way it is used suggests that this is not the case.
The issue lies in [D, B] = qndiag(C, 'max_iter', 1000, 'tol', 1e-3): the arguments are passed positionally, so B0 (the second parameter) is assigned the string 'max_iter' instead of an array. B then ends up being a string, hence the error message 'str' object has no attribute 'dot'. If you only want to pass the matrix C as a parameter, just do [D, B] = qndiag(C).
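To illustrate the mechanism without depending on the library being installed, here is a small sketch with a hypothetical stand-in function, qndiag_stub, that only mirrors the first few parameters of the signature quoted in the question. It is not the real qndiag and performs no diagonalization; it just shows where positional MATLAB-style name/value pairs end up compared with Python keyword arguments.

```python
import numpy as np


def qndiag_stub(C, B0=None, weights=None, max_iter=1000, tol=1e-6):
    """Hypothetical stand-in: reports which value each parameter received."""
    return {"B0": B0, "weights": weights, "max_iter": max_iter, "tol": tol}


C = np.zeros((2, 7, 7))

# MATLAB-style name/value pairs are bound by position in Python:
wrong = qndiag_stub(C, 'max_iter', 1000, 'tol', 1e-3)
print(wrong)  # B0 receives the string 'max_iter', weights receives 1000, ...

# Keyword arguments bind each option to the intended parameter:
right = qndiag_stub(C, max_iter=1000, tol=1e-3)
print(right)  # B0 and weights keep their default of None

# Separately, with NumPy arrays `*` is element-wise; the matrix product
# B C B^T from the script would be written with the @ operator.
B = np.eye(7)
D0 = B @ C[0] @ B.T
```

Under the same assumption about the signature, the call in the question would become qndiag(C, max_iter=1000, tol=1e-3).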

How to make scipy.optimize.curve_fit result in a better sine regression fit?

I have a problem where I am using scipy.optimize.curve_fit to do a regression fit to a sine/cosine function but the fit does not seem as optimized as I want it to be. How can I change my code to make the fitting better?
I have already tried changing how parameters are tried for the dataset and there is always seemingly a difference in phase-offset of my generated fit or the fitting function is not fitting to the proper minima/maxima.
Here is the code I am using to generate the regression fit. The output (fitfunc) can be plotted to show the result.
def sin_regress(data_x, data_y):
    """Function regression fits data to SIN function; does not need guess of freq.
    Parameters
    ----------
    data_x :
        Data for X values, most likely a set of voltages.
    data_y :
        Data for Y values, most likely the resulting powers from voltages.
    Returns
    -------
    __ :
        Dictionary containing values for amplitude, angular frequency, phase, offset, frequency, period, fit function, max covariance, initial guess.
    """
    data_x = np.array(data_x)
    data_y = np.array(data_y)
    freqz = np.fft.rfftfreq(len(data_x), (data_x[1] - data_x[0]))  # uniform spacing
    freq_y = abs(np.fft.rfft(data_y))
    guess_freq = abs(freqz[np.argmax(freq_y[1:])+1])  # exclude offset peak
    guess_amp = np.std(data_y) * 2.**0.5
    guess_offset = np.mean(data_y)
    guess = np.array([guess_amp, 2.*np.pi*guess_freq, 0., guess_offset])

    def sinfunc(t, A, w, p, c):
        """Raw function to be used to fit data.
        Parameters
        ----------
        t :
            Voltage array
        A :
            Amplitude
        w :
            Angular frequency
        p :
            Phase
        c :
            Constant value
        Returns
        -------
        __ :
            Formed fit function with provided values.
        """
        return A * np.sin(w*t + p) + c

    popt, pcov = scipy.optimize.curve_fit(sinfunc, data_x, data_y, p0=guess)
    A, w, p, c = popt
    f = w/(2.*np.pi)
    fitfunc = lambda t: A * np.sin(w*t + p) + c
    return {"amp": A, "omega": w, "phase": p, "offset": c, "freq": f, "period": 1./f, "fitfunc": fitfunc, "maxcov": np.max(pcov), "rawres": (guess,popt,pcov)}
With my trial dataset being:
x = np.linspace(3.5, 9.5, int(round((9.5 - 3.5) / 0.00625)) + 1)  # num must be an integer (961 points)
pow1 = [1.8262110863, 1.80944546009, 1.7970185646900003, 1.77120336754, 1.7458101235699999, 1.73597098224, 1.7122529922799998, 1.70015674142, 1.68968617429, 1.6989396515, 1.69760676076, 1.6946375613599998, 1.6895321899, 1.68145658386, 1.68581793183, 1.6920468775900002, 1.6865452951599997, 1.68570953338, 1.6922784791700003, 1.70958957412, 1.71683408637, 1.70360183933, 1.6919669752199997, 1.6669487117300001, 1.6351298032300001, 1.6061729066600001, 1.57344333403, 1.54723708217, 1.5277773737599998, 1.5122628414300001, 1.4962354965200002, 1.4873367459, 1.47567715522, 1.4696584634, 1.46159565032, 1.45320592315, 1.4487225244200002, 1.44572887186, 1.44089260198, 1.4367157657399998, 1.4349226211, 1.43614316806, 1.4381950627400002, 1.43947658627, 1.4483572314200002, 1.4504305909200002, 1.44436990692, 1.43367609757, 1.42637295252, 1.41197427963, 1.4067529511399999, 1.39714414185, 1.38309980493, 1.3730701362500004, 1.3693239836499997, 1.3729558979599998, 1.38291189477, 1.3988274622900003, 1.42112832324, 1.44217266068, 1.4578792438300001, 1.46478639274, 1.46676801398, 1.4646383458800003, 1.45918801344, 1.44561402809, 1.4212145146499997, 1.4012453921299999, 1.38070199226, 1.36215759642, 1.3540496661500003, 1.35470913884, 1.3481165993199997, 1.34059081754, 1.332964567, 1.33426054366, 1.34052562222, 1.3343255632100002, 1.3310385903, 1.33044179339, 1.32827462527, 1.3356201140500001, 1.3400144893900001, 1.3157198001600001, 1.27716313727, 1.2517667292400003, 1.2406836620500001, 1.2354036030700002, 1.23110776291, 1.22492582889, 1.22074838719, 1.21816502762, 1.21015135518, 1.20038737012, 1.1920263929700001, 1.18723010357, 1.19656731125, 1.2237068834899998, 1.2373841696199999, 1.2251076648299999, 1.1963014909299998, 1.16152861736, 1.13940556893, 1.12839812676, 1.12368066547, 1.1190219542100002, 1.11384679759, 1.10555781262, 1.0977575386300003, 1.0901734365399998, 1.0824275375699999, 1.07552931443, 1.0696565210100002, 1.06481394254, 1.0578173014299999, 1.05204230102, 1.0482530038799998, 
1.04237087457, 1.0361766944300002, 1.0297906393, 1.0240842912299999, 1.01250548183, 0.9964340353700001, 0.9859450307400002, 0.98614987451, 0.9826424718800002, 0.9739505767299999, 0.9578738177999998, 0.9416973908799999, 0.92975112051, 0.9204409049900001, 0.91821299468, 0.9100360995600001, 0.89589154778, 0.8799530701000002, 0.8640439088, 0.8500274234399999, 0.8428500205999999, 0.8358678326, 0.8333072464999999, 0.83420148485, 0.8362578717, 0.83608947323, 0.83035464861, 0.82315039029, 0.81220152235, 0.80169300598, 0.7918658959, 0.7808782388700001, 0.77684747687, 0.7743299962, 0.76797978094, 0.7591097217, 0.7520710688500001, 0.7452609707, 0.73562753255, 0.7256206568399999, 0.71663518742, 0.70951165178, 0.7035884873, 0.6973768853, 0.6900439160299999, 0.68062538021, 0.67096725454, 0.66585371901, 0.6663177033900001, 0.67214877804, 0.6787934074299999, 0.68365489213, 0.68581510712, 0.6820892084400001, 0.67805153237, 0.67540688376, 0.6724865515, 0.6674502035, 0.6593852224500001, 0.6524835227400001, 0.64758563177, 0.6424489126599999, 0.63385426361, 0.6242639699699999, 0.6143974848999999, 0.60705328516, 0.60087306988, 0.5928024247700001, 0.5864009594799999, 0.5786877362899999, 0.57457744302, 0.57012636848, 0.56554310644, 0.5618750202299999, 0.55731189492, 0.55057384756, 0.5419996086800001, 0.52987726408, 0.51025575876, 0.48599474143000004, 0.46231124366000004, 0.44151899608999995, 0.42632008877, 0.42655368254, 0.42784393651999997, 0.42863940533999995, 0.42506971759, 0.41952014686999994, 0.41337420894, 0.40570705996, 0.39706149294, 0.38721395321, 0.3806321949, 0.37313342483999995, 0.36982676447, 0.36704194004, 0.36189430296, 0.3560628963, 0.34954350131, 0.34540695806, 0.34178605934, 0.33629549256, 0.3293877577, 0.32357672213, 0.31864117490000005, 0.31165906503, 0.30439039263000006, 0.29875160317, 0.29294459105000004, 0.28847285244, 0.28509162173, 0.28265949265, 0.28003828154, 0.27814630873999996, 0.27599048828, 0.27524025386, 0.27406833971, 0.27281988259, 0.27155314420999993, 
0.26840999947000005, 0.2634181241, 0.25883622926000005, 0.25503165868, 0.25056988104, 0.24466620872, 0.23932761459000002, 0.23422685251999997, 0.22880456697, 0.22310130485000004, 0.21785542557999998, 0.21366651902000006, 0.20966530780999998, 0.20521315906, 0.20012157666000002, 0.19469597081, 0.18957032591999995, 0.18423432945, 0.17946309866000001, 0.17845044232, 0.17746098912000002, 0.17475331315, 0.17039776599, 0.16363173032999997, 0.15716942518, 0.15214176858, 0.14870803788, 0.14515563527000003, 0.14218680693, 0.13893215828, 0.13546723615, 0.13178983356, 0.12747471604, 0.12350983297, 0.12011202021999998, 0.11627787931000003, 0.11218377746, 0.10821276155, 0.10384311280999999, 0.09960625706000001, 0.09615194041000003, 0.09216061199, 0.08847719376999999, 0.08481545522999999, 0.08163922452000001, 0.07851820869000001, 0.07535195845, 0.07259346216999998, 0.06996658694999999, 0.06748611806, 0.06513859836, 0.06343437948, 0.06174502390000001, 0.059727113600000006, 0.05755100017, 0.054968070300000005, 0.052386214650000006, 0.05002439809, 0.04768410494, 0.04532047195999999, 0.04319275697, 0.04105023728, 0.03894787384, 0.03695523698, 0.03513302983, 0.033548459399999994, 0.032170295249999994, 0.030958654539999998, 0.02983605681, 0.028375548879999997, 0.02671830267, 0.024898224419999997, 0.0230959196, 0.02139548979, 0.01983882955, 0.018419727860000002, 0.017108712149999997, 0.01590183706, 0.01467630964, 0.01340369235, 0.01204181727, 0.011048145310000002, 0.01072443434, 0.010401953859999999, 0.010151465580000001, 0.00990748117, 0.00972232492, 0.00956939523, 0.009442617850000001, 0.009344043619999999, 0.009241641279999999, 0.00915107487, 0.009064981109999998, 0.008985430320000001, 0.00890431702, 0.00883441469, 0.008775488880000001, 0.00873752015, 0.00871498109, 0.008710938120000001, 0.00872328188, 0.00874796935, 0.008778945909999999, 0.00882859436, 0.00889468812, 0.00898683656, 0.00910033268, 0.009214043629999998, 0.00934455143, 0.00949293034, 0.00965939522, 
0.009844610069999999, 0.01005115305, 0.010290684330000001, 0.01054888746, 0.010822364050000002, 0.011132617979999999, 0.012252539939999998, 0.013524844710000001, 0.01492336044, 0.01639437616, 0.01790093876, 0.01949634904, 0.02112754055, 0.022849025059999997, 0.02457990408, 0.02637656436, 0.02816101762, 0.02999357634, 0.031735392870000004, 0.03370418208999999, 0.03591160409, 0.03868365509, 0.0413049248, 0.043746897629999996, 0.04622211263, 0.04871939798, 0.051123460649999994, 0.05370180068, 0.05625859775000001, 0.058868656510000006, 0.06136678167, 0.06394643029, 0.06623680155999997, 0.06885605955999999, 0.07171654804, 0.07483811078, 0.07798461489, 0.08075584557000001, 0.08390440047999999, 0.08690709601, 0.09012059232, 0.09292447923, 0.09569860054, 0.09869240932999998, 0.10204307363999998, 0.10579037859, 0.10944262493000001, 0.11339190256000002, 0.11739889503, 0.12165444219999999, 0.12640639566999998, 0.13103823193000003, 0.13545668928, 0.13980243177, 0.1445100493, 0.14892381914000002, 0.15358704212000002, 0.15754780411999997, 0.1620275896, 0.16721823448, 0.17344235602999997, 0.17972712208000002, 0.18671513038999998, 0.19370331449, 0.1997322407, 0.20632862788999998, 0.21168169468000003, 0.2186676522, 0.22613634413, 0.23308478213, 0.24056257561, 0.24694894328, 0.25289726401, 0.26043587782, 0.26523394455, 0.27115650357, 0.27472996084, 0.27757628917, 0.28195025433, 0.28717476642, 0.29255468867, 0.29700002103, 0.29903203287999996, 0.30043668141, 0.30362955273000003, 0.30861634997000004, 0.3146493582, 0.32141648759, 0.33050709371, 0.34155311010999995, 0.35347176329, 0.3641544984300001, 0.37273471389, 0.37810184317999995, 0.38245108175, 0.38773739072, 0.39195147307000006, 0.39284567233, 0.39723110233000003, 0.39968268453, 0.40089368072000003, 0.40181627844999995, 0.40374096608, 0.40828194296, 0.41598909193000005, 0.42570815513, 0.43468223779000004, 0.4419052070599999, 0.44814120359, 0.4541516141699999, 0.45904682936999996, 0.46598345094999993, 0.47421183044, 0.48259810056, 
0.49064425346, 0.49772194929999997, 0.50355609034, 0.5097226337399999, 0.5242588261700001, 0.53191943219, 0.5427558587299999, 0.5558334377799999, 0.57145400528, 0.58596031492, 0.6017949058700001, 0.61620852018, 0.62886383358, 0.63983492811, 0.64928899126, 0.65807748798, 0.66440410952, 0.67291110232, 0.68452424766, 0.6952567679499999, 0.7045326279799999, 0.7168566913700001, 0.72438360596, 0.7334800323799999, 0.73850692728, 0.7444589784699999, 0.75250327593, 0.7652333354299999, 0.7794230629700001, 0.79152575915, 0.80011656054, 0.80971581904, 0.8176350188100001, 0.82681863275, 0.83466310596, 0.84169904395, 0.85246648611, 0.8612931078200001, 0.8712971515300001, 0.88083937874, 0.89039777788, 0.89838717297, 0.90641512274, 0.9111584238600001, 0.9159304749999999, 0.9210217253499999, 0.92296264345, 0.9233887177, 0.9218466277399999, 0.9176133266600001, 0.91940151039, 0.9208485417400001, 0.9220888543199999, 0.9236718817800001, 0.9276074484799999, 0.93015244864, 0.9343631130099999, 0.93763016402, 0.9384009648400001, 0.93879867973, 0.93652442175, 0.93662918739, 0.9331820972899999, 0.93503584744, 0.9360406912399999, 0.93994795716, 0.9444487777899999, 0.95150762595, 0.9574753021500001, 0.9659650293199998, 0.9757605964, 0.9878513785299999, 0.99883880117, 1.01323052095, 1.0311493112499999, 1.04763474212, 1.0677277318200002, 1.086237323, 1.0988490621599998, 1.10287175775, 1.11006095748, 1.1203823058799998, 1.1266948453599999, 1.1295011150999998, 1.13468379124, 1.13839008058, 1.1417559206699999, 1.1386140845, 1.1368738695300002, 1.13791410398, 1.1443759989699998, 1.1533826011700001, 1.16127430094, 1.1771807669, 1.19318348288, 1.2014892452, 1.20715822998, 1.21764737132, 1.23158125907, 1.2387470993899998, 1.2441262208700001, 1.2562376475, 1.2682344256899998, 1.28293907518, 1.2903573374300001, 1.3040509126199997, 1.3260814219800001, 1.3595052134299999, 1.3870089263099998, 1.4040962907899999, 1.4190098465199998, 1.43005375357, 1.4343605702800002, 1.4355429141099998, 1.43638377355, 
1.44962018073, 1.45147113789, 1.45921588453, 1.4661880139399999, 1.47414703793, 1.47941295628, 1.47950143284, 1.4748920184699998, 1.4692222329000004, 1.4631299473100001, 1.45757789614, 1.4527345168899999, 1.4434376802999997, 1.4390123479299999, 1.4387321330999998, 1.4376372501999999, 1.44922049319, 1.46122473234, 1.47480432313, 1.48463330822, 1.50740325124, 1.52143227566, 1.5388702456399996, 1.5586354228100001, 1.5670929624799999, 1.57654938893, 1.60239005482, 1.6187282200499997, 1.6195258763400002, 1.6341473226799998, 1.6455264836499999, 1.6550699218299996, 1.6682315829299998, 1.68167279482, 1.6900114477300001, 1.6978344170500002, 1.7018968392199998, 1.70642375358, 1.71237959385, 1.7205134225500003, 1.7311321537799997, 1.7430771546100001, 1.7517999091500003, 1.76491293742, 1.7833902824799999, 1.8081253623500004, 1.83075608662, 1.8524498577000004, 1.86711454623, 1.8814965784800002, 1.8857294108200002, 1.90378495898, 1.9156142957500002, 1.9241271088399998, 1.92694429655, 1.92836076148, 1.9246632612399999, 1.9177767372999999, 1.9240789057399996, 1.93491201195, 1.95508541182, 1.9667632837499998, 1.97663894849, 1.9838888513599997, 1.9862320351100002, 1.9850681678399997, 1.9724571903800001, 1.9569690057000002, 1.9450577939199998, 1.93385585952, 1.91272038928, 1.90263962687, 1.89419806376, 1.8846363638699999, 1.8752989218, 1.8721239020399998, 1.87465480067, 1.87635644139, 1.8883053875500004, 1.90622687322, 1.9326186524100002, 1.96217418184, 1.99341387155, 2.0052843606899997, 2.0198940101400003, 2.03224112041, 2.04585828934, 2.0482686606100002, 2.0761935844499995, 2.10636661393, 2.1218703845699998, 2.1265723770799996, 2.13344606897, 2.13480411595, 2.12395452534, 2.11298829408, 2.10366419185, 2.10279155509, 2.10582569592, 2.12401487691, 2.14351597204, 2.1603280826, 2.1732762280399998, 2.1829961701499996, 2.1825562873100006, 2.1829598615399997, 2.18269224434, 2.18542837733, 2.18136038877, 2.17195739983, 2.16672507523, 2.1595190200499994, 2.15408655871, 2.16100126623, 
2.1646243915, 2.16989273172, 2.1760575368399997, 2.18993197141, 2.20082640578, 2.18953400264, 2.1673666182699995, 2.15301331645, 2.1344672799800004, 2.1212936853000004, 2.1081594070399996, 2.08825354625, 2.0697085058700004, 2.045492469, 2.02153998684, 2.0038663723099996, 2.0038828566799998, 2.0085019585599997, 2.0192783851200002, 2.03833670679, 2.05771370034, 2.08050465897, 2.1006803439999997, 2.1263974552, 2.14748327701, 2.17287144288, 2.1941383974899997, 2.19820122981, 2.2003345112000003, 2.20800316408, 2.21184328157, 2.21310867227, 2.21112832057, 2.1998480658600004, 2.1906804089599996, 2.17670294702, 2.1515223983699996, 2.1337058932199997, 2.11742559909, 2.1017357932899996, 2.0798991511200002, 2.05328198125, 2.02510619803, 2.00362619651, 1.98193234731, 1.9618359005700001, 1.9612528146099997, 1.97096636996, 1.9761617414300001, 1.9782324642600002, 1.99263889104, 2.00500029816, 2.01506871685, 2.02912785846, 2.04221860157, 2.06368362263, 2.07491317421, 2.08832055797, 2.09538342956, 2.1084886843899997, 2.1158979036700005, 2.1260576895499996, 2.13639327622, 2.14181249535, 2.1392352295499997, 2.14448495648, 2.1421138235, 2.14009620617, 2.1384934521399996, 2.1319765571600002, 2.1216323962400003, 2.1065051490999998, 2.08999485498, 2.06996758792, 2.05396301646, 2.0366352808700006, 2.023489069, 1.9927697308899996, 1.9807445347400001, 1.97629449536, 1.9772154719699997, 1.9837454333899998, 1.9903514690000002, 1.9990068602399997, 2.0052703762999995, 2.0102515290099996, 2.01071088451, 2.00780344289, 2.00202451671, 1.99526703575, 1.9894158244, 1.9859053554, 1.9872483633099995, 1.99006639085, 2.00697930222, 2.0329301048299997, 2.05059264513, 2.0540770985099996, 2.04176762498, 2.0093012359700007, 1.9757453156100002, 1.94977980597, 1.94015615295, 1.93165724611, 1.9207719523600002, 1.90945249843, 1.89062300491, 1.87690150004, 1.8621346825699998, 1.84607821661, 1.828253313, 1.8169694254700002, 1.8075289169999997, 1.8040289362800004, 1.79267489253, 1.78023102445, 1.7778953016200003, 
1.7787011610500003, 1.78226670819, 1.7830425676100004, 1.77486727406, 1.7675372149399997, 1.7575688744100002, 1.7498299871300003, 1.74518012353, 1.73248096246, 1.7160241253800002, 1.70317674164, 1.6978293584500002, 1.6946921121299998, 1.6961595927200002, 1.70211670251, 1.7104493398199998, 1.7203816647499999, 1.7274331496, 1.7311123100199999, 1.73665119714, 1.74750018228, 1.7625600270900001, 1.76829838689, 1.7683754962599998, 1.7604641870999997, 1.7378729159800002, 1.7182883638100002, 1.7072806677199999, 1.7037852573199999, 1.6963237919299996, 1.67904111493, 1.64849412058, 1.61509034869, 1.58860298353, 1.56708077499, 1.5563275906199998, 1.5508352464699997, 1.5448227655799998, 1.53880546048, 1.54041544105, 1.5403843473000003, 1.53577729621, 1.5273169831, 1.51722079097, 1.5010415320300001, 1.4873523904299997, 1.47098713536, 1.45343877476, 1.4333900233299999, 1.4214382256099998, 1.4199358231499999, 1.42357822576, 1.42446916333, 1.4169634987200002, 1.40651060735, 1.39602957147, 1.38608337936, 1.38502109414, 1.38722933647, 1.3877573052599999, 1.38915685615, 1.3879546490299999, 1.38030042971, 1.37484574183, 1.36882917891, 1.36771619056, 1.36598312403, 1.35475238104, 1.3352715984299999, 1.31243304213, 1.29205091175, 1.26981483599, 1.25096920963, 1.23261465755, 1.2107178005399999, 1.1896016271599998, 1.1758782668, 1.17342422369, 1.17358562993, 1.17110207509, 1.1674486178099999, 1.1603703751, 1.1565048865399998, 1.15140617524, 1.15148740571, 1.15832875386, 1.16650391071, 1.1712949266600001, 1.16865191865, 1.16596408644, 1.1661593208199998, 1.16419447693, 1.15754447647, 1.15312982771, 1.1506705697300001, 1.14375644814, 1.13705099847, 1.12589113437, 1.11212277402, 1.10001296849, 1.08946394429, 1.0747068729400002, 1.05980790705, 1.0438431988799999, 1.02497712333, 1.00659505173, 0.98919173016, 0.9715707328300001, 0.95416868081, 0.9416231916500001, 0.92753217501, 0.91364512326, 0.90414607963, 0.8947884227199999, 0.8843405703999998, 0.8769049253500001, 0.8719632452999999, 
0.86833484662, 0.8680955887799999, 0.86604049098, 0.86558996362, 0.86372701427, 0.85893691627, 0.85435131048, 0.84886228665, 0.8409088095199999, 0.82732292967, 0.8182398235399999, 0.81298593645, 0.8065804672500001, 0.7963832009099999, 0.7813524576499999, 0.7642633939500001, 0.74891606863, 0.73387495429, 0.72021307831, 0.70711249145, 0.6972523931, 0.68836254874, 0.6789805168, 0.66917573095, 0.65520369872, 0.6405349086200001, 0.6262600443299999, 0.6128265668199999, 0.6004827768800001, 0.58821246352, 0.5763513298499999, 0.56580466895, 0.55820613325, 0.5498382224900001, 0.5432313079700001, 0.5383656045, 0.53169802591];
Here are some additional values for the pow dataset:
(Link to pastebin to not exceed post length limit)
https://pastebin.com/5GP8sj4N
The fit I get from the trial dataset (x, pow1) is shown here (orange) with the original (pow1) data (blue)
As mentioned, there is an issue with how the phase fits the minima and maxima. Unfortunately, the application leaves very little room for error in this fit function.
Please help out if you have an idea of how to make this fit the data better!
Edit:
I tried what @Joe mentioned in the comments, first filtering the data. I used a Savitzky-Golay filter and received the following result: original data (blue), the filtered data (green), and the fit to the filtered data (orange). The same shift in minima and maxima is still present in the fit to the filtered data.
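For reference, a Savitzky-Golay pre-filter of the kind described in the edit might look like the sketch below. The signal is synthetic, and window_length=51 and polyorder=3 are assumed values, since the post does not say which settings were used.

```python
import numpy as np
from scipy.signal import savgol_filter

# Synthetic stand-in for the measured data (assumption): a sine plus noise
# on the same grid as the trial dataset.
rng = np.random.default_rng(0)
x = np.linspace(3.5, 9.5, 961)
clean = 1.0 + np.sin(1.26 * x + 4.0)
noisy = clean + 0.05 * rng.standard_normal(x.size)

# window_length must be odd and larger than polyorder; both are assumed here.
smoothed = savgol_filter(noisy, window_length=51, polyorder=3)

rms_before = np.sqrt(np.mean((noisy - clean) ** 2))
rms_after = np.sqrt(np.mean((smoothed - clean) ** 2))
print(rms_before, rms_after)
```

Smoothing of this kind reduces noise but does not change the underlying shape, which is consistent with the shift remaining after filtering.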
Here are my results with more aggressive clipping bounds of 0.5 to 1.75 for each data set.
for pow1:
A = 9.6711505138648990E-01
c = 9.7613787086912507E-01
p = 4.0262076448344617E+00
w = 1.2654001570670070E+00
for pow2:
A = 9.4894637490866129E-01
c = 9.6733405789489280E-01
p = 4.0892433833755097E+00
w = 1.2578627414445132E+00
for pow3:
A = 9.8595630272060597E-01
c = 9.6749868212694512E-01
p = 4.0859456191316230E+00
w = 1.2598547148182329E+00
for pow4:
A = -9.4636707498392481E-01
c = 9.5047597808408602E-01
p = -4.2643913461857056E+02
w = 1.2761107231684055E+00
I think I have this figured out - your data is not a mathematically perfect sine wave + noise, so the fitting software can only come close to modeling a sine function to this data. If you must have more accuracy, try splitting the model into different segments and use a piecewise fit. Here is a close-up of the problem area:
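A rough sketch of the piecewise idea, under stated assumptions: the data are synthetic (a sine whose frequency drifts slightly, standing in for data that is not a perfect sine wave), the split into two equal halves is arbitrary, and each segment fit is started from the global fit parameters. The names sinfunc, edges and sse are chosen here for illustration.

```python
import numpy as np
from scipy.optimize import curve_fit


def sinfunc(t, A, w, p, c):
    return A * np.sin(w * t + p) + c


def sse(params, tt, yy):
    # Sum of squared residuals of the model on one stretch of data.
    return float(np.sum((sinfunc(tt, *params) - yy) ** 2))


# Synthetic data (assumption): the phase term grows quadratically,
# so no single sine matches the whole range exactly.
t = np.linspace(3.5, 9.5, 961)
y = 1.0 + np.sin(1.26 * t + 0.03 * (t - 3.5) ** 2 + 4.0)

# One sine over the whole range.
p_global, _ = curve_fit(sinfunc, t, y, p0=(1.0, 1.26, 4.0, 1.0), maxfev=20000)
sse_global = sse(p_global, t, y)

# One sine per segment, each started from the global solution.
edges = [0, t.size // 2, t.size]
sse_piecewise = 0.0
segment_params = []
for lo, hi in zip(edges[:-1], edges[1:]):
    p_seg, _ = curve_fit(sinfunc, t[lo:hi], y[lo:hi], p0=p_global, maxfev=20000)
    segment_params.append(p_seg)
    sse_piecewise += sse(p_seg, t[lo:hi], y[lo:hi])

print(sse_global, sse_piecewise)
```

Note that independently fitted segments are generally not continuous where they meet, so an application that needs a smooth curve would have to handle the joins separately.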

Curve fitting in Python with constraint of zero slope at edges

I am looking to curve fit the following data, such that I get it to fit a trend with the condition of zero slope at the edges. The output of polyfit fits that data, but not with zero slopes at the edges.
Here is what I'm looking to output - pardon my Paint job. I need it to fit like this so I can properly remove the sine/cosine bias towards the center of the data, which isn't real.
Here is the data:
[0.23353535 0.25586247 0.26661164 0.26410896 0.24963951 0.22670266
0.19955422 0.17190263 0.1598439 0.17351905 0.18212444 0.18438673
0.17952432 0.18314894 0.19265689 0.19432385 0.19605163 0.20326011
0.20890851 0.20590997 0.21856518 0.23771665 0.24530019 0.23940831
0.22078396 0.23075128 0.2346082 0.22466281 0.24384843 0.26339594
0.26414153 0.24664183 0.24278978 0.31023648 0.3614195 0.37773436
0.3505998 0.28893167 0.23965877 0.24063917 0.27922502 0.32716477
0.36553767 0.42293146 0.50968856 0.5458872 0.52192533 0.45243764
0.36313155 0.3683921 0.40942553 0.4420537 0.46145585 0.4648034
0.4523771 0.4272876 0.39404616 0.3570107 0.35060245 0.3860975
0.3996996 0.44551122 0.46611032 0.45998383 0.4309985 0.38563925
0.37105605 0.4074444 0.48815584 0.5704579 0.6448988 0.7018853
0.73397845 0.73739105 0.7122451 0.6618154 0.591451 0.5076601
0.48578677 0.47347385 0.4791471 0.48306277 0.47025493 0.43479836
0.44380915 0.45868078 0.5341566 0.57549906 0.55790776 0.56244135
0.57668275 0.561856 0.67564166 0.7512851 0.76957643 0.7266262
0.734133 0.7231936 0.6776926 0.60511285 0.51599765 0.5579323
0.56723005 0.5440337 0.5775593 0.5950776 0.5722321 0.57858473
0.5652703 0.54723704 0.59561515 0.7071321 0.8169259 0.91443264
0.9883759 1.0275097 1.0235045 0.9737119 1.029139 1.1354861
1.1910824 1.1826864 1.1092159 0.9832138 0.9643041 0.92324203
0.9093703 0.88915515 1.0007693 1.0542978 1.0857164 1.0211861
0.88474303 0.8458009 0.76522666 0.7478076 0.90081936 1.0690157
1.1569089 1.1493248 1.0622779 1.0327609 0.9805119 0.9583969
0.8973544 0.9543319 0.9777171 0.94951093 0.97323567 1.0244237
1.0569099 1.0951824 1.0771195 1.3078191 1.7212077 2.09409
2.320331 2.3279085 2.125451 1.7908521 1.4180487 1.0744424
1.0218129 1.0916439 1.1255138 1.125803 1.1139745 1.2187989
1.300092 1.3025533 1.2312403 1.221301 1.2535597 1.2298189
1.1458241 1.1012102 1.0889369 1.1558667 1.3051153 1.4143198
1.6345526 1.8093723 1.9037704 1.8961821 1.7866236 1.5958548
1.3865516 1.5308585 1.6140417 1.627337 1.5733193 1.4981418
1.5048542 1.4935548 1.4798748 1.4131776 1.3792214 1.3728334
1.3683671 1.3593615 1.2995907 1.2965002 1.366058 1.4795257
1.5462885 1.61591 1.5968509 1.5222199 1.6210756 1.7074443
1.8351102 2.3187535 2.6568012 2.7676315 2.6480794 2.3636303
2.0673316 1.9607923 1.8074365 1.713272 1.5893831 1.4734347
1.507817 1.5213271 1.6091452 1.7162323 1.7608733 1.7497622
1.9187828 2.0197518 2.0487514 2.01107 1.9193696 1.7904462
1.8558109 2.1955926 2.4700975 2.6562278 2.675197 2.6645825
2.6295316 2.4182043 2.2114453 2.2506614 2.2086055 2.0497518
1.9557768 1.901191 2.067513 2.1077373 2.0159333 1.8138607
1.5413624 1.600069 1.7631899 1.9541935 1.9340311 1.805134
2.0671906 2.2247658 2.2641945 2.3594956 2.2504601 1.9749025
1.8905054 2.0679731 2.1193469 2.0307171 2.0717037 2.0340347
1.925536 1.7820351 1.9467943 2.315468 2.4834185 2.3751369
2.0240622 1.9363666 2.1732547 2.3113241 2.3264208 2.22015
2.0187428 1.7619076 1.796859 1.8757095 2.0501778 2.44711
2.6179967 2.508112 2.1694388 1.7242104 1.7671669 1.862043
1.8392721 1.7120028 1.6650634 1.6319505 1.482931 1.5240219
1.5815579 1.5691646 1.4766116 1.3731087 1.4666644 1.4061015
1.3652745 1.425564 1.4006845 1.5000012 1.581379 1.6329607
1.6444355 1.6098644 1.5300899 1.6876912 1.8968476 2.048039
2.1006014 2.0271482 1.8300935 1.6986666 1.9628603 2.0521066
1.9337255 1.6407858 1.2583638 1.2110122 1.2476432 1.2360718
1.2886397 1.2862154 1.2343681 1.1458222 1.209224 1.2475786
1.2353342 1.1797879 1.0963987 1.0928186 1.1553882 1.1569618
1.1932304 1.3002363 1.3386917 1.2973225 1.1816871 1.0557054
0.9350373 0.896656 0.8565816 0.90168726 0.9897751 1.02342
1.0232298 1.1199353 1.1466643 1.1081418 1.0377598 1.0348651
1.0223045 1.0607077 1.0089502 0.885213 1.023178 1.1131796
1.1331098 1.0779471 0.9626393 0.81472665 0.85455835 0.87542623
0.87286425 0.89130884 0.9545931 1.0355722 1.0201533 0.93568784
0.9180018 0.8202782 0.7450139 0.72550577 0.68578506 0.6431666
0.66193295 0.6386373 0.7060119 0.7650972 0.80093855 0.803342
0.76590335 0.7151591 0.6946282 0.7136788 0.7714012 0.8022328
0.79840165 0.8543819 0.8586749 0.8028453 0.7383879 0.73423904
0.65107304 0.61139977 0.5940311 0.6151931 0.59349155 0.54995483
0.5837645 0.5891752 0.56406695 0.5638191 0.5762535 0.58305734
0.5830114 0.57470953 0.5568098 0.52852243 0.49031836 0.45275375
0.47168964 0.46634504 0.4600581 0.45332378 0.41508177 0.3834329
0.4137769 0.41392407 0.3824464 0.36310086 0.434278 0.48041886
0.49433306 0.475708 0.43060693 0.36886734 0.34740242 0.34108457
0.36160505 0.40907663 0.43613982 0.4394311 0.42070773 0.38575593
0.3827834 0.4338096 0.46581286 0.45669746 0.40830874 0.3505502
0.32584783 0.3381971 0.33949164 0.36409503 0.3759155 0.3610108
0.37174097 0.39990777 0.38925973 0.34376588 0.32478797 0.32705626
0.3228174 0.30941254 0.28542265 0.2687348 0.25517422 0.26127565
0.27331188 0.3028561 0.31277937 0.29953563 0.2660389 0.27051866
0.2913383 0.30363902 0.30684754 0.3011791 0.28737035 0.26648855
0.26413882 0.25501928 0.23947525 0.21937743 0.19659272 0.18965112
0.21511254 0.23329383 0.24157354 0.2391297 0.22697571 0.20739041
0.1855308 0.18856761 0.19565174 0.20542233 0.21473111 0.22244582
0.22726117 0.22789808 0.22336568 0.21322969 0.20314343 0.2031754
0.19738965 0.1959791 0.20284075 0.20859875 0.21363212 0.21804498
0.22160804 0.22381367]
This came close, but it isn't quite right because the edges don't have zero slope: How do I fit a sine curve to my data with pylab and numpy?
Is there anything available that will let me do this without having to write up a custom algorithm to handle this? Thanks.
Here is a Lorentzian type of peak equation fitted to your data, and for the "x" values I used an index similar to what I see on the Example Output plot in your post. I have also zoomed in on the peak center to better display the sinusoidal shapes you mention. You may be able to subtract the predicted value from this peak equation to condition or pre-process the raw data as you discuss.
a = 1.7056067124682076E+02
b = 7.2900803359572393E+01
c = 2.5047064423525464E+02
d = 1.4184767800540945E+01
Offset = -2.4940994412221318E-01
y = a/ (b + pow((x-c)/d, 2.0)) + Offset
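As a rough illustration of the pre-processing idea above, the fitted peak can be evaluated and subtracted with NumPy. This is only a sketch: the function names `lorentzian_peak` and `remove_peak` are hypothetical, and it assumes `x` is the same index that was used for the fit.

```python
import numpy as np

# Fitted parameters quoted in the answer above
a = 1.7056067124682076E+02
b = 7.2900803359572393E+01
c = 2.5047064423525464E+02
d = 1.4184767800540945E+01
offset = -2.4940994412221318E-01


def lorentzian_peak(x):
    """Evaluate y = a / (b + ((x - c) / d)**2) + offset."""
    x = np.asarray(x, dtype=float)
    return a / (b + ((x - c) / d) ** 2) + offset


def remove_peak(x, y):
    """Return what is left of y after subtracting the fitted peak."""
    return np.asarray(y, dtype=float) - lorentzian_peak(x)
```

Calling `remove_peak(x, y)` on the raw data would leave the residual signal, which can then be fitted separately.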
Starting from your own example based on a sine fit, I added constraints so that the derivatives of the model have to be zero at the end points. I did this using symfit, a package I wrote to make this kind of thing easier. If you prefer to do this using scipy, you can adapt the example to that syntax; symfit is just a wrapper around their minimizers that adds symbolic manipulation using sympy.
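# Imports (xdata and ydata are assumed to already hold the data from the question)
import matplotlib.pyplot as plt
from symfit import variables, parameters, Model, Fit, sin, D, Eq
# Make variables and parameters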
x, y = variables('x, y')
a, b, c, d = parameters('a, b, c, d')
# Initial guesses
b.value = 1e-2
c.value = 100
# Create a model object
model = Model({y: a * sin(b * x + c) + d})
# Take the derivative and constrain the end-points to be equal to zero.
dydx = D(model[y], x).doit()
constraints = [Eq(dydx.subs(x, xdata[0]), 0),
               Eq(dydx.subs(x, xdata[-1]), 0)]
# Do the fit!
fit = Fit(model, x=xdata, y=ydata, constraints=constraints)
fit_result = fit.execute()
print(fit_result)
plt.plot(xdata, ydata)
plt.plot(xdata, model(x=xdata, **fit_result.params).y)
plt.show()
This prints: (from current symfit PR#221, which has better reporting of the results.)
Parameter Value Standard Deviation
a 8.790393e-01 1.879788e-02
b 1.229586e-02 3.824249e-04
c 9.896017e+01 1.011472e-01
d 1.001717e+00 2.928506e-02
Status message Optimization terminated successfully.
Number of iterations 10
Objective <symfit.core.objectives.LeastSquares object at 0x0000016F670DF080>
Minimizer <symfit.core.minimizers.SLSQP object at 0x0000016F78057A58>
Goodness of fit qualifiers:
chi_squared 29.72125657199736
objective_value 14.86062828599868
r_squared 0.8695978050586373
Constraints:
--------------------
Question: a*b*cos(c) == 0?
Answer: 1.5904051811454707e-17
Question: a*b*cos(511*b + c) == 0?
Answer: -6.354261416082215e-17
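The two constraint expressions reported above also suggest a shortcut. For nonzero a and b, zero slope at both end points forces the sine to be a cosine measured from the first point, with a whole number k of half-periods across the data range: y = d + A*cos(k*pi*(x - x0)/(x1 - x0)). For a fixed k this is linear in A and d, so ordinary least squares is enough. The sketch below is not symfit code; the function name `fit_flat_ended_cosine` and the synthetic data are assumptions made for illustration.

```python
import numpy as np


def fit_flat_ended_cosine(x, y, max_half_periods=6):
    """Fit y = offset + amplitude * cos(k*pi*(x - x[0]) / (x[-1] - x[0])).

    Every candidate has zero slope at both end points by construction.
    Each integer k is tried and the one with the smallest squared error is kept.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    span = x[-1] - x[0]
    best = None
    for k in range(1, max_half_periods + 1):
        basis = np.cos(k * np.pi * (x - x[0]) / span)
        design = np.column_stack([basis, np.ones_like(x)])
        coeffs, *_ = np.linalg.lstsq(design, y, rcond=None)
        sse = float(np.sum((design @ coeffs - y) ** 2))
        if best is None or sse < best[0]:
            best = (sse, k, coeffs[0], coeffs[1])
    sse, k, amplitude, offset = best
    return {"k": k, "amplitude": amplitude, "offset": offset, "sse": sse}


# Synthetic stand-in data (not the data from the question): one full period
x_demo = np.arange(512, dtype=float)
y_demo = 1.0 - 0.9 * np.cos(2 * np.pi * x_demo / 511.0)
result = fit_flat_ended_cosine(x_demo, y_demo)
print(result)
```

The trade-off is that the frequency can only take discrete values, which is exactly what the two equality constraints imply; a constrained minimizer such as the one above reaches the same family of solutions numerically.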

Modified BPMF in PyMC3 using `LKJCorr` priors: PositiveDefiniteError using `NUTS`

I previously implemented the original Bayesian Probabilistic Matrix Factorization (BPMF) model in pymc3. See my previous question for reference, data source, and problem setup. Per the answer to that question from @twiecki, I've implemented a variation of the model using LKJCorr priors for the correlation matrices and uniform priors for the standard deviations. In the original model, the covariance matrices are drawn from Wishart distributions, but due to current limitations of pymc3, the Wishart distribution cannot be sampled from properly. This answer to a loosely related question provides a succinct explanation for the choice of LKJCorr priors. The new model is below.
import logging
import pymc3 as pm
import numpy as np
import theano.tensor as t
n, m = train.shape
dim = 10 # dimensionality
beta_0 = 1 # scaling factor for lambdas; unclear on its use
alpha = 2 # fixed precision for likelihood function
std = .05 # how much noise to use for model initialization
# We will use separate priors for sigma and correlation matrix.
# In order to convert the upper triangular correlation values to a
# complete correlation matrix, we need to construct an index matrix:
n_elem = dim * (dim - 1) // 2  # integer division, so n_elem can be used as a size
tri_index = np.zeros([dim, dim], dtype=int)
tri_index[np.triu_indices(dim, k=1)] = np.arange(n_elem)
tri_index[np.triu_indices(dim, k=1)[::-1]] = np.arange(n_elem)
logging.info('building the BPMF model')
with pm.Model() as bpmf:
    # Specify user feature matrix
    sigma_u = pm.Uniform('sigma_u', shape=dim)
    corr_triangle_u = pm.LKJCorr(
        'corr_u', n=1, p=dim,
        testval=np.random.randn(n_elem) * std)
    corr_matrix_u = corr_triangle_u[tri_index]
    corr_matrix_u = t.fill_diagonal(corr_matrix_u, 1)
    cov_matrix_u = t.diag(sigma_u).dot(corr_matrix_u.dot(t.diag(sigma_u)))
    lambda_u = t.nlinalg.matrix_inverse(cov_matrix_u)
    mu_u = pm.Normal(
        'mu_u', mu=0, tau=beta_0 * lambda_u, shape=dim,
        testval=np.random.randn(dim) * std)
    U = pm.MvNormal(
        'U', mu=mu_u, tau=lambda_u,
        shape=(n, dim), testval=np.random.randn(n, dim) * std)
    # Specify item feature matrix
    sigma_v = pm.Uniform('sigma_v', shape=dim)
    corr_triangle_v = pm.LKJCorr(
        'corr_v', n=1, p=dim,
        testval=np.random.randn(n_elem) * std)
    corr_matrix_v = corr_triangle_v[tri_index]
    corr_matrix_v = t.fill_diagonal(corr_matrix_v, 1)
    cov_matrix_v = t.diag(sigma_v).dot(corr_matrix_v.dot(t.diag(sigma_v)))
    lambda_v = t.nlinalg.matrix_inverse(cov_matrix_v)
    mu_v = pm.Normal(
        'mu_v', mu=0, tau=beta_0 * lambda_v, shape=dim,
        testval=np.random.randn(dim) * std)
    V = pm.MvNormal(
        'V', mu=mu_v, tau=lambda_v,
        testval=np.random.randn(m, dim) * std)
    # Specify rating likelihood function
    R = pm.Normal(
        'R', mu=t.dot(U, V.T), tau=alpha * np.ones((n, m)),
        observed=train)
# `start` is the start dictionary obtained from running find_MAP for PMF.
# See the previous post for PMF code.
for key in bpmf.test_point:
    if key not in start:
        start[key] = bpmf.test_point[key]
with bpmf:
    step = pm.NUTS(scaling=start)
The goal with this reimplementation was to produce a model that could be estimated using the NUTS sampler. Unfortunately, I'm still getting the same error at the last line:
PositiveDefiniteError: Scaling is not positive definite. Simple check failed. Diagonal contains negatives. Check indexes [ 0 1 2 3 ... 1030 1031 1032 1033 1034 ]
I've made all the code for PMF, BPMF, and this modified BPMF available in this gist to make it simple to replicate the error. All you need to do is download the data (also referenced in the gist).
It looks like you are passing the complete precision matrix into the normal distribution:
mu_u = pm.Normal(
    'mu_u', mu=0, tau=beta_0 * lambda_u, shape=dim,
    testval=np.random.randn(dim) * std)
I assume you only want to pass the diagonal values:
mu_u = pm.Normal(
    'mu_u', mu=0, tau=beta_0 * t.diag(lambda_u), shape=dim,
    testval=np.random.randn(dim) * std)
Does this change to mu_u and mu_v fix it for you?
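To see the shape mismatch outside of PyMC3, here is a plain NumPy sketch. It is not the model code: the values of `sigma` and `corr` are made up for illustration, and `precision` and `tau_diag` merely play the roles of `lambda_u` and `t.diag(lambda_u)`.

```python
import numpy as np

dim = 3
sigma = np.array([0.5, 1.0, 2.0])        # made-up standard deviations
corr = np.array([[1.0, 0.2, 0.0],
                 [0.2, 1.0, 0.3],
                 [0.0, 0.3, 1.0]])        # made-up correlation matrix

# Same construction as cov_matrix_u in the question
cov = np.diag(sigma) @ corr @ np.diag(sigma)
precision = np.linalg.inv(cov)            # analogue of lambda_u
tau_diag = np.diag(precision)             # analogue of t.diag(lambda_u)

print(precision.shape)  # (3, 3): one value per pair of dimensions
print(tau_diag.shape)   # (3,): one precision per dimension, matching shape=dim
```

A univariate Normal with `shape=dim` needs one precision per element, which is what the diagonal provides. Note that with nonzero correlations the diagonal of the precision matrix is generally not equal to `1 / sigma**2`.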

Fitting negative binomial in python

In scipy there is no support for fitting a negative binomial distribution using data
(maybe due to the fact that the negative binomial in scipy is only discrete).
For a normal distribution I would just do:
from scipy.stats import norm
param = norm.fit(samp)
Is there a similar ready-to-use function in any other library?
Statsmodels has discrete.discrete_model.NegativeBinomial.fit(), see here:
https://www.statsmodels.org/dev/generated/statsmodels.discrete.discrete_model.NegativeBinomial.fit.html#statsmodels.discrete.discrete_model.NegativeBinomial.fit
It is not only because it is discrete, but also because a maximum likelihood fit to a negative binomial can be quite involved, especially with an additional location parameter. That would be the reason why a .fit() method is not provided for it (or for other discrete distributions in SciPy). Here is an example:
In [163]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as ss
import scipy.optimize as so
In [164]:
#define a likelihood function
def likelihood_f(P, x, neg=1):
    n=np.round(P[0]) #by definition, it should be an integer
    p=P[1]
    loc=np.round(P[2])
    return neg*(np.log(ss.nbinom.pmf(x, n, p, loc))).sum()
In [165]:
#generate a random variable
X=ss.nbinom.rvs(n=100, p=0.4, loc=0, size=1000)
In [166]:
#The likelihood
likelihood_f([100,0.4,0], X)
Out[166]:
-4400.3696690513316
In [167]:
#A simple fit, the fit is not good and the parameter estimate is way off
result=so.fmin(likelihood_f, [50, 1, 1], args=(X,-1), full_output=True, disp=False)
P1=result[0]
(result[1], result[0])
Out[167]:
(4418.599495886474, array([ 59.61196161, 0.28650831, 1.15141838]))
In [168]:
#Try a different set of start parameters, the fit is still not good and the parameter estimate is still way off
result=so.fmin(likelihood_f, [50, 0.5, 0], args=(X,-1), full_output=True, disp=False)
P1=result[0]
(result[1], result[0])
Out[168]:
(4417.1495981801972,
array([ 6.24809397e+01, 2.91877405e-01, 6.63343536e-04]))
In [169]:
#In this case we need a loop to get it right
result=[]
for i in range(40, 120): #in fact (80, 120) should probably be enough
    _=so.fmin(likelihood_f, [i, 0.5, 0], args=(X,-1), full_output=True, disp=False)
    result.append((_[1], _[0]))
In [170]:
#get the MLE
P2=sorted(result, key=lambda x: x[0])[0][1]
sorted(result, key=lambda x: x[0])[0]
Out[170]:
(4399.780263084549,
array([ 9.37289361e+01, 3.84587087e-01, 3.36856705e-04]))
In [171]:
#Which one is visually better?
plt.hist(X, bins=20, density=True)  # `normed` was removed from matplotlib; `density` replaces it
plt.plot(range(260), ss.nbinom.pmf(range(260), np.round(P1[0]), P1[1], np.round(P1[2])), 'g-')
plt.plot(range(260), ss.nbinom.pmf(range(260), np.round(P2[0]), P2[1], np.round(P2[2])), 'r-')
Out[171]:
[<matplotlib.lines.Line2D at 0x109776c10>]
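Since scipy.stats.nbinom also accepts a non-integer n, one alternative to rounding n and looping over start values is to optimise over a real-valued n directly. The following is only a sketch under several assumptions: loc is fixed at 0, n and p are kept in range with log and logit transforms, the search starts from the moment estimates, and the function name `fit_nbinom_mle` and the synthetic sample are made up for illustration.

```python
import numpy as np
import scipy.optimize as so
import scipy.stats as ss


def fit_nbinom_mle(x):
    """Maximum likelihood estimate of (n, p) for nbinom with loc fixed at 0."""
    x = np.asarray(x)
    mean, var = x.mean(), x.var()
    if var <= mean:
        raise ValueError("sample is not overdispersed; moment start is undefined")
    # Moment estimates serve as the starting point
    n0 = mean ** 2 / (var - mean)
    p0 = mean / var
    start = np.array([np.log(n0), np.log(p0 / (1.0 - p0))])

    def neg_log_likelihood(theta):
        n = np.exp(theta[0])                     # keeps n > 0
        p = 1.0 / (1.0 + np.exp(-theta[1]))      # keeps 0 < p < 1
        return -ss.nbinom.logpmf(x, n, p).sum()

    res = so.minimize(neg_log_likelihood, start, method="Nelder-Mead")
    n_hat = np.exp(res.x[0])
    p_hat = 1.0 / (1.0 + np.exp(-res.x[1]))
    return n_hat, p_hat, res.fun


# Synthetic sample with known parameters, for illustration only
sample = ss.nbinom.rvs(n=5, p=0.3, size=5000, random_state=0)
n_hat, p_hat, nll = fit_nbinom_mle(sample)
print(n_hat, p_hat, nll)
```

Using `logpmf` instead of `np.log(pmf)` avoids taking the log of values that underflow to zero, and the transforms mean the optimiser never proposes an invalid n or p.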
I know this thread is quite old, but current readers may want to look at this repo which is made for this purpose: https://github.com/gokceneraslan/fit_nbinom
There's also an implementation here, though part of a larger package: https://github.com/ernstlab/ChromTime/blob/master/optimize.py
I stumbled across this thread, and found an answer for anyone else wondering.
If you simply need the n, p parameterisation used by scipy.stats.nbinom you can convert the mean and variance estimates:
mu = np.mean(sample)
sigma_sqr = np.var(sample)
n = mu**2 / (sigma_sqr - mu)
p = mu / sigma_sqr
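As a quick check of the conversion above, the moment estimates can be compared against a sample with known parameters. This is a sketch: the helper name `nbinom_moment_estimates` and the synthetic sample are assumptions for illustration. The guard is there because the formula only makes sense when the sample variance exceeds the mean.

```python
import numpy as np
from scipy.stats import nbinom


def nbinom_moment_estimates(sample):
    """Method-of-moments estimates of (n, p) in scipy's nbinom parameterisation."""
    sample = np.asarray(sample, dtype=float)
    mu = sample.mean()
    sigma_sqr = sample.var()
    if sigma_sqr <= mu:
        # n would be negative or infinite: the data are not overdispersed
        raise ValueError("variance must exceed the mean for a negative binomial")
    n = mu ** 2 / (sigma_sqr - mu)
    p = mu / sigma_sqr
    return n, p


# Synthetic sample with known parameters, for illustration only
demo = nbinom.rvs(n=5, p=0.3, size=100_000, random_state=1)
n_est, p_est = nbinom_moment_estimates(demo)
print(n_est, p_est)
```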
If you need the dispersion parameter, you can use a negative binomial regression model from statsmodels with just an intercept term. This will find the dispersion parameter alpha using MLE.
# Data processing
import pandas as pd
import numpy as np
# Analysis models
import statsmodels.formula.api as smf
from scipy.stats import nbinom
def convert_params(mu, alpha):
    """
    Convert mean/dispersion parameterization of a negative binomial to the ones scipy supports

    Parameters
    ----------
    mu : float
        Mean of NB distribution.
    alpha : float
        Overdispersion parameter used for variance calculation.

    See https://en.wikipedia.org/wiki/Negative_binomial_distribution#Alternative_formulations
    """
    var = mu + alpha * mu ** 2
    p = mu / var
    r = mu ** 2 / (var - mu)
    return r, p
# Generate sample data
n = 2
p = 0.9
sample = nbinom.rvs(n=n, p=p, size=10000)
# Estimate parameters
## Mean estimates expectation parameter for negative binomial distribution
mu = np.mean(sample)
## Dispersion parameter from nb model with only an intercept term
nbfit = smf.negativebinomial("nbdata ~ 1", data=pd.DataFrame({"nbdata": sample})).fit()
alpha = nbfit.params.iloc[1] # Dispersion parameter (second entry, after the intercept)
# Convert parameters to n, p parameterization
n_est, p_est = convert_params(mu, alpha)
# Check that estimates are close to the true values:
print("""
{:<3} {:<3}
True parameters: {:<3} {:<3}
Estimates : {:<3} {:<3}""".format('n', 'p', n, p,
np.round(n_est, 2), np.round(p_est, 2)))
