! Copyright (C) GFD Dennou Club, 2000.  All rights reserved.
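! gtbinary - apply a binary operator element-wise to two gtool variables

! Usage:
!   gtbinary lhs operator rhs [option=value ...]
!
! The result is written to the URL given by the output (or out) option,
! or to gtool.nc@default when that option is omitted.
! Exactly one operator is taken: + - * / (or add, sub, mul, div).
! An unrecognized operator falls back to addition, with a warning.
! The -expand option (comma-separated dimension names) reuses the whole
! of rhs for every slice of lhs taken along those dimensions.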

!> Print a usage summary (positional arguments, accepted operators and
!> recognized option names) to standard output, then call abortprogram
!> so the program terminates.
subroutine help
    use sysdep
    implicit none
    write(*, *) "usage: gtbinary lhs operator rhs [option=value ...]"
    write(*, *) "  operator: + - * / (or add sub mul div)"
    write(*, *) "  options:  output (or out), -expand, -debug"
    call abortprogram('gtbinary')
end subroutine help

!> Command-line driver: parse "lhs operator rhs [option=value ...]",
!> apply the operator element-wise to the two gtool variables, and
!> write the result to the URL in `output`.
program gtbinary
    use gtool
    use dc_trace, only: setdebug, message
    implicit none
    character(STRING):: lhs = "", rhs = ""
    character(string):: output = "gtool.nc@default"
    character(string):: operator = ""
    character(STRING):: arg, optname, optvalue
    type(GT_VARIABLE):: lvar, rvar, ovar
    double precision, allocatable:: lhsbuf(:), rhsbuf(:), obuf(:)
    character(string):: expand_dim = ""
    integer:: i, nargs, siz, stat
!--
    nargs = GtArgCount()
    do, i = 1, nargs
        call GtArgGet(i, arg)
        if (gtoptionform(arg, optname, optvalue)) then
            if (optname == "output" .or. optname == "out") then
                output = optvalue
            else if (optname == '-expand') then
                expand_dim = optvalue
            else if (optname == '-debug') then
                call setdebug
            else
                call put(optname // ": unknown option")
                call help
            endif
        else if (lhs == "") then
            lhs = arg
        else if (operator == "") then
            operator = arg
        else if (rhs == "") then
            rhs = arg
        else
            ! Surplus positional arguments used to be dropped silently.
            write(0, *) "gtbinary: extra argument ignored: ", trim(arg)
        endif
    enddo
    if (rhs == "") then
        call help
    endif

    ! Validate the operator once here; otherwise binop would repeat the
    ! "unknown" warning for every slice it processes.
    call check_operator

    if (expand_dim /= "") then
        call do_expand
    else
        call do_equivalent
    endif
    stop
contains

!> Replace an unrecognized operator with '+', printing the warning once.
subroutine check_operator
    select case (operator)
    case ('+', 'add', '-', 'sub', '*', 'mul', '/', 'div')
        ! recognized: nothing to do
    case default
        print *, 'operator "', trim(operator), '" unknown'
        operator = '+'
    end select
end subroutine check_operator

!> Apply the operator with rhs reused along the dimensions listed in
!> expand_dim (comma-separated). lhs and the output are each sliced with
!> "<dim>=1:1:1" for every listed dimension and then advanced with
!> Slice_Next, while the whole of rhs, read once, is combined with every
!> lhs slice.
subroutine do_expand
    character(string):: expand, dimname, spec
    integer:: comma, islice, rsiz
    logical:: err
    call Open(rvar, rhs)
    call Open(lvar, lhs)
    call Create(ovar, url=output, copyfrom=lvar, copyvalue=.false.)
    expand = expand_dim
    do, while (expand /= "")
        comma = index(expand, ',')
        if (comma == 0) then
            dimname = expand
            expand = ""
        else
            dimname = expand(1: comma-1)
            expand = expand(comma+1: )
        endif
        ! Tolerate blanks around names ("x, y") and empty fields ("x,,y").
        dimname = adjustl(dimname)
        if (dimname == "") cycle
        write(0, *) "expand for ", trim(dimname)
        ! Build the spec in a STRING-length buffer so that a long
        ! dimension name is not truncated when "=1:1:1" is appended.
        spec = trim(dimname) // "=1:1:1"
        call Slice(lvar, spec, err)
        if (err) then
            write(0, *) "gtbinary: cannot slice lhs along ", trim(dimname)
            exit
        endif
        call Slice(ovar, spec, err)
        if (err) then
            write(0, *) "gtbinary: cannot slice output along ", trim(dimname)
            exit
        endif
    enddo
    call Inquire(lvar, size=siz)
    call Inquire(rvar, size=rsiz)
    if (rsiz /= siz) then
        ! rhs is read into a buffer sized for one lhs slice.
        write(0, *) "gtbinary: warning: rhs size", rsiz, &
            " differs from lhs slice size", siz
    endif
    allocate(lhsbuf(siz), rhsbuf(siz), obuf(siz))
    call Get(rvar, rhsbuf, siz)
    islice = 0
    do
        islice = islice + 1
        write(0, "(a,i4)") "slice", islice
        call Get(lvar, lhsbuf, siz)
        call BinOp(obuf, lhsbuf, operator, rhsbuf, siz)
        call Put(ovar, obuf, siz)
        call Slice_Next(lvar, stat=stat);  if (stat /= 0) exit
        call Slice_Next(ovar, stat=stat);  if (stat /= 0) exit
    enddo
    deallocate(lhsbuf, rhsbuf, obuf)
    call Close(lvar)
    call Close(rvar)
    call Close(ovar)
end subroutine do_expand

!> Apply the operator to lhs and rhs without broadcasting. rhs is passed
!> through Transform against lhs (presumably mapping it onto lhs's axes;
!> confirm in gtool). lhs and the output are sliced compatibly, and all
!> three variables are then advanced together one slice at a time.
subroutine do_equivalent
    call Open(lvar, lhs)
    call Open(rvar, rhs)
    call Transform(rvar, lvar)
    call Create(ovar, url=output, copyfrom=lvar, copyvalue=.false.)
    call Slice(lvar, compatible=rvar)
    call Slice(ovar, compatible=lvar)
    call Inquire(rvar, size=siz)
    allocate(lhsbuf(siz), rhsbuf(siz), obuf(siz))
    do
        call Get(rvar, rhsbuf, siz)
        call Get(lvar, lhsbuf, siz)
        call BinOp(obuf, lhsbuf, operator, rhsbuf, siz)
        call Put(ovar, obuf, siz)
        call Slice_Next(lvar, stat=stat);  if (stat /= 0) exit
        call Slice_Next(rvar, stat=stat);  if (stat /= 0) exit
        call Slice_Next(ovar, stat=stat);  if (stat /= 0) exit
    enddo
    deallocate(lhsbuf, rhsbuf, obuf)
    call Close(lvar)
    call Close(rvar)
    call Close(ovar)
end subroutine do_equivalent

!> Element-wise out = lhs <operator> rhs over n values.
!> Recognized operators: '+'/'add', '-'/'sub', '*'/'mul', '/'/'div'.
!> Any other operator prints a warning and falls back to addition.
!> Division saturates to +/-huge(1d0) wherever the quotient would not be
!> representable (including rhs == 0); the sign of the saturated value
!> is taken from lhs*rhs.
subroutine binop(out, lhs, operator, rhs, n)
    integer, intent(in):: n
    double precision, intent(out):: out(n)
    character(*), intent(in):: operator
    double precision, intent(in):: lhs(n), rhs(n)
    double precision, parameter:: big = huge(1.0d0)
    select case (operator)
    case ('+', 'add')
        out = lhs + rhs
    case ('-', 'sub')
        out = lhs - rhs
    case ('*', 'mul')
        out = lhs * rhs
    case ('/', 'div')
        ! |lhs/rhs| <= big  <=>  |rhs| >= |lhs|/big, and |lhs|/big never
        ! overflows. The former absolute test |rhs| > epsilon(1d0) wrongly
        ! saturated small but valid denominators (1e-20/1e-20 gave huge).
        where (abs(rhs) > abs(lhs) / big)
            out = lhs / rhs
        elsewhere
            out = sign(big, lhs * rhs)
        end where
    case default
        print *, 'operator "', trim(operator), '" unknown'
        out = lhs + rhs
    end select
end subroutine binop

end program gtbinary
