require './gphys_const_v1.2'
require "numru/ggraph"
include NumRu
include NMath

#########################################################################################

# Interpolate a sigma-coordinate field onto fixed pressure levels.
#
# plev       : target pressure levels in Pa (anything NArray.to_na accepts)
# gp_ps_part : surface pressure; assumed to be laid out (lon,lat,time) with the
#              same sizes as gp_part -- TODO confirm against callers
# gp_part    : field with axes 'lon','lat','sig','time'; its data are assumed
#              to be ordered (lon,lat,sig,time), as the reshapes below require
#
# Returns gp_part interpolated onto plev along 'sig'.
# Side effect: the computed pressure field p = ps * sigma is attached to
# gp_part as an associated coordinate (set_assoc_coords).
def interpolate_on_p( plev, gp_ps_part, gp_part )

  va_plev = VArray.new( NArray.to_na(plev), {"units"=>"Pa"}, "pressure" )

  imax   = gp_part.coord('lon').val.size
  jmax   = gp_part.coord('lat').val.size
  kmax   = gp_part.coord('sig').val.size
  tmax   = gp_part.coord('time').val.size
  na_sig = gp_part.coord('sig').val

  # Pressure at every grid point: p = ps * sigma (broadcast over sig / lon,lat,time).
  # Non-destructive reshape: reshape! would modify the arrays returned by .val
  # in place, which corrupts the source data if those arrays are shared.
  na_p = gp_ps_part.val.reshape(imax,jmax,1,tmax) *
         na_sig.reshape(1,1,kmax,1)
  va_p = VArray.new( na_p,
                     {"long_name"=>"air_pressure", "units"=>"Pa"},
                     "pressure" )

  # Build the pressure GPhys on the same four axes as the input field.
  ax_lon  = gp_part.axis('lon')
  ax_lat  = gp_part.axis('lat')
  ax_sig  = gp_part.axis('sig')
  ax_time = gp_part.axis('time')
  gp_p = GPhys.new( Grid.new(ax_lon,ax_lat,ax_sig,ax_time), va_p )

  # Register pressure as an associated coordinate so that 'sig' can be
  # interpolated onto pressure values.
  gp_part.set_assoc_coords([gp_p])

  return gp_part.interpolate("sig"=>va_plev)

end

#########################################################################################

# Mean of gphys over the index range is..ie (inclusive) along axis axname.
# Older variant of longmean without any missing-value handling.
#
# Negative is/ie count from the end of the axis. The sum is built in chunks of
# di indices, with di chosen so each chunk holds about 268e6 elements (at least
# one index), to bound memory use.
#
# Raises ArgumentError when is > ie (after normalization).
def old_longmean( axname, is, ie, gphys )

  axis = gphys.coord(axname).val
  n    = axis.size
  is += n if is < 0
  ie += n if ie < 0
  if is > ie then
    raise ArgumentError, "old_longmean: start index #{is} is after end index #{ie}"
  end

  gphysout = gphys.cut(axname=>axis[is]).copy

  di = [268000000 / ( gphys.length / axis.length ), 1].max

  # Add the remaining indices is+1..ie chunk by chunk. When is == ie there is
  # nothing left, so no (reversed / out-of-range) cut is attempted.
  i1 = is + 1
  while i1 <= ie
    i2 = [i1 + di - 1, ie].min
    gphysout = gphysout + gphys.cut(axname=>axis[i1]..axis[i2]).sum(axname)
    i1 = i2 + 1
  end

  return gphysout / ( ie - is + 1 )

end

#########################################################################################

# Linear scan: return the first index i = 0, step, 2*step, ... at which
# axisnarray[i] >= targetvalue. Works on any indexable array (NArray, Array).
#
# Raises IndexError when the array is empty or the scan leaves the array
# without finding such an element. Raises ArgumentError when step is 0 and
# the scan would have to advance (this previously looped forever).
def findindex( axisnarray, targetvalue, step )

  n = axisnarray.size
  raise IndexError, "findindex: empty array" if n == 0

  i = 0
  while axisnarray[i] < targetvalue do
    raise ArgumentError, "findindex: step must be non-zero" if step == 0
    i += step
    # Stop before indexing outside the array (negative indices are allowed,
    # as they are valid for Ruby/NArray).
    unless i < n && i >= -n
      raise IndexError, "findindex: no element >= #{targetvalue} found"
    end
  end
  i

end

#########################################################################################

# Replace the coordinate values of axis replaceaxisname of replacegphys
# (in place) with axisnarray, named axisname, with the given long_name/units.
#
# Raises ArgumentError when axisnarray does not have the same length as the
# existing axis. Previously this printed a message and called exit, which
# silently terminated the whole run with a success status.
# Returns the result of Axis#set_pos.
def replace_axis( axisname, axislongname, axisunits, axisnarray, replaceaxisname, replacegphys )

  oldsize = replacegphys.coord(replaceaxisname).val.size
  if axisnarray.size != oldsize then
    raise ArgumentError,
          "Array size is not the same in replace_axis " +
          "(#{axisnarray.size} given, axis '#{replaceaxisname}' has #{oldsize})"
  end

  varray = VArray.new( axisnarray,
                       { "long_name"=>axislongname,
                         "units"=>axisunits },
                       axisname )
  replacegphys.axis(replaceaxisname).set_pos(varray)

end

#########################################################################################

# Year (iyr0) and month (imon0) of record 0 of each monthly data set:
# dcpam data : iyr0 =    1, imon0 = 1
# NCEP  data : iyr0 = 1948, imon0 = 1
# ECMWF data : iyr0 = 1957, imon0 = 9

# Climatological mean of calendar month imon over years iyrs..iyre, taken from
# a series of monthly means whose record 0 is year iyr0, month imon0.
#
# If the data carry a missing-value mask, each grid point is divided by the
# number of years in which it is valid (points valid in no year are divided by
# 1). Otherwise the sum is divided by the number of years.
# Whether a mask exists is decided once, from the first year, so the mask
# accumulator and the final division always agree.
#
# Raises ArgumentError when iyre < iyrs.
def clim_mon_mean_with_mon_mean( iyr0, imon0, iyrs, iyre, imon, gphys )

  if iyre < iyrs then
    raise ArgumentError, "clim_mon_mean_with_mon_mean: iyre (#{iyre}) < iyrs (#{iyrs})"
  end

  time = gphys.coord('time').val

  gphysout     = nil
  gphysoutmask = nil
  flagmiss     = false

  for iyr in iyrs..iyre
    # Record index of month imon of year iyr.
    t = ( iyr - iyr0 ) * 12 + imon - imon0

    gphysout1 = gphys.cut('time'=>time[t]).copy

    if iyr == iyrs
      gphysout = gphysout1
      begin
        gphysoutmask = gphysout1.val.get_mask.to_f
        flagmiss     = true
      rescue
        flagmiss     = false   # plain (unmasked) data
      end
    else
      gphysout = gphysout + gphysout1
      if flagmiss then
        gphysoutmask = gphysoutmask + gphysout1.val.get_mask.to_f
      end
    end
  end

  if flagmiss then
    # Avoid division by zero where no year was valid.
    gphysoutmask[ (gphysoutmask.eq 0).where ] = 1
    gphysout = gphysout / gphysoutmask
  else
    gphysout = gphysout / ( iyre - iyrs + 1 )
  end

  return gphysout

end

#########################################################################################

# Year (iyr0) and month (imon0) of record 0 of each monthly data set:
# dcpam data : iyr0 =    1, imon0 = 1
# NCEP  data : iyr0 = 1948, imon0 = 1
# ECMWF data : iyr0 = 1957, imon0 = 9

# Climatological (all-month) mean over January of year iyrs through December
# of year iyre, from a monthly series whose record 0 is year iyr0, month imon0.
# Delegates the averaging to longmean along 'time'.
def clim_mean_with_mon_mean( iyr0, imon0, iyrs, iyre, gphys )
  # Record index of January of iyrs and of December of iyre.
  first_jan = ( iyrs - iyr0 ) * 12 - ( imon0 - 1 )
  last_dec  = ( iyre - iyr0 ) * 12 + ( 12 - imon0 )
  longmean( 'time', first_jan, last_dec, gphys )
end

#########################################################################################

# Mean of gphys over the index range is..ie (inclusive) of axis axisname.
#
# Negative is/ie count from the end of the axis. The sum is accumulated in
# chunks of di indices, with di chosen so each chunk holds about 268e6 elements
# (at least one index), to bound memory use.
# If the data carry a missing-value mask, each point is divided by its number
# of valid samples (points never valid are divided by 1). Otherwise the sum is
# divided by the number of indices.
#
# NOTE(review): the valid-sample count sums over NArray dimension gphys.rank-1,
# so it assumes axisname is the LAST dimension of gphys (as 'time' is for the
# callers in this file) -- confirm before using it with other axes.
#
# Raises ArgumentError when is > ie (after normalization).
def longmean( axisname, is, ie, gphys )

  axis = gphys.coord(axisname).val
  n    = axis.size
  is += n if is < 0
  ie += n if ie < 0
  if is > ie then
    raise ArgumentError, "longmean: start index #{is} is after end index #{ie}"
  end

  gphysout = gphys.cut(axisname=>axis[is]).copy
  begin
    gphysoutmask = gphysout.val.get_mask.to_f
  rescue
    flagmiss = false   # plain (unmasked) data
  else
    flagmiss = true
  end

  di = [268000000 / ( gphys.length / axis.length ), 1].max

  # Add the remaining indices is+1..ie chunk by chunk. When is == ie there is
  # nothing left, so no (reversed / out-of-range) cut is attempted.
  i1 = is + 1
  while i1 <= ie
    i2   = [i1 + di - 1, ie].min
    slab = gphys.cut(axisname=>axis[i1]..axis[i2])
    gphysout = gphysout + slab.sum(axisname)
    if flagmiss then
      gphysoutmask = gphysoutmask + slab.val.get_mask.to_f.sum(gphys.rank-1)
    end
    i1 = i2 + 1
  end

  if flagmiss then
    gphysoutmask[ (gphysoutmask.eq 0).where ] = 1
    gphysout = gphysout / gphysoutmask
  else
    gphysout = gphysout / ( ie - is + 1 )
  end

  return gphysout

end

#########################################################################################

# Mean of the product gphys * gphys2 over the index range is..ie (inclusive)
# of axis axisname. Both fields are cut at the coordinate values of gphys's
# axis, so they are assumed to share that axis.
#
# Negative is/ie count from the end of the axis. The sum is accumulated in
# chunks of di indices, with di chosen so each chunk holds about 268e6 elements
# (at least one index). If the product carries a missing-value mask, each
# point is divided by its number of valid samples (never-valid points by 1).
# Otherwise the sum is divided by the number of indices.
#
# NOTE(review): the valid-sample count sums over NArray dimension gphys.rank-1,
# i.e. it assumes axisname is the LAST dimension -- confirm for other axes.
#
# Raises ArgumentError when is > ie (after normalization).
def longmean_2varmulti( axisname, is, ie, gphys, gphys2 )

  axis = gphys.coord(axisname).val
  n    = axis.size
  is += n if is < 0
  ie += n if ie < 0
  if is > ie then
    raise ArgumentError, "longmean_2varmulti: start index #{is} is after end index #{ie}"
  end

  gphysout = ( gphys.cut(axisname=>axis[is]) * gphys2.cut(axisname=>axis[is]) ).copy
  begin
    gphysoutmask = gphysout.val.get_mask.to_f
  rescue
    flagmiss = false   # plain (unmasked) data
  else
    flagmiss = true
  end

  di = [268000000 / ( gphys.length / axis.length ), 1].max

  # Add the remaining indices is+1..ie chunk by chunk; the product of each
  # chunk is formed once and reused for both the sum and the mask count.
  i1 = is + 1
  while i1 <= ie
    i2    = [i1 + di - 1, ie].min
    range = axis[i1]..axis[i2]
    prod  = gphys.cut(axisname=>range) * gphys2.cut(axisname=>range)
    gphysout = gphysout + prod.sum(axisname)
    if flagmiss then
      gphysoutmask = gphysoutmask + prod.val.get_mask.to_f.sum(gphys.rank-1)
    end
    i1 = i2 + 1
  end

  if flagmiss then
    gphysoutmask[ (gphysoutmask.eq 0).where ] = 1
    gphysout = gphysout / gphysoutmask
  else
    gphysout = gphysout / ( ie - is + 1 )
  end

  return gphysout

end

#########################################################################################

# Climatological mean of calendar month imon over years iyrs..iyre for dcpam
# output. The 'time' axis is assumed to hold DataNumOfDay records per day,
# with a calendar given by DaysOfMonth and years numbered from 1.
#
# Each year's monthly mean comes from longmean. If the result carries a
# missing-value mask, each point is divided by the number of years in which
# it is valid (never-valid points by 1). Otherwise the sum is divided by the
# number of years. Whether a mask exists is decided once, from the first year.
#
# Raises ArgumentError when iyre < iyrs.
def dcpam_clim_mon_mean( iyrs, iyre, imon, gphys )

  if iyre < iyrs then
    raise ArgumentError, "dcpam_clim_mon_mean: iyre (#{iyre}) < iyrs (#{iyrs})"
  end

  daysOfYear = 0
  DaysOfMonth.each do |i|
    daysOfYear += i
  end

  gphysout     = nil
  gphysoutmask = nil
  flagmiss     = false

  for iyr in iyrs..iyre
    # Day offset (0-based) of the first day of month imon in year iyr.
    t = daysOfYear*(iyr-1)
    for i in 1..imon-1
      t = t + DaysOfMonth[i-1]
    end

    # First and last record index of that month.
    ts = t * DataNumOfDay
    te = ( t + DaysOfMonth[imon-1] ) * DataNumOfDay - 1

    gphysout1 = longmean( 'time', ts, te, gphys )

    if iyr == iyrs
      gphysout = gphysout1
      begin
        gphysoutmask = gphysout1.val.get_mask.to_f
        flagmiss     = true
      rescue
        flagmiss     = false   # plain (unmasked) data
      end
    else
      gphysout = gphysout + gphysout1
      if flagmiss then
        gphysoutmask = gphysoutmask + gphysout1.val.get_mask.to_f
      end
    end
  end

  if flagmiss then
    gphysoutmask[ (gphysoutmask.eq 0).where ] = 1
    gphysout = gphysout / gphysoutmask
  else
    gphysout = gphysout / ( iyre - iyrs + 1 )
  end

  return gphysout

end

#########################################################################################

# Climatological mean of the product gphys * gphys2 for calendar month imon
# over years iyrs..iyre for dcpam output. The 'time' axis is assumed to hold
# DataNumOfDay records per day, with a calendar given by DaysOfMonth and years
# numbered from 1.
#
# Each year's monthly mean of the product comes from longmean_2varmulti. If
# the result carries a missing-value mask, each point is divided by the number
# of years in which it is valid (never-valid points by 1). Otherwise the sum
# is divided by the number of years. Whether a mask exists is decided once,
# from the first year.
#
# Raises ArgumentError when iyre < iyrs.
def dcpam_clim_mon_mean_2varmulti( iyrs, iyre, imon, gphys, gphys2 )

  if iyre < iyrs then
    raise ArgumentError, "dcpam_clim_mon_mean_2varmulti: iyre (#{iyre}) < iyrs (#{iyrs})"
  end

  daysOfYear = 0
  DaysOfMonth.each do |i|
    daysOfYear += i
  end

  gphysout     = nil
  gphysoutmask = nil
  flagmiss     = false

  for iyr in iyrs..iyre
    # Day offset (0-based) of the first day of month imon in year iyr.
    t = daysOfYear*(iyr-1)
    for i in 1..imon-1
      t = t + DaysOfMonth[i-1]
    end

    # First and last record index of that month.
    ts = t * DataNumOfDay
    te = ( t + DaysOfMonth[imon-1] ) * DataNumOfDay - 1

    gphysout1 = longmean_2varmulti( 'time', ts, te, gphys, gphys2 )

    if iyr == iyrs
      gphysout = gphysout1
      begin
        gphysoutmask = gphysout1.val.get_mask.to_f
        flagmiss     = true
      rescue
        flagmiss     = false   # plain (unmasked) data
      end
    else
      gphysout = gphysout + gphysout1
      if flagmiss then
        gphysoutmask = gphysoutmask + gphysout1.val.get_mask.to_f
      end
    end
  end

  if flagmiss then
    gphysoutmask[ (gphysoutmask.eq 0).where ] = 1
    gphysout = gphysout / gphysoutmask
  else
    gphysout = gphysout / ( iyre - iyrs + 1 )
  end

  return gphysout

end

#########################################################################################

# Climatological mean over whole years iyrs..iyre of dcpam output (years
# numbered from 1, DataNumOfDay records per day, calendar from DaysOfMonth).
# Delegates the averaging to longmean along 'time'.
def dcpam_clim_mean( iyrs, iyre, gphys )
  year_length = 0
  DaysOfMonth.each { |ndays| year_length += ndays }

  # First record of year iyrs and last record of year iyre.
  first_step = year_length * ( iyrs - 1 ) * DataNumOfDay
  last_step  = year_length * iyre * DataNumOfDay - 1

  longmean( 'time', first_step, last_step, gphys )
end

#########################################################################################

# Climatological mean of the product gphys * gphys2 over whole years
# iyrs..iyre of dcpam output (years numbered from 1, DataNumOfDay records per
# day, calendar from DaysOfMonth). Delegates to longmean_2varmulti.
def dcpam_clim_mean_2varmulti( iyrs, iyre, gphys, gphys2 )
  year_length = 0
  DaysOfMonth.each { |ndays| year_length += ndays }

  # First record of year iyrs and last record of year iyre.
  first_step = year_length * ( iyrs - 1 ) * DataNumOfDay
  last_step  = year_length * iyre * DataNumOfDay - 1

  longmean_2varmulti( 'time', first_step, last_step, gphys, gphys2 )
end

#########################################################################################

# Mass stream function of gphys; yname/zname name the latitude and vertical
# axes (defaults 'lat' and 'level'). All work is done by calc_msf_core.
def calc_msf( gphys, yname = 'lat', zname = 'level' )
  calc_msf_core( gphys, yname, zname )
end

#########################################################################################

# Mass stream function from a 2-D (latitude x vertical) field, indexed
# [j, k] with the vertical as dimension 1 (see shape[1] and [true,k] below).
#
#   psi(k) = 2*pi*RPlanet*cos(lat)/Grav * integral of gphys over the vertical
#            coordinate, integrated from 0 up to level(k)
#
# Integration starts at the end of the vertical axis with the smaller
# coordinate value (gphys is assumed constant between 0 and that level) and
# proceeds with the trapezoidal rule. At least two levels are required.
#
# namelev == 'level': the axis is converted to Pa, missing points are
#   integrated as 0, the input mask is re-applied to the result, and the
#   result is scaled to units of 1e8 kg/s.
# otherwise: no unit conversion or mask handling; the result is multiplied by
#   Grav and scaled to 1e4 m2/s ("normalized mass stream function").
def calc_msf_core( gphys, namelat, namelev )

  gphyscopied = gphys.copy

  if namelev == 'level' then
    z = gphyscopied.axis(namelev).pos
    z = z.convert_units( Units['Pa'] )
    z.long_name = 'pressure'
    gphyscopied.axis(namelev).set_pos(z)
  end

  km    = gphyscopied.shape[1]
  lat   = gphyscopied.coord(namelat).val
  level = gphyscopied.coord(namelev).val

  gphysout = gphyscopied.copy

  # Raw values to integrate (NArray). For 'level' data, missing points become
  # 0 here and the original mask is restored on the result below.
  if namelev == 'level' then
    mask = gphyscopied.val.get_mask
    na_v = gphyscopied.val.set_missing_value(0.0).all_valid
  else
    na_v = gphyscopied.val
  end

  # Loop-invariant cos(latitude), latitude in degrees.
  coslat = cos( lat[true] * PI / 180.0 )

  if level[0] > level[1] then
    # Vertical coordinate decreases with k: start from the last level.
    gphysout[true,km-1] = na_v[true,km-1] * level[km-1] / Grav * 2 * PI * RPlanet * coslat
    k = km-1-1
    while k >= 0
      gphysout[true,k] = gphysout[true,k+1] + ( na_v[true,k] + na_v[true,k+1] ) * 0.5 * ( level[k] - level[k+1] ) / Grav * 2 * PI * RPlanet * coslat
      k -= 1
    end
  else
    # Vertical coordinate increases with k: start from the first level.
    gphysout[true,0] = na_v[true,0] * level[0] / Grav * 2 * PI * RPlanet * coslat
    k = 0+1
    while k < km
      gphysout[true,k] = gphysout[true,k-1] + ( na_v[true,k] + na_v[true,k-1] ) * 0.5 * ( level[k] - level[k-1] ) / Grav * 2 * PI * RPlanet * coslat
      k += 1
    end
  end

  if namelev == 'level' then
    gphysoutNaMiss = gphysout.val.set_mask(mask)
    gphysout[true,true] = gphysoutNaMiss[true,true]

    gphysout = gphysout * 1.0e-8
    gphysout.long_name = 'mass stream function'
    gphysout.units     = '1e8 kg/s'
  else
    gphysout = gphysout * 1.0e-4 * Grav
    gphysout.long_name = 'normalized mass stream function'
    gphysout.units     = '1e4 m2/s'
  end

  return gphysout

end

#########################################################################################

# Angular momentum of a ('lat', 'level') field; all work is done by
# calc_angmom_core.
def calc_angmom( gphys )
  calc_angmom_core( gphys, 'lat', 'level' )
end

#########################################################################################

# Absolute angular momentum per unit mass on a 2-D (latitude x vertical)
# field indexed [j, k]:
#
#   M = ( RPlanet*cos(lat)*Omega + u ) * RPlanet*cos(lat),  in 1e8 m2 s-1
#
# where u is the value of gphys (presumably zonal wind in m/s -- confirm).
# The vertical axis namelev is converted to Pa on the returned object.
# Points masked as missing in gphys are left unchanged (and then scaled).
# Unmasked input is treated as entirely valid.
def calc_angmom_core( gphys, namelat, namelev )

  gphyscopied = gphys.copy

  z = gphyscopied.axis(namelev).pos.convert_units( Units['Pa'] )
  z.long_name = 'pressure'
  gphyscopied.axis(namelev).set_pos(z)

  # Read the data once. The original code called gphysout.val for every grid
  # point, copying the whole array each time (O(N^2)).
  na_u = gphyscopied.val
  mask = na_u.respond_to?(:get_mask) ? na_u.get_mask : nil

  jm  = gphyscopied.shape[0]
  km  = gphyscopied.shape[1]
  lat = gphyscopied.coord(namelat).val

  # cos(latitude) per row, latitude in degrees (loop-invariant in k).
  coslat = (0...jm).map { |j| cos( lat[j] * PI / 180.0 ) }

  gphysout = gphyscopied.copy

  km.times do |k|
    jm.times do |j|
      next if mask && mask[j,k] != 1
      gphysout[j,k] = ( RPlanet * coslat[j] * Omega + na_u[j,k] ) * RPlanet * coslat[j]
    end
  end

  gphysout = gphysout * 1.0e-8

  gphysout.long_name = 'angular momentum'
  gphysout.units     = '1e8 m2 s-1'

  return gphysout

end

#########################################################################################

# Saturation specific humidity from a temperature field gphys (assumed in K,
# with 'level' as dimension 2 of a 4-D array -- see the [true,true,k,true]
# slicing). The 'level' axis is converted to Pa on the returned object.
#
#   qsat = eps * es0 * exp( L/Rv * (1/273 - 1/T) ) / p
#
# NOTE(review): despite the name, this is a Clausius-Clapeyron form with
# constant latent heat, not the Tetens formula.
def calc_qsat_tetens( gphys )
  gphysout = gphys.copy

  pres_axis = gphysout.axis('level').pos.convert_units( Units['Pa'] )
  pres_axis.long_name = 'pressure'
  gphysout.axis('level').set_pos(pres_axis)

  nlev  = gphysout.shape[2]
  level = gphysout.coord('level').val

  # Physical constants (SI units).
  gas_r_univ  = 8.314         # universal gas constant
  mol_wt_dry  = 28.964e-3     # molar mass of dry air
  mol_wt_wet  = 18.01528e-3   # molar mass of water vapour
  es0         = 611.0         # saturation vapour pressure at 273 K (Pa)
  latent_heat = 2.5e6         # latent heat of vaporization
  gas_r_wet   = gas_r_univ / mol_wt_wet
  eps_v       = mol_wt_wet / mol_wt_dry

  # eps * saturation vapour pressure ...
  gphysout = eps_v * es0 * ( latent_heat / gas_r_wet * ( 1.0/273.0 - 1.0/gphysout ) ).exp

  # ... divided by pressure, one level at a time.
  nlev.times do |k|
    gphysout[true,true,k,true] = gphysout[true,true,k,true] / level[k]
  end

  gphysout.long_name = 'saturation specific humidity'
  gphysout.units     = '1'

  gphysout
end

#########################################################################################

