# Numerical continuation with AUTO
It's a pain, but helpful. First, load a model (a neuron) and simulate it in time with `brian2`. Then convert it to AUTO's equation files and find bifurcations or calculate its [phase response curve](https://en.wikipedia.org/wiki/Phase_response_curve).

In [None]:
import json

import brian2
import numpy
import scipy
import sympy

import auto
import autoutils
# Map every brian2 unit name to its unit object, so that unit expressions
# from the model file can be evaluated later with eval(..., units).
# Entries from brian2.units.allunits take precedence over brian2.units.
units = dict(vars(brian2.units))
units.update(vars(brian2.units.allunits))

# Read the model description; the context manager closes the file handle.
with open("wangbuzsaki.json") as model_file:
    modeldict = json.load(model_file)

# print(...) with a single argument behaves the same on Python 2 and 3.
print("These are the model parameter")
for name, value in modeldict['par'].items():
    print("{}: \t\t {}".format(name, value))

Load the model into a brian2 `Equations` object. We also select the parameters in which we want to continue the model. Let's use
the membrane capacitance $C_m$ and the input current $I$.

In [None]:
def load_mod(moddict, bp):
    """Build a brian2 ``Equations`` object from a model dictionary.

    The dictionary (as loaded from the model JSON file) is read under these keys:

    * ``"fun"``: name -> ``"expression[:unit]"`` helper functions; only the
      expression before the first ``:`` is used.
    * ``"par"``: name -> parameter value string.
    * ``"aux_odes"``: name -> ``"rhs:unit"`` for the auxiliary ODEs.
    * ``"current_balance_eq"``: equation that is solved for ``dv/dt``.
    * ``"currents"``: substitutions applied to the solved ``dv/dt``.

    Helper functions are substituted repeatedly until the expression stops
    changing (bounded number of rounds), so functions may refer to other
    functions at any depth. Afterwards all parameters except those named in
    ``bp`` are replaced by their values; the ``bp`` parameters stay symbolic.

    Args:
        moddict: model dictionary as described above.
        bp: names of the parameters to keep symbolic (continuation parameters).

    Returns:
        ``(sde, bifpar)``: ``sde`` is a ``brian2.Equations`` with one
        ``d<name>/dt`` equation per auxiliary ODE plus ``dv/dt`` (unit volt);
        ``bifpar`` maps each name in ``bp`` to its value string from ``"par"``.

    Raises:
        KeyError: if a name in ``bp`` is not in ``moddict["par"]``.
    """
    # Drop the optional unit annotation of helper functions.
    fundic = dict((name, spec.split(":")[0]) for name, spec in moddict["fun"].items())
    pardic = dict(moddict["par"].items())
    # Removing the continuation parameters keeps them symbolic below.
    bifpar = [(name, pardic.pop(name)) for name in bp]

    def _expand(expr, max_rounds=50):
        """Substitute helper functions up to a fixed point, then parameters."""
        expr = sympy.S(expr)
        for _ in range(max_rounds):
            expanded = expr.subs(fundic)
            if expanded == expr:
                break
            expr = expanded
        return expr.subs(pardic)

    sdelist = []
    for name, spec in moddict['aux_odes'].items():
        # Split only at the first ':' (expression vs. unit part).
        expr, unit = spec.split(':', 1)
        sdelist.append((name, ":".join((str(_expand(expr)), unit))))

    # The membrane equation comes from solving the current balance for dv/dt.
    dvdt = sympy.solve(moddict['current_balance_eq'], 'dv/dt')[0]
    sdelist.append(('v', str(_expand(dvdt.subs(moddict['currents']))) + ":volt"))

    sde = brian2.Equations("d{}/dt = {}".format(*sdelist[0]))
    for name, rhs in sdelist[1:]:
        sde += brian2.Equations("d{}/dt = {}".format(name, rhs))
    return sde, dict(bifpar)

## Time integration with `brian2`

In [None]:
## LOAD MODEL ##
# I and Cm stay symbolic so that they can serve as continuation parameters.
ode, bifpar = load_mod(modeldict, ["I", "Cm"])

## ADD PAR ##
# Declare each continuation parameter as a model parameter whose dimension
# is taken from its value string in the model file.
# NOTE(review): eval() executes strings from the model JSON; only use
# trusted model files.
for parname, parvalue in bifpar.items():
    pardim = repr(eval(parvalue, units).dim)
    ode += brian2.Equations("{} : {}".format(parname, pardim))

## INTEGRATOR ##
dt = 0.01 * brian2.ms
brian2.defaultclock.dt = dt
G = brian2.NeuronGroup(1, model=ode, method="rk4",
                       threshold='not_refractory and (v>5*mV)',
                       refractory='v>-40*mV')

## STATE INIT ##
G.v = -65 * brian2.mV
G.I = bifpar['I']
G.Cm = bifpar['Cm']

## RUN ##
states = brian2.StateMonitor(G, ode.eq_names, record=True)
spikes = brian2.SpikeMonitor(G)
net = brian2.Network(G, states, spikes)
duration = 1000 * brian2.ms
net.run(duration)


In [None]:
%matplotlib inline
from matplotlib import pyplot
states.v[0]
pyplot.plot(states.t/brian2.ms,states.v[0]/brian2.mV)
pyplot.ylabel('v [mV]')
pyplot.xlabel('time [ms]')
pyplot.title('membrane voltage')
pyplot.tight_layout()

Oh, it does not spike. But can it spike? Try to make it spike by changing $I$.

## Numerical continuation of the fixpoint (resting potential)
Write the `AUTO` equation files and run a fixpoint continuation. Check out the AUTO constants [here](http://www.macs.hw.ac.uk/~gabriel/auto07/node263.html).

In [None]:
# Unit symbols occurring in the equations; each is replaced by a float so
# that AUTO receives purely numerical right-hand sides.
unitlist = ["mV", "ms", "cm2", "uF", "psiemens", "um2", "msiemens", "cm"]
# Continuation parameters as plain floats for AUTO.
autobifpar = {name: float(eval(value, units)) for name, value in bifpar.items()}
baseunits = [(unit, float(eval(unit, units).base)) for unit in unitlist]
varrhs = [(name, sympy.S(expr).subs(baseunits))
          for name, expr in ode.eq_expressions]
# Order the state variables by name, descending. key= gives the same order
# as the former Python-2-only cmp= argument (builtin cmp no longer exists in
# Python 3).
varrhs.sort(key=lambda item: item[0], reverse=True)
var, rhs = zip(*varrhs)
# Extra boundary condition: the v right-hand side evaluated at the left
# boundary. NOTE(review): presumably this pins dv/dt = 0 at t = 0 (phase
# condition) — confirm in autoutils.writeFP.
spikecriterion = [str(sympy.S(expr).subs([(name, "{}_left".format(name)) for name in var]))
                  for target, expr in zip(var, rhs) if target == "v"]

unames, pnames = autoutils.writeFP('wb',
    bifpar=autobifpar, rhs=rhs, var=var,
    bc=['{0}_left-{0}_right'.format(v) for v in var] + spikecriterion,
    ic=[])

################
# CONT FP & LC #
# Fixpoint branch in I, starting from the last sample of the brian2 run
# (assumed to be at rest), stopped at the first Hopf point (STOP=['HB1']).
r1 = auto.run([float(getattr(states, name)[0][-1]) for name in var], e='wb',
    c='wb', parnames=pnames, unames=unames,
    ICP=['I'], ISP=1, ILP=1, SP=['LP', 'HB', 'BP'],
    PAR=autobifpar, ITNW=17, NWTN=13, NMX=500000, NPR=500000,
    DS=1e-6, DSMAX=1e-5, STOP=['HB1'],
    UZSTOP={'I': 350.0})

s1HB = r1.getLabel('HB')[0]
s1LP = r1.getLabel('LP')[0]

# Periodic orbits (IPS=2) emanating from the Hopf point, with I and period
# free; stopped at a user-defined period.
r2 = auto.run(s1HB, e='wb', c='wb',
    parnames=pnames, unames=unames,
    ICP=['I', 'period'], ILP=1, ISW=1, IPS=2,
    ITNW=7, NWTN=3, NMX=1000, NPR=1000,
    DS=-1e-2, DSMAX=1e1, NTST=300,
    SP=['BT', 'LP', 'HB', 'BP', 'CP'],
    UZSTOP={'period': 0.15})


In [None]:
from IPython.display import SVG, display

# Bifurcation diagram of the fixpoint (r1) and limit-cycle (r2) branches,
# showing max(v) with stability information, rendered inline as SVG.
bdiag_file = "r1_bdiag.svg"
p1 = auto.plot(r1 + r2, stability=True, bifurcation_y='MAX v')
p1.savefig(bdiag_file)  # p1.config() lists further plot options
display(SVG(bdiag_file))

## Continuation of boundary value problems
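Remember that the PRC $Z(t)$ is the periodic solution of the adjoint linear system

$\dot Z=-J^\mathsf{T}(t)Z$,

where $J(t)$ is the Jacobian evaluated along the limit cycle.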

In [None]:
## CREATE ADJOINT LINEAR SYSTEM ##
# One adjoint symbol "ad<name>" per state variable (e.g. adv for v).
advar = sympy.S(["ad{}".format(name) for name in var])

# Jacobian J[i][j] = d rhs_i / d var_j, with unit symbols replaced by floats.
J = []
for expr in rhs:
    J.append([sympy.S(expr).diff(x).subs(baseunits) for x in var])

# Adjoint equations (lam*Id - J^T) Z, written as strings for AUTO.
adjoint_matrix = sympy.S("lam") * sympy.eye(len(advar)) - sympy.Matrix(J).T
adlinsys = [str(component) for component in adjoint_matrix * sympy.Matrix(advar)]

# Normalisation constraint Z^T F - dotZF/period (F = vector field).
zf = (sympy.Matrix(sympy.S(advar)).T * sympy.Matrix(sympy.S(rhs)))[0, 0]
prcnorm = str(zf - sympy.S("dotZF/period"))


Actually, a slightly more general adjoint linear equation was used:

$\dot Z = (\lambda I-J^\mathsf{T}(t))Z$.

The additional parameter $\lambda$ has proven useful for the continuation.

In [None]:
# Jacobian evaluated at the saddle-node (LP) of the fixpoint branch; it is
# kept for inspection (at a saddle-node one eigenvalue should be near zero).
statepardic = dict(zip(s1LP.coordnames, s1LP.coordarray.flatten()))
statepardic.update(s1LP.PAR)
# numpy.array replaces the deprecated scipy.array alias.
J0 = numpy.array([[sympy.S(entry).subs(statepardic) for entry in row] for row in J],
                 dtype=float)

Since we start near the saddle-node on limit cycle (SNLC) bifurcation, we can initialise the PRC with the canonical one, $1-\cos(2\pi t)$ (up to scaling).

In [None]:
sol = r2.getLabel('UZ')[0]  # STARTING POINT: limit cycle at the UZ stop of r2

# Rotate the orbit so that the voltage maximum (the spike) sits at t = 0.
# The last mesh point of a periodic orbit repeats the first one and is
# dropped from the data, so it is also excluded from the peak search;
# otherwise argmax could return that dropped index and tdat[ix] would fail.
ix = sol['v'][:-1].argmax()
orbdat = numpy.roll(sol.coordarray[:, :-1], -ix, 1)
tdat = sol.indepvararray[:-1]
tdat = numpy.mod(numpy.roll(tdat, -ix) - tdat[ix], 1)
# Rows are state variables, columns are mesh points (explicit instead of
# guessing the axes from min/max of the shape).
nvar, npts = orbdat.shape

# Initial data for AUTO: row 0 is time, rows 1..nvar the orbit, and the last
# nvar rows the adjoint (PRC) guess.
dat = numpy.zeros((2 * nvar + 1, npts))
dat[0, :] = tdat
dat[1:nvar + 1, :] = orbdat
# Canonical PRC shape 1 - cos(2 pi t), scaled by 10, for every adjoint component.
dat[nvar + 1:, :] = 10. * (1 - numpy.cos(2 * numpy.pi * tdat))[None, :].repeat(nvar, 0)

# Target of the normalisation Z^T F = dotZF/period (see prcnorm). A
# Goldstone-mode estimate is kept for reference:
# goldstone = numpy.array([numpy.gradient(k, numpy.gradient(dat[0, :]))
#                          for k in dat[1:nvar + 1, :]])
# dotZF = numpy.trapz((goldstone * dat[nvar + 1:, :]).sum(0), dat[0, :])
dotZF = 1

autobifpar['period'] = sol['period']
autobifpar['lam'] = 0
autobifpar['dotZF'] = dotZF
# Copy back only the parameters that AUTO is already given.
autobifpar.update([(name, value) for name, value in sol.PAR.items()
                   if name in autobifpar])

unames, pnames = autoutils.writeBVP('wbBVP',
    bifpar=autobifpar,
    rhs=list(rhs) + adlinsys,
    var=list(var) + advar,
    bc=['{0}_left-{0}_right'.format(v) for v in list(var) + advar]
       + spikecriterion,
    ic=[prcnorm])

r3 = auto.run(dat, e='wbBVP', c='wbBVP',
    parnames=pnames, unames=unames, PAR=autobifpar,
    ICP=['dotZF', 'lam', 'period', 'I'],
    NTST=200, ITNW=17, NWTN=13, NMX=1000, NPR=200,
    DS=1e-3, DSMAX=1e1, UZSTOP={'dotZF': 1})


In [None]:
# Continue the PRC boundary value problem from the first UZ solution of r3,
# with period, I and lam free, until the period reaches 1.0 (in the time
# units of the AUTO equations).
r4= auto.run(r3.getLabel("UZ")[0], e='wbBVP', c='wbBVP',
    parnames= pnames, unames=unames, ICP=['period','I','lam'],
    NTST=200, ITNW=7, NWTN=3, NMX=2000, NPR=2000, DS=5e-2, DSMAX=1e1,
    UZSTOP={"period":1.0})

In [None]:
# Voltage component of the adjoint solution (the PRC for voltage
# perturbations) over the normalised period, labelled like the voltage plot.
s4 = r4.getLabel('UZ')[0]
pyplot.plot(s4['t'], s4['adv'])
pyplot.xlabel('phase')
pyplot.ylabel('adv (PRC)')
pyplot.title('phase response curve')

In [None]:
# Shape of the time mesh of s4 (shown as cell output; inspection only).
s4['t'].shape

In [None]:


# Continue the PRC boundary value problem from the first UZ solution of r4
# in the membrane capacitance Cm (period and lam free) until Cm = 0.025.
# NOTE(review): autobifpar values were converted with float(eval(...)), so
# Cm is presumably in SI units (F/m^2) here, i.e. 0.025 ~ 2.5 uF/cm^2 — confirm.
r5= auto.run(r4.getLabel("UZ")[0], e='wbBVP', c='wbBVP',
    parnames= pnames, unames=unames, ICP=['Cm','period','lam'],
    NTST=200, ITNW=7, NWTN=3, NMX=200, NPR=20, DS=5e-0, DSMAX=1e4,
    UZSTOP={"Cm":0.025})

# NOTE(review): the disabled r6 run below uses a parameter 'c', which is not
# among the parameters passed to writeBVP (I, Cm, period, lam, dotZF);
# presumably a stale name for 'Cm' — fix before re-enabling.
# r6= auto.run(r5.getLabel("UZ")[0], e='wbBVP', c='wbBVP',
#     parnames= pnames, unames=unames, ICP=['c','I','lam'],
#     NTST=200, ITNW=7, NWTN=3, NMX=300, NPR=30, DS=5e-0, DSMAX=1e4,
#     UZSTOP={"c":0.025})
