# MANIFEST:
#  class NMDArray
#  class Range (extension for NMDArray)

require "NArray"

class NMDArray

# NMDArray -- Numeric Multi-Dimension Array
# 1999/09/08  T. Horinouchi
#
# Wraps a flat 1D NArray (@dat) and adds multi-dimensional indexing on top
# of it; it does not inherit from NArray.  Elements are stored with the
# FIRST dimension varying fastest: element (i0,i1,i2,...) of an array of
# shape [l0,l1,l2,...] is @dat[i0 + i1*l0 + i2*l0*l1 + ...].
#
# class methods
#
#   NMDArray.new(l1,l2,...)     l? is the length of the ?-th dimension;
#                      the number of arguments determines the rank
#
# methods
#
#   [],[]=     : subset extraction and subset assignment, respectively
#                Example 1 (when 3D): [1..2,2,-4]   [10..-9]
#                    The number of index arguments must be either equal to
#                    the rank of the array (the former case) or one (the
#                    latter case). Negative indices count from the end.
#                    If the former, the rank of the result of [] is the same
#                    as that of the original array. Use "trim" or "trim!" to
#                    remove the dimensions of length 1.
#                    If the latter, the array is treated as if it were 1D,
#                    and the result of [] is 1D.
#                Example 2: [{0..-1=>2},{-1..0=>-1},0]
#                    A step can be specified with a one-entry hash whose key
#                    is the range (here, 0..-1) and whose value is the
#                    Integer step (here, 2). (The old literal {0..-1,2} is
#                    not valid syntax in Ruby 1.9 or later.) Negative steps
#                    are accepted; in that case range.first must come after
#                    range.last.
#                []= accepts a Numeric (assigned to every selected element)
#                or an Array/NArray/NMDArray with at least as many elements
#                as were selected (extra elements are ignored).
#
#   dup        : duplication (the data is copied)
#   clone      : same as dup
#
#   trim trim!    :    eliminate the dimensions of length 1.
#                      trim creates another object; trim! transforms "self"
#                      and returns it.
#                      (Example: a 3*1*2*1 4D array becomes 2D of 3*2.)
#                      If every dimension has length 1, one is kept so the
#                      result stays indexable.
#
#   length size       : total number of elements (length as a 1D array)
#   + - * / **        : numeric operators (operand: NMDArray, NArray or scalar)
#   abs               : absolute values
#   span_fill         : as in NArray (1D numeric array)
#
#   to_a       : the underlying NArray itself, not a copy (to get an Array,
#                apply to_a again)
#
#   shape      : shape of the array (lengths of dimensions), as a copy
#
#   sin cos tan exp log log10 sqrt ldexp atan2   :  element-wise math,
#                delegated to the NArray methods of the same names

  def initialize(*lens)
    @lens = lens.dup                # lengths of dimensions
    @nd = @lens.length              # number of dimensions (rank)
    @ntot = @lens.inject(1, :*)     # total number of elements
    @dat = NArray.new(@ntot)        # flat storage, first dimension fastest
  end

  def dup
    derive(@dat)
  end

  def clone
    dup
  end

  def length; @dat.length; end
  def size; @dat.size; end
  def span_fill(*v); @dat.span_fill(*v); end

  # Element-wise operators. The operand may be another NMDArray (its data is
  # used), an NArray, or a scalar; the result keeps the shape of self.
  def +(a);  derive(@dat + operand(a)); end
  def -(a);  derive(@dat - operand(a)); end
  def *(a);  derive(@dat * operand(a)); end
  def /(a);  derive(@dat / operand(a)); end
  def **(a); derive(@dat**operand(a)); end
  def +@; dup; end
  def -@; derive(-@dat); end
  def abs; derive(@dat.abs); end
  def sin; derive(@dat.sin); end
  def cos; derive(@dat.cos); end
  def tan; derive(@dat.tan); end
  def exp; derive(@dat.exp); end
  def log; derive(@dat.log); end
  def log10; derive(@dat.log10); end
  def sqrt; derive(@dat.sqrt); end
  def ldexp(exp); derive(@dat.ldexp(exp)); end
  def atan2(y); derive(@dat.atan2(operand(y))); end

  # The underlying flat NArray (shared, not copied).
  def to_a; @dat; end

  # Lengths of the dimensions; a copy, so callers cannot corrupt the shape.
  def shape; @lens.dup; end

  def trim!
    # Drop every dimension of length 1 in place and return self.
    kept = @lens.reject { |l| l == 1 }
    kept = [1] if kept.empty?    # keep rank >= 1 so the object stays indexable
    @lens = kept
    @nd = @lens.length
    self
  end

  def trim
    dup.trim!
  end

  def import1d(a)
    # Import a 1D NArray into the flat storage, starting at element 0.
    # Extra elements of a are ignored (trimmed); if a is shorter, the tail of
    # the storage is left unchanged.  An empty a (or empty storage) is a no-op.
    # Raises RuntimeError if a is not an NArray.
    raise RuntimeError, "Not a NArray" unless a.is_a?(NArray)
    n = a.length
    # guard: @dat[0..n-1] with n == 0 would be @dat[0..-1], the WHOLE array
    return if n.zero? || @dat.length.zero?
    if n <= @dat.length
      @dat[0..n - 1] = a
    else
      @dat[0..-1] = a[0..@dat.length - 1]
    end
  end

  def [](*idx)
    # multi-dimensional subset (get); returns a new NMDArray.
    # NOTICE: if the number of the arguments is 1, the array is treated
    # as if it were 1D.
    ni = idx.length
    raise RuntimeError, "# of arguments > # of dimensions" if ni > @nd
    raise RuntimeError, "argument(s) is(are) needed" if ni == 0

    sel = indxar(*idx)
    if ni == 1 && (idx[0].is_a?(Range) || idx[0].is_a?(Numeric))
      # Short cut for 1D specification: let NArray slice directly.
      # A scalar index becomes a one-element range of its normalized value,
      # because a bare @dat[n] would return a scalar, not an array.
      key = idx[0].is_a?(Numeric) ? (sel[0].first..sel[0].first) : idx[0]
      out = NMDArray.new(sel[0].length)
      out.setv(@dat[key])
    else
      # real multi-D treatment
      out = NMDArray.new(*sel.map(&:length))
      flat = indx1d(sel)                                # private helper
      out.setv(@dat.indices(*flat)) unless flat.nil?    # nil: empty selection
    end
    out
  end

  def []=(*idx)
    # multi-dimensional subset (set).
    # NOTICE: if the number of index arguments is 1, the array is treated
    # as if it were 1D.
    # The last argument is the right-hand side: a Numeric (assigned to every
    # selected element) or an Array / NArray / NMDArray supplying one value
    # per selected element in storage order (extra values are ignored).
    rhs = idx.pop
    ni = idx.length
    raise RuntimeError, "# of arguments > # of dimensions" if ni > @nd
    raise RuntimeError, "argument(s) is(are) needed" if ni == 0

    targets = indx(*idx)
    if targets.nil?
      raise RuntimeError, "Invalid substitution -- nil subset specification"
    end

    # copy, so a[...] = a does not read values it has already overwritten
    rhs = rhs.to_a.dup if rhs.is_a?(NMDArray)
    if rhs.is_a?(Array) || rhs.is_a?(NArray)
      # the rhs must supply a value for every selected element
      if rhs.length < targets.length
        raise RuntimeError, "the array at rhs is too short"
      end
      targets.each_with_index { |t, i| @dat[t] = rhs[i] }
    elsif rhs.is_a?(Numeric)
      targets.each { |t| @dat[t] = rhs }
    else
      raise RuntimeError, "invalid type: " + rhs.class.to_s
    end
  end

  ######### PRIVATE & PROTECTED METHODS ########

  protected

  def setv(a)
    #(PROTECTED)
    # Replace the storage with a copy of the 1D NArray a, for internal usage
    # -- no validity check; assumes a has the same length as the storage.
    @dat = a.dup
  end

  private

  def derive(data)
    #(PRIVATE)
    # New NMDArray with the shape of self and (a copy of) data as storage.
    out = NMDArray.new(*@lens)
    out.setv(data)
    out
  end

  def operand(a)
    #(PRIVATE)
    # Unwrap an NMDArray operand to its NArray data; pass others through.
    a.is_a?(NMDArray) ? a.to_a : a
  end

  def indxar(*idx)
    #(PRIVATE)
    # Per-dimension index specifications -> per-dimension Arrays of indices.
    # Each element of idx may be
    #   Range            -> expanded with Range#to_idx
    #   Numeric          -> a one-element Array (negative values count from
    #                       the end of the dimension)
    #   {Range=>Integer} -> a stepped range (only the first pair is used)
    #   anything else    -> passed through unchanged
    # With a single argument the array is treated as 1D of length @ntot.
    # Example: [0..-1,{1..3=>2}] of a 3x4 array -> [[0,1,2],[1,3]]
    ni = idx.length
    if ni != @nd && ni != 1
      raise RuntimeError, "# of arguments (#{ni}) do not agree with dimension # (#{@nd})"
    end

    idx.each_with_index.map do |spec, i|
      len = (ni == 1 ? @ntot : @lens[i])
      case spec
      when Range
        spec.to_idx(len)
      when Numeric
        [spec < 0 ? spec % len : spec]    # wrap negative values only
      when Hash
        range, step = spec.first
        raise RuntimeError, "Not a range" unless range.is_a?(Range)
        raise RuntimeError, "Not an Integer" unless step.is_a?(Integer)
        range.to_idx(len, step)
      else
        spec
      end
    end
  end

  def indx1d(idx)
    #(PRIVATE)
    # vector indices (see indxar) -> Array of flat indices into @dat, ordered
    # with the first dimension varying fastest.  The stride of dimension d is
    # @lens[0]*...*@lens[d-1] (1 for d == 0, which also covers the 1D case).
    # Returns nil if any dimension selects no elements.
    return nil if idx.any? { |sel| sel.length <= 0 }

    flat = [0]
    stride = 1
    idx.each_with_index do |sel, d|
      # each new dimension varies more slowly than the ones already combined
      flat = sel.flat_map { |i| flat.map { |base| base + i * stride } }
      stride *= @lens[d]
    end
    flat
  end

  def indx(*idx)
    #(PRIVATE)
    # indices of a multi-D array -> indices of its equivalent 1D array
    # (nil if the selection is empty).
    # Example: [0..1,0..1] of a 3x4 array -> [0,1,3,4]
    indx1d(indxar(*idx))
  end

end

class Range
# extension for NMDArray
# NOTE(review): this reopens a core class for every Range in the process;
# NMDArray calls to_idx directly, so it is kept as a plain method here.

  def to_idx(len, step = 1)
    # Expand the range into an Array of Integer indices for a dimension of
    # length len, with negative-value alternation:
    #   first -> first+len if (first < 0);  last -> last+len if (last < 0)
    # (computed as value % len).
    #
    # len  : length of the dimension
    # step : non-zero Integer stride.  With a positive step the indices run
    #        upward from first; with a negative step they run downward, so
    #        first must then come after last.
    # An exclusive range (first...last) omits last in either direction.
    # Returns [] when nothing lies between first and last in the direction
    # of step.  Raises RuntimeError "step==0" if step is zero.
    raise RuntimeError, "step==0" if step == 0

    f = first
    l = last
    f = f % len if f < 0
    l = l % len if l < 0
    # drop the excluded endpoint by moving last one unit back toward first
    # (previously it always moved down, which was wrong for negative steps)
    l -= (step > 0 ? 1 : -1) if exclude_end?
    f.step(l, step).to_a
  end

end


####### test for development ########
# Runs only when this file is executed directly, not when it is required
# as a library.

if __FILE__ == $PROGRAM_NAME
  a = NMDArray.new(4, 3, 2)
  src = NArray.indgen(24)
  a.import1d(src)

  #p '+++++'
  #p a[0]
  #p a[-3..-1]
  #p a[-2..-1,0..2,-1]
  #a[-2..-1,0..2,-1]=555
  #p a.to_a

  # hash-with-step uses {range => step}; the old {range, step} literal is a
  # syntax error in Ruby 1.9+
  a[{ -1..0 => -2 }, { 0..2 => 2 }, 0] = 99
  p a.to_a

  #p a.to_a.class
  #p a.to_a.to_a.class
  #p a[{-1..0 => -2},{0..2 => 2},0]
  #p a[0..7]

  #b=a[0,0..2,0]
  #p b
  #b.trim!
  #p b

  #b=a.dup
  #p b.to_a
  #a[{-1..0 => -2},{0..2 => 2},0]=99
  #p b.to_a

  #a[-3..-1]=[331,332,333]
  #p a.to_a
  #a[1..3,0,0]=[881,882,883]
  #p a.to_a
  #p '+++++'

  #b=a+a
  #p b.to_a
  #p b.shape
  #b=(-a).abs**2
  #p b.to_a

  src = NArray.indgen(24, 0.0, Math::PI / 23)
  a.import1d(src)
  p a.sin.to_a
end
