''' Shallow-water equations on a beta plane, with initial conditions.
 Author : M. Crucifix, University catholique de Louvain 2013
 Based on :
 An Unconditionally Stable Scheme for the Shallow Water Equations
 M. Israeli, N. Naik and M. Cane
 Monthly Weather Review, vol. 128, p. 810, 2000
 Non-dimensional variables and scaling:
 the dimensional equations and variables can be recovered by multiplying
 h by H, (u,v) by U, t by T, and (x,y) by L, where
 U = c ; H = c*c/g ; T = (c*beta)^(-1/2) ; L = (c/beta)^(1/2)
 so that the non-dimensional Coriolis parameter near the equator,
 which dimensionally equals beta * y, reduces to f = y.
 Free to redistribute and modify, but please cite the author and propose
 modifications; provided without any warranty whatsoever.   '''

import numpy as np
import matplotlib as mpl 
import matplotlib.pyplot as plt
import td 


def check_positive (i):
  ''' Check that a space or time step is strictly positive.

      i : the step to check (dx, dy or dt)
      Returns None when i > 0.
      Raises NameError when i <= 0. NameError is the exception type used for
      every input-validation error in this module. '''
  if i <= 0:
    # The check also rejects 0, hence "non-positive" rather than "negative".
    raise NameError('non-positive space or time step supplied')

class Field:
  ''' Horizontal field organised as a 2-D numpy array of shape (m+1, n+1).

      The first index corresponds to position along the 'x' axis and the
      second index to position along the 'y' axis.

      The time step dt is expressed in grid cells as
      dtdx = dt/dx = d1 + r, with d1 an integer and 0 <= r < 1.
      R and L use this to interpolate the field a distance dt to the east
      (R) or to the west (L) of a given column. '''
  def __init__ (self, val, m, n, dx, dy, dt):
    ''' val    : array of shape (m+1, n+1) holding the field values
        m, n   : number of grid intervals along x and y
        dx, dy : grid spacing
        dt     : time step
        dx, dy and dt are passed to check_positive.

        Raises NameError('Wrong field shape') if val.shape != (m+1, n+1). '''
    # Use plain int(): the np.int alias was deprecated in NumPy 1.20 and
    # removed in NumPy 1.24, where it raises AttributeError.
    self.m = int(m)
    self.n = int(n)
    for step in (dx, dy, dt):
      check_positive(step)
    self.dx = dx
    self.dy = dy
    self.dt = dt
    self.dtdx = dt/dx

    # dt/dx = d1 + r: integer and fractional parts of the shift, in cells.
    # R and L interpolate between the columns d1 and d2 = d1 + 1 away.
    self.d1 = int(np.floor(self.dtdx))
    self.d2 = int(self.d1 + 1)
    self.r  = self.dtdx - self.d1

    if (val.shape == (m+1,n+1)):
      self.field = val
    else:
      raise NameError('Wrong field shape')

  def R(self,i):
    ''' Return the transect along y (all n+1 values) of the field at
        position i*dx + dt (_R operator).

        Uses linear interpolation between columns i+d1 and i+d2 when
        i + d2 < m. Otherwise it extrapolates linearly from the two
        easternmost columns m-1 and m. For i + d2 == m both formulas
        give the same value. '''
    if i + self.d2 < self.m:
      return(self.field[i+self.d1,:] * (1-self.r) + self.field[i+self.d2,:] * self.r)
    # (dtdx - (m - i)) is the signed offset, in cells, of i*dx + dt from column m
    return(self.field[self.m,:] + (self.field[self.m,:]-
           self.field[self.m-1,:])*(self.dtdx - (self.m-i) ))

  def L(self,i):
    ''' Return the transect along y (all n+1 values) of the field at
        position i*dx - dt (_L operator).

        Uses linear interpolation between columns i-d1 and i-d2 when
        i - d2 > 0. Otherwise it extrapolates linearly from the two
        westernmost columns 0 and 1. For i - d2 == 0 both formulas
        give the same value. '''
    if i - self.d2 > 0:
      return(self.field[i-self.d1,:] * (1-self.r) + self.field[i-self.d2,:] * self.r)
    # The target position i - dtdx lies (dtdx - i) cells west of column 0.
    return(self.field[0,:] + (self.field[0,:]-self.field[1,:])*( self.dtdx  - i ))

  def C(self,i):
    ''' Return the transect along y (all n+1 values) of column i itself. '''
    return (self.field[i,:])


def fdy(vector,n, dy):
  ''' Centred finite-difference derivative along y.

      vector : array of n+1 values on a grid of spacing dy
      n      : index of the last point
      dy     : grid spacing

      Returns a new array of n+1 values. Entries 1..n-1 hold
      (vector[k+1] - vector[k-1]) / (2*dy). Both end points are set to
      zero; one-sided end stencils are not used. '''
  grad = np.zeros(n + 1)
  # Centred differences on the interior points only.
  grad[1:n] = (vector[2:n + 1] - vector[0:n - 1]) / 2.
  # The end points grad[0] and grad[n] keep their initial zero.
  grad /= dy
  return grad

class swfield:
  ''' Shallow-water model on a beta plane, on an (m+1) x (n+1) grid.

      Holds the prognostic fields h, u and v (as Field objects), the
      Coriolis parameter f(y) and the forcing arrays F, G and Q. Each call
      to iterate() advances the state by one time step dt.

      In each step:
        * the x-direction is handled through the combinations II1 (built
          on u + h) and II2 (built on u - h), sampled at x - dt with
          Field.L and at x + dt with Field.R;
        * the y-direction is handled by one linear solve for v per x-column
          via td.solve_s. The td module is not shown here; it is presumably
          a tridiagonal solver with diagonal aa, off-diagonals bb/cc and
          right-hand side dd (confirm). '''

  def field ( self, val ):
    ''' Wrap array `val` into a Field using this model's grid and time step. '''
    return(Field(val, self.m, self.n, self.dx, self.dy, self.dt))

  def __init__ (self,m, n, dx, dy, dt, h, u, v, f, F, G, Q):
    ''' m, n       : number of grid intervals along x and y; fields have
                     shape (m+1, n+1)
        dx, dy, dt : grid spacing and time step (checked by Field)
        h, u, v    : initial fields, arrays of shape (m+1, n+1)
        f          : Coriolis parameter along y, array of shape (n+1,)
        F, G, Q    : forcing arrays, indexed per x-column as X[i, ...].
                     F enters the u update, G the v solve and Q the
                     h update. Presumably of shape (m+1, n+1) (confirm).

        Raises NameError if f or one of h, u, v has the wrong shape. '''
    self.dx = dx
    self.dy = dy
    self.ddy = dy*dy
    self.dt = dt
    self.tau = dt/2.                  # half time step
    self.ttau = self.tau*self.tau
    self.m = int(m)
    self.n = int(n)
    # dv/dy on the whole grid, refreshed by update_dyv() at every step
    self.dyv = np.zeros((self.m+1,self.n+1))
    if (f.shape == (self.n+1,)):
      self.f = f
    else:
      raise NameError ("Coriolis factor 'f' must be supplied as a vector of length n+1")

    self.F = F
    self.G = G
    self.Q = Q

    self.h = self.field(h)
    self.u = self.field(u)
    self.v = self.field(v)
    # u+h / u-h combinations, rebuilt by update_II() at every step
    self.II1 = self.field(np.zeros((self.m+1,self.n+1)))
    self.II2 = self.field(np.zeros((self.m+1,self.n+1)))
    self.rhs = np.zeros((self.m+1,self.n+1))   # not used by the methods below

  def update_dyv(self):
    ''' Recompute dyv = dv/dy for every x-column with fdy. Returns self. '''
    # self.m, not the module-level global m, which only exists when the
    # file is run as a script
    for i in range(self.m+1):
      self.dyv[i,:] = fdy(self.v.C(i), self.n, self.dy)
    return(self)

  def update_II(self):
    ''' Rebuild II1 and II2 from the current state:

          II1 = u + h + tau*(f*v - dv/dy + F + Q)
          II2 = u - h + tau*(f*v + dv/dy + F - Q)

        update_dyv() must be called first so that dyv is current.
        f, of shape (n+1,), broadcasts along the y axis. Returns self. '''
    self.II1 = self.field(self.u.field+self.h.field+self.tau*
          (self.f*self.v.field - self.dyv + self.F + self.Q))
    self.II2 = self.field(self.u.field-self.h.field+self.tau*
          (self.f*self.v.field + self.dyv + self.F - self.Q))
    return(self)

  def iterate_interior(self,i):
    ''' Return the new (u, v, h) transects of interior column i (0 < i < m).

        I1 and I2 are II1 sampled at x_i - dt and II2 sampled at x_i + dt.
        cn and dn are their half-sum and half-difference, plus tau times
        the forcing; they predict u and h respectively. v is obtained from
        a linear system (diagonal aa, off-diagonals bb and cc, right-hand
        side dd), then u and h are recovered from v. '''
    I1 = self.II1.L(i)
    I2 = self.II2.R(i)
    cn = I1 + 0.5*(I2-I1) + self.tau*self.F[i,]
    dn = I1 - 0.5*(I2+I1) + self.tau*self.Q[i,]

    aa = -2./(self.ddy) - (1/(self.ttau) + self.f*self.f)
    bb = np.ones(self.n)/(self.ddy)
    cc = np.ones(self.n)/(self.ddy)
    dd = (self.f*self.u.C(i) + self.f*cn + fdy(self.h.C(i)+dn, self.n, self.dy) -
          2*self.G[i,]) / self.tau - self.v.C(i) / self.ttau
    v_ = td.solve_s(aa,bb,cc,dd)
    u_ = self.tau * self.f * v_ + cn
    h_ = - self.tau * fdy(v_, self.n, self.dy) + dn

    return u_, v_, h_

  def iterate_west(self):
    ''' Return the new (u, v, h) transects of the western boundary column
        (i = 0). There u is set to zero and only II2, sampled at x_0 + dt,
        is used. '''
    i = 0
    I2 = self.II2.R(i)
    dn = - I2

    # TODO: the +self.f * self.f is incorrect, but
    # is necessary to ensure stability. Need to find
    # the way to keep the system stable without it
    aa = -2./(self.ddy) - (1/(self.ttau) * np.ones(self.n+1) + self.f * self.f)
    # self.n, not the module-level global n, which only exists when the
    # file is run as a script
    cc = np.ones(self.n)/(self.ddy)  + self.f[1:self.n+1]/(2.*self.dy)
    bb = np.ones(self.n)/(self.ddy)  - self.f[0:self.n]/(2.*self.dy)
    dd = (fdy(self.h.C(i)+dn, self.n, self.dy) -
          2*self.G[i,]) / self.tau - self.v.C(i) / self.ttau
    v_ = td.solve_s(aa,bb,cc,dd)
    u_ = np.zeros(self.n+1)
    h_ = - self.tau * fdy(v_, self.n, self.dy) + dn

    return u_, v_, h_

  def iterate_east(self):
    ''' Return the new (u, v, h) transects of the eastern boundary column
        (i = m). There u is set to zero and only II1, sampled at x_m - dt,
        is used. '''
    i = self.m
    dn = self.II1.L(self.m)

    # Same +self.f * self.f stability term as in iterate_west (see TODO there).
    aa = -2./(self.ddy) - (1/(self.ttau) * np.ones(self.n+1) + self.f * self.f )
    # self.n, not the module-level global n (see iterate_west)
    cc = np.ones(self.n)/(self.ddy)  - self.f[1:self.n+1]/(2.*self.dy)
    bb = np.ones(self.n)/(self.ddy)  + self.f[0:self.n]/(2.*self.dy)
    dd = (fdy(self.h.C(i)+dn, self.n, self.dy) -
          2*self.G[i,]) / self.tau - self.v.C(i) / self.ttau
    v_ = td.solve_s(aa,bb,cc,dd)
    u_ = np.zeros(self.n+1)
    h_ =  self.tau * fdy(v_, self.n, self.dy) + dn

    return u_, v_, h_

  def iterate (self):
    ''' Advance h, u and v by one time step dt. Returns self. '''
    u_ = np.zeros((self.m+1,self.n+1))
    v_ = np.zeros((self.m+1,self.n+1))
    h_ = np.zeros((self.m+1,self.n+1))

    # dyv must be refreshed before II1/II2, which depend on it
    self.update_dyv()
    self.update_II()

    u_[0,:],v_[0,:],h_[0,:] = self.iterate_west()
    u_[self.m,:],v_[self.m,:],h_[self.m,:] = self.iterate_east()

    for i in range(1,self.m):
      u_[i,:],v_[i,:],h_[i,:] = self.iterate_interior(i)

    self.u = self.field(u_)
    self.v = self.field(v_)
    self.h = self.field(h_)
    return(self)

  def plot (self):
    ''' Show h with x along the horizontal image axis, using a fixed colour
        range [-0.4, 0.4]. '''
    plt.imshow(self.h.field.T,vmin=-0.4, vmax=0.4)

import matplotlib.animation as animation

if __name__ == "__main__":

  # Grid and time step, in non-dimensional units.
  m, n = 60, 60
  dx = dy = .6
  dt = 0.3

  # Initial state: fluid at rest with a square bump in h.
  u, v, h = (np.zeros((m + 1, n + 1)) for _ in range(3))
  h[25:35, 25:35] = 0.5

  # Coriolis parameter f = beta * y, with y = 0 on the middle row.
  beta = 1  # set to zero to simulate gravity waves
  f = (np.arange(n + 1) - n / 2) * dy * beta

  # No forcing.
  F, G, Q = (np.zeros((m + 1, n + 1)) for _ in range(3))

  H = swfield(m, n, dx, dy, dt, h, u, v, f, F, G, Q)

  def init():
    ''' Draw the initial state. '''
    H.plot()

  def animate(step):
    ''' Advance the model by one time step and print the step number. '''
    H.iterate()
    print(step)

  fig = plt.figure()
  init()
  for step in range(1, 350):
    fig.clear()
    animate(step)
    H.plot()
    # One PNG per step: iter_0001.png, iter_0002.png, ...
    plt.savefig('iter_%4.4i.png' % step)

  # NOTE: writing a movie directly with matplotlib.animation
  # (anim.save('basic_animation', writer='ffmpeg')) did not work on macOS
  # ('rgba' not supported), hence the individual PNG frames above.