require "numru/gphys/gphys"
require "numru/netcdf_miss"

# TODO: handle missing data; use scaled_put/scaled_get

module NumRu
  class GPhys

    module NetCDF_Convention_Users_Guide
      # Interpretation rules taken from the NetCDF User's Guide.
      # This module is not meant to be included: NetCDF_IO calls its
      # functions directly as module methods.

      module_function

      # Makes +data+ read and write through its missing-value/scaling aware
      # methods: data.get forwards to get_with_miss_and_scaling and
      # data.current_put forwards to put_with_miss_and_scaling. This is done
      # by defining singleton methods on +data+. Always returns nil.
      def interpret_miss_handling_and_scaling(data)
        # Applied unconditionally. A guard such as
        #   data.get_att("scale_factor") && data.get_att("add_offset")
        # exists only as a disabled idea.
        class << data
          def get(*args)
            get_with_miss_and_scaling(*args)
          end

          def current_put(*args)
            put_with_miss_and_scaling(*args)
          end
        end
        nil
      end

      # Names of the coordinate variables of +ncvar+. Each one becomes the
      # "pos" object of an Axis. Under the User's Guide these are simply
      # the variable's dimension names.
      def coord_var_names(ncvar)
        ncvar.dim_names
      end

      # Tells whether +coord_var+ (a VArray) represents grid cell bounds.
      # Returns a pair [is_bounds, varray_cell_center]. The User's Guide
      # defines no cell-bounds convention, so the answer is always
      # [false, nil].
      def cell_bounds?(coord_var)
        [false, nil]   # [is_bounds, varray_cell_center]
      end

      # Tells whether +coord_var+ (a VArray) represents grid cell centers.
      # Returns a pair [is_center, varray_cell_bounds]. Here varray_cell_bounds
      # would be the matching bounds VArray when one can be identified. The
      # User's Guide defines no cell convention, so the answer is always
      # [false, nil].
      def cell_center?(coord_var)
        [false, nil]   # [is_center, varray_cell_bounds]
      end
    end

    module NetCDF_IO

      module_function

      # Convention module consulted by open. It must respond to
      # interpret_miss_handling_and_scaling, coord_var_names, cell_center?
      # and cell_bounds?. Defaults to NetCDF_Convention_Users_Guide.
      @@convention = NetCDF_Convention_Users_Guide

      # Selects the convention module used by subsequent calls to open.
      def NetCDF_IO.set_convention(convention)
        @@convention = convention
      end

      # Returns the convention module currently in use.
      def NetCDF_IO.convention
        @@convention
      end

      # Reads variable +varname+ of +file+ as a GPhys object.
      #
      # file    :: a NetCDF object, or a String file name given to NetCDF.open
      # varname :: name of the variable to read
      #
      # One Axis is built per dimension of the variable. Its position is the
      # coordinate variable named by the convention's coord_var_names. When
      # the file has no such variable, a float index 0, 1, 2, ... spanning
      # the dimension is used instead, and the axis is flagged as a bare
      # index.
      #
      # Raises ArgumentError if +file+ is neither a String nor a NetCDF, or
      # if the file has no variable named +varname+.
      def open(file, varname)
        if file.is_a?(String)
          file = NetCDF.open(file)
        elsif ! file.is_a?(NetCDF)
          raise ArgumentError, "1st arg must be a NetCDF or a file name"
        end
        ncvar = file.var(varname)
        # NetCDF#var yields nil for an unknown name. Fail here with a clear
        # message instead of deep inside VArrayNetCDF.
        raise ArgumentError, "variable '#{varname}' is not found in the file" unless ncvar
        data = VArrayNetCDF.new(ncvar)
        @@convention::interpret_miss_handling_and_scaling(data)
        axposnames = @@convention::coord_var_names(ncvar)    # NC User's Guide
        # Needed to size bare-index axes. The original code referenced an
        # undefined `dimnames` here, so that path always raised NameError.
        dimnames = ncvar.dim_names
        var_names = file.var_names

        axes = (0...ncvar.rank).map do |i|
          bare_index = !var_names.include?(axposnames[i])
          if bare_index
            # No coordinate variable: index the dimension as 0, 1, 2, ...
            na = NArray.float(file.dim(dimnames[i]).length).indgen!
            axpos = VArray.new( na )
          else
            axpos = VArrayNetCDF.new( file.var(axposnames[i]) )
          end

          cell_center, varray_cell_bounds = @@convention::cell_center?( axpos )
          cell_bounds, varray_cell_center = @@convention::cell_bounds?( axpos )
          axis = Axis.new(cell_center || cell_bounds, bare_index)
          if cell_center
            if varray_cell_bounds
              axis.set_cell(axpos, varray_cell_bounds).set_pos_to_center
            else
              p "cell bounds are guessed"
              axis.set_cell_guess_bounds(axpos).set_pos_to_center
            end
          elsif cell_bounds
            if varray_cell_center
              axis.set_cell(varray_cell_center, axpos).set_pos_to_bounds
            else
              p "cell center is guessed"
              axis.set_cell_guess_center(axpos).set_pos_to_bounds
            end
          else
            axis.set_pos( axpos )
          end
          # TODO: define auxiliary coordinate variables.
          axis
        end

        GPhys.new(Grid.new(*axes), data)
      end

      # Writes +gphys+ into the NetCDF +file+ via VArrayNetCDF.write.
      #
      # For every axis, each VArray returned by Axis#flatten is written first:
      # - a VArray as long as the axis position uses the dimension named
      #   after the position;
      # - any other length gets an alternative dimension named
      #   "<posname>0", "<posname>1", ... (one per distinct length).
      # The data itself is then written under +name+ over gphys.axnames.
      # (A nil +name+ is passed through to VArrayNetCDF.write. Presumably this
      # keeps the data's own name; TODO confirm.)
      def write(file, gphys, name=nil)
        gphys.rank.times do |i|
          ax = gphys.axis(i)
          dimname = ax.pos.name
          length = ax.pos.length
          isfx = 0
          altdimnames = {}   # VArray length => alternative dimension name
          ax.flatten.each do |va|
            if va.length == length
              dimnames = [dimname]
            elsif (nm = altdimnames[va.length])
              dimnames = [nm]
            else
              dimnames = [ (altdimnames[va.length] = dimname + isfx.to_s) ]
              isfx += 1
            end
            VArrayNetCDF.write(file, va, nil, dimnames)
          end
        end
        VArrayNetCDF.write(file, gphys.data, name, gphys.axnames)
      end

    end
  end
end

######################################################
# Self-test: reads T from testdata/T.jan.nc, does some arithmetic on it, and
# writes the results to tmp.nc and tmp2.nc in the current directory.
if $0 == __FILE__
  include NumRu

  # The test data lives at a different relative depth depending on where
  # the script is run from; try the nearer location first.
  begin
    file = NetCDF.open("../../testdata/T.jan.nc")
  rescue
    file = NetCDF.open("../../../testdata/T.jan.nc")
  end
  temp = GPhys::NetCDF_IO.open(file,"T")
  p temp.name, temp.shape_current
  temp2 = temp[true,true,2]
  p temp2.name, temp2.shape_current

  temp_xmean = temp.average(0)
  p temp.val

  temp_edy = ( temp - temp_xmean )
  p '###',temp_edy.name,temp_edy.val[0,true,true]
  p '@@@',temp
  p '///',temp.copy
  p '+++',temp2

  puts "\n** test write (tmp.nc) **"
  file2 = NetCDF.create('tmp.nc')
  # Aux coordinates shorter than the axis exercise the alternative-dimension
  # naming in NetCDF_IO.write.
  p v = temp_edy.axis(0).pos[0..-2].copy.rename('lonlon')
  temp_edy.axis(0).set_aux('test',v)
  temp_edy.axis(0).set_aux('test2',(v/2).rename('lonlon2'))
  # NOTE(review): this reuses the key 'test2', so it presumably replaces the
  # aux set on the previous line. Confirm 'test3' was not intended.
  temp_edy.axis(0).set_aux('test2',(v/2).rename('lonlon3')[0..-2])
  GPhys::NetCDF_IO.write(file2,temp_edy)
  file2.close

  file3 = NetCDF.create('tmp2.nc')
  # Write to file3: file2 is already closed at this point.
  GPhys::NetCDF_IO.write(file3,temp_xmean)
  file3.close
  file.close
end
