#----------------------------------------------------------------------
# Python interface for ISPACK3
# Copyright (C) 2023--2024 Toshiki Matsushima <toshiki@gfd-dennou.org>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301 USA.
#----------------------------------------------------------------------
import sys
from mpi4py import MPI
import numpy as np
import time
import ispack3 as isp

# Benchmark parameters.  These are the grid / truncation sizes handed to the
# ISPACK3 routines below; the spectral arrays hold (mm+1)**2 coefficients per
# layer.
jm  = 2**10
ntr = 1      # number of timed repetitions of each transform
mm  = jm-1
im  = jm*2
nm  = mm
nn  = nm
km = 8       # number of layers (fields) transformed
ndv = 2      # number of groups the km layers are split into

# Passed through unchanged to isp.syini2 (ig) and the transforms (ipow);
# see the ISPACK3 documentation for their meaning.
ig=1
ipow=0

comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

# Validate the decomposition before any ISPACK3 setup.  Every rank evaluates
# the same condition (constants plus the common communicator size), so all
# ranks leave together.  Only rank 0 reports, and the non-zero status tells
# mpirun / calling scripts that the run did not happen.
if ndv > km:
    if rank == 0:
        print('Please set NDV not to be larger than KM. --> stop', file=sys.stderr)
    sys.exit(1)

if ndv > size:
    if rank == 0:
        print('Please set NDV not to be larger than the number of processes. --> stop', file=sys.stderr)
    sys.exit(1)

# Divide the km layers among ndv groups.  k1..k2 is the local layer range
# (this script iterates over k2-k1+1 layers) and coms is the communicator
# used for all per-group ISPACK3 calls.
k1, k2, coms = isp.sykini(km, ndv, comm)
np_size = coms.Get_size()

# ISPACK3 initialisation tables; the call order is preserved.
jv = isp.syqrjv(jm)
it, t, r = isp.syini1(mm, nm, im, coms)
p, jc = isp.syini2(mm, nm, jm, ig, r, coms)

# Per-group / per-process extents shared by the buffer shapes below.
layers_per_group = (km - 1) // ndv + 1        # ceil(km / ndv)
jblocks_per_proc = (jm // jv - 1) // np_size + 1   # ceil((jm // jv) / np_size)
m_per_proc = mm // np_size + 1

g_shape = (layers_per_group, jblocks_per_proc * jv, im)
w_shape = (2 * jv * jblocks_per_proc * m_per_proc * np_size * 2,)
s_shape = (layers_per_group,
           m_per_proc * (2 * (nn + 1) - (mm // np_size) * np_size))

# Grid and work buffers are 64-byte aligned; the spectral buffer is plain.
G = isp.aligned_array(g_shape, align=64)
W = isp.aligned_array(w_shape, dtype=np.float64, align=64)
S = np.empty(s_shape, dtype=np.float64)

# Reference spectral data.  Only rank 0 of COMM_WORLD holds the full set of
# km layers of (mm+1)**2 coefficients, drawn uniformly from [-1, 1); other
# ranks get empty placeholders.  SALL is full-width only on rank 0 of coms
# (presumably where isp.syksxx delivers each group's layers -- confirm).
ncoef = (mm+1)**2

if rank == 0:
    skall_shape = (km, ncoef)
else:
    skall_shape = (0, 0)

if coms.Get_rank() == 0:
    sall_shape = ((km-1)//ndv+1, ncoef)
else:
    sall_shape = ((km-1)//ndv+1, 0)

SALL = np.empty(sall_shape, dtype=np.float64)

# Fixed seed so the reference data is reproducible between runs.  SKALL is
# generated directly; pre-allocating it with np.empty was a wasted buffer.
np.random.seed(0)
SKALL = 2 * np.random.rand(*skall_shape) - 1
SKALLD = np.empty(SKALL.shape, dtype=np.float64)  # receives the round trip

isp.syksxx(ncoef, km, ndv, SKALL, SALL, comm, coms)

# Convert each local layer from the gathered layout into the distributed
# spectral buffer S used by the transforms.
for k in range(k2-k1+1):
    isp.syss2s(mm, nn, SALL[k,:], S[k,:], coms)

# Rank 0 prints the run configuration and what isp.mxgcpu() / isp.mxgomp()
# report about the build and threading.
if rank == 0:
    print("MM=", mm, ", IM=", im, ", JM=", jm, ", JV=", jv,
          ", KM=", km, ", NDV=", ndv, ", NTR=", ntr)
    print("SSE=", isp.mxgcpu())
    print("number of threads =", isp.mxgomp())
    print("number of processes =", size)

# Nominal operation count for one pass over all km layers, used only to
# convert wall time into GFlops.  The two terms look like FFT and Legendre-
# transform costs (an assumption -- confirm against ISPACK3 docs); the
# evaluation order of the original expression is kept.
fft_cost = 5*im*np.log(im)/np.log(2.0)*0.5*jm
legendre_cost = (mm+1)*(mm+1)*jm
rc = (fft_cost + legendre_cost)*km
    
def _time_transform(transform, label):
    """Time ``ntr`` passes of *transform* over every local layer.

    *transform* is called with the S2G/G2S argument list used by
    ``isp.syts2g`` / ``isp.sytg2s``.  Barriers bracket the timed region so
    every rank measures the same interval.  Rank 0 prints the time per pass
    and the GFlops rate derived from the nominal count ``rc``.
    """
    comm.Barrier()
    started = time.perf_counter()

    for _ in range(ntr):
        for k in range(k2-k1+1):
            transform(mm, nm, nn, im, jm, jv, S[k,:], G[k,:],
                      it, t, p, r, jc, W, ipow, coms)

    comm.Barrier()
    elapsed = time.perf_counter() - started
    gflops = rc*ntr/elapsed/10**9

    if rank == 0:
        print(label, elapsed/ntr, "sec (", gflops, "GFlops)")


# Spectral -> grid, then grid -> spectral (the latter overwrites S).
_time_transform(isp.syts2g, "S2G:")
_time_transform(isp.sytg2s, "G2S:")

# Undo the distribution: convert each local layer back to the gathered
# layout, then collect all layers into SKALLD.  Every rank makes these
# calls; only rank 0 of COMM_WORLD has a full-size SKALLD to receive into.
for k in range(k2-k1+1):
    isp.sygs2s(mm,nn,S[k,:],SALL[k,:],coms)

isp.sykgxx( (mm+1)**2, km, ndv, SALL, SKALLD, comm, coms )


if rank == 0:
    # Round-trip error per spectral mode.  SKALL/SKALLD hold ALL km layers
    # on this rank, so every layer is checked (iterating only k2-k1+1 local
    # layers left the other layers' errors at zero and unchecked).
    l = np.arange(SKALL.shape[1])
    n_of_l, m_of_l = isp.sxl2nm(nn, l)

    err = SKALLD - SKALL
    SL_values = np.zeros_like(SKALL, dtype=np.float64)
    SL_values[:, m_of_l == 0] = np.abs(err[:, m_of_l == 0])
    # Each m>0 entry is combined with an m<0 entry taken in the same order
    # (assumes SXL2NM lists +m/-m partners in matching order -- as the
    # original pairing did); the m<0 slots themselves stay zero.
    SL_values[:, m_of_l > 0] = np.hypot(err[:, m_of_l > 0], err[:, m_of_l < 0])

    SLMAX = np.max(SL_values)
    k, LAS = np.unravel_index(np.argmax(SL_values), SL_values.shape)
    # Report (n, m) from the same mapping that built the masks above.
    n, m = n_of_l[LAS], m_of_l[LAS]

    # Per-layer sum of squared mode errors; report the worst layer's RMS
    # over the (mm+1)*(mm+2)/2 modes with m >= 0.
    SLK_values = np.sum(SL_values**2, axis=1)
    SLAMAXK = np.max(SLK_values)
    KSA = np.argmax(SLK_values)

    print("maxerror=", SLMAX, "(n=", n, ", m=", m, ", k=", k,  ")" )
    print("rmserror=", np.sqrt(SLAMAXK/((mm+1)*(mm+2)/2) ), " (k=", KSA, ")" )
