!--
!----------------------------------------------------------------------
! Copyright (c) 2011-2016 SPMODEL Development Group. All rights reserved.
!----------------------------------------------------------------------
!
!表題  lumatrix_cuda : 行列の LU 分解による線形連立方程式の解法(CUDA版)
!
!履歴  2011/03/14(takepiro) CUDA 版作成
!      2016/04/11(takepiro) ベクトル長問題対応 lusol3 を用意
!
!++
module lumatrix_cuda_param
  !
  ! Launch configuration shared by the host wrappers that call the
  ! kernels in module lumatrix_cuda.
  !
  implicit none

  ! Threads per block used for every kernel launch in this file.
  integer, parameter :: nthread = 512

end module lumatrix_cuda_param

module lumatrix_cuda
  !
  ! CUDA Fortran kernels for batched LU factorisation with partial
  ! pivoting, and the matching forward/back substitution.
  !
  ! Storage convention produced by lumake_kernel and consumed by the
  ! solve kernels:
  !   * the lower triangle INCLUDING the diagonal holds L;
  !   * the strict upper triangle holds U, whose diagonal is implicitly 1;
  !   * kp(k) is the row exchanged with row k at elimination step k
  !     (kp(ndim) = ndim, i.e. the last step never exchanges rows).
  !
  ! A zero pivot (singular matrix) is not detected anywhere; it produces
  ! Inf/NaN entries in the factors and in the solutions.
  !
  use cudafor
  implicit none

  private
  public :: lumake_kernel, lusolv_kernel, lusol2_kernel, lusolm_kernel

contains

  attributes(global) subroutine lumake_kernel(alu, kp, jdim, ndim)
    !
    ! LU-factorise jdim independent ndim x ndim matrices in place.
    ! Global thread index j works on matrix alu(j,:,:).
    !
    ! alu  : on entry the matrices, on exit their LU factors
    ! kp   : pivot row chosen at each elimination step
    ! jdim : number of matrices
    ! ndim : order of each matrix
    !
    implicit none
    integer, value :: jdim
    integer, value :: ndim
    integer, intent(out)   :: kp(jdim, ndim)
    real(8), intent(inout) :: alu(jdim, ndim, ndim)
    ! Plain locals: the VALUE attribute is only legal on dummy arguments.
    integer :: j, k, m, n, kpiv
    real(8) :: pivot, temp

    j = (blockidx%x - 1) * blockdim%x + threadidx%x
    if ( j > jdim .or. ndim < 1 ) return

    do k = 1, ndim-1
       ! Partial pivoting: first entry of largest magnitude in column k
       ! at or below the diagonal.
       pivot = alu(j,k,k)
       kpiv  = k
       do m = k+1, ndim
          if ( abs(alu(j,m,k)) > abs(pivot) ) then
             pivot = alu(j,m,k)
             kpiv  = m
          end if
       end do
       kp(j,k) = kpiv

       ! Exchange the whole rows (including the L part already computed).
       if ( kpiv /= k ) then
          do n = 1, ndim
             temp          = alu(j,k,n)
             alu(j,k,n)    = alu(j,kpiv,n)
             alu(j,kpiv,n) = temp
          end do
       end if

       ! Scale row k of U by the pivot, then update the trailing block.
       do n = k+1, ndim
          alu(j,k,n) = alu(j,k,n) / pivot
          do m = k+1, ndim
             alu(j,m,n) = alu(j,m,n) - alu(j,m,k) * alu(j,k,n)
          end do
       end do
    end do

    ! The last step never exchanges rows; record it so kp is fully defined.
    kp(j,ndim) = ndim
  end subroutine lumake_kernel

  attributes(global) subroutine lusolv_kernel(xv, alu, kp, idim, jdim, ndim)
    !
    ! Solve A_j x = b for jdim factorised ndim x ndim matrices A_j (from
    ! lumake_kernel), each with idim right-hand sides. Global thread index j
    ! handles matrix j and all of its right-hand sides xv(:,j,:).
    ! The solutions overwrite the right-hand sides in xv.
    !
    implicit none
    integer, value :: idim
    integer, value :: jdim
    integer, value :: ndim
    real(8), intent(inout) :: xv(idim, jdim, ndim)
    real(8), intent(in)    :: alu(jdim, ndim, ndim)
    integer, intent(in)    :: kp(jdim, ndim)
    integer :: i, j, k, n, nn, kpiv
    real(8) :: temp

    j = (blockidx%x - 1) * blockdim%x + threadidx%x
    if ( j > jdim ) return

    do i = 1, idim
       ! Apply the recorded row exchanges to the right-hand side.
       do k = 1, ndim-1
          kpiv = kp(j,k)
          if ( kpiv /= k ) then
             temp           = xv(i,j,k)
             xv(i,j,k)      = xv(i,j,kpiv)
             xv(i,j,kpiv)   = temp
          end if
       end do

       ! Forward substitution with L (non-unit diagonal).
       do n = 1, ndim
          xv(i,j,n) = xv(i,j,n) / alu(j,n,n)
          do nn = n+1, ndim
             xv(i,j,nn) = xv(i,j,nn) - xv(i,j,n) * alu(j,nn,n)
          end do
       end do

       ! Back substitution with U (unit diagonal).
       do k = ndim-1, 1, -1
          do n = k+1, ndim
             xv(i,j,k) = xv(i,j,k) - xv(i,j,n) * alu(j,k,n)
          end do
       end do
    end do
  end subroutine lusolv_kernel

  attributes(global) subroutine lusol2_kernel(xv, alu, kp, idim, ndim)
    !
    ! Solve A x = b for a single factorised ndim x ndim matrix A and idim
    ! right-hand sides. This is the jdim = 1 case of lusolv_kernel, but the
    ! threads are spread over the right-hand sides (global index i handles
    ! xv(i,:)) because a single matrix would give only one thread.
    ! The solutions overwrite the right-hand sides in xv.
    !
    implicit none
    integer, value :: idim
    integer, value :: ndim
    real(8), intent(inout) :: xv(idim, ndim)
    real(8), intent(in)    :: alu(ndim, ndim)
    integer, intent(in)    :: kp(ndim)
    integer :: i, k, n, nn, kpiv
    real(8) :: temp

    i = (blockidx%x - 1) * blockdim%x + threadidx%x
    if ( i > idim ) return

    ! Apply the recorded row exchanges to the right-hand side.
    do k = 1, ndim-1
       kpiv = kp(k)
       if ( kpiv /= k ) then
          temp         = xv(i,k)
          xv(i,k)      = xv(i,kpiv)
          xv(i,kpiv)   = temp
       end if
    end do

    ! Forward substitution with L (non-unit diagonal).
    do n = 1, ndim
       xv(i,n) = xv(i,n) / alu(n,n)
       do nn = n+1, ndim
          xv(i,nn) = xv(i,nn) - xv(i,n) * alu(nn,n)
       end do
    end do

    ! Back substitution with U (unit diagonal).
    do k = ndim-1, 1, -1
       do n = k+1, ndim
          xv(i,k) = xv(i,k) - xv(i,n) * alu(k,n)
       end do
    end do
  end subroutine lusol2_kernel

  attributes(global) subroutine lusolm_kernel(xv, alu, kp, jmtx, idim, jdim, ndim)
    !
    ! Solve A x = b for idim right-hand sides. Right-hand side i uses the
    ! factorised matrix number jmtx(i) out of jdim matrices. Global thread
    ! index i handles xv(i,:). The solutions overwrite the right-hand sides.
    ! NOTE(review): jmtx(i) is assumed to lie in 1..jdim; it is not checked.
    !
    implicit none
    integer, value :: idim
    integer, value :: jdim
    integer, value :: ndim
    real(8), intent(inout) :: xv(idim, ndim)
    real(8), intent(in)    :: alu(jdim, ndim, ndim)
    integer, intent(in)    :: kp(jdim, ndim)
    integer, intent(in)    :: jmtx(idim)
    integer :: i, jm, k, n, nn, kpiv
    real(8) :: temp

    i = (blockidx%x - 1) * blockdim%x + threadidx%x
    if ( i > idim ) return

    jm = jmtx(i)     ! matrix used by this right-hand side

    ! Apply the recorded row exchanges to the right-hand side.
    do k = 1, ndim-1
       kpiv = kp(jm,k)
       if ( kpiv /= k ) then
          temp         = xv(i,k)
          xv(i,k)      = xv(i,kpiv)
          xv(i,kpiv)   = temp
       end if
    end do

    ! Forward substitution with L (non-unit diagonal).
    do n = 1, ndim
       xv(i,n) = xv(i,n) / alu(jm,n,n)
       do nn = n+1, ndim
          xv(i,nn) = xv(i,nn) - xv(i,n) * alu(jm,nn,n)
       end do
    end do

    ! Back substitution with U (unit diagonal).
    do k = ndim-1, 1, -1
       do n = k+1, ndim
          xv(i,k) = xv(i,k) - xv(i,n) * alu(jm,k,n)
       end do
    end do
  end subroutine lusolm_kernel

end module lumatrix_cuda

subroutine lumake(alu, kp, jdim, ndim)
  !
  ! LU-factorise jdim independent ndim x ndim matrices in place on the GPU.
  !
  ! alu  : on entry the matrices, on exit their LU factors
  !        (lower triangle incl. diagonal = L, strict upper = unit-diagonal U)
  ! kp   : pivot row chosen at each elimination step
  ! jdim : number of matrices
  ! ndim : order of each matrix
  !
  use lumatrix_cuda_param, only: nthread
  use lumatrix_cuda, only: lumake_kernel
  implicit none
  integer, intent(in)    :: jdim
  integer, intent(in)    :: ndim
  integer, intent(out)   :: kp(jdim, ndim)
  real(8), intent(inout) :: alu(jdim, ndim, ndim)

  integer, device        :: kpd(jdim, ndim)
  real(8), device        :: alud(jdim, ndim, ndim)

  integer :: nblock

  ! Empty batch: nothing to factorise, and launching zero blocks would be
  ! an invalid kernel configuration.
  if ( jdim < 1 .or. ndim < 1 ) return

  ! One thread per matrix: ceiling(jdim / nthread) blocks (overflow-safe form).
  nblock = (jdim - 1) / nthread + 1

  alud = alu

  call lumake_kernel<<<nblock, nthread>>>(alud, kpd, jdim, ndim)

  kp  = kpd
  alu = alud

end subroutine lumake

subroutine lumak1(alu, kp, ndim)
  !
  ! LU-factorise a single ndim x ndim matrix in place on the host, with
  ! partial pivoting. This uses the same storage convention as
  ! lumake_kernel:
  !   * lower triangle incl. diagonal = L;
  !   * strict upper triangle = U (unit diagonal implied);
  !   * kp(k) = row exchanged with row k at step k, and kp(ndim) = ndim.
  !
  ! A zero pivot (singular matrix) is not detected and yields Inf/NaN.
  !
  implicit none
  integer, intent(in)    :: ndim
  real(8), intent(inout) :: alu(ndim,ndim)
  integer, intent(out)   :: kp(ndim)
  integer :: k, m, n
  real(8) :: pivot, temp

  if ( ndim < 1 ) return

  do k = 1, ndim-1
    ! Partial pivoting: first entry of largest magnitude in column k
    ! at or below the diagonal.
    pivot = alu(k,k)
    kp(k) = k
    do m = k+1, ndim
      if ( abs(alu(m,k)) > abs(pivot) ) then
        pivot = alu(m,k)
        kp(k) = m
      end if
    end do

    ! Exchange the whole rows (including the L part already computed).
    if ( kp(k) /= k ) then
      do n = 1, ndim
        temp         = alu(k,n)
        alu(k,n)     = alu(kp(k),n)
        alu(kp(k),n) = temp
      end do
    end if

    ! Row k of U, scaled by the pivot.
    alu(k,k+1:ndim) = alu(k,k+1:ndim) / pivot

    ! Rank-1 update of the trailing block, one contiguous column at a time.
    do n = k+1, ndim
      alu(k+1:ndim,n) = alu(k+1:ndim,n) - alu(k+1:ndim,k) * alu(k,n)
    end do
  end do

  ! The last step never exchanges rows; define kp(ndim) instead of leaving
  ! this intent(out) element undefined.
  kp(ndim) = ndim
end subroutine lumak1

subroutine lusolv(xv, alu, kp, idim, jdim, ndim)
  !
  ! Solve A_j x = b on the GPU for jdim factorised ndim x ndim matrices
  ! A_j (as produced by lumake), each with idim right-hand sides
  ! xv(:,j,:). The solutions overwrite xv.
  !
  use lumatrix_cuda_param, only: nthread
  use lumatrix_cuda, only: lusolv_kernel
  use cudafor
  implicit none
  integer, intent(in) :: idim
  integer, intent(in) :: jdim
  integer, intent(in) :: ndim
  real(8), intent(inout) :: xv(idim, jdim, ndim)
  real(8), intent(in) :: alu(jdim, ndim, ndim)
  integer, intent(in) :: kp(jdim, ndim)

  real(8), device :: xvd(idim, jdim, ndim)
  real(8), device :: alud(jdim, ndim, ndim)
  integer, device :: kpd(jdim, ndim)

  integer    :: nblock

  ! Nothing to solve; also avoids launching a kernel with zero blocks.
  if ( idim < 1 .or. jdim < 1 .or. ndim < 1 ) return

  ! One thread per matrix: ceiling(jdim / nthread) blocks.
  nblock = (jdim - 1) / nthread + 1

  alud = alu
  kpd  = kp
  xvd  = xv

  call lusolv_kernel<<<nblock, nthread>>>(xvd, alud, kpd, idim, jdim, ndim)

  xv = xvd

end subroutine lusolv

subroutine lusol3(xv, alu, kp, jdim, ndim)
  !
  ! Solve A_j x_j = b_j on the GPU for jdim factorised ndim x ndim
  ! matrices A_j (as produced by lumake), one right-hand side per matrix
  ! stored as xv(j,:). This is lusolv with idim = 1, taking a rank-2 xv.
  ! The solutions overwrite xv.
  !
  use lumatrix_cuda_param, only: nthread
  use lumatrix_cuda, only: lusolv_kernel
  use cudafor
  implicit none
  integer, intent(in) :: jdim
  integer, intent(in) :: ndim
  real(8), intent(inout) :: xv(jdim, ndim)
  real(8), intent(in) :: alu(jdim, ndim, ndim)
  integer, intent(in) :: kp(jdim, ndim)

  ! lusolv_kernel expects xv(idim, jdim, ndim); here idim = 1.
  real(8), device :: xvd(1, jdim, ndim)
  real(8), device :: alud(jdim, ndim, ndim)
  integer, device :: kpd(jdim, ndim)

  integer    :: nblock

  ! Nothing to solve; also avoids launching a kernel with zero blocks.
  if ( jdim < 1 .or. ndim < 1 ) return

  ! One thread per matrix: ceiling(jdim / nthread) blocks.
  nblock = (jdim - 1) / nthread + 1

  alud = alu
  kpd  = kp
  ! Transfer through the rank-2 section xvd(1,:,:): assigning the rank-2
  ! array xv directly to the rank-3 xvd (or back) is a rank mismatch.
  xvd(1,:,:) = xv

  call lusolv_kernel<<<nblock, nthread>>>(xvd, alud, kpd, 1, jdim, ndim)

  xv = xvd(1,:,:)

end subroutine lusol3

subroutine lusol2(xv, alu, kp, idim, ndim)
  !
  ! Solve A x = b on the GPU for a single factorised ndim x ndim matrix A
  ! (as produced by lumak1) and idim right-hand sides xv(i,:).
  ! This is the jdim = 1 case of lusolv. It parallelises over the
  ! right-hand sides instead of the matrices, because a single matrix
  ! would give only one GPU thread. The solutions overwrite xv.
  !
  use lumatrix_cuda_param, only: nthread
  use lumatrix_cuda, only: lusol2_kernel
  implicit none
  integer, intent(in) :: idim
  integer, intent(in) :: ndim
  real(8), intent(inout) :: xv(idim, ndim)
  real(8), intent(in) :: alu(ndim,ndim)
  integer, intent(in) :: kp(ndim)

  real(8), device :: xvd(idim, ndim)
  real(8), device :: alud(ndim,ndim)
  integer, device :: kpd(ndim)

  integer :: nblock

  ! Nothing to solve; also avoids launching a kernel with zero blocks.
  if ( idim < 1 .or. ndim < 1 ) return

  ! One thread per right-hand side: ceiling(idim / nthread) blocks.
  nblock = (idim - 1) / nthread + 1

  xvd  = xv
  alud = alu
  kpd  = kp

  call lusol2_kernel<<<nblock, nthread>>>(xvd, alud, kpd, idim, ndim)

  xv = xvd

end subroutine lusol2

subroutine lusolm(xv, alu, kp, jmtx, idim, jdim, ndim)
  !
  ! Solve A x = b on the GPU for idim right-hand sides xv(i,:). Right-hand
  ! side i uses factorised matrix number jmtx(i) out of the jdim matrices
  ! in alu/kp (as produced by lumake). The solutions overwrite xv.
  ! NOTE(review): jmtx values are assumed to lie in 1..jdim; not checked.
  !
  use lumatrix_cuda_param, only: nthread
  use lumatrix_cuda, only: lusolm_kernel
  implicit none
  integer, intent(in)    :: idim
  integer, intent(in)    :: jdim
  integer, intent(in)    :: ndim
  real(8), intent(inout) :: xv(idim, ndim)
  real(8), intent(in)    :: alu(jdim, ndim, ndim)
  integer, intent(in)    :: kp(jdim, ndim)
  integer, intent(in)    :: jmtx(idim)

  real(8), device :: xvd(idim, ndim)
  real(8), device :: alud(jdim, ndim, ndim)
  integer, device :: kpd(jdim, ndim)
  integer, device :: jmtxd(idim)

  integer :: nblock

  ! Nothing to solve; also avoids launching a kernel with zero blocks.
  if ( idim < 1 .or. ndim < 1 ) return

  ! One thread per right-hand side: ceiling(idim / nthread) blocks.
  nblock = (idim - 1) / nthread + 1

  xvd   = xv
  alud  = alu
  kpd   = kp
  jmtxd = jmtx

  call lusolm_kernel<<<nblock, nthread>>>(xvd, alud, kpd, jmtxd, idim, jdim, ndim)

  xv = xvd

end subroutine lusolm

