require "numru/gphys"

module NumRu
  module GAnalysis

    # Library for function fitting
    # 
    module Fitting

      # Predefined functions for convenience

      # Predefined Proc for fitting: the polynomial term x (a function of the
      # 1st independent variable). It accepts one argument per independent
      # variable; only the first one is used, and it must be a 1D NArray.
      # Returns a copy of x with trailing length-1 dimensions appended so that
      # its rank equals the number of arguments.
      X = proc do |*args|
        raise(ArgumentError,"# of arge must be >= 1") if args.empty?
        first = args.first
        self.ensure_1D_NArray(first, 0)
        res = first.dup          # work on a copy; newdim! changes shape in place
        n_trailing = args.length - 1
        n_trailing.times { res.newdim!(-1) }
        res
      end

      # Predefined Proc for fitting: the polynomial term x**2 (a function of
      # the 1st independent variable). It accepts one argument per independent
      # variable; only the first one is used, and it must be a 1D NArray.
      # Returns x*x with trailing length-1 dimensions appended so that its
      # rank equals the number of arguments.
      XX = proc do |*args|
        raise(ArgumentError,"# of arge must be >= 1") if args.empty?
        first = args.first
        self.ensure_1D_NArray(first, 0)
        res = first * first      # the product is a new array
        n_trailing = args.length - 1
        n_trailing.times { res.newdim!(-1) }
        res
      end

      # Predefined Proc for fitting: the polynomial term y (a function of the
      # 2nd independent variable). At least two arguments are required; only
      # the second one is used, and it must be a 1D NArray.
      # Returns a copy of y with a length-1 dimension inserted in front (the
      # x dimension) and trailing length-1 dimensions appended so that its
      # rank equals the number of arguments.
      Y = proc do |*args|
        raise(ArgumentError,"# of arge must be >= 2") if args.length < 2
        second = args[1]
        self.ensure_1D_NArray(second, 1)
        res = second.dup         # work on a copy; newdim! changes shape in place
        res.newdim!(0)
        n_trailing = args.length - 2
        n_trailing.times { res.newdim!(-1) }
        res
      end

      # Predefined Proc for fitting: the polynomial term y**2 (a function of
      # the 2nd independent variable). At least two arguments are required;
      # only the second one is used, and it must be a 1D NArray.
      # Returns y*y with a length-1 dimension inserted in front (the x
      # dimension) and trailing length-1 dimensions appended so that its rank
      # equals the number of arguments.
      YY = proc do |*args|
        raise(ArgumentError,"# of arge must be >= 2") if args.length < 2
        second = args[1]
        self.ensure_1D_NArray(second, 1)
        res = second * second    # the product is a new array
        res.newdim!(0)
        n_trailing = args.length - 2
        n_trailing.times { res.newdim!(-1) }
        res
      end

      # Predefined Proc for fitting: the polynomial term x*y (a function of
      # the 1st and 2nd independent variables). At least two arguments are
      # required; the first two are used, and both must be 1D NArrays.
      # Returns the outer product of x and y with trailing length-1 dimensions
      # appended so that its rank equals the number of arguments.
      XY = proc do |*args|
        raise(ArgumentError,"# of arge must be >= 2") if args.length < 2
        first, second = args[0], args[1]
        self.ensure_1D_NArray(first, 0)
        self.ensure_1D_NArray(second, 1)
        # x along dim 0 times y along dim 1 broadcasts to a 2D array
        res = first.newdim(-1) * second.newdim(0)
        n_trailing = args.length - 2
        n_trailing.times { res.newdim!(-1) }
        res
      end

      # Constant function 1, appended to the user's functions by
      # least_square_fit to represent the constant offset. Returns a
      # one-element single-float NArray holding 1.0 (it is coerced to float
      # when needed), given one length-1 dimension per argument.
      @@unity = proc do |*args|
        one = NArray.sfloat(1).fill!(1.0)
        n_trailing = args.length - 1
        n_trailing.times { one.newdim!(-1) }
        one
      end


      module_function
      
      # Least square fit of a linear combination of any functions (basic NArray version).
      # 
      # === ARGUMENTS
      # * +data+ [NArray or NArrayMiss] multi-D data to fit
      # * +grid_locs+ [Array of 1D NArrays] Grid points of independent variables
      #   (so grid_locs.length == the # of independent variables).
      # * +functions+ [Array of Procs] Proc objects to represent the functions,
      #   which accept the elements of +grid_locs+ as the arguments (so the
      #   number of arguments fed is equal to the length of +grid_locs+).
      # * +ensemble_dims+ (optional) [nil (default) or Array of Integers]
      #   When <tt>grid_locs.length < data.rank</tt>,
      #   this argument can be used to specify the dimensions that are
      #   not included in grid_locs and are used for ensemble averaging
      # * +indep_dims+ (optional) [nil (default) or Array of Integers]
      #   When <tt>grid_locs.length < data.rank</tt>,
      #   this argument can be used to specify the dimensions that are
      #   not included in +grid_locs+ and are treated as independent, so
      #   the fitting is made for each of their component.
      # 
      # Note that the sum of the lengths of +grid_locs+, +ensemble_dims+ and
      # +indep_dims+ must be equal to the rank (# of dims) of +data+.
      # 
      # === RETURN VALUES
      #   [ c, bf, diff ]
      # where
      # *  +c+ is a NArray containing the coefficients of the functions
      #    and the constant offset; its length is one greater than the
      #    number of +functions+ because of the offset. 
      #    It is 1D unless the +indep_dims+ argument is used
      #    (see the examples below).
      # *  +bf+ is a NArray having the best fit grid point values. 
      #    Its rank is equal to data.rank, but the lengths along 
      #    +ensemble_dims+ are simply 1.
      # *  +diff+ is the rms of the difference between the data and best fit
      # 
      # === EXAMPLES
      # * Simple 1D case
      # 
      #   Line fitting:
      # 
      #     nx = 5
      #     x = NArray.float(nx).indgen! - nx/2
      #     data = x + x*x*0.1
      #     c, bf = GAnalysis::Fitting.least_square_fit(data, [x], 
      #                                               [GAnalysis::Fitting::X])
      #     p "data:", data, "c:", c, "bf:", bf
      #   
      #   Here, +GAnalysis::Fitting::X+ is a predefined Proc to represent
      #   the first order polynomial x. The data values given as above follow
      #   f(x) = x + x**2/10. Then the result printed by the last line is
      #     "data:"
      #     NArray.float(5): 
      #     [ -1.6, -0.9, 0.0, 1.1, 2.4 ]
      #     "c:"
      #     NArray.float(2): 
      #     [ 1.0, 0.2 ]
      #     "bf:"
      #     NArray.float(5): 
      #     [ -1.8, -0.8, 0.2, 1.2, 2.2 ]
      #   The +c+ values indicate that the fitting result is f(x) = 1.0*x + 0.2,
      #   and the +bf+ values are its grid point values. 
      #
      #   Parabolic fitting:
      # 
      #   You can also fit the data by 2nd order polynomial as
      #     c, bf = GAnalysis::Fitting.least_square_fit(data, [x], 
      #                       [GAnalysis::Fitting::XX,GAnalysis::Fitting::X])
      #   Then the result will be
      #       p c   #--> [0.1, 1.0, 0.0]
      #   which indicates the original 2nd order polynomial 0.1 x**2 + x,
      #   so it follows <tt>data == bf</tt> (except for round-off error if any).
      #
      # * 1D fitting of multi-D data (ensemble case)
      #   
      #   Suppose you have a 2D NArray (or NArrayMiss) data, in which
      #   the 1st dim represents x and the 2nd dim represents something
      #   else (such as time sequence, or just a simple ensemble).
      #   If you want to use the entire data to get a single fit,
      #   use the +ensemble_dims+ argument to specify the non-x dimension(s).
      #   You can fit the data, for example, by
      #   p*sin(x) + q*cos(x) + r as follows:
      #      
      #      sin = proc{|x| NMath.sin(x)}
      #      cos = proc{|x| NMath.cos(x)}
      #      c, bf = GAnalysis::Fitting.least_square_fit(data, [x], 
      #           [sin, cos], [1])
      #   Here, the last parameter [1] is given as the argument
      #   +ensemble_dims+ to express that the dimension 1
      #   (2nd dimension) of +data+ is the ensemble dimension, so the x
      #   coordinate is the remaining dimension 0 (1st dimension). The 
      #   coefficients of the functions are returned by 
      #   the 1st return value as a NArray, so
      #       p = c[0]
      #       q = c[1]
      #       r = c[2]
      # 
      # * 1D fitting of multi-D data (individual fitting)
      # 
      #   Suppose you have the same data as above, but
      #   you want to fit it for each of the 2nd dim elements. You can
      #   do it as follows:
      #      
      #      c, bf = GAnalysis::Fitting.least_square_fit(data, [x], 
      #           [sin, cos], nil, [1])
      #   
      #   Here, +nil+ is given as the 4th argument (+ensemble_dims+) 
      #   and [1] is given as the fifth (+indep_dims+).
      #   In this case, the return value +c+ is 2-dimensional; the
      #   first being the coefficients as above and the second representing
      #   the non-x (i.e., the second) dim of +data+.
      # 
      # * 2D fitting
      # 
      #   It can be done like
      # 
      #     cosx = proc {|x,y| NMath.cos(x).newdim!(-1)}
      #     sinx = proc {|x,y| NMath.sin(x).newdim!(-1)}
      #     cosy = proc {|x,y| NMath.cos(y).newdim!(0)}
      #     siny = proc {|x,y| NMath.sin(y).newdim!(0)}
      #     c, bf = GAnalysis::Fitting.least_square_fit(data4D, [x,y], 
      #                 [cosx, sinx, cosy, siny], [2,3])
      #   where +data4D+ is a 4D NArray, whose first and second dimensions
      #   (dimensions 0 and 1) represent x and y axis, respectively, and the
      #   1D NArrays +x+ and +y+ are the grid points.
      #   Note that the functions (+cosx+ etc) accept 2 arguments (x and y),
      #   and they use NArray's +newdim!+ method to return 2D NArray
      #   (newdim!(-1) inserts a 1-element dim to the end, and
      #   newdim!(0) inserts a 1-element dim to the beginning).
      #
      # TYPICAL ERRORS
      # * Error is raised (from the LU decomposition), if the problem 
      #   cannot be solved. That happens if you specify a same function twice
      #   (redundantly) in the +functions+ argument, as a matter of course.
      # * Error is raised if the number of data is insufficient for the
      #   number of functions (also unsolvable).
      # 
      # === OFFSET CONTROL
      # * +with_offset+ (optional, 6th argument) [true (default) or false]
      #   If true, a constant function is appended to +functions+, so the
      #   last element of +c+ is the constant offset. If false, no constant
      #   term is fitted: +c+ has exactly as many elements as +functions+,
      #   and the data are fitted as they are (no mean is removed from them).
      #
      def least_square_fit(data, grid_locs, functions, ensemble_dims=nil,
                           indep_dims=nil, with_offset=true)

        #< argument check >

        grid_locs.each_with_index{|x,i| self.ensure_1D_NArray(x, i)}
        functions.each{|f| raise("Found non-Proc arg") if !f.is_a?(Proc)}

        if with_offset
          functions = functions + [@@unity]   # constant offset
        end

        ng = grid_locs.length
        rank = data.rank
        ensemble_dims = [ ensemble_dims ] if ensemble_dims.is_a?(Integer)
        indep_dims = [ indep_dims ] if indep_dims.is_a?(Integer)

        ensemble_dims = Array.new if ensemble_dims.nil?   # --> always an Array
        n_indep = ( indep_dims ? indep_dims.length : 0 )

        if ng < rank 
          # normalize negative dimension numbers and sort
          ensemble_dims = ensemble_dims.collect{|d| 
            if d<-rank || d>=rank
              raise "Invalid ensemble_dims value (#{d}) for rank #{rank} NArray"
            end
            d += rank if d<0
            d
          }
          ensemble_dims.sort!
          if indep_dims
            indep_dims = indep_dims.collect{|d|
              if d<-rank || d>=rank
                raise "Invalid indep_dims value (#{d}) for rank #{rank} NArray"
              end
              d += rank if d<0
              d
            }
            indep_dims.sort!
          end
        elsif ng > rank
          raise "# of grid_locs (#{ng}) > data.rank (#{rank})"
        end

        if data.rank != ng + ensemble_dims.length + n_indep
            raise ArgumentError,
               "lengths of grid_locs, ensemble_dims and indep_dims != data.rank"
        end

        # otherdims: all the dimensions of data not covered by grid_locs
        otherdims = ensemble_dims
        if indep_dims
          otherdims += indep_dims
          otherdims.sort!.uniq!
          if otherdims.length != ensemble_dims.length + n_indep
            raise ArgumentError, "Overlap in ensemble_dims and indep_dims"
          end
        end

        #< pre-process data >

        # The data mean is removed for numerical stability and added back to
        # the offset coefficient after solving. That is valid only when a
        # constant offset is among the fitted functions; without it, removing
        # the mean would change the problem being solved and there would be
        # no coefficient to add it back to, so the data are left as they are.
        d0 = ( with_offset ? data.mean : 0.0 )
        data = data - d0

        if data.is_a?(NArrayMiss)
          mask = data.get_mask
        elsif data.is_a?(NArray)
          mask = nil  # NArray.byte(*data.shape).fill!(1)
        else
          raise "Data type (#{data.class}) is not NArray or NArrayMiss"
        end

        #< derive the matrix >

        # Evaluate each function on the grid and insert length-1 dims at the
        # positions of otherdims, so the values broadcast against data.
        # The non-destructive newdim is used (not newdim!) so that the object
        # returned by a user-supplied Proc -- which may be one of the
        # caller's own arrays, e.g. with proc{|x| x} -- is never reshaped
        # in place.
        fv = functions.collect{|f| 
          v = f[*grid_locs]
          otherdims.each{|d| v = v.newdim(d)}
          v
        }

        ms = fv.length    # matrix size

        if ( (len=data.length) < ms )
          raise "Insufficient data length (#{len}) for the # of funcs+1 (#{ms})"
        end

        mat = NMatrix.float(ms,ms)   #  will be symmetric

        for i in 0...ms
          for j in 0..i
            if mask
              # average only over the valid data points
              fvij = NArrayMiss.to_nam( fv[i] * fv[j] * mask, mask )
              mat[i,j] = (fvij).mean
            else
              mat[i,j] = (fv[i] * fv[j]).mean
            end
          end
        end

        for i in 0...ms
          for j in i+1...ms
            mat[i,j] = mat[j,i]      # symmetric
          end
        end
        #p "*** mat ***",mat
        lu = mat.lu

        #< derive the vector, solve, and best fit >

        unless indep_dims    # fitting only once
          # derive the vector
          b = NVector.float(ms)
          for i in 0...ms
            b[i] = (data * fv[i]).mean
          end

          # solve
          c = lu.solve(b)
          c[-1] += d0 if with_offset     # add the mean subtracted

          # convert c from NVector to NArray (just for cleanliness)
          na = NArray.float(ms)
          na[true] = c[true]
          c = na

          # best fit
          if with_offset
            bf = c[-1]    # the constant offset
            for i in 0...ms-1
              bf += c[i]*fv[i]
            end
          else
            bf = 0.0
            for i in 0...ms
              bf += c[i]*fv[i]
            end
          end

        else    # fitting multiple times

          # derive vectors (one right-hand side per element of indep_dims)
          idshp = indep_dims.collect{|d| data.shape[d]}
          bs = NArray.float(ms,*idshp)
          meandims = (0...rank).collect{|d| d} - indep_dims
          for i in 0...ms
            bsi = (data * fv[i]).mean(*meandims)
            if bsi.is_a?(NArrayMiss)
              if bsi.count_invalid > 0
                raise("Found invalid data everywhere along indep_dims. Trim data in advance and try again.")
              end
              bsi = bsi.to_na
            end
            bs[i,false] = bsi
          end
          idlen = 1
          idshp.each{|l| idlen *= l}

          # solve (the same LU decomposition is reused for every vector)
          bs = bs.reshape(ms, idlen)
          c = NArray.float(ms,idlen)
          b = NVector.float(ms)
          for id in 0...idlen
            b[true] = bs[true,id]
            c[true,id] = lu.solve(b)
          end
          c[-1,true] += d0 if with_offset     # add the mean subtracted
          c = c.reshape(ms, *idshp)

          # best fit
          # cs: c reshaped to have the rank of data (plus the leading
          # coefficient dim), with length 1 along all but indep_dims
          idshp_full = Array.new
          for d in 0...rank
            if indep_dims.include?(d)
              idshp_full[d] = data.shape[d]
            else
              idshp_full[d] = 1
            end
          end
          cs = c.reshape(ms, *idshp_full)
          if with_offset
            bf = cs[-1,false]
            for i in 0...ms-1
              bf += cs[i,false]*fv[i]
            end
          else
            bf = cs[-1,false] * 0
            for i in 0...ms
              bf += cs[i,false]*fv[i]
            end
          end

        end

        # rms difference against the original data (d0 restores the mean)
        diff = Math.sqrt( ( (data + d0 - bf)**2 ).mean )

        #< return >

        [ c, bf, diff ]
      end

      ################################################
      # For internal usage

      private

      # Checks that +na+ is a one-dimensional NArray.
      # +ith+ is the position of the argument being checked; it is used only
      # in the error messages. Raises RuntimeError if +na+ is not an NArray
      # or if its rank is not 1; returns nil otherwise.
      def self.ensure_1D_NArray(na, ith)
        unless na.is_a?(NArray)
          raise("proc argument #{ith}: not a NArray")
        end
        unless na.rank == 1
          raise("proc argument #{ith}: not 1 dimensional")
        end
        nil
      end

    end

  end

  # GPhys extension with GAnalysis::Fitting
  class GPhys
    # Least square fit of a linear combination of any functions (GPhys version).
    # 
    # This method calls GAnalysis::Fitting.least_square_fit in
    # the GAnalysis::Fitting module.
    # See its documentation for the details, usage, and predefined functions.
    # 
    # === ARGUMENTS
    # 
    # The arguments are the same as the third to fifth arguments of 
    # GAnalysis::Fitting.least_square_fit except that +ensemble_dims+
    # and +indep_dims+ accept dimension specification by names (in Strings).
    #
    # * +functions+ [Array of Procs] Proc objects to represent the functions,
    #   which accept the elements of +grid_locs+ as the arguments (so the
    #   number of arguments fed is equal to the length of +grid_locs+).
    #   (Some predefined functions are available in GAnalysis::Fitting).
    # * +ensemble_dims+ (optional) [nil (default) or Array of Integers or Strings]
    #   When <tt>grid_locs.length < data.rank</tt>,
    #   this argument can be used to specify the dimensions that are
    #   not included in grid_locs and are used for ensemble averaging
    # * +indep_dims+ (optional) [nil (default) or Array of Integers or Strings]
    #   When <tt>grid_locs.length < data.rank</tt>,
    #   this argument can be used to specify the dimensions that are
    #   not included in +grid_locs+ and are treated as independent, so
    #   the fitting is made for each of their component.
    # 
    # === RETURN VALUES
    #   [ c, bf, diff ]
    # where
    # *  +c+ is a NArray containing the coefficients of the functions
    #    and the constant offset; its length is one greater than the
    #    number of +functions+ because of the offset. 
    #    It is 1D unless the +indep_dims+ argument is used
    #    (see the examples below).
    # *  +bf+ is a GPhys having the best fit grid point values. 
    #    Its rank is equal to data.rank unless ensemble_dims
    #    are given; ensemble_dims are deleted unlike the return
    #    value of GAnalysis::Fitting.least_square_fit.
    # *  +diff+ is the rms of the difference between the data and best fit
    # 
    # === USAGE
    # See GAnalysis::Fitting.least_square_fit.



    def least_square_fit(functions, ensemble_dims=nil, indep_dims=nil)

      #< preparation >

      # A single dimension may be given as a bare Integer or String (the
      # NArray-level GAnalysis::Fitting.least_square_fit accepts a bare
      # Integer too); wrap it into an Array so that it can be iterated.
      if ensemble_dims.is_a?(Integer) || ensemble_dims.is_a?(String)
        ensemble_dims = [ ensemble_dims ]
      end
      if indep_dims.is_a?(Integer) || indep_dims.is_a?(String)
        indep_dims = [ indep_dims ]
      end

      # convert the dimension specifications with Grid#dim_index and
      # collect the dimensions that are not used as fitting coordinates
      no_fitting_dims = Array.new
      if ensemble_dims
        ensemble_dims = ensemble_dims.collect{|d| @grid.dim_index(d)}
        no_fitting_dims += ensemble_dims
      end
      if indep_dims
        indep_dims = indep_dims.collect{|d| @grid.dim_index(d)}
        no_fitting_dims += indep_dims
      end
      fitting_dims = (0...rank).collect{|i| i} - no_fitting_dims
      grid_locs = fitting_dims.collect{|d| coord(d).val}
      data = self.val

      #< fitting >
      c, bf, diff = GAnalysis::Fitting.least_square_fit(data, grid_locs, 
                                          functions, ensemble_dims, indep_dims)

      #< make a GPhys of the best fit >

      # The best fit returned above has length 1 along any fitting dimension
      # on which none of the functions depends (e.g. when only a function of
      # x is fitted to x-y data), and it is a plain Float if it consists of
      # the constant offset alone. Adding an array of zeros broadcasts it to
      # the full shape (with length 1 along ensemble_dims), so that it
      # agrees with the shape of the grid built below.
      bf_shape = data.shape.dup
      ensemble_dims.each{|d| bf_shape[d] = 1} if ensemble_dims
      bf = bf + NArray.float(*bf_shape)

      if !ensemble_dims
        grid = self.grid
      else
        # drop the ensemble dimensions from both the grid and the best fit
        axes = Array.new
        (0...rank).each{|d| 
          axes.push(self.axis(d)) unless ensemble_dims.include?(d)
        }
        grid = Grid.new(*axes)
        shape = bf.shape
        # delete from the last dimension so that the indices stay valid
        ensemble_dims.sort.reverse_each{|d| shape.delete_at(d)}
        bf = bf.reshape(*shape)
      end

      va = VArray.new(bf, self.data, self.name)
      bf = GPhys.new(grid, va)

      [c, bf, diff]
    end
  end

end


######################################################
if $0 == __FILE__

  # Demonstration / self test, run only when this file is executed directly.
  # (A stray "exit" used to stop the script after the line-fitting demo,
  # leaving the parabolic and the 2D trigonometric demos unreachable.)

  include NumRu

  # 2D plane fitting of noisy data: data ~ x + 2*y + 100
  nx = 7
  ny = 5
  x = NArray.float(nx).indgen!
  y = NArray.float(ny).indgen! - 1

  xx = x.newdim(-1)
  yy = y.newdim(0)
  data = xx + 2*yy + 100
  data += data.random * 0.1
  p "***data**", data

  f_x = GAnalysis::Fitting::X
  f_y = GAnalysis::Fitting::Y

  p GAnalysis::Fitting.least_square_fit(data, [x,y], [f_x, f_y])

  # the same with an ensemble dimension (dimension 1) of length 2
  data2 = NArray.float(nx,2,ny)
  data2[true,0,true] = data - 1
  data2[true,1,true] = data + 1
  p GAnalysis::Fitting.least_square_fit(data2, [x,y], [f_x, f_y], [1])

  # 1D line fitting of data that follow x + 0.1*x**2
  nx = 5
  x = NArray.float(nx).indgen! - nx/2
  data = x + x*x*0.1
  c, bf, diff = GAnalysis::Fitting.least_square_fit(data, [x], 
                                              [GAnalysis::Fitting::X])
  p "data:", data, "c:", c, "bf:", bf

  # 1D parabolic fitting of the same data
  c, bf, diff = GAnalysis::Fitting.least_square_fit(data, [x], 
                             [GAnalysis::Fitting::X,GAnalysis::Fitting::XX])
  p c

  # 2D fitting with user-defined trigonometric functions
  xx = x.newdim(-1)
  data = xx + 2*yy + 100
  cosx = proc {|x,y| NMath.cos(x).newdim!(-1)}
  sinx = proc {|x,y| NMath.sin(x).newdim!(-1)}
  cosy = proc {|x,y| NMath.cos(y).newdim!(0)}
  siny = proc {|x,y| NMath.sin(y).newdim!(0)}
  p GAnalysis::Fitting.least_square_fit(data, [x,y], [cosx, sinx, cosy, siny])

end
