require "numru/ggraph"
include NumRu
include NMath

###################################################################

# For one time step, locate each target pressure within every model sigma
# column, so that interpolate_in_press can reuse the result for several
# variables that share the same grid.
#
# gphysPs    : surface-pressure GPhys with a 'time' axis.
# gphysArr3d : GPhys whose axes 0..3 are lon, lat, sigma, time. Only its
#              coordinates are used here; its data values are not read.
# tindex     : index into gphysArr3d's time coordinate.
# tgtpress   : 1-D NArray of target pressures, in the same units as Ps*sigma.
#
# Returns an NArrayMiss int array shaped (im, jm, km2). Element (i,j,k2)
# holds the flat index i + im*j + im*jm*k into an (im,jm,km) array. Here k is
# the first sigma level whose pressure is below tgtpress[k2], so the pair
# used for interpolation is (k-1, k). When k would be 0 there is no level k-1
# to pair with, and the element is left unset. If level 0 is the lowest
# model level, the target is presumably below ground in that case.
# NOTE(review): unset NArrayMiss elements are assumed to stay masked
# (invalid); interpolate_in_press relies on get_mask! for that.
# NOTE(review): when no level has pressure below tgtpress[k2], the top level
# km-1 is used, so the caller extrapolates from the top two levels.
# Consider whether such points should be missing instead.
def search_index_for_interpolate_in_press( gphysPs, gphysArr3d, tindex, tgtpress )

  km2 = tgtpress.shape[0]

  sig = gphysArr3d.coord(2).val

  im = gphysArr3d.shape[0] # number of elements for longitude dimension
  jm = gphysArr3d.shape[1] # number of elements for latitude  dimension
  km = gphysArr3d.shape[2] # number of elements for sigma     dimension

  kindexarr3d = NArrayMiss.int(im,jm,km2)

  # Only the surface pressure is needed to locate the levels; the 3-D field
  # itself is read later, in interpolate_in_press.
  ps = gphysPs.cut('time'=>gphysArr3d.coord(3).val[tindex]).val

  # pressure of every model level: p(i,j,k) = Ps(i,j) * sigma(k)
  press = ps * sig.reshape!(1,1,km)

  imjm = im*jm
  for j in 0...jm
    for i in 0...im
      pcol = press[i,j,true]   # this column's pressure profile (invariant over k2)
      for k2 in 0...km2
        below = pcol.lt(tgtpress[k2])
        k = below.any? ? below.where[0] : km-1
        kindexarr3d[i,j,k2] = i + im * j + imjm * k if k > 0
      end
    end
  end

  return kindexarr3d

end


###################################################################

# Interpolate one time step of gphysArr3d from sigma levels to the target
# pressures, linearly in log(pressure). It uses the bracketing-level indices
# found by search_index_for_interpolate_in_press.
#
# gphysPs, gphysArr3d, tindex, tgtpress : as for
#   search_index_for_interpolate_in_press (here gphysArr3d's data is read).
# kindexarr3d : NArrayMiss (im, jm, km2) of flat indices of level k (the
#   upper level of each bracketing pair); masked elements yield rmiss.
# rmiss       : value written wherever no interpolation was possible.
#
# Returns an sfloat NArray shaped (im, jm, km2).
def interpolate_in_press( gphysPs, gphysArr3d, tindex, tgtpress, kindexarr3d, rmiss )

  km2 = tgtpress.shape[0]

  sig  = gphysArr3d.coord(2).val
  time = gphysArr3d.coord(3).val

  im = gphysArr3d.shape[0] # number of elements for longitude dimension
  jm = gphysArr3d.shape[1] # number of elements for latitude  dimension
  km = gphysArr3d.shape[2] # number of elements for sigma     dimension

  tgtarr3d = NArray.sfloat(im,jm,km2).fill(rmiss)

  mask = kindexarr3d.get_mask!
  # Nothing bracketed anywhere: the result is all missing. Returning here
  # also avoids indexing with an empty selection below.
  return tgtarr3d unless mask.any?

  ps    = gphysPs.cut('time'=>time[tindex]).val
  arr3d = gphysArr3d.cut('time'=>time[tindex]).val

  # pressure of every model level: p(i,j,k) = Ps(i,j) * sigma(k)
  press = ps * sig.reshape!(1,1,km)

  # compute in double precision, as the original source this code follows did
  arr3d = arr3d.to_type(NArray::FLOAT)
  press = press.to_type(NArray::FLOAT)

  tgtpress3d = NArray.float(im,jm).fill(1)*tgtpress.reshape(1,1,km2)

  idx0 = kindexarr3d.get_array![mask]   # level k   (pressure normally below target)
  idx1 = idx0 - im*jm                   # level k-1 (pressure normally at/above target)
  # Straight line in log(p) through (p[k-1], a[k-1]) and (p[k], a[k]):
  #   a = a[k-1] + (a[k]-a[k-1]) * log(p/p[k-1]) / log(p[k]/p[k-1])
  # Measure the log offset from p[k-1], the same level whose value is added.
  # Otherwise the line is shifted by one level's difference and does not
  # reproduce the data at the model levels.
  tgtarr3d[mask] = ( arr3d[idx0] - arr3d[idx1] ) / log( press[idx0] / press[idx1] ) * log( tgtpress3d[mask] / press[idx1] ) + arr3d[idx1]

  return tgtarr3d

end

###################################################################

# Zonal (longitude) mean of a (lon, lat, press) GPhys, ignoring missing values.
#
# gphysArr3d : GPhys with 'lat' and 'press' coordinates and 'missing_value',
#              'long_name' and 'units' attributes (as built by packgphys).
#
# Returns a GPhys on a (lat, press) grid. Points where every longitude is
# missing, including the case where the whole field is missing, get
# missing_value.
def zonalmean( gphysArr3d )

  im = gphysArr3d.shape[0] # number of elements for longitude dimension
  jm = gphysArr3d.shape[1] # number of elements for latitude  dimension
  km = gphysArr3d.shape[2] # number of elements for pressure  dimension

  arr3d = gphysArr3d.val

  rmiss = gphysArr3d.get_att('missing_value')[0]

  zmarr = NArray.sfloat(jm,km).fill(rmiss)

  # view as (lon, lat*press): each column is one latitude circle
  arr3d.reshape!(im, jm*km)

  mask = arr3d.ne(rmiss)
  # Average only when at least one value is valid. Without this guard,
  # idx and count would be nil below and the division would raise.
  if mask.any?
    # INT (not SINT) so the per-column count cannot overflow for large im
    count = mask.to_type(NArray::INT).sum(0)
    idx = count.gt(0).where          # columns with at least one valid value
    mask = mask[true,idx]
    count = count[idx]
    arr3d = arr3d[true,idx]
    arr3d[mask.not] = 0.0            # missing values contribute nothing to the sum
    zmarr[idx] = arr3d.sum(0)/count
  end

  lat_a = VArray.new( gphysArr3d.coord('lat').val, 
                      { "long_name"=>gphysArr3d.coord('lat').get_att('long_name'),
                        "units"=>gphysArr3d.coord('lat').get_att('units') }, 
                      "lat" )
  lat = Axis.new.set_pos(lat_a)
 
  press_a = VArray.new( gphysArr3d.coord('press').val, 
                        { "long_name"=>gphysArr3d.coord('press').get_att('long_name'),
                          "units"=>gphysArr3d.coord('press').get_att('units')},
                        "press" )
  press = Axis.new.set_pos(press_a)
 
  data = VArray.new( zmarr, 
                   {"long_name"=>gphysArr3d.get_att('long_name'), 
                    "units"=>gphysArr3d.get_att('units'), 
                    "missing_value"=>gphysArr3d.get_att('missing_value')},
                   gphysArr3d.name )
  gphys = GPhys.new( Grid.new(lat,press), data )

  return gphys

end

###################################################################

# Zonal (longitude) sum of a (lon, lat, press) GPhys; missing values count
# as 0.
#
# gphysArr3d : GPhys with 'lat' and 'press' coordinates and 'missing_value',
#              'long_name' and 'units' attributes (as built by packgphys).
#
# Returns a GPhys on a (lat, press) grid. A latitude circle with no valid
# value at all gets missing_value rather than 0.0, consistent with
# zonalmean.
def zonalsum( gphysArr3d )

  im = gphysArr3d.shape[0] # number of elements for longitude dimension
  jm = gphysArr3d.shape[1] # number of elements for latitude  dimension
  km = gphysArr3d.shape[2] # number of elements for pressure  dimension

  arr3d = gphysArr3d.val

  rmiss = gphysArr3d.get_att('missing_value')[0]

  # view as (lon, lat*press): each column is one latitude circle
  arr3d.reshape!(im,jm*km)
  mask = arr3d.ne(rmiss)
  arr3d[mask.not] = 0.0

  zmarr = arr3d.sum(0)
  # a circle with no valid sample has no meaningful sum: flag it missing
  # instead of reporting a legitimate-looking 0.0
  zmarr[mask.to_type(NArray::INT).sum(0).eq(0)] = rmiss
  zmarr.reshape!(jm,km)

  lat_a = VArray.new( gphysArr3d.coord('lat').val, 
                      { "long_name"=>gphysArr3d.coord('lat').get_att('long_name'),
                        "units"=>gphysArr3d.coord('lat').get_att('units') }, 
                      "lat" )
  lat = Axis.new.set_pos(lat_a)
 
  press_a = VArray.new( gphysArr3d.coord('press').val, 
                        { "long_name"=>gphysArr3d.coord('press').get_att('long_name'),
                          "units"=>gphysArr3d.coord('press').get_att('units')},
                        "press" )
  press = Axis.new.set_pos(press_a)
 
  data = VArray.new( zmarr, 
                   {"long_name"=>gphysArr3d.get_att('long_name'), 
                    "units"=>gphysArr3d.get_att('units'), 
                    "missing_value"=>gphysArr3d.get_att('missing_value')},
                   gphysArr3d.name )
  gphys = GPhys.new( Grid.new(lat,press), data )

  return gphys

end

###################################################################

# Add one sample (arr3d2) into a running sum (arr3d1), element by element.
#
# Where both arrays hold valid values (not rmiss), the sum is stored and
# num3d is incremented; elsewhere the running sum is carried over unchanged.
# num3d is updated in place, so the caller's count array changes.
# Returns a new sfloat array shaped like the first three dims of arr3d1.
def addition( arr3d1, arr3d2, num3d, rmiss )

  nx, ny, nz = arr3d1.shape

  usable  = arr3d1.ne(rmiss) & arr3d2.ne(rmiss)
  skipped = usable.not

  # count this sample, writing through to the caller's num3d
  num3d[true,true,true] = num3d + usable

  sumarr = NArray.sfloat(nx,ny,nz)
  sumarr[usable]  = arr3d1[usable] + arr3d2[usable]
  sumarr[skipped] = arr3d1[skipped]

  return sumarr

end

###################################################################

# Turn accumulated sums into averages.
#
# arr3d : running sums; num3d : matching per-element sample counts.
# Returns a new sfloat array equal to arr3d/num3d where num3d > 0 and rmiss
# everywhere else.
def division( arr3d, num3d, rmiss )

  nx, ny, nz = arr3d.shape
  sampled = num3d.gt(0)

  meanarr = NArray.sfloat(nx,ny,nz).fill(rmiss)
  meanarr[sampled] = arr3d[sampled] / num3d[sampled]

  return meanarr

end

###################################################################

# Scale every valid (non-rmiss) element of arr3d by value; missing elements
# stay rmiss.
#
# The product is formed in double precision to match the original source,
# then stored into a new sfloat array.
def multiplyconst( arr3d, value, rmiss )

  nx, ny, nz = arr3d.shape
  valid = arr3d.ne(rmiss)

  scaled = NArray.sfloat(nx,ny,nz).fill(rmiss)
  scaled[valid] = arr3d[valid].to_type(NArray::FLOAT) * value

  return scaled

end

###################################################################

# Mass stream function at every longitude (no zonal averaging here).
#
# gphysArr3d : GPhys whose axis 1 supplies latitudes in degrees.
# v3d        : (im, jm, km) field on the pressure levels p1d, with rmiss
#              where missing. Presumably the meridional wind; the script
#              passes interpolated V.
# p1d        : 1-D NArray of pressure levels. Integration starts at index
#              km-1, which is the smallest pressure in the script's level list.
#
# Vertical integration, from index km-1 downward:
#   level km-1 : msf = v * p / grav
#   level k    : msf(k) = msf(k+1) + 0.5*(v(k)+v(k+1))*(p(k)-p(k+1))/grav
#                where v is valid at k and k+1. Where only level k is valid,
#                the integral restarts with msf(k) = v(k)*p(k)/grav.
# Each valid point is then multiplied by 2*PI*pradi*cos(lat)/im, the length
# of one longitude interval; presumably this is so that a later zonal sum
# yields the stream function.
# Returns an sfloat (im, jm, km) array, rmiss where v was missing.
def calcmsf_notzm( gphysArr3d, v3d, p1d, rmiss )

  nlon, nlat, nlev = v3d.shape

  lat = gphysArr3d.coord(1).val

  msf = NArray.sfloat(nlon,nlat,nlev).fill(rmiss)

  grav = 9.8

  nhoriz = nlon*nlat   # elements per level in the flat index

  # work in double precision to match the original source
  vd = v3d.to_type(NArray::FLOAT)

  # top of the integration
  ktop = nlev-1
  top_idx = vd[true,true,ktop].ne(rmiss).where + nhoriz * ktop
  msf[top_idx] = vd[top_idx] * p1d[ktop] / grav

  (nlev-2).downto(0) do |k|
    here_ok  = vd[true,true,k].ne(rmiss)
    above_ok = vd[true,true,k+1].ne(rmiss) & msf[true,true,k+1].ne(rmiss)

    # trapezoidal step from the level above
    stepping = here_ok & above_ok
    if stepping.any?
      lo = stepping.where + nhoriz * k
      hi = lo + nhoriz
      msf[lo] = msf[hi] + ( vd[lo] + vd[hi] ) * 0.5 * ( p1d[k] - p1d[k+1] ) / grav
    end

    # no valid level above: start the integral afresh
    restart = here_ok & above_ok.not
    if restart.any?
      rs = restart.where + nhoriz * k
      msf[rs] = vd[rs] * p1d[k] / grav
    end
  end

  pradi = 6378e3

  valid_idx = msf.ne(rmiss).where
  # length of one longitude interval at each latitude, broadcast to
  # (nlon, nlat, nlev); computed in double precision to match the original
  dx = (cos( lat.to_type(NArray::FLOAT) * PI / 180.0 ) * pradi * PI * 2 / nlon).reshape!(1,nlat,1) * NArray.float(nlon,1,nlev).fill(1)
  msf[valid_idx] = msf[valid_idx].to_type(NArray::FLOAT) * dx[valid_idx]

  return msf

end

###################################################################

# Wrap an (im, jm, km2) array of values on fixed pressure levels into a
# GPhys on a (lon, lat, press) grid.
#
# gphysArr3d : source GPhys. It supplies the lon/lat coordinate values (with
#              their long_name/units) and the variable's name, long_name
#              and units.
# tgtpress   : 1-D NArray of pressure levels, labelled "pressure" in "Pa".
# tgtarr3d   : data array (lon, lat, press).
# rmiss      : stored as the data's missing_value attribute.
def packgphys( gphysArr3d, tgtpress, tgtarr3d, rmiss )

  lon_a = VArray.new( gphysArr3d.coord(0).val,
                      {  "long_name"=>gphysArr3d.coord('lon').get_att('long_name'),
                         "units"=>gphysArr3d.coord('lon').get_att('units') },
                        "lon" )
  lon = Axis.new.set_pos(lon_a)
 
  lat_a = VArray.new( gphysArr3d.coord('lat').val, 
                      { "long_name"=>gphysArr3d.coord('lat').get_att('long_name'),
                        "units"=>gphysArr3d.coord('lat').get_att('units') }, 
                      "lat" )
  lat = Axis.new.set_pos(lat_a)

  press_a = VArray.new( tgtpress, 
                    {"long_name"=>"pressure","units"=>"Pa"},
                    "press" )
  press = Axis.new.set_pos(press_a)

  data = VArray.new( tgtarr3d, 
                    {"long_name"=>gphysArr3d.get_att('long_name'), 
                     "units"=>gphysArr3d.get_att('units'), 
                    "missing_value"=>[rmiss]}, 
                    gphysArr3d.name )
  gphys = GPhys.new( Grid.new(lon,lat,press), data )

  return gphys

end

###################################################################
###################################################################


# Usage: ruby <script> DIR TSTART TEND TITLE
#   DIR    : directory holding Ps.nc, Temp.nc, U.nc, V.nc and QVap.nc
#   TSTART : first time index to average (0-based)
#   TEND   : last time index to average (inclusive)
#   TITLE  : title drawn above the first panel
#
# For each time step, every field is interpolated from sigma to fixed
# pressure levels and accumulated. The results are then averaged over
# TSTART..TEND. Temp, U and QVap are reduced to zonal means. V is turned
# into a mass stream function and summed zonally.

abort "usage: #{$0} DIR TSTART TEND TITLE" if ARGV.size < 4

dir = ARGV[0]
ts = ARGV[1].to_i
te = ARGV[2].to_i

title = ARGV[3]

vname0 = 'Ps'
vname1 = 'Temp'
vname2 = 'U'
vname3 = 'V'
vname4 = 'QVap'

gphysPs     = GPhys::IO.open( dir+'/'+vname0+'.nc', vname0 )
gphysArr3d1 = GPhys::IO.open( dir+'/'+vname1+'.nc', vname1 )
gphysArr3d2 = GPhys::IO.open( dir+'/'+vname2+'.nc', vname2 )
gphysArr3d3 = GPhys::IO.open( dir+'/'+vname3+'.nc', vname3 )
gphysArr3d4 = GPhys::IO.open( dir+'/'+vname4+'.nc', vname4 )

# target pressure levels (hPa; converted to Pa below)
arrpress = [1000.0, 925.0, 850.0, 700.0, 600.0, 500.0, 400.0, 300.0, 250.0, 200.0, 150.0, 100.0, 70.0, 50.0, 30.0, 20.0, 10.0]


rmiss = -999.0

napress = NArray.to_na(arrpress)
napress = napress * 1.0e2

im = gphysArr3d1.shape[0] # number of elements for longitude dimension
jm = gphysArr3d1.shape[1] # number of elements for latitude  dimension
km = gphysArr3d1.shape[2] # number of elements for sigma     dimension
tm = gphysArr3d1.shape[3] # number of elements for time      dimension

# An empty window (ts > te) would silently produce all-missing plots, and
# te >= tm would fail deep inside the loop; reject both up front.
if ts > te || te >= tm
  abort "invalid time range #{ts}..#{te}: file has #{tm} time steps (0..#{tm-1})"
end

km2 = napress.shape[0]

# running sums (naarr3dN) and per-point sample counts (nanum3dN)
naarr3d1 = NArray.sfloat(im,jm,km2)
nanum3d1 = NArray.int(im,jm,km2)
naarr3d2 = NArray.sfloat(im,jm,km2)
nanum3d2 = NArray.int(im,jm,km2)
naarr3d3 = NArray.sfloat(im,jm,km2)
nanum3d3 = NArray.int(im,jm,km2)
naarr3d4 = NArray.sfloat(im,jm,km2)
nanum3d4 = NArray.int(im,jm,km2)

for t in ts..te
  p t

  # the bracketing levels depend only on Ps and sigma, so share them
  nakindex3d = search_index_for_interpolate_in_press( gphysPs, gphysArr3d1, t, napress )

  naarr3dss = interpolate_in_press( gphysPs, gphysArr3d1, t, napress, nakindex3d, rmiss )
  naarr3d1  = addition( naarr3d1, naarr3dss, nanum3d1, rmiss )

  naarr3dss = interpolate_in_press( gphysPs, gphysArr3d2, t, napress, nakindex3d, rmiss )
  naarr3d2  = addition( naarr3d2, naarr3dss, nanum3d2, rmiss )

  naarr3dss = interpolate_in_press( gphysPs, gphysArr3d3, t, napress, nakindex3d, rmiss )
  naarr3dss = calcmsf_notzm( gphysArr3d3, naarr3dss, napress, rmiss )
  naarr3d3  = addition( naarr3d3, naarr3dss, nanum3d3, rmiss )

  naarr3dss = interpolate_in_press( gphysPs, gphysArr3d4, t, napress, nakindex3d, rmiss )
  naarr3d4  = addition( naarr3d4, naarr3dss, nanum3d4, rmiss )
end

# time means
naarr3d1 = division( naarr3d1, nanum3d1, rmiss )
naarr3d2 = division( naarr3d2, nanum3d2, rmiss )
naarr3d3 = division( naarr3d3, nanum3d3, rmiss )
naarr3d4 = division( naarr3d4, nanum3d4, rmiss )

# stream function plotted in units of 1e-10 kg s-1 (see units below)
naarr3d3 = multiplyconst( naarr3d3, 1e-10, rmiss )

gphys    = packgphys( gphysArr3d1, napress, naarr3d1, rmiss )
gphyszm1 = zonalmean( gphys )
gphys    = packgphys( gphysArr3d2, napress, naarr3d2, rmiss )
gphyszm2 = zonalmean( gphys )
gphys    = packgphys( gphysArr3d3, napress, naarr3d3, rmiss )
gphyszm3 = zonalsum( gphys )
gphys    = packgphys( gphysArr3d4, napress, naarr3d4, rmiss )
gphyszm4 = zonalmean( gphys )

gphyszm3.name      = 'mass stream function'
gphyszm3.long_name = 'mass stream function'
gphyszm3.units     = '1e-10 kg s-1'


# ------------------------------------------------------------------
# Drawing: four latitude-pressure panels on one page
#   panel 1: gphyszm1 (zonal mean of Temp)
#   panel 2: gphyszm2 (zonal mean of U)
#   panel 3: gphyszm3 (mass stream function)
#   panel 4: gphyszm4 (zonal mean of QVap)
# ------------------------------------------------------------------

DCL.gropn(2)
DCL.sldiv('y',2,2)           # split into 2x2 frames; 'y' (yoko) order: top-left -> top-right -> bottom-left ...
DCL.sgpset('lcntl', false)   # do not interpret control characters in text
DCL.sgpset('lfull',true)     # use the full drawing area
DCL.uzfact(0.7)              # scale axis label text by 0.7
DCL.sgpset('lfprop',true)    # use proportional fonts

# let DCL treat rmiss as the missing value
DCL.glpset('lmiss',true)
DCL.glpset('rmiss',rmiss)

# GGraph figure settings; y axis reversed for axes in Pa ('yrev').
# alternative setting: 'itr'=>1
GGraph.set_fig 'itr'=>2, 'viewport'=>[0.15,0.85,0.15,0.6], 'yrev'=>'units:Pa'

# Draw one panel: shading with explicitly given levels and patterns, then
# its colour bar. Each pattern list has one more entry than its level list,
# so the outermost patterns cover everything beyond the end levels.
tone_panel = lambda do |gp, levels, patterns|
  GGraph.tone( gp, true, 'lev'=>levels, 'pat'=>patterns )
  GGraph.color_bar
end

# panel 1: temperature
#   alternative: 'lev'=>[200,210,220,230,240,250,260,270,280,290,300],
#                'pat'=>[10999,20999,30999,40999,50999,60999,65999,70999,75999,80999,90999,95999]
temp_lev = [170,180,190,200,210,220,230,240,250,260,270,280,290,300]
temp_pat = [10999,15999,20999,25999,30999,35999,40999,50999,60999,65999,70999,75999,80999,90999,95999]
tone_panel.call( gphyszm1, temp_lev, temp_pat )

# page title above the first panel: two ' ' spacer titles, then the title
DCL.uxmttl('T', ' ', 1.0)
DCL.uxmttl('T', ' ', 1.0)
DCL.uxmttl('T', title, -1.0)

# panel 2: zonal wind
#   alternative: 'lev'=>[-20,-10,0,10,20,30,40,50,60],
#                'pat'=>[10999,20999,30999,40999,50999,60999,65999,70999,80999,90999]
u_lev = [-20,-15,-10,-5,0,5,10,15,20,25,30,35,40,45,50]
u_pat = [10999,15999,20999,25999,30999,35999,40999,45999,50999,55999,60999,65999,70999,75999,80999,90999]
tone_panel.call( gphyszm2, u_lev, u_pat )

# panel 3: mass stream function
#   alternative: 'lev'=>[-12e0,-10e0,-8e0,-6e0,-4e0,-2e0,0,2e0,4e0,6e0,8e0,10e0,12e0],
#                'pat'=>[10999,15999,20999,25999,30999,40999,50999,60999,70999,75999,80999,85999,90999,95999]
msf_lev = [-14e0,-12e0,-10e0,-8e0,-6e0,-4e0,-2e0,0,2e0,4e0,6e0,8e0,10e0,12e0,14e0]
msf_pat = [10999,15999,20999,25999,30999,35999,40999,50999,60999,65999,70999,75999,80999,85999,90999,95999]
tone_panel.call( gphyszm3, msf_lev, msf_pat )

# panel 4: water vapour
q_lev = [1e-4,2e-3,4e-3,6e-3,8e-3,10e-3,12e-3,14e-3,16e-3,18e-3,20e-3]
q_pat = [1,10999,20999,30999,40999,50999,60999,65999,70999,80999,90999,95999]
tone_panel.call( gphyszm4, q_lev, q_pat )

DCL.grcls
