#----------------------------------------------------------------------
# Python interface for ISPACK3
# Copyright (C) 2023--2024 Toshiki Matsushima <toshiki@gfd-dennou.org>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
# 
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301 USA.
#----------------------------------------------------------------------
import ctypes
import os

import numpy as np

# Load the ISPACK3 shared library.
# The historical path './libispack3.so' is tried first; dlopen resolves it
# against the process's current working directory. If that fails, fall back
# to a copy located next to this module, so the wrapper can also be imported
# from other directories. If neither can be loaded, the OSError from the
# fallback attempt propagates (implicitly chained to the first failure).
try:
    lib = ctypes.CDLL('./libispack3.so')
except OSError:
    lib = ctypes.CDLL(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), 'libispack3.so')
    )

def aligned_array(shape, dtype=np.float64, align=64):
    """
    Create an aligned array with a given alignment.
    
    Args:
    shape (tuple): The shape of the array.
    dtype (data-type, optional): The desired data-type for the array.
    align (int, optional): The desired memory alignment in bytes.
    
    Returns:
    np.ndarray: Aligned array.
    """

    element_size = np.dtype(dtype).itemsize
    bytes_to_allocate = np.prod(shape) * element_size

    buffer = np.empty(bytes_to_allocate + align, dtype=np.uint8)

    start_index = -buffer.ctypes.data % align

    aligned_arr = np.frombuffer(buffer[start_index:start_index + bytes_to_allocate], dtype=dtype)
    aligned_arr = aligned_arr.reshape(shape)

    aligned_arr[:] = 0

    assert aligned_arr.ctypes.data % align == 0, "Array is not properly aligned"
    
    return aligned_arr

def sxini1(mm, nm, im):
    """
    Call ISPACK3's SXINI1 and return the arrays it fills in.

    Parameters:
    mm, nm, im : int
        Passed by reference to the Fortran routine as 64-bit integers.

    Returns:
    (IT, T, R) : tuple of ndarray
        IT : int64 array of shape (im // 2,)
        T  : float64 array of shape (im * 3 // 2,)
        R  : float64 array of length
             ((mm+1)*(2*nm-mm-1)+1)//4*3 + (2*nm-mm)*(mm+1)//2 + mm + 1
        All three are allocated uninitialised here and handed to SXINI1.
    """
    r_len = ((mm + 1) * (2 * nm - mm - 1) + 1) // 4 * 3 \
        + (2 * nm - mm) * (mm + 1) // 2 + mm + 1
    it_shape, t_shape, r_shape = (im // 2,), (im * 3 // 2,), (r_len,)

    IT = np.empty(it_shape, dtype=np.int64)
    T = np.empty(t_shape, dtype=np.float64)
    R = np.empty(r_shape, dtype=np.float64)

    int64_p = ctypes.POINTER(ctypes.c_int64)
    routine = lib.sxini1_
    # MM, NM, IM by reference, then IT, T, R.
    routine.argtypes = [int64_p] * 3 + [
        np.ctypeslib.ndpointer(dtype=np.int64, ndim=1, shape=it_shape),
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=t_shape),
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=r_shape),
    ]
    routine.restype = None

    scalar_refs = [ctypes.byref(ctypes.c_int64(v)) for v in (mm, nm, im)]
    routine(*scalar_refs, IT, T, R)

    return IT, T, R

def sxini2(mm, nm, jm, ig, r):
    """
    Call ISPACK3's SXINI2 and return the arrays it fills in.

    Parameters:
    mm, nm, jm, ig : int
        Passed by reference to the Fortran routine as 64-bit integers.
    r : ndarray
        C-contiguous float64 array of shape (len_r,), where
        len_r = ((mm+1)*(2*nm-mm-1)+1)//4*3 + (2*nm-mm)*(mm+1)//2 + mm + 1
        (the same length as the R array allocated by sxini1).

    Returns:
    (P, JC)
        P  : float64 array of shape (2*mm+5, jm//2), 64-byte aligned
             (allocated with aligned_array).
        JC : int64 array of shape (mm*(2*nm-mm-1)//16+mm,), allocated
             uninitialised here.
        Both are handed to SXINI2.

    Raises:
    ctypes.ArgumentError
        If `r` has the wrong dtype/shape or is not C-contiguous.
    """
    def _c_array(dtype, shape):
        # ndpointer passes only the data address to the routine, with no
        # stride information, so a non-contiguous view of the right shape
        # would be read as if it were contiguous. Requiring C_CONTIGUOUS
        # turns that silent misuse into a ctypes.ArgumentError.
        return np.ctypeslib.ndpointer(dtype=dtype, ndim=len(shape), shape=shape,
                                      flags='C_CONTIGUOUS')

    int64_ref = ctypes.POINTER(ctypes.c_int64)

    jc_shape = (mm * (2 * nm - mm - 1) // 16 + mm,)
    p_shape = (2 * mm + 5, jm // 2)
    r_shape = (((mm + 1) * (2 * nm - mm - 1) + 1) // 4 * 3
               + (2 * nm - mm) * (mm + 1) // 2 + mm + 1,)

    P = aligned_array(p_shape, dtype=np.float64, align=64)
    JC = np.empty(jc_shape, dtype=np.int64)

    lib.sxini2_.argtypes = [
        int64_ref,                     # MM
        int64_ref,                     # NM
        int64_ref,                     # JM
        int64_ref,                     # IG
        _c_array(np.float64, p_shape),  # P
        _c_array(np.float64, r_shape),  # R
        _c_array(np.int64, jc_shape),   # JC
    ]
    lib.sxini2_.restype = None

    lib.sxini2_(
        ctypes.byref(ctypes.c_int64(mm)),
        ctypes.byref(ctypes.c_int64(nm)),
        ctypes.byref(ctypes.c_int64(jm)),
        ctypes.byref(ctypes.c_int64(ig)),
        P,
        r,
        JC,
    )

    return P, JC

def sxnm2l(nn, n, m):
    """
    Calculates the storage location of spectral data from the total wave number and zonal wave number.

    Parameters:
    nn : int
        Cutoff wave number N.
    n : int or array_like
        Total wave number. Can be a single integer or an array of integers.
    m : int or array_like
        Zonal wave number. Can be a single integer or an array of integers. Positive for m = M, negative for m = -M.
        `n` and `m` are broadcast against each other.

    Returns:
    l : int or ndarray
        The adjusted storage location of the spectral data. Returns a NumPy integer scalar if both
        arguments are scalars (0-d), or an array of integers otherwise.

    Notes:
    For m > 0, returns the storage location of the real part, and for m < 0, returns the storage location of the imaginary part.
    The calculated storage locations are adjusted to be compatible with Python's zero-based indexing system.
    """
    # Decide on the return form before np.atleast_1d turns scalars into
    # 1-element arrays (checking afterwards would always see arrays).
    scalar_input = np.ndim(n) == 0 and np.ndim(m) == 0

    n_arr, m_arr = np.atleast_1d(n), np.atleast_1d(m)
    l = np.where(
        m_arr == 0,
        n_arr,
        np.where(
            m_arr > 0,
            nn + 1 + (m_arr - 1) * (2 * nn + 2 - m_arr) + 2 * (n_arr - m_arr),
            nn + 1 + (-m_arr - 1) * (2 * nn + 2 + m_arr) + 2 * (n_arr + m_arr) + 1,
        ),
    )

    if scalar_input:
        return l[0]
    return l

def sxl2nm(nn, l):
    """
    Performs the inverse operation of sxnm2l, i.e., calculates the total and zonal wave numbers from the storage location of spectral data.

    Parameters:
    nn : int
        Cutoff wave number N.
    l : int or array_like
        The adjusted (zero-based) storage location of the spectral data. Can be a single integer or an array of integers.

    Returns:
    n : int or ndarray
        Total wave number. A scalar is returned when the input holds exactly one element; otherwise an int64 array.
    m : int or ndarray
        Zonal wave number. The sign of m has the same meaning as in sxnm2l.

    Notes:
    Locations 0..nn hold the m = 0 coefficients. Beyond that, m is recovered
    from the quadratic block-start formula via a square root, and the
    parity of the offset within the block selects the real (m > 0) or
    imaginary (m < 0) part.
    """
    idx = np.atleast_1d(l)
    n_out = np.zeros_like(idx, dtype=np.int64)
    m_out = np.zeros_like(idx, dtype=np.int64)

    beyond_zonal = idx > nn

    # Float estimate of |m|; storing into an int64 array truncates it.
    m_out[:] = np.where(~beyond_zonal & (idx <= nn), 0,
                        nn + 1.5 - np.sqrt(1.0 * (nn + 1) * (nn + 1) - idx))

    # Storage index of the real part of (n = m, m) for each estimated m.
    block_start = nn + 1 + (m_out - 1) * (2 * nn + 2 - m_out)
    n_out[:] = np.where(idx <= nn, idx, m_out + (idx - block_start) / 2)

    # An odd offset from the block start means the imaginary part -> m < 0.
    is_imag = np.where(beyond_zonal, idx - block_start != 2 * (n_out - m_out), False)
    m_out[is_imag] = -m_out[is_imag]

    if idx.size == 1:
        return n_out[0], m_out[0]
    return n_out, m_out

def sxts2g(mm, nm, nn, im, jm, s, G, it, t, p, r, jc, W, ipow):
    """
    Call ISPACK3's SXTS2G on caller-supplied arrays.

    Judging by the argument names, this presumably transforms spectral data
    `s` into grid data `G` (TODO confirm against the ISPACK3 docs). Nothing is
    allocated or returned here; the library works on the arrays passed in.

    Parameters:
    mm, nm, nn, im, jm, ipow : int
        Passed by reference to the Fortran routine as 64-bit integers.
    s  : float64, shape ((2*nn+1-mm)*mm+nn+1,)
    G  : float64, shape (jm, im)
    it : int64,   shape (im//2,)                    (IT from sxini1)
    t  : float64, shape (im*3//2,)                  (T from sxini1)
    p  : float64, shape (2*mm+5, jm//2)             (P from sxini2)
    r  : float64, same length as R from sxini1
    jc : int64,   shape (mm*(2*nm-mm-1)//16+mm,)    (JC from sxini2)
    W  : float64, shape (jm*im,)
    All arrays must be C-contiguous.

    Returns:
    None

    Raises:
    ctypes.ArgumentError
        If an array argument has the wrong dtype/shape or is not C-contiguous.
    """
    def _c_array(dtype, shape):
        # ndpointer passes only the data address (no strides), so a
        # non-contiguous view of the right shape would be silently misread;
        # C_CONTIGUOUS makes ctypes reject it instead.
        return np.ctypeslib.ndpointer(dtype=dtype, ndim=len(shape), shape=shape,
                                      flags='C_CONTIGUOUS')

    int64_ref = ctypes.POINTER(ctypes.c_int64)

    s_shape = ((2 * nn + 1 - mm) * mm + nn + 1,)
    g_shape = (jm, im)
    it_shape = (im // 2,)
    t_shape = (im * 3 // 2,)
    p_shape = (2 * mm + 5, jm // 2)
    r_shape = (((mm + 1) * (2 * nm - mm - 1) + 1) // 4 * 3
               + (2 * nm - mm) * (mm + 1) // 2 + mm + 1,)
    jc_shape = (mm * (2 * nm - mm - 1) // 16 + mm,)
    w_shape = (jm * im,)

    lib.sxts2g_.argtypes = [
        int64_ref,                      # MM
        int64_ref,                      # NM
        int64_ref,                      # NN
        int64_ref,                      # IM
        int64_ref,                      # JM
        _c_array(np.float64, s_shape),   # S
        _c_array(np.float64, g_shape),   # G
        _c_array(np.int64, it_shape),    # IT
        _c_array(np.float64, t_shape),   # T
        _c_array(np.float64, p_shape),   # P
        _c_array(np.float64, r_shape),   # R
        _c_array(np.int64, jc_shape),    # JC
        _c_array(np.float64, w_shape),   # W
        int64_ref,                      # IPOW
    ]
    lib.sxts2g_.restype = None

    lib.sxts2g_(
        ctypes.byref(ctypes.c_int64(mm)),
        ctypes.byref(ctypes.c_int64(nm)),
        ctypes.byref(ctypes.c_int64(nn)),
        ctypes.byref(ctypes.c_int64(im)),
        ctypes.byref(ctypes.c_int64(jm)),
        s,
        G,
        it,
        t,
        p,
        r,
        jc,
        W,
        ctypes.byref(ctypes.c_int64(ipow)),
    )

    return

def sxtg2s(mm, nm, nn, im, jm, S, g, it, t, p, r, jc, W, ipow):
    """
    Call ISPACK3's SXTG2S on caller-supplied arrays.

    Judging by the argument names, this presumably transforms grid data `g`
    into spectral data `S` (TODO confirm against the ISPACK3 docs). Nothing
    is allocated or returned here; the library works on the arrays passed in.

    Parameters:
    mm, nm, nn, im, jm, ipow : int
        Passed by reference to the Fortran routine as 64-bit integers.
    S  : float64, shape ((2*nn+1-mm)*mm+nn+1,)
    g  : float64, shape (jm, im)
    it : int64,   shape (im//2,)
    t  : float64, shape (im*3//2,)
    p  : float64, shape (2*mm+5, jm//2)
    r  : float64, same length as R from sxini1
    jc : int64,   shape (mm*(2*nm-mm-1)//16+mm,)
    W  : float64, shape (jm*im,)
    All arrays must be C-contiguous.

    Returns:
    None

    Raises:
    ctypes.ArgumentError
        If an array argument has the wrong dtype/shape or is not C-contiguous.
    """
    def _c_array(dtype, shape):
        # ndpointer passes only the data address (no strides), so a
        # non-contiguous view of the right shape would be silently misread;
        # C_CONTIGUOUS makes ctypes reject it instead.
        return np.ctypeslib.ndpointer(dtype=dtype, ndim=len(shape), shape=shape,
                                      flags='C_CONTIGUOUS')

    int64_ref = ctypes.POINTER(ctypes.c_int64)

    s_shape = ((2 * nn + 1 - mm) * mm + nn + 1,)
    g_shape = (jm, im)
    it_shape = (im // 2,)
    t_shape = (im * 3 // 2,)
    p_shape = (2 * mm + 5, jm // 2)
    r_shape = (((mm + 1) * (2 * nm - mm - 1) + 1) // 4 * 3
               + (2 * nm - mm) * (mm + 1) // 2 + mm + 1,)
    jc_shape = (mm * (2 * nm - mm - 1) // 16 + mm,)
    w_shape = (jm * im,)

    lib.sxtg2s_.argtypes = [
        int64_ref,                      # MM
        int64_ref,                      # NM
        int64_ref,                      # NN
        int64_ref,                      # IM
        int64_ref,                      # JM
        _c_array(np.float64, s_shape),   # S
        _c_array(np.float64, g_shape),   # G
        _c_array(np.int64, it_shape),    # IT
        _c_array(np.float64, t_shape),   # T
        _c_array(np.float64, p_shape),   # P
        _c_array(np.float64, r_shape),   # R
        _c_array(np.int64, jc_shape),    # JC
        _c_array(np.float64, w_shape),   # W
        int64_ref,                      # IPOW
    ]
    lib.sxtg2s_.restype = None

    lib.sxtg2s_(
        ctypes.byref(ctypes.c_int64(mm)),
        ctypes.byref(ctypes.c_int64(nm)),
        ctypes.byref(ctypes.c_int64(nn)),
        ctypes.byref(ctypes.c_int64(im)),
        ctypes.byref(ctypes.c_int64(jm)),
        S,
        g,
        it,
        t,
        p,
        r,
        jc,
        W,
        ctypes.byref(ctypes.c_int64(ipow)),
    )

    return

def sxts2v(mm, nm, nn, im, jm, s1, s2, G1, G2, it, t, p, r, jc, W, ipow):
    """
    Call ISPACK3's SXTS2V on caller-supplied arrays.

    Judging by the argument names, this presumably maps the spectral pair
    (s1, s2) to the grid pair (G1, G2) (TODO confirm against the ISPACK3
    docs). Nothing is allocated or returned here.

    Parameters:
    mm, nm, nn, im, jm, ipow : int
        Passed by reference to the Fortran routine as 64-bit integers.
    s1, s2 : float64, shape ((2*nn+1-mm)*mm+nn+1,)
    G1, G2 : float64, shape (jm, im)
    it : int64,   shape (im//2,)
    t  : float64, shape (im*3//2,)
    p  : float64, shape (2*mm+5, jm//2)
    r  : float64, same length as R from sxini1
    jc : int64,   shape (mm*(2*nm-mm-1)//16+mm,)
    W  : float64, shape (jm*im*2,)
    All arrays must be C-contiguous.

    Returns:
    None

    Raises:
    ctypes.ArgumentError
        If an array argument has the wrong dtype/shape or is not C-contiguous.
    """
    def _c_array(dtype, shape):
        # ndpointer passes only the data address (no strides), so a
        # non-contiguous view of the right shape would be silently misread;
        # C_CONTIGUOUS makes ctypes reject it instead.
        return np.ctypeslib.ndpointer(dtype=dtype, ndim=len(shape), shape=shape,
                                      flags='C_CONTIGUOUS')

    int64_ref = ctypes.POINTER(ctypes.c_int64)

    s_shape = ((2 * nn + 1 - mm) * mm + nn + 1,)
    g_shape = (jm, im)
    it_shape = (im // 2,)
    t_shape = (im * 3 // 2,)
    p_shape = (2 * mm + 5, jm // 2)
    r_shape = (((mm + 1) * (2 * nm - mm - 1) + 1) // 4 * 3
               + (2 * nm - mm) * (mm + 1) // 2 + mm + 1,)
    jc_shape = (mm * (2 * nm - mm - 1) // 16 + mm,)
    w_shape = (jm * im * 2,)

    lib.sxts2v_.argtypes = [
        int64_ref,                      # MM
        int64_ref,                      # NM
        int64_ref,                      # NN
        int64_ref,                      # IM
        int64_ref,                      # JM
        _c_array(np.float64, s_shape),   # S1
        _c_array(np.float64, s_shape),   # S2
        _c_array(np.float64, g_shape),   # G1
        _c_array(np.float64, g_shape),   # G2
        _c_array(np.int64, it_shape),    # IT
        _c_array(np.float64, t_shape),   # T
        _c_array(np.float64, p_shape),   # P
        _c_array(np.float64, r_shape),   # R
        _c_array(np.int64, jc_shape),    # JC
        _c_array(np.float64, w_shape),   # W
        int64_ref,                      # IPOW
    ]
    lib.sxts2v_.restype = None

    lib.sxts2v_(
        ctypes.byref(ctypes.c_int64(mm)),
        ctypes.byref(ctypes.c_int64(nm)),
        ctypes.byref(ctypes.c_int64(nn)),
        ctypes.byref(ctypes.c_int64(im)),
        ctypes.byref(ctypes.c_int64(jm)),
        s1,
        s2,
        G1,
        G2,
        it,
        t,
        p,
        r,
        jc,
        W,
        ctypes.byref(ctypes.c_int64(ipow)),
    )

    return

def sxtv2s(mm, nm, nn, im, jm, S1, S2, g1, g2, it, t, p, r, jc, W, ipow):
    """
    Call ISPACK3's SXTV2S on caller-supplied arrays.

    Judging by the argument names, this presumably maps the grid pair
    (g1, g2) to the spectral pair (S1, S2) (TODO confirm against the ISPACK3
    docs). Nothing is allocated or returned here.

    Parameters:
    mm, nm, nn, im, jm, ipow : int
        Passed by reference to the Fortran routine as 64-bit integers.
    S1, S2 : float64, shape ((2*nn+1-mm)*mm+nn+1,)
    g1, g2 : float64, shape (jm, im)
    it : int64,   shape (im//2,)
    t  : float64, shape (im*3//2,)
    p  : float64, shape (2*mm+5, jm//2)
    r  : float64, same length as R from sxini1
    jc : int64,   shape (mm*(2*nm-mm-1)//16+mm,)
    W  : float64, shape (jm*im*2,)
    All arrays must be C-contiguous.

    Returns:
    None

    Raises:
    ctypes.ArgumentError
        If an array argument has the wrong dtype/shape or is not C-contiguous.
    """
    def _c_array(dtype, shape):
        # ndpointer passes only the data address (no strides), so a
        # non-contiguous view of the right shape would be silently misread;
        # C_CONTIGUOUS makes ctypes reject it instead.
        return np.ctypeslib.ndpointer(dtype=dtype, ndim=len(shape), shape=shape,
                                      flags='C_CONTIGUOUS')

    int64_ref = ctypes.POINTER(ctypes.c_int64)

    s_shape = ((2 * nn + 1 - mm) * mm + nn + 1,)
    g_shape = (jm, im)
    it_shape = (im // 2,)
    t_shape = (im * 3 // 2,)
    p_shape = (2 * mm + 5, jm // 2)
    r_shape = (((mm + 1) * (2 * nm - mm - 1) + 1) // 4 * 3
               + (2 * nm - mm) * (mm + 1) // 2 + mm + 1,)
    jc_shape = (mm * (2 * nm - mm - 1) // 16 + mm,)
    w_shape = (jm * im * 2,)

    lib.sxtv2s_.argtypes = [
        int64_ref,                      # MM
        int64_ref,                      # NM
        int64_ref,                      # NN
        int64_ref,                      # IM
        int64_ref,                      # JM
        _c_array(np.float64, s_shape),   # S1
        _c_array(np.float64, s_shape),   # S2
        _c_array(np.float64, g_shape),   # G1
        _c_array(np.float64, g_shape),   # G2
        _c_array(np.int64, it_shape),    # IT
        _c_array(np.float64, t_shape),   # T
        _c_array(np.float64, p_shape),   # P
        _c_array(np.float64, r_shape),   # R
        _c_array(np.int64, jc_shape),    # JC
        _c_array(np.float64, w_shape),   # W
        int64_ref,                      # IPOW
    ]
    lib.sxtv2s_.restype = None

    lib.sxtv2s_(
        ctypes.byref(ctypes.c_int64(mm)),
        ctypes.byref(ctypes.c_int64(nm)),
        ctypes.byref(ctypes.c_int64(nn)),
        ctypes.byref(ctypes.c_int64(im)),
        ctypes.byref(ctypes.c_int64(jm)),
        S1,
        S2,
        g1,
        g2,
        it,
        t,
        p,
        r,
        jc,
        W,
        ctypes.byref(ctypes.c_int64(ipow)),
    )

    return

def sxinic(mm, nt):
    """
    Call ISPACK3's SXINIC and return the array it fills.

    Parameters:
    mm, nt : int
        Passed by reference to the Fortran routine as 64-bit integers.

    Returns:
    C : ndarray
        float64 array of shape ((2*nt - mm + 1) * (mm + 1),), allocated
        uninitialised here and handed to SXINIC.
    """
    n_c = (2 * nt - mm + 1) * (mm + 1)
    c_shape = (n_c,)
    C = np.empty(c_shape, dtype=np.float64)

    routine = lib.sxinic_
    # MM and NT by reference, followed by C.
    routine.argtypes = [ctypes.POINTER(ctypes.c_int64)] * 2 + [
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=c_shape),
    ]
    routine.restype = None

    mm_c, nt_c = ctypes.c_int64(mm), ctypes.c_int64(nt)
    routine(ctypes.byref(mm_c), ctypes.byref(nt_c), C)

    return C

def sxcs2y(mm, nt, s, SY, c):
    """
    Call ISPACK3's SXCS2Y on caller-supplied arrays.

    Nothing is allocated or returned here; `SY` must be provided by the
    caller (the routine presumably writes its result into it).

    Parameters:
    mm, nt : int
        Passed by reference to the Fortran routine as 64-bit integers.
    s  : float64, shape (nt + 1 + mm*(2*nt - mm + 1),)
    SY : float64, shape (nt + 2 + mm*(2*nt - mm + 3),)
    c  : float64, shape ((2*nt - mm + 1)*(mm + 1),), e.g. as returned by sxinic
    All arrays must be C-contiguous.

    Returns:
    None

    Raises:
    ctypes.ArgumentError
        If an array argument has the wrong dtype/shape or is not C-contiguous.
    """
    def _c_array(dtype, shape):
        # ndpointer passes only the data address (no strides), so a
        # non-contiguous view of the right shape would be silently misread;
        # C_CONTIGUOUS makes ctypes reject it instead.
        return np.ctypeslib.ndpointer(dtype=dtype, ndim=len(shape), shape=shape,
                                      flags='C_CONTIGUOUS')

    int64_ref = ctypes.POINTER(ctypes.c_int64)

    s_shape = (nt + 1 + mm * (2 * nt - mm + 1),)
    sy_shape = (nt + 2 + mm * (2 * nt - mm + 3),)
    c_shape = ((2 * nt - mm + 1) * (mm + 1),)

    lib.sxcs2y_.argtypes = [
        int64_ref,                       # MM
        int64_ref,                       # NT
        _c_array(np.float64, s_shape),    # S
        _c_array(np.float64, sy_shape),   # SY
        _c_array(np.float64, c_shape),    # C
    ]
    lib.sxcs2y_.restype = None

    lib.sxcs2y_(
        ctypes.byref(ctypes.c_int64(mm)),
        ctypes.byref(ctypes.c_int64(nt)),
        s,
        SY,
        c,
    )

    return

def sxcy2s(mm, nt, sy, S, c):
    """
    Call ISPACK3's SXCY2S on caller-supplied arrays.

    Nothing is allocated or returned here; `S` must be provided by the
    caller (the routine presumably writes its result into it).

    Parameters:
    mm, nt : int
        Passed by reference to the Fortran routine as 64-bit integers.
    sy : float64, shape (nt + 2 + mm*(2*nt - mm + 3),)
    S  : float64, shape (nt + 1 + mm*(2*nt - mm + 1),)
    c  : float64, shape ((2*nt - mm + 1)*(mm + 1),), e.g. as returned by sxinic
    All arrays must be C-contiguous.

    Returns:
    None

    Raises:
    ctypes.ArgumentError
        If an array argument has the wrong dtype/shape or is not C-contiguous.
    """
    def _c_array(dtype, shape):
        # ndpointer passes only the data address (no strides), so a
        # non-contiguous view of the right shape would be silently misread;
        # C_CONTIGUOUS makes ctypes reject it instead.
        return np.ctypeslib.ndpointer(dtype=dtype, ndim=len(shape), shape=shape,
                                      flags='C_CONTIGUOUS')

    int64_ref = ctypes.POINTER(ctypes.c_int64)

    sy_shape = (nt + 2 + mm * (2 * nt - mm + 3),)
    s_shape = (nt + 1 + mm * (2 * nt - mm + 1),)
    c_shape = ((2 * nt - mm + 1) * (mm + 1),)

    lib.sxcy2s_.argtypes = [
        int64_ref,                       # MM
        int64_ref,                       # NT
        _c_array(np.float64, sy_shape),   # SY
        _c_array(np.float64, s_shape),    # S
        _c_array(np.float64, c_shape),    # C
    ]
    lib.sxcy2s_.restype = None

    lib.sxcy2s_(
        ctypes.byref(ctypes.c_int64(mm)),
        ctypes.byref(ctypes.c_int64(nt)),
        sy,
        S,
        c,
    )

    return

def sxcs2x(mm, nt, s, SX):
    """
    Call ISPACK3's SXCS2X on caller-supplied arrays.

    Nothing is allocated or returned here; `SX` must be provided by the
    caller (the routine presumably writes its result into it).

    Parameters:
    mm, nt : int
        Passed by reference to the Fortran routine as 64-bit integers.
    s  : float64, shape (nt + 1 + mm*(2*nt - mm + 1),)
    SX : float64, same shape as `s`
    Both arrays must be C-contiguous.

    Returns:
    None

    Raises:
    ctypes.ArgumentError
        If an array argument has the wrong dtype/shape or is not C-contiguous.
    """
    def _c_array(dtype, shape):
        # ndpointer passes only the data address (no strides), so a
        # non-contiguous view of the right shape would be silently misread;
        # C_CONTIGUOUS makes ctypes reject it instead.
        return np.ctypeslib.ndpointer(dtype=dtype, ndim=len(shape), shape=shape,
                                      flags='C_CONTIGUOUS')

    int64_ref = ctypes.POINTER(ctypes.c_int64)

    s_shape = (nt + 1 + mm * (2 * nt - mm + 1),)
    sx_shape = (nt + 1 + mm * (2 * nt - mm + 1),)

    lib.sxcs2x_.argtypes = [
        int64_ref,                       # MM
        int64_ref,                       # NT
        _c_array(np.float64, s_shape),    # S
        _c_array(np.float64, sx_shape),   # SX
    ]
    lib.sxcs2x_.restype = None

    lib.sxcs2x_(
        ctypes.byref(ctypes.c_int64(mm)),
        ctypes.byref(ctypes.c_int64(nt)),
        s,
        SX,
    )

    return

def sxinid(mm, nt):
    """
    Call ISPACK3's SXINID and return the array it fills.

    Parameters:
    mm, nt : int
        Passed by reference to the Fortran routine as 64-bit integers.

    Returns:
    D : ndarray
        float64 array of shape (2, nt + 1 + mm*(2*nt - mm + 1)), allocated
        uninitialised here and handed to SXINID.
    """
    n_spec = nt + 1 + mm * (2 * nt - mm + 1)
    d_shape = (2, n_spec)
    D = np.empty(d_shape, dtype=np.float64)

    int64_p = ctypes.POINTER(ctypes.c_int64)
    routine = lib.sxinid_
    routine.argtypes = [
        int64_p,  # MM
        int64_p,  # NT
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, shape=d_shape),  # D
    ]
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (mm, nt)], D)

    return D

def sxclap(mm, nt, s, SL, d, iflag):
    """
    Call ISPACK3's SXCLAP on caller-supplied arrays.

    Nothing is allocated or returned here; `SL` must be provided by the
    caller (the routine presumably writes its result into it).

    Parameters:
    mm, nt, iflag : int
        Passed by reference to the Fortran routine as 64-bit integers.
    s  : float64, shape (nt + 1 + mm*(2*nt - mm + 1),)
    SL : float64, same shape as `s`
    d  : float64, shape (2, nt + 1 + mm*(2*nt - mm + 1)), e.g. as returned
         by sxinid
    All arrays must be C-contiguous.

    Returns:
    None

    Raises:
    ctypes.ArgumentError
        If an array argument has the wrong dtype/shape or is not C-contiguous.
    """
    def _c_array(dtype, shape):
        # ndpointer passes only the data address (no strides), so a
        # non-contiguous view of the right shape would be silently misread;
        # C_CONTIGUOUS makes ctypes reject it instead.
        return np.ctypeslib.ndpointer(dtype=dtype, ndim=len(shape), shape=shape,
                                      flags='C_CONTIGUOUS')

    int64_ref = ctypes.POINTER(ctypes.c_int64)

    s_shape = (nt + 1 + mm * (2 * nt - mm + 1),)
    sl_shape = (nt + 1 + mm * (2 * nt - mm + 1),)
    d_shape = (2, nt + 1 + mm * (2 * nt - mm + 1))

    lib.sxclap_.argtypes = [
        int64_ref,                       # MM
        int64_ref,                       # NT
        _c_array(np.float64, s_shape),    # S
        _c_array(np.float64, sl_shape),   # SL
        _c_array(np.float64, d_shape),    # D
        int64_ref,                       # IFLAG
    ]
    lib.sxclap_.restype = None

    lib.sxclap_(
        ctypes.byref(ctypes.c_int64(mm)),
        ctypes.byref(ctypes.c_int64(nt)),
        s,
        SL,
        d,
        ctypes.byref(ctypes.c_int64(iflag)),
    )

    return

def sxcrpk(mm, nt1, nt2, s1, S2):
    """
    Call ISPACK3's SXCRPK on caller-supplied arrays.

    Nothing is allocated or returned here; `S2` must be provided by the
    caller (the routine presumably writes its result into it).

    Parameters:
    mm, nt1, nt2 : int
        Passed by reference to the Fortran routine as 64-bit integers.
    s1 : float64, shape ((2*nt1 + 1 - mm)*mm + nt1 + 1,)
    S2 : float64, shape ((2*nt2 + 1 - mm)*mm + nt2 + 1,)
    Both arrays must be C-contiguous.

    Returns:
    None

    Raises:
    ctypes.ArgumentError
        If an array argument has the wrong dtype/shape or is not C-contiguous.
    """
    def _c_array(dtype, shape):
        # ndpointer passes only the data address (no strides), so a
        # non-contiguous view of the right shape would be silently misread;
        # C_CONTIGUOUS makes ctypes reject it instead.
        return np.ctypeslib.ndpointer(dtype=dtype, ndim=len(shape), shape=shape,
                                      flags='C_CONTIGUOUS')

    int64_ref = ctypes.POINTER(ctypes.c_int64)

    s1_shape = ((2 * nt1 + 1 - mm) * mm + nt1 + 1,)
    s2_shape = ((2 * nt2 + 1 - mm) * mm + nt2 + 1,)

    lib.sxcrpk_.argtypes = [
        int64_ref,                       # MM
        int64_ref,                       # NT1
        int64_ref,                       # NT2
        _c_array(np.float64, s1_shape),   # S1
        _c_array(np.float64, s2_shape),   # S2
    ]
    lib.sxcrpk_.restype = None

    lib.sxcrpk_(
        ctypes.byref(ctypes.c_int64(mm)),
        ctypes.byref(ctypes.c_int64(nt1)),
        ctypes.byref(ctypes.c_int64(nt2)),
        s1,
        S2,
    )

    return


def syini1(mm, nm, im, icom):
    """
    Call ISPACK3's SYINI1 (MPI-distributed variant) and return the arrays it fills.

    Parameters:
    mm, nm, im : int
        Passed by reference to the Fortran routine as 64-bit integers.
    icom : communicator
        Object providing Get_size() and py2f() (presumably an mpi4py
        communicator); its Fortran handle is passed as ICOM.

    Returns:
    (IT, T, R) : tuple of ndarray
        IT : int64 array of shape (im // 2,)
        T  : float64 array of shape (im * 3 // 2,)
        R  : float64 array of length
             5*(mm//P+1)*(2*nm - (mm//P)*P)//4 + mm//P + 1, where P is the
             communicator size
        All three are allocated uninitialised here and handed to SYINI1.
    """
    nprocs = icom.Get_size()
    mloc = mm // nprocs  # integer quotient, as in Fortran MM/NP

    it_shape = (im // 2,)
    t_shape = (im * 3 // 2,)
    r_shape = (5 * (mloc + 1) * (2 * nm - mloc * nprocs) // 4 + mloc + 1,)

    IT = np.empty(it_shape, dtype=np.int64)
    T = np.empty(t_shape, dtype=np.float64)
    R = np.empty(r_shape, dtype=np.float64)

    int64_p = ctypes.POINTER(ctypes.c_int64)
    routine = lib.syini1_
    routine.argtypes = [int64_p] * 3 + [
        np.ctypeslib.ndpointer(dtype=np.int64, ndim=1, shape=it_shape),    # IT
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=t_shape),   # T
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=r_shape),   # R
        int64_p,                                                           # ICOM
    ]
    routine.restype = None

    routine(
        *[ctypes.byref(ctypes.c_int64(v)) for v in (mm, nm, im)],
        IT,
        T,
        R,
        ctypes.byref(ctypes.c_int64(icom.py2f())),
    )

    return IT, T, R

def syini2(mm, nm, jm, ig, r, icom):
    """
    Call ISPACK3's SYINI2 (MPI-distributed variant) and return the arrays it fills.

    Parameters:
    mm, nm, jm, ig : int
        Passed by reference to the Fortran routine as 64-bit integers.
    r : ndarray
        C-contiguous float64 array of length
        5*(mm//P+1)*(2*nm - (mm//P)*P)//4 + mm//P + 1 (P = communicator
        size), i.e. the R array returned by syini1.
    icom : communicator
        Object providing Get_size() and py2f() (presumably an mpi4py
        communicator); its Fortran handle is passed as ICOM.

    Returns:
    (P, JC)
        P  : float64 array of shape (5 + 2*(mm//P+1), jm//2), 64-byte aligned
             (allocated with aligned_array).
        JC : int64 array of length (mm//P+1)*(2*nm - (mm//P)*P)//16 + mm//P + 1,
             allocated uninitialised here.

    Raises:
    ctypes.ArgumentError
        If `r` has the wrong dtype/shape or is not C-contiguous.
    """
    def _c_array(dtype, shape):
        # ndpointer passes only the data address (no strides), so a
        # non-contiguous view of the right shape would be silently misread;
        # C_CONTIGUOUS makes ctypes reject it instead.
        return np.ctypeslib.ndpointer(dtype=dtype, ndim=len(shape), shape=shape,
                                      flags='C_CONTIGUOUS')

    int64_ref = ctypes.POINTER(ctypes.c_int64)

    np_size = icom.Get_size()
    mloc = mm // np_size  # integer quotient, as in Fortran MM/NP

    p_shape = (5 + 2 * (mloc + 1), jm // 2)
    r_shape = (5 * (mloc + 1) * (2 * nm - mloc * np_size) // 4 + mloc + 1,)
    jc_shape = ((mloc + 1) * (2 * nm - mloc * np_size) // 16 + mloc + 1,)

    P = aligned_array(p_shape, dtype=np.float64, align=64)
    JC = np.empty(jc_shape, dtype=np.int64)

    lib.syini2_.argtypes = [
        int64_ref,                      # MM
        int64_ref,                      # NM
        int64_ref,                      # JM
        int64_ref,                      # IG
        _c_array(np.float64, p_shape),   # P
        _c_array(np.float64, r_shape),   # R
        _c_array(np.int64, jc_shape),    # JC
        int64_ref,                      # ICOM
    ]
    lib.syini2_.restype = None

    lib.syini2_(
        ctypes.byref(ctypes.c_int64(mm)),
        ctypes.byref(ctypes.c_int64(nm)),
        ctypes.byref(ctypes.c_int64(jm)),
        ctypes.byref(ctypes.c_int64(ig)),
        P,
        r,
        JC,
        ctypes.byref(ctypes.c_int64(icom.py2f())),
    )

    return P, JC

def synm2l(nn, n, m, icom):
    """
    Calculates the responsible process ID (rank) and the storage location of spectral data from total and zonal wave numbers.

    Parameters:
    nn : int
        Cutoff wave number N.
    n : int or array_like
        Total wave number. Can be a single integer or an array of integers.
    m : int or array_like
        Zonal wave number. Can be a single integer or an array of integers.
        `n` and `m` are broadcast against each other.
    icom : MPI Communicator
        MPI Communicator used for computation (only Get_size() is used).

    Returns:
    l : int or ndarray
        The adjusted storage location of the spectral data in the array.
    ip : int or ndarray
        The ID (rank) of the process responsible for the given wave numbers.
    Both are NumPy integer scalars when `n` and `m` are both scalars (0-d),
    and arrays otherwise.

    Notes:
    For m > 0, computes the storage location of Re(sm_n), and for m < 0, computes the storage location of Im(sm_n).
    SYPACK distributes the spectral data among processes, and this subroutine calculates the storage location and responsible process ID for given wave numbers.
    The calculated storage locations are adjusted to be compatible with Python's zero-based indexing system.
    """
    np_size = icom.Get_size()

    # Decide on the return form before np.atleast_1d turns scalars into
    # 1-element arrays (checking afterwards would always see arrays).
    scalar_input = np.ndim(n) == 0 and np.ndim(m) == 0

    n_arr, m_arr = np.broadcast_arrays(np.atleast_1d(n), np.atleast_1d(m))
    m_abs = np.abs(m_arr)

    # Zonal wave numbers are dealt to ranks in a back-and-forth pattern:
    # k is the local block index, ip the owning rank.
    k = m_abs // np_size
    ip = np.where(k % 2 == 0, m_abs - k * np_size, (k + 1) * np_size - m_abs - 1)

    # Local storage index at which block k starts.
    block_start = k * (2 * (nn + 1) - (k - 1) * np_size)
    l = np.where(
        m_arr == 0,
        n_arr,
        # Real part for m > 0, imaginary part (one slot later) for m < 0.
        block_start + 2 * (n_arr - m_abs) + (m_arr < 0),
    ).astype(np.int64)

    if scalar_input:
        return l[0], ip[0]
    return l, ip

def syl2nm(mm, nn, l, icom):
    """
    Inverse operation of synm2l. Calculates total and zonal wave numbers from the storage location of spectral data in each process.

    Parameters:
    mm : int
        Cutoff wave number M.
    nn : int
        Cutoff wave number N.
    l : int or array_like
        The adjusted storage location of the spectral data. Can be a single integer or an array of integers.
    icom : MPI Communicator
        MPI Communicator used for computation (Get_size() and Get_rank() are used).

    Returns:
    n : int or ndarray
        Total wave number.
    m : int or ndarray
        Zonal wave number.
    Scalars are returned when `l` holds exactly one element.

    Notes:
    (a) The sign of M has the same meaning as in synm2l.
    (b) If the spectral data is not handled by the process, or if there is no spectral data at the specified location L, then N=-1, M=0 are returned as output.
    This function is used to calculate the total and zonal wave numbers based on the storage location of spectral data in distributed processing environments.
    """
    np_size = icom.Get_size()
    ip = icom.Get_rank()

    l = np.atleast_1d(l)
    m = np.zeros_like(l, dtype=np.int64)
    n = np.zeros_like(l, dtype=np.int64)

    # Length of the locally stored spectral array; locations at or beyond it
    # hold no data on this rank.
    lm = (mm // np_size + 1) * (2 * (nn + 1) - mm // np_size * np_size)

    # Recover the local block index k by inverting the block-start formula
    # ls(k) = k*(2*(nn+1) - (k-1)*np_size), a quadratic in k. For l >= lm the
    # radicand can become negative; those entries are masked out below, so
    # the resulting "invalid value" warnings are suppressed.
    dk = np.full_like(l, (nn + 1.0) / np_size + 0.5, dtype=np.float64)
    with np.errstate(invalid="ignore"):
        dk = dk - np.sqrt(dk ** 2 - (l + 0.5) / np_size)
    k = np.floor(dk)

    # |m| owned by this rank in block k (back-and-forth distribution).
    m[:] = np.where(l >= lm, 0, k * np_size + ip + np.mod(k, 2) * (np_size - 2 * ip - 1))
    ls = np.where(l >= lm, 0, k * (2 * (nn + 1) - (k - 1) * np_size))
    n[:] = np.where(l >= lm, -1, np.where(m == 0, l, m + (l - ls) / 2))

    # An odd offset from the block start is an imaginary-part slot: m < 0.
    mask_odd = np.mod(l - ls, 2) == 1
    m[mask_odd] = -m[mask_odd]

    # Slots whose |m| exceeds mm or whose n exceeds nn hold no spectral data.
    # |m| is compared because m has already been negated for odd slots.
    mask_invalid = (np.abs(m) > mm) | (n > nn)
    m[mask_invalid] = 0
    n[mask_invalid] = -1

    if np.isscalar(l) or l.size == 1:
        return n[0], m[0]
    return n, m

def syqrnm(mm, icom):
    """
    Call ISPACK3's SYQRNM and return what it reports for this rank.

    Parameters:
    mm : int
        Passed by reference to the Fortran routine as a 64-bit integer.
    icom : communicator
        Object providing Get_size() and py2f() (presumably an mpi4py
        communicator); its Fortran handle is passed as ICOM.

    Returns:
    (mcm, MC)
        mcm : int, the value SYQRNM writes into its MCM argument.
        MC  : int64 array of shape (mm // P + 1,), P = communicator size,
              allocated uninitialised here and handed to SYQRNM (presumably
              receives the zonal wave numbers owned by this rank; TODO confirm).
    """
    nprocs = icom.Get_size()

    mc_len = mm // nprocs + 1
    mc_shape = (mc_len,)
    MC = np.empty(mc_shape, dtype=np.int64)

    int64_p = ctypes.POINTER(ctypes.c_int64)
    routine = lib.syqrnm_
    routine.argtypes = [
        int64_p,                                                       # MM
        int64_p,                                                       # MCM
        np.ctypeslib.ndpointer(dtype=np.int64, ndim=1, shape=mc_shape),  # MC
        int64_p,                                                       # ICOM
    ]
    routine.restype = None

    mcm_out = ctypes.c_int64(0)

    routine(
        ctypes.byref(ctypes.c_int64(mm)),
        ctypes.byref(mcm_out),
        MC,
        ctypes.byref(ctypes.c_int64(icom.py2f())),
    )

    return mcm_out.value, MC


def syqrnj(jm, jv, icom):
    """
    Queries the range of Gaussian latitudes handled by the process, adjusted for Python's zero-based indexing.

    Thin wrapper around ISPACK3's SYQRNJ. The Fortran routine reports 1-based
    indices J1 and J2; both are shifted down by one before being returned.

    Parameters:
    jm : int
        Total number of north-south grid points.
    jv : int
        Vector length for the transformation.
    icom : MPI.Comm
        The MPI communicator used for computation (its py2f() handle is passed as ICOM).

    Returns:
    j1 : int
        The lower boundary index (zero-based) of the Gaussian latitudes handled by the process.
    j2 : int
        The upper boundary index (zero-based) of the Gaussian latitudes handled by the process.

    Notes:
    - If the process does not handle any grid data, j1=-1 and j2=-2 are returned.
    - The returned indices are adjusted to align with Python's zero-based indexing, where the first element is indexed as 0.
    """
    routine = lib.syqrnj_
    # JM, JV, J1, J2, ICOM -- all 64-bit integers passed by reference.
    routine.argtypes = [ctypes.POINTER(ctypes.c_int64)] * 5
    routine.restype = None

    first_row = ctypes.c_int64(0)
    last_row = ctypes.c_int64(0)

    routine(
        ctypes.byref(ctypes.c_int64(jm)),
        ctypes.byref(ctypes.c_int64(jv)),
        ctypes.byref(first_row),
        ctypes.byref(last_row),
        ctypes.byref(ctypes.c_int64(icom.py2f())),
    )

    # Fortran indices are 1-based; convert to 0-based.
    lower = first_row.value - 1
    upper = last_row.value - 1
    return lower, upper

def syts2g(mm, nm, nn, im, jm, jv, s, G, it, t, p, r, jc, W, ipow, icom):
    """
    Call ISPACK3's SYTS2G (MPI-distributed variant) on caller-supplied arrays.

    Judging by the argument names, this presumably transforms this rank's
    spectral data `s` into grid data `G` (TODO confirm against the ISPACK3
    docs). Nothing is allocated or returned here.

    With P = icom.Get_size(), mloc = mm // P and jblk = (jm//jv - 1)//P + 1:

    Parameters:
    mm, nm, nn, im, jm, jv, ipow : int
        Passed by reference to the Fortran routine as 64-bit integers.
    s  : float64, shape ((mloc+1)*(2*(nn+1) - mloc*P),)
    G  : float64, shape (jblk*jv, im)
    it : int64,   shape (im//2,)
    t  : float64, shape (im*3//2,)
    p  : float64, shape (5 + 2*(mloc+1), jm//2)
    r  : float64, shape (5*(mloc+1)*(2*nm - mloc*P)//4 + mloc + 1,)
    jc : int64,   shape ((mloc+1)*(2*nm - mloc*P)//16 + mloc + 1,)
    W  : float64, shape (2*jv*jblk*(mloc+1)*P*2,)
    icom : communicator providing Get_size() and py2f().
    All arrays must be C-contiguous.

    Returns:
    None

    Raises:
    ctypes.ArgumentError
        If an array argument has the wrong dtype/shape or is not C-contiguous.
    """
    def _c_array(dtype, shape):
        # ndpointer passes only the data address (no strides), so a
        # non-contiguous view of the right shape would be silently misread;
        # C_CONTIGUOUS makes ctypes reject it instead.
        return np.ctypeslib.ndpointer(dtype=dtype, ndim=len(shape), shape=shape,
                                      flags='C_CONTIGUOUS')

    int64_ref = ctypes.POINTER(ctypes.c_int64)

    np_size = icom.Get_size()
    mloc = mm // np_size                      # Fortran MM/NP
    jblk = (jm // jv - 1) // np_size + 1      # presumably JV-row blocks per rank

    s_shape = ((mloc + 1) * (2 * (nn + 1) - mloc * np_size),)
    g_shape = (jblk * jv, im)
    it_shape = (im // 2,)
    t_shape = (im * 3 // 2,)
    p_shape = (5 + 2 * (mloc + 1), jm // 2)
    r_shape = (5 * (mloc + 1) * (2 * nm - mloc * np_size) // 4 + mloc + 1,)
    jc_shape = ((mloc + 1) * (2 * nm - mloc * np_size) // 16 + mloc + 1,)
    w_shape = (2 * jv * (jblk * (mloc + 1) * np_size * 2),)

    lib.syts2g_.argtypes = [
        int64_ref,                      # MM
        int64_ref,                      # NM
        int64_ref,                      # NN
        int64_ref,                      # IM
        int64_ref,                      # JM
        int64_ref,                      # JV
        _c_array(np.float64, s_shape),   # S
        _c_array(np.float64, g_shape),   # G
        _c_array(np.int64, it_shape),    # IT
        _c_array(np.float64, t_shape),   # T
        _c_array(np.float64, p_shape),   # P
        _c_array(np.float64, r_shape),   # R
        _c_array(np.int64, jc_shape),    # JC
        _c_array(np.float64, w_shape),   # W
        int64_ref,                      # IPOW
        int64_ref,                      # ICOM
    ]
    lib.syts2g_.restype = None

    lib.syts2g_(
        ctypes.byref(ctypes.c_int64(mm)),
        ctypes.byref(ctypes.c_int64(nm)),
        ctypes.byref(ctypes.c_int64(nn)),
        ctypes.byref(ctypes.c_int64(im)),
        ctypes.byref(ctypes.c_int64(jm)),
        ctypes.byref(ctypes.c_int64(jv)),
        s,
        G,
        it,
        t,
        p,
        r,
        jc,
        W,
        ctypes.byref(ctypes.c_int64(ipow)),
        ctypes.byref(ctypes.c_int64(icom.py2f())),
    )

    return

def syts2v(mm, nm, nn, im, jm, jv, s1, s2, G1, G2, it, t, p, r, jc, W, ipow, icom):
    """
    Call ISPACK3's SYTS2V (MPI-distributed variant) on caller-supplied arrays.

    Judging by the argument names, this presumably maps this rank's spectral
    pair (s1, s2) to the grid pair (G1, G2) (TODO confirm against the ISPACK3
    docs). Nothing is allocated or returned here.

    With P = icom.Get_size(), mloc = mm // P and jblk = (jm//jv - 1)//P + 1:

    Parameters:
    mm, nm, nn, im, jm, jv, ipow : int
        Passed by reference to the Fortran routine as 64-bit integers.
    s1, s2 : float64, shape ((mloc+1)*(2*(nn+1) - mloc*P),)
    G1, G2 : float64, shape (jblk*jv, im)
    it : int64,   shape (im//2,)
    t  : float64, shape (im*3//2,)
    p  : float64, shape (5 + 2*(mloc+1), jm//2)
    r  : float64, shape (5*(mloc+1)*(2*nm - mloc*P)//4 + mloc + 1,)
    jc : int64,   shape ((mloc+1)*(2*nm - mloc*P)//16 + mloc + 1,)
    W  : float64, shape (2*jv*jblk*(mloc+1)*P*2*2,)
    icom : communicator providing Get_size() and py2f().
    All arrays must be C-contiguous.

    Returns:
    None

    Raises:
    ctypes.ArgumentError
        If an array argument has the wrong dtype/shape or is not C-contiguous.
    """
    def _c_array(dtype, shape):
        # ndpointer passes only the data address (no strides), so a
        # non-contiguous view of the right shape would be silently misread;
        # C_CONTIGUOUS makes ctypes reject it instead.
        return np.ctypeslib.ndpointer(dtype=dtype, ndim=len(shape), shape=shape,
                                      flags='C_CONTIGUOUS')

    int64_ref = ctypes.POINTER(ctypes.c_int64)

    np_size = icom.Get_size()
    mloc = mm // np_size                      # Fortran MM/NP
    jblk = (jm // jv - 1) // np_size + 1      # presumably JV-row blocks per rank

    s_shape = ((mloc + 1) * (2 * (nn + 1) - mloc * np_size),)
    g_shape = (jblk * jv, im)
    it_shape = (im // 2,)
    t_shape = (im * 3 // 2,)
    p_shape = (5 + 2 * (mloc + 1), jm // 2)
    r_shape = (5 * (mloc + 1) * (2 * nm - mloc * np_size) // 4 + mloc + 1,)
    jc_shape = ((mloc + 1) * (2 * nm - mloc * np_size) // 16 + mloc + 1,)
    w_shape = (2 * jv * (jblk * (mloc + 1) * np_size * 2 * 2),)

    lib.syts2v_.argtypes = [
        int64_ref,                      # MM
        int64_ref,                      # NM
        int64_ref,                      # NN
        int64_ref,                      # IM
        int64_ref,                      # JM
        int64_ref,                      # JV
        _c_array(np.float64, s_shape),   # S1
        _c_array(np.float64, s_shape),   # S2
        _c_array(np.float64, g_shape),   # G1
        _c_array(np.float64, g_shape),   # G2
        _c_array(np.int64, it_shape),    # IT
        _c_array(np.float64, t_shape),   # T
        _c_array(np.float64, p_shape),   # P
        _c_array(np.float64, r_shape),   # R
        _c_array(np.int64, jc_shape),    # JC
        _c_array(np.float64, w_shape),   # W
        int64_ref,                      # IPOW
        int64_ref,                      # ICOM
    ]
    lib.syts2v_.restype = None

    lib.syts2v_(
        ctypes.byref(ctypes.c_int64(mm)),
        ctypes.byref(ctypes.c_int64(nm)),
        ctypes.byref(ctypes.c_int64(nn)),
        ctypes.byref(ctypes.c_int64(im)),
        ctypes.byref(ctypes.c_int64(jm)),
        ctypes.byref(ctypes.c_int64(jv)),
        s1, s2,
        G1, G2,
        it, t, p, r, jc,
        W,
        ctypes.byref(ctypes.c_int64(ipow)),
        ctypes.byref(ctypes.c_int64(icom.py2f())),
    )

    return

def sytg2s(mm, nm, nn, im, jm, jv, S, g, it, t, p, r, jc, W, ipow, icom):
    """
    Wrap the Fortran routine SYTG2S (presumably a grid -> spectral transform).

    Every array is checked by ctypes against the shape expected on this rank
    of ``icom`` before ``lib.sytg2s_`` is invoked. Arrays are passed by
    pointer, so the library can modify them in place (presumably ``S`` is the
    output and ``W`` is scratch space). Nothing is returned.

    Args:
    mm, nm, nn, im, jm, jv (int): Size parameters, passed by reference.
    S (np.ndarray): float64, 1-D.
    g (np.ndarray): float64, 2-D local grid slab.
    it, jc (np.ndarray): int64, 1-D tables.
    t, r (np.ndarray): float64, 1-D tables.
    p (np.ndarray): float64, 2-D table.
    W (np.ndarray): float64, 1-D work buffer.
    ipow (int): Passed through as IPOW.
    icom (MPI.Comm): Communicator; its size determines the local shapes.
    """
    nproc = icom.Get_size()
    mq = mm // nproc
    mloc = mq + 1
    jblk = (jm // jv - 1) // nproc + 1

    def f64(*shape):
        return np.ctypeslib.ndpointer(dtype=np.float64, ndim=len(shape), shape=shape)

    def i64(*shape):
        return np.ctypeslib.ndpointer(dtype=np.int64, ndim=len(shape), shape=shape)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.sytg2s_
    routine.argtypes = [int_ptr] * 6 + [                      # MM NM NN IM JM JV
        f64(mloc * (2 * (nn + 1) - mq * nproc)),              # S
        f64(jblk * jv, im),                                   # G
        i64(im // 2),                                         # IT
        f64(im * 3 // 2),                                     # T
        f64(5 + 2 * mloc, jm // 2),                           # P
        f64(5 * mloc * (2 * nm - mq * nproc) // 4 + mloc),    # R
        i64(mloc * (2 * nm - mq * nproc) // 16 + mloc),       # JC
        f64(2 * jv * jblk * mloc * nproc * 2),                # W
        int_ptr,                                              # IPOW
        int_ptr,                                              # ICOM
    ]
    routine.restype = None

    sizes = [ctypes.byref(ctypes.c_int64(v)) for v in (mm, nm, nn, im, jm, jv)]
    routine(*sizes, S, g, it, t, p, r, jc, W,
            ctypes.byref(ctypes.c_int64(ipow)),
            ctypes.byref(ctypes.c_int64(icom.py2f())))

def sytv2s(mm, nm, nn, im, jm, jv, S1, S2, g1, g2, it, t, p, r, jc, W, ipow, icom):
    """
    Wrap the Fortran routine SYTV2S (presumably a vector grid -> spectral transform).

    The two spectral arrays share one shape, and so do the two grid arrays.
    All shapes depend on the size of ``icom``. Arrays are passed by pointer
    to ``lib.sytv2s_``. Nothing is returned.

    Args:
    mm, nm, nn, im, jm, jv (int): Size parameters, passed by reference.
    S1, S2 (np.ndarray): float64, 1-D.
    g1, g2 (np.ndarray): float64, 2-D local grid slabs.
    it, jc (np.ndarray): int64, 1-D tables.
    t, r (np.ndarray): float64, 1-D tables.
    p (np.ndarray): float64, 2-D table.
    W (np.ndarray): float64, 1-D work buffer (twice the scalar-transform size).
    ipow (int): Passed through as IPOW.
    icom (MPI.Comm): Communicator.
    """
    nproc = icom.Get_size()
    mq = mm // nproc
    mloc = mq + 1
    jblk = (jm // jv - 1) // nproc + 1
    spec_shape = (mloc * (2 * (nn + 1) - mq * nproc),)
    grid_shape = (jblk * jv, im)

    def f64(shape):
        return np.ctypeslib.ndpointer(dtype=np.float64, ndim=len(shape), shape=shape)

    def i64(shape):
        return np.ctypeslib.ndpointer(dtype=np.int64, ndim=len(shape), shape=shape)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    signature = [int_ptr] * 6                                         # MM NM NN IM JM JV
    signature += [f64(spec_shape)] * 2                                # S1 S2
    signature += [f64(grid_shape)] * 2                                # G1 G2
    signature += [
        i64((im // 2,)),                                              # IT
        f64((im * 3 // 2,)),                                          # T
        f64((5 + 2 * mloc, jm // 2)),                                 # P
        f64((5 * mloc * (2 * nm - mq * nproc) // 4 + mloc,)),         # R
        i64((mloc * (2 * nm - mq * nproc) // 16 + mloc,)),            # JC
        f64((2 * jv * (jblk * mloc * nproc * 2 * 2),)),               # W
        int_ptr,                                                      # IPOW
        int_ptr,                                                      # ICOM
    ]

    routine = lib.sytv2s_
    routine.argtypes = signature
    routine.restype = None

    sizes = [ctypes.byref(ctypes.c_int64(v)) for v in (mm, nm, nn, im, jm, jv)]
    routine(*sizes, S1, S2, g1, g2, it, t, p, r, jc, W,
            ctypes.byref(ctypes.c_int64(ipow)),
            ctypes.byref(ctypes.c_int64(icom.py2f())))


def sygs2s(mm, nn, s, SALL, icom):
    """
    Wrap the Fortran routine SYGS2S.

    ``s`` is this rank's local float64 spectral array. ``SALL`` must be the
    full-length float64 array on rank 0 of ``icom`` and an empty array on
    every other rank. Going by the name, this presumably gathers ``s`` into
    ``SALL``. Nothing is returned.
    """
    nproc = icom.Get_size()
    mq = mm // nproc
    local_len = (mq + 1) * (2 * (nn + 1) - mq * nproc)
    full_len = (2 * nn + 1 - mm) * mm + nn + 1 if icom.Get_rank() == 0 else 0

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.sygs2s_
    routine.argtypes = [
        int_ptr,                                                                # MM
        int_ptr,                                                                # NN
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=(local_len,)),   # S
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=(full_len,)),    # SALL
        int_ptr,                                                                # ICOM
    ]
    routine.restype = None

    routine(ctypes.byref(ctypes.c_int64(mm)),
            ctypes.byref(ctypes.c_int64(nn)),
            s, SALL,
            ctypes.byref(ctypes.c_int64(icom.py2f())))


def syss2s(mm, nn, sall, S, icom):
    """
    Wrap the Fortran routine SYSS2S.

    ``sall`` must be the full-length float64 spectral array on rank 0 of
    ``icom`` and an empty array elsewhere. ``S`` is this rank's local float64
    spectral array. Going by the name, this presumably scatters ``sall`` into
    ``S``. Nothing is returned.
    """
    nproc = icom.Get_size()
    mq = mm // nproc
    local_len = (mq + 1) * (2 * (nn + 1) - mq * nproc)
    full_len = (2 * nn + 1 - mm) * mm + nn + 1 if icom.Get_rank() == 0 else 0

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.syss2s_
    routine.argtypes = [
        int_ptr,                                                                # MM
        int_ptr,                                                                # NN
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=(full_len,)),    # SALL
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=(local_len,)),   # S
        int_ptr,                                                                # ICOM
    ]
    routine.restype = None

    routine(ctypes.byref(ctypes.c_int64(mm)),
            ctypes.byref(ctypes.c_int64(nn)),
            sall, S,
            ctypes.byref(ctypes.c_int64(icom.py2f())))

def sygg2g(im, jm, jv, GALL, g, icom):
    """
    Wrap the Fortran routine SYGG2G.

    ``g`` is this rank's float64 grid slab of shape
    (((jm//jv - 1)//P + 1)*jv, im), where P is the size of ``icom``.
    ``GALL`` must be float64 (jm, im) on rank 0 and (0, 0) elsewhere.
    Presumably this gathers ``g`` into ``GALL``. Nothing is returned.
    """
    nproc = icom.Get_size()
    full_shape = (jm, im) if icom.Get_rank() == 0 else (0, 0)
    slab_shape = (((jm // jv - 1) // nproc + 1) * jv, im)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.sygg2g_
    routine.argtypes = [int_ptr] * 3 + [                                        # IM JM JV
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, shape=slab_shape),     # G
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, shape=full_shape),     # GALL
        int_ptr,                                                                # ICOM
    ]
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (im, jm, jv)],
            g, GALL,
            ctypes.byref(ctypes.c_int64(icom.py2f())))


def sysg2g(im, jm, jv, gall, G, icom):
    """
    Wrap the Fortran routine SYSG2G, which writes this rank's grid slab into ``G``.

    ``G`` is the caller's output buffer and is filled in place by the
    library. It must be a float64 array of shape
    (((jm//jv - 1)//P + 1)*jv, im), where P is the size of ``icom``.
    If aligned storage is wanted, allocate it with ``aligned_array`` before
    calling. ``gall`` must be float64 (jm, im) on rank 0 of ``icom`` and
    (0, 0) on every other rank. Presumably this scatters ``gall`` from
    rank 0. Nothing is returned.

    Args:
    im, jm, jv (int): Grid size parameters, passed by reference.
    gall (np.ndarray): Full grid on rank 0; empty (0, 0) array elsewhere.
    G (np.ndarray): Output buffer for the local grid slab.
    icom (MPI.Comm): Communicator.
    """
    np_size = icom.Get_size()

    if icom.Get_rank() == 0:
        gall_shape = (jm, im)
    else:
        gall_shape = (0, 0)
    g_shape = (((jm // jv - 1) // np_size + 1) * jv, im)

    # G must NOT be rebound to a freshly allocated array here. Doing so made
    # the library write into a temporary that was discarded on return, so the
    # caller's buffer was never filled.

    lib.sysg2g_.argtypes = [
        ctypes.POINTER(ctypes.c_int64),  # IM
        ctypes.POINTER(ctypes.c_int64),  # JM
        ctypes.POINTER(ctypes.c_int64),  # JV
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, shape=gall_shape),  # GALL
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, shape=g_shape),  # G
        ctypes.POINTER(ctypes.c_int64)   # ICOM
    ]

    lib.sysg2g_.restype = None

    lib.sysg2g_(
        ctypes.byref(ctypes.c_int64(im)),
        ctypes.byref(ctypes.c_int64(jm)),
        ctypes.byref(ctypes.c_int64(jv)),
        gall,
        G,
        ctypes.byref(ctypes.c_int64(icom.py2f()))
    )

    return

def syinic(mm, nt, icom):
    """
    Allocate the coefficient array C and have the Fortran routine SYINIC fill it.

    Returns:
    np.ndarray: float64, 1-D, of length (mm//P + 1)*(2*(nt + 1) - (mm//P)*P),
    where P is the size of ``icom``. It is allocated uninitialised, and its
    contents are whatever ``lib.syinic_`` writes.
    """
    nproc = icom.Get_size()
    mq = mm // nproc
    c_shape = ((mq + 1) * (2 * (nt + 1) - mq * nproc),)
    C = np.empty(c_shape, dtype=np.float64)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.syinic_
    routine.argtypes = [
        int_ptr,                                                         # MM
        int_ptr,                                                         # NT
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=c_shape),  # C
        int_ptr,                                                         # ICOM
    ]
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (mm, nt)],
            C,
            ctypes.byref(ctypes.c_int64(icom.py2f())))
    return C

def sycs2y(mm, nt, s, SY, c, icom):
    """
    Wrap the Fortran routine SYCS2Y.

    ``s`` and ``c`` have the local length for truncation ``nt``. ``SY`` has
    the local length for truncation ``nt + 1``. All are float64 1-D arrays.
    Arrays are passed by pointer, so presumably ``SY`` is written in place.
    Nothing is returned.
    """
    nproc = icom.Get_size()
    mq = mm // nproc

    def local_len(ntrunc):
        # Length formula shared by the distributed 1-D coefficient arrays.
        return ((mq + 1) * (2 * (ntrunc + 1) - mq * nproc),)

    def f64(shape):
        return np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=shape)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.sycs2y_
    routine.argtypes = [
        int_ptr,                       # MM
        int_ptr,                       # NT
        f64(local_len(nt)),            # S
        f64(local_len(nt + 1)),        # SY
        f64(local_len(nt)),            # C
        int_ptr,                       # ICOM
    ]
    routine.restype = None

    routine(ctypes.byref(ctypes.c_int64(mm)),
            ctypes.byref(ctypes.c_int64(nt)),
            s, SY, c,
            ctypes.byref(ctypes.c_int64(icom.py2f())))

def sycy2s(mm, nt, sy, S, c, icom):
    """
    Wrap the Fortran routine SYCY2S.

    ``sy`` has the local length for truncation ``nt + 1``. ``S`` and ``c``
    have the local length for truncation ``nt``. All are float64 1-D arrays.
    Arrays are passed by pointer, so presumably ``S`` is written in place.
    Nothing is returned.
    """
    nproc = icom.Get_size()
    mq = mm // nproc

    def local_len(ntrunc):
        # Length formula shared by the distributed 1-D coefficient arrays.
        return ((mq + 1) * (2 * (ntrunc + 1) - mq * nproc),)

    def f64(shape):
        return np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=shape)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.sycy2s_
    routine.argtypes = [
        int_ptr,                       # MM
        int_ptr,                       # NT
        f64(local_len(nt + 1)),        # SY
        f64(local_len(nt)),            # S
        f64(local_len(nt)),            # C
        int_ptr,                       # ICOM
    ]
    routine.restype = None

    routine(ctypes.byref(ctypes.c_int64(mm)),
            ctypes.byref(ctypes.c_int64(nt)),
            sy, S, c,
            ctypes.byref(ctypes.c_int64(icom.py2f())))


def sycs2x(mm, nt, s, SX, icom):
    """
    Wrap the Fortran routine SYCS2X.

    ``s`` and ``SX`` are float64 1-D arrays of the same local length,
    (mm//P + 1)*(2*(nt + 1) - (mm//P)*P), where P is the size of ``icom``.
    Presumably ``SX`` is written in place. Nothing is returned.
    """
    nproc = icom.Get_size()
    mq = mm // nproc
    n_local = (mq + 1) * (2 * (nt + 1) - mq * nproc)
    vec = np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=(n_local,))

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.sycs2x_
    routine.argtypes = [int_ptr, int_ptr, vec, vec, int_ptr]  # MM NT S SX ICOM
    routine.restype = None

    routine(ctypes.byref(ctypes.c_int64(mm)),
            ctypes.byref(ctypes.c_int64(nt)),
            s, SX,
            ctypes.byref(ctypes.c_int64(icom.py2f())))

def syinid(mm, nt, icom):
    """
    Allocate the coefficient array D and have the Fortran routine SYINID fill it.

    Returns:
    np.ndarray: float64, 1-D, twice the local spectral length
    (mm//P + 1)*(2*(nt + 1) - (mm//P)*P), where P is the size of ``icom``.
    It is allocated uninitialised; its contents come from ``lib.syinid_``.
    """
    nproc = icom.Get_size()
    mq = mm // nproc
    d_shape = ((mq + 1) * (2 * (nt + 1) - mq * nproc) * 2,)
    D = np.empty(d_shape, dtype=np.float64)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.syinid_
    routine.argtypes = [
        int_ptr,                                                         # MM
        int_ptr,                                                         # NT
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=d_shape),  # D
        int_ptr,                                                         # ICOM
    ]
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (mm, nt)],
            D,
            ctypes.byref(ctypes.c_int64(icom.py2f())))
    return D

def syclap(mm, nt, s, SL, d, iflag, icom):
    """
    Wrap the Fortran routine SYCLAP.

    ``s`` and ``SL`` are float64 1-D arrays of the local spectral length.
    ``d`` is twice that length. ``iflag`` is passed through as IFLAG. The
    name suggests a Laplacian-type operator (unconfirmed here). Presumably
    ``SL`` is written in place. Nothing is returned.
    """
    nproc = icom.Get_size()
    mq = mm // nproc
    n_local = (mq + 1) * (2 * (nt + 1) - mq * nproc)

    def f64(length):
        return np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=(length,))

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.syclap_
    routine.argtypes = [
        int_ptr,              # MM
        int_ptr,              # NT
        f64(n_local),         # S
        f64(n_local),         # SL
        f64(n_local * 2),     # D
        int_ptr,              # IFLAG
        int_ptr,              # ICOM
    ]
    routine.restype = None

    routine(ctypes.byref(ctypes.c_int64(mm)),
            ctypes.byref(ctypes.c_int64(nt)),
            s, SL, d,
            ctypes.byref(ctypes.c_int64(iflag)),
            ctypes.byref(ctypes.c_int64(icom.py2f())))

def sycrpk(mm, nt1, nt2, s1, S2, icom):
    """
    Wrap the Fortran routine SYCRPK.

    ``s1`` has the local spectral length for truncation ``nt1``, and ``S2``
    has it for ``nt2``. Both are float64 1-D arrays. Presumably this repacks
    ``s1`` into ``S2`` in place. Nothing is returned.
    """
    nproc = icom.Get_size()
    mq = mm // nproc

    def spec(ntrunc):
        return np.ctypeslib.ndpointer(
            dtype=np.float64, ndim=1,
            shape=((mq + 1) * (2 * (ntrunc + 1) - mq * nproc),))

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.sycrpk_
    routine.argtypes = [int_ptr] * 3 + [spec(nt1), spec(nt2), int_ptr]  # MM NT1 NT2 S1 S2 ICOM
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (mm, nt1, nt2)],
            s1, S2,
            ctypes.byref(ctypes.c_int64(icom.py2f())))

def sykini(km, ndv, icoml):
    """
    Initializes a multilayer setup in SYPACK for Python. It divides a global MPI communicator for each layer of a
    multilayer model and returns the communicator for the current process, along with the adjusted range of layer
    numbers it handles, suitable for Python's zero-based indexing.

    Parameters:
    km : int
        The number of layers in the multilayer model.
    ndv : int
        The number of divisions in the layer direction.
    icoml : MPI.Comm
        The global MPI communicator used for the entire multilayer computation.

    Returns:
    k1 : int
        The adjusted lower layer number boundary (zero-based, inclusive) that the divided communicator handles.
    k2 : int
        The adjusted upper layer number boundary (zero-based, inclusive) that the divided communicator handles.
    icoms : MPI.Comm
        The divided MPI communicator that the current process belongs to.

    Raises:
    ValueError
        If NDV < 1, NDV > number of processes in ICOML, or NDV > KM.

    Notes:
    - The division number NDV must satisfy 1 ≤ NDV ≤ min(NPL, KM), where NPL is the number of processes in the
      original communicator ICOML, and KM is the number of layers.
    - Both processes and layers are split into NDV groups whose sizes differ by at most one; the larger groups come
      first.
    """

    npl4 = icoml.Get_size()
    ipl4 = icoml.Get_rank()

    # NDV is used as a divisor below; reject non-positive values explicitly
    # instead of failing with ZeroDivisionError or producing bogus ranges.
    if ndv < 1:
        raise ValueError('*** error in SYKINI: NDV must be >= 1.')
    if ndv > npl4 or ndv > km:
        raise ValueError('*** error in SYKINI: NDV must <= # of processes and <= KM.')

    # Balanced split of the npl4 processes: groups 0..na-1 get nph2 = nph1 + 1
    # processes, the remaining ndv - na groups get nph1.
    nph1 = (npl4 - 1) // ndv
    nph2 = nph1 + 1
    na = npl4 - nph1 * ndv

    if ipl4 < na * nph2:
        icolor = ipl4 // nph2
        ikey = ipl4 % nph2
    else:
        # Reachable only when nph1 >= 1. If nph1 == 0, then na == npl4 and
        # every rank takes the branch above, so this division is safe.
        icolor = ndv - 1 - (npl4 - ipl4 - 1) // nph1
        ikey = nph1 - 1 - (npl4 - ipl4 - 1) % nph1

    icoms = icoml.Split(color=icolor, key=ikey)

    # Same balanced split for the km layers. Group `icolor` owns the
    # zero-based inclusive layer range k1..k2.
    nk1 = (km - 1) // ndv
    nk2 = nk1 + 1
    na = km - nk1 * ndv

    if icolor < na:
        k1 = nk2 * icolor
        k2 = k1 + nk2 - 1
    else:
        k2 = km - nk1 * (ndv - icolor - 1) - 1
        k1 = k2 - nk1 + 1

    return k1, k2, icoms

def sykgxx(ndim, km, ndv, x, XALL, icoml, icoms):
    """
    Wrap the Fortran routine SYKGXX for multilayer data.

    Expected shapes:
    - ``x``: float64 ((km-1)//ndv + 1, ndim) on rank 0 of ``icoms``, and
      ((km-1)//ndv + 1, 0) on the other ranks of ``icoms``.
    - ``XALL``: float64 (km, ndim) on rank 0 of ``icoml``, and (0, 0)
      elsewhere.

    Only ``icoml`` is handed to the library; ``icoms`` is consulted solely
    for its rank. Presumably this gathers the per-group layers into
    ``XALL``. Nothing is returned.
    """
    local_root = icoms.Get_rank() == 0
    nlayer = (km - 1) // ndv + 1
    x_shape = (nlayer, ndim if local_root else 0)
    xall_shape = (km, ndim) if icoml.Get_rank() == 0 else (0, 0)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.sykgxx_
    routine.argtypes = [
        int_ptr, int_ptr, int_ptr,                                           # NDIM KM NDV
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, shape=x_shape),     # X
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, shape=xall_shape),  # XALL
        int_ptr,                                                             # ICOML
    ]
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (ndim, km, ndv)],
            x, XALL,
            ctypes.byref(ctypes.c_int64(icoml.py2f())))

def syksxx(ndim, km, ndv, xall, X, icoml, icoms):
    """
    Wrap the Fortran routine SYKSXX for multilayer data.

    Expected shapes:
    - ``xall``: float64 (km, ndim) on rank 0 of ``icoml``, and (0, 0)
      elsewhere.
    - ``X``: float64 ((km-1)//ndv + 1, ndim) on rank 0 of ``icoms``, and
      ((km-1)//ndv + 1, 0) on the other ranks of ``icoms``.

    Only ``icoml`` is handed to the library; ``icoms`` is consulted solely
    for its rank. Presumably this scatters ``xall`` into ``X``. Nothing is
    returned.
    """
    local_root = icoms.Get_rank() == 0
    nlayer = (km - 1) // ndv + 1
    x_shape = (nlayer, ndim if local_root else 0)
    xall_shape = (km, ndim) if icoml.Get_rank() == 0 else (0, 0)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.syksxx_
    routine.argtypes = [
        int_ptr, int_ptr, int_ptr,                                           # NDIM KM NDV
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, shape=xall_shape),  # XALL
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, shape=x_shape),     # X
        int_ptr,                                                             # ICOML
    ]
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (ndim, km, ndv)],
            xall, X,
            ctypes.byref(ctypes.c_int64(icoml.py2f())))


def syqrjv(jm):
    """Return the JV value that the library routine SYQRJV reports for ``jm``."""
    routine = lib.syqrjv_
    routine.argtypes = [ctypes.POINTER(ctypes.c_int64)] * 2  # JM, JV
    routine.restype = None

    result = ctypes.c_int64(0)
    routine(ctypes.byref(ctypes.c_int64(jm)), ctypes.byref(result))
    return result.value

def mxgcpu():
    """
    Query MXGCPU and translate its integer ICPU code into a label.

    Codes 0, 10, 20 and 30 map to "fort", "avx", "fma" and "avx512".
    Any other code yields "unknown". Presumably the code identifies the
    instruction-set variant the library selected.
    """
    code = ctypes.c_int64()

    routine = lib.mxgcpu_
    routine.argtypes = [ctypes.POINTER(ctypes.c_int64)]  # ICPU
    routine.restype = None
    routine(ctypes.byref(code))

    labels = dict(zip((0, 10, 20, 30), ("fort", "avx", "fma", "avx512")))
    return labels.get(code.value, "unknown")


def mxsomp(nth):
    """Pass ``nth`` to the library routine MXSOMP (presumably sets the OpenMP thread count)."""
    routine = lib.mxsomp_
    routine.argtypes = [ctypes.POINTER(ctypes.c_int64)]  # NTH
    routine.restype = None
    routine(ctypes.byref(ctypes.c_int64(nth)))

def mxgomp():
    """Return the integer NTH reported by MXGOMP (presumably the OpenMP thread count)."""
    count = ctypes.c_int64()

    routine = lib.mxgomp_
    routine.argtypes = [ctypes.POINTER(ctypes.c_int64)]  # NTH
    routine.restype = None
    routine(ctypes.byref(count))

    return count.value

def lxinig(jm, ig):
    """
    Allocate PZ and initialise it with the Fortran routine LXINIG.

    Args:
    jm (int): Passed as JM; fixes the second dimension of PZ.
    ig (int): Passed through as IG.

    Returns:
    np.ndarray: float64 array of shape (5, jm // 2), allocated uninitialised
    and filled by the library.
    """
    shape = (5, jm // 2)
    pz = np.empty(shape, dtype=np.float64)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.lxinig_
    routine.argtypes = [
        int_ptr,                                                       # JM
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, shape=shape),  # PZ
        int_ptr,                                                       # IG
    ]
    routine.restype = None

    routine(ctypes.byref(ctypes.c_int64(jm)), pz, ctypes.byref(ctypes.c_int64(ig)))
    return pz


def lxinir(nm, m):
    """
    Allocate RM and initialise it with the Fortran routine LXINIR.

    Returns:
    np.ndarray: float64, 1-D, of length (nm - m)//2*3 + (nm - m) + 1,
    allocated uninitialised and filled by the library.
    """
    dm = nm - m
    shape = (dm // 2 * 3 + dm + 1,)
    rm = np.empty(shape, dtype=np.float64)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.lxinir_
    routine.argtypes = [
        int_ptr,                                                       # NM
        int_ptr,                                                       # M
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=shape),  # RM
    ]
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (nm, m)], rm)
    return rm

def lxiniw(nm, jm, m, pz, rm):
    """
    Allocate PM and JC and initialise them with the Fortran routine LXINIW.

    Args:
    nm, jm, m (int): Size parameters, passed by reference.
    pz (np.ndarray): float64 (5, jm // 2).
    rm (np.ndarray): float64 ((nm - m)//2*3 + (nm - m) + 1,).

    Returns:
    tuple: ``(pm, jc)``. ``pm`` is float64 (2, jm // 2) and ``jc`` is int64
    ((nm - m)//8 + 1,). Both are allocated uninitialised and filled by the
    library.
    """
    half = jm // 2
    dm = nm - m
    pz_shape, pm_shape = (5, half), (2, half)
    rm_shape = (dm // 2 * 3 + dm + 1,)
    jc_shape = (dm // 8 + 1,)

    pm = np.empty(pm_shape, dtype=np.float64)
    jc = np.empty(jc_shape, dtype=np.int64)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.lxiniw_
    routine.argtypes = [int_ptr] * 3 + [                                   # NM JM M
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, shape=pz_shape),  # PZ
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=2, shape=pm_shape),  # PM
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=rm_shape),  # RM
        np.ctypeslib.ndpointer(dtype=np.int64, ndim=1, shape=jc_shape),    # JC
    ]
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (nm, jm, m)], pz, pm, rm, jc)
    return pm, jc

def lxtswg(nm, nn, jm, m, s, g, pz, pm, rm, jc, ipow):
    """
    Wrap the Fortran routine LXTSWG (presumably spectral -> grid for wavenumber ``m``).

    Expected arrays:
    - float64: ``s`` (nn-m+1, 2), ``g`` (jm, 2), ``pz`` (5, jm//2),
      ``pm`` (2, jm//2), ``rm`` ((nm-m)//2*3 + (nm-m) + 1,).
    - int64: ``jc`` ((nm-m)//8 + 1,).

    Arrays are passed by pointer, so presumably ``g`` is written in place.
    Nothing is returned.
    """
    half = jm // 2
    dm = nm - m

    def f64(*shape):
        return np.ctypeslib.ndpointer(dtype=np.float64, ndim=len(shape), shape=shape)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.lxtswg_
    routine.argtypes = [int_ptr] * 4 + [                                        # NM NN JM M
        f64(nn - m + 1, 2),                                                     # S
        f64(jm, 2),                                                             # G
        f64(5, half),                                                           # PZ
        f64(2, half),                                                           # PM
        f64(dm // 2 * 3 + dm + 1),                                              # RM
        np.ctypeslib.ndpointer(dtype=np.int64, ndim=1, shape=(dm // 8 + 1,)),   # JC
        int_ptr,                                                                # IPOW
    ]
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (nm, nn, jm, m)],
            s, g, pz, pm, rm, jc,
            ctypes.byref(ctypes.c_int64(ipow)))

def lxtgws(nm, nn, jm, m, s, g, pz, pm, rm, jc, ipow):
    """
    Wrap the Fortran routine LXTGWS (presumably grid -> spectral for wavenumber ``m``).

    Expected arrays:
    - float64: ``s`` (nn-m+1, 2), ``g`` (jm, 2), ``pz`` (5, jm//2),
      ``pm`` (2, jm//2), ``rm`` ((nm-m)//2*3 + (nm-m) + 1,).
    - int64: ``jc`` ((nm-m)//8 + 1,).

    ``s`` is passed to the library before ``g``. Presumably ``s`` is written
    in place. Nothing is returned.
    """
    half = jm // 2
    dm = nm - m

    def f64(*shape):
        return np.ctypeslib.ndpointer(dtype=np.float64, ndim=len(shape), shape=shape)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.lxtgws_
    routine.argtypes = [int_ptr] * 4 + [                                        # NM NN JM M
        f64(nn - m + 1, 2),                                                     # S
        f64(jm, 2),                                                             # G
        f64(5, half),                                                           # PZ
        f64(2, half),                                                           # PM
        f64(dm // 2 * 3 + dm + 1),                                              # RM
        np.ctypeslib.ndpointer(dtype=np.int64, ndim=1, shape=(dm // 8 + 1,)),   # JC
        int_ptr,                                                                # IPOW
    ]
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (nm, nn, jm, m)],
            s, g, pz, pm, rm, jc,
            ctypes.byref(ctypes.c_int64(ipow)))


def lxtszg(nm, nn, jm, s, g, pz, rm, ipow):
    """
    Wrap the Fortran routine LXTSZG (presumably spectral -> grid for m = 0).

    Expected float64 arrays: ``s`` (nn+1,) (indices 0..NN), ``g`` (jm,),
    ``pz`` (5, jm//2), ``rm`` (nm//2*3 + nm + 1,). Presumably ``g`` is
    written in place. Nothing is returned.
    """
    def f64(*shape):
        return np.ctypeslib.ndpointer(dtype=np.float64, ndim=len(shape), shape=shape)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.lxtszg_
    routine.argtypes = [int_ptr] * 3 + [     # NM NN JM
        f64(nn + 1),                         # S
        f64(jm),                             # G
        f64(5, jm // 2),                     # PZ
        f64(nm // 2 * 3 + nm + 1),           # RM
        int_ptr,                             # IPOW
    ]
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (nm, nn, jm)],
            s, g, pz, rm,
            ctypes.byref(ctypes.c_int64(ipow)))


def lxtgzs(nm, nn, jm, s, g, pz, rm, ipow):
    """
    Wrap the Fortran routine LXTGZS (presumably grid -> spectral for m = 0).

    Expected float64 arrays: ``s`` (nn+1,) (indices 0..NN), ``g`` (jm,),
    ``pz`` (5, jm//2), ``rm`` (nm//2*3 + nm + 1,). Presumably ``s`` is
    written in place. Nothing is returned.

    NOTE(review): here ``g`` is passed to the library BEFORE ``s``, unlike
    ``lxtgws`` and ``lxtvzs`` in this module, which pass the spectral array
    first. This presumably mirrors the Fortran LXTGZS argument order; confirm
    against the library signature.
    """
    def f64(*shape):
        return np.ctypeslib.ndpointer(dtype=np.float64, ndim=len(shape), shape=shape)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    routine = lib.lxtgzs_
    routine.argtypes = [int_ptr] * 3 + [     # NM NN JM
        f64(jm),                             # G
        f64(nn + 1),                         # S
        f64(5, jm // 2),                     # PZ
        f64(nm // 2 * 3 + nm + 1),           # RM
        int_ptr,                             # IPOW
    ]
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (nm, nn, jm)],
            g, s, pz, rm,
            ctypes.byref(ctypes.c_int64(ipow)))

def lxtswv(nm, nn, jm, m, s1, s2, g1, g2, pz, pm, rm, jc, ipow):
    """
    Wrap the Fortran routine LXTSWV (presumably a vector spectral -> grid transform).

    Expected arrays:
    - float64: ``s1``/``s2`` (nn-m+1, 2), ``g1``/``g2`` (jm, 2),
      ``pz`` (5, jm//2), ``pm`` (2, jm//2),
      ``rm`` ((nm-m)//2*3 + (nm-m) + 1,).
    - int64: ``jc`` ((nm-m)//8 + 1,).

    Presumably ``g1`` and ``g2`` are written in place. Nothing is returned.
    """
    half = jm // 2
    dm = nm - m

    def f64(*shape):
        return np.ctypeslib.ndpointer(dtype=np.float64, ndim=len(shape), shape=shape)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    spec = f64(nn - m + 1, 2)
    grid = f64(jm, 2)
    routine = lib.lxtswv_
    routine.argtypes = [int_ptr] * 4 + [spec, spec, grid, grid] + [       # NM NN JM M S1 S2 G1 G2
        f64(5, half),                                                     # PZ
        f64(2, half),                                                     # PM
        f64(dm // 2 * 3 + dm + 1),                                        # RM
        np.ctypeslib.ndpointer(dtype=np.int64, ndim=1, shape=(dm // 8 + 1,)),  # JC
        int_ptr,                                                          # IPOW
    ]
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (nm, nn, jm, m)],
            s1, s2, g1, g2, pz, pm, rm, jc,
            ctypes.byref(ctypes.c_int64(ipow)))


def lxtvws(nm, nn, jm, m, s1, s2, g1, g2, pz, pm, rm, jc, ipow):
    """
    Wrap the Fortran routine LXTVWS (presumably a vector grid -> spectral transform).

    Expected arrays:
    - float64: ``s1``/``s2`` (nn-m+1, 2), ``g1``/``g2`` (jm, 2),
      ``pz`` (5, jm//2), ``pm`` (2, jm//2),
      ``rm`` ((nm-m)//2*3 + (nm-m) + 1,).
    - int64: ``jc`` ((nm-m)//8 + 1,).

    The spectral arrays are passed first. Presumably ``s1`` and ``s2`` are
    written in place. Nothing is returned.
    """
    half = jm // 2
    dm = nm - m

    def f64(*shape):
        return np.ctypeslib.ndpointer(dtype=np.float64, ndim=len(shape), shape=shape)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    spec = f64(nn - m + 1, 2)
    grid = f64(jm, 2)
    routine = lib.lxtvws_
    routine.argtypes = [int_ptr] * 4 + [spec, spec, grid, grid] + [       # NM NN JM M S1 S2 G1 G2
        f64(5, half),                                                     # PZ
        f64(2, half),                                                     # PM
        f64(dm // 2 * 3 + dm + 1),                                        # RM
        np.ctypeslib.ndpointer(dtype=np.int64, ndim=1, shape=(dm // 8 + 1,)),  # JC
        int_ptr,                                                          # IPOW
    ]
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (nm, nn, jm, m)],
            s1, s2, g1, g2, pz, pm, rm, jc,
            ctypes.byref(ctypes.c_int64(ipow)))

def lxtszv(nm, nn, jm, s1, s2, g1, g2, pz, rm, ipow):
    """
    Wrap the Fortran routine LXTSZV (presumably a vector spectral -> grid transform for m = 0).

    Expected float64 arrays: ``s1``/``s2`` (nn+1,), ``g1``/``g2`` (jm,),
    ``pz`` (5, jm//2), ``rm`` (nm//2*3 + nm + 1,). Presumably ``g1`` and
    ``g2`` are written in place. Nothing is returned.
    """
    def f64(*shape):
        return np.ctypeslib.ndpointer(dtype=np.float64, ndim=len(shape), shape=shape)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    spec, grid = f64(nn + 1), f64(jm)
    routine = lib.lxtszv_
    routine.argtypes = [int_ptr] * 3 + [spec, spec, grid, grid] + [   # NM NN JM S1 S2 G1 G2
        f64(5, jm // 2),                                              # PZ
        f64(nm // 2 * 3 + nm + 1),                                    # RM
        int_ptr,                                                      # IPOW
    ]
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (nm, nn, jm)],
            s1, s2, g1, g2, pz, rm,
            ctypes.byref(ctypes.c_int64(ipow)))


def lxtvzs(nm, nn, jm, s1, s2, g1, g2, pz, rm, ipow):
    """
    Wrap the Fortran routine LXTVZS (presumably a vector grid -> spectral transform for m = 0).

    Expected float64 arrays: ``s1``/``s2`` (nn+1,), ``g1``/``g2`` (jm,),
    ``pz`` (5, jm//2), ``rm`` (nm//2*3 + nm + 1,). The spectral arrays are
    passed first. Presumably ``s1`` and ``s2`` are written in place. Nothing
    is returned.
    """
    def f64(*shape):
        return np.ctypeslib.ndpointer(dtype=np.float64, ndim=len(shape), shape=shape)

    int_ptr = ctypes.POINTER(ctypes.c_int64)
    spec, grid = f64(nn + 1), f64(jm)
    routine = lib.lxtvzs_
    routine.argtypes = [int_ptr] * 3 + [spec, spec, grid, grid] + [   # NM NN JM S1 S2 G1 G2
        f64(5, jm // 2),                                              # PZ
        f64(nm // 2 * 3 + nm + 1),                                    # RM
        int_ptr,                                                      # IPOW
    ]
    routine.restype = None

    routine(*[ctypes.byref(ctypes.c_int64(v)) for v in (nm, nn, jm)],
            s1, s2, g1, g2, pz, rm,
            ctypes.byref(ctypes.c_int64(ipow)))


def lxinic(nt, m):
    """Allocate the coefficient array ``cm`` and pass it to ``lxinic_``.

    Args:
        nt: integer, passed by reference as a 64-bit int.
        m: integer, passed by reference as a 64-bit int.

    Returns:
        np.ndarray: float64 array of shape ``(2 * (nt - m) + 1,)``. It is
        allocated uninitialised with ``np.empty`` and handed to ``lxinic_``,
        which presumably fills it (confirm in the ISPACK3 docs).
    """
    cm_len = 2 * (nt - m) + 1
    cm = np.empty((cm_len,), dtype=np.float64)

    routine = lib.lxinic_
    routine.argtypes = [
        ctypes.POINTER(ctypes.c_int64),                                   # nt
        ctypes.POINTER(ctypes.c_int64),                                   # m
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=(cm_len,)),  # cm
    ]
    routine.restype = None

    nt_ref = ctypes.byref(ctypes.c_int64(nt))
    m_ref = ctypes.byref(ctypes.c_int64(m))
    routine(nt_ref, m_ref, cm)
    return cm

def lxcswy(nt, m, s, SY, cm):
    """Call the ISPACK3 routine ``lxcswy_`` for zonal wavenumber ``m``.

    Args:
        nt, m: integers, passed by reference as 64-bit ints.
        s: float64 array of shape ``(nt - m + 1, 2)``.
        SY: float64 array of shape ``(nt - m + 2, 2)``. It is supplied by the
            caller, presumably as the output buffer (confirm in ISPACK3 docs).
        cm: float64 array of shape ``(2 * (nt - m) + 1,)``, e.g. the array
            returned by ``lxinic(nt, m)``.

    Returns:
        None
    """
    rows = nt - m + 1
    int_ref = ctypes.POINTER(ctypes.c_int64)
    ndp = np.ctypeslib.ndpointer

    routine = lib.lxcswy_
    routine.argtypes = [
        int_ref,                                                 # NT
        int_ref,                                                 # M
        ndp(dtype=np.float64, ndim=2, shape=(rows, 2)),          # S
        ndp(dtype=np.float64, ndim=2, shape=(rows + 1, 2)),      # SY
        ndp(dtype=np.float64, ndim=1, shape=(2 * rows - 1,)),    # CM
    ]
    routine.restype = None

    routine(ctypes.byref(ctypes.c_int64(nt)), ctypes.byref(ctypes.c_int64(m)), s, SY, cm)

def lxcszy(nt, s, SY, cm):
    """Call the ISPACK3 routine ``lxcszy_`` (the zonal, single-column variant).

    Args:
        nt: integer, passed by reference as a 64-bit int.
        s: float64 array of shape ``(nt + 1,)``.
        SY: float64 array of shape ``(nt + 2,)``. It is supplied by the
            caller, presumably as the output buffer (confirm in ISPACK3 docs).
        cm: float64 array of shape ``(2 * nt + 1,)``.

    Returns:
        None
    """
    def vec(length):
        # 1-D float64 ndpointer of a fixed length.
        return np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=(length,))

    routine = lib.lxcszy_
    routine.argtypes = [
        ctypes.POINTER(ctypes.c_int64),  # NT
        vec(nt + 1),                     # S
        vec(nt + 2),                     # SY
        vec(2 * nt + 1),                 # CM
    ]
    routine.restype = None

    nt_ref = ctypes.byref(ctypes.c_int64(nt))
    routine(nt_ref, s, SY, cm)


def lxcyws(nt, m, sy, S, cm):
    """Call the ISPACK3 routine ``lxcyws_`` for zonal wavenumber ``m``.

    The array layout mirrors ``lxcswy`` with the first two array roles
    swapped: ``sy`` comes first and ``S`` second.

    Args:
        nt, m: integers, passed by reference as 64-bit ints.
        sy: float64 array of shape ``(nt - m + 2, 2)``.
        S: float64 array of shape ``(nt - m + 1, 2)``. It is supplied by the
            caller, presumably as the output buffer (confirm in ISPACK3 docs).
        cm: float64 array of shape ``(2 * (nt - m) + 1,)``.

    Returns:
        None
    """
    rows = nt - m + 1
    int_ref = ctypes.POINTER(ctypes.c_int64)

    def f64(*dims):
        return np.ctypeslib.ndpointer(dtype=np.float64, ndim=len(dims), shape=dims)

    routine = lib.lxcyws_
    routine.argtypes = [
        int_ref,             # NT
        int_ref,             # M
        f64(rows + 1, 2),    # SY
        f64(rows, 2),        # S
        f64(2 * rows - 1),   # CM
    ]
    routine.restype = None

    nt_ref, m_ref = (ctypes.byref(ctypes.c_int64(v)) for v in (nt, m))
    routine(nt_ref, m_ref, sy, S, cm)


def lxcyzs(nt, sy, S, cm):
    """Call the ISPACK3 routine ``lxcyzs_`` (the zonal, single-column variant).

    Args:
        nt: integer, passed by reference as a 64-bit int.
        sy: float64 array of shape ``(nt + 2,)``.
        S: float64 array of shape ``(nt + 1,)``. It is supplied by the
            caller, presumably as the output buffer (confirm in ISPACK3 docs).
        cm: float64 array of shape ``(2 * nt + 1,)``.

    Returns:
        None
    """
    lengths = {"sy": nt + 2, "s": nt + 1, "cm": 2 * nt + 1}
    pointers = {
        key: np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=(size,))
        for key, size in lengths.items()
    }

    routine = lib.lxcyzs_
    routine.argtypes = [
        ctypes.POINTER(ctypes.c_int64),  # NT
        pointers["sy"],                  # SY
        pointers["s"],                   # S
        pointers["cm"],                  # CM
    ]
    routine.restype = None

    routine(ctypes.byref(ctypes.c_int64(nt)), sy, S, cm)


def lxtszp(nn, s):
    """Call ``lxtszp_`` and return the two scalars it writes.

    Args:
        nn: integer, passed by reference as a 64-bit int.
        s: float64 array of shape ``(nn + 1,)``.

    Returns:
        tuple[float, float]: the values left in the routine's ``GNP`` and
        ``GSP`` double arguments. The names suggest values at the north and
        south poles; confirm against the ISPACK3 docs.
    """
    double_ref = ctypes.POINTER(ctypes.c_double)

    routine = lib.lxtszp_
    routine.argtypes = [
        ctypes.POINTER(ctypes.c_int64),                                     # NN
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=(nn + 1,)),  # S
        double_ref,                                                         # GNP
        double_ref,                                                         # GSP
    ]
    routine.restype = None

    gnp_out = ctypes.c_double()
    gsp_out = ctypes.c_double()
    routine(
        ctypes.byref(ctypes.c_int64(nn)),
        s,
        ctypes.byref(gnp_out),
        ctypes.byref(gsp_out),
    )
    return gnp_out.value, gsp_out.value


def fxzini(n):
    """Allocate the work tables ``IT`` and ``T`` and pass them to ``fxzini_``.

    Args:
        n: integer, passed by reference as a 64-bit int.

    Returns:
        tuple[np.ndarray, np.ndarray]:
        ``IT`` (int64, shape ``(n,)``) and ``T`` (float64, shape ``(2 * n,)``).
        Both are allocated uninitialised with ``np.empty`` and handed to
        ``fxzini_``, which presumably fills them. They are then accepted by
        ``fxztfa`` / ``fxztba`` with the same ``n``.
    """
    it_dims = (n,)
    t_dims = (n * 2,)
    IT = np.empty(it_dims, dtype=np.int64)
    T = np.empty(t_dims, dtype=np.float64)

    routine = lib.fxzini_
    routine.argtypes = [
        ctypes.POINTER(ctypes.c_int64),                                 # N
        np.ctypeslib.ndpointer(dtype=np.int64, ndim=1, shape=it_dims),  # IT
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=t_dims),  # T
    ]
    routine.restype = None

    routine(ctypes.byref(ctypes.c_int64(n)), IT, T)
    return IT, T


def fxztfa(m, n, x, it, t):
    """Call the ISPACK3 routine ``fxztfa_`` on the array ``x``.

    ``x`` is passed by pointer and nothing is returned, so the result
    presumably replaces the contents of ``x`` (confirm in ISPACK3 docs).

    Args:
        m, n: integers, passed by reference as 64-bit ints.
        x: float64 array of shape ``(m * 2 * n,)``.
        it: int64 table of shape ``(n,)``, as returned by ``fxzini(n)``.
        t: float64 table of shape ``(2 * n,)``, as returned by ``fxzini(n)``.

    Returns:
        None
    """
    ndp = np.ctypeslib.ndpointer
    int_ref = ctypes.POINTER(ctypes.c_int64)

    routine = lib.fxztfa_
    routine.argtypes = [
        int_ref,                                               # M
        int_ref,                                               # N
        ndp(dtype=np.float64, ndim=1, shape=(m * 2 * n,)),     # X
        ndp(dtype=np.int64, ndim=1, shape=(n,)),               # IT
        ndp(dtype=np.float64, ndim=1, shape=(n * 2,)),         # T
    ]
    routine.restype = None

    m_ref, n_ref = (ctypes.byref(ctypes.c_int64(v)) for v in (m, n))
    routine(m_ref, n_ref, x, it, t)

def fxztba(m, n, x, it, t):
    """Call the ISPACK3 routine ``fxztba_`` on the array ``x``.

    This is the counterpart of ``fxztfa`` and has an identical argument
    layout. The result presumably replaces the contents of ``x`` (confirm in
    ISPACK3 docs).

    Args:
        m, n: integers, passed by reference as 64-bit ints.
        x: float64 array of shape ``(m * 2 * n,)``.
        it: int64 table of shape ``(n,)``, as returned by ``fxzini(n)``.
        t: float64 table of shape ``(2 * n,)``, as returned by ``fxzini(n)``.

    Returns:
        None
    """
    def ptr(dtype, length):
        # 1-D ndpointer with a fixed dtype and length.
        return np.ctypeslib.ndpointer(dtype=dtype, ndim=1, shape=(length,))

    routine = lib.fxztba_
    routine.argtypes = [
        ctypes.POINTER(ctypes.c_int64),   # M
        ctypes.POINTER(ctypes.c_int64),   # N
        ptr(np.float64, m * 2 * n),       # X
        ptr(np.int64, n),                 # IT
        ptr(np.float64, n * 2),           # T
    ]
    routine.restype = None

    routine(
        ctypes.byref(ctypes.c_int64(m)),
        ctypes.byref(ctypes.c_int64(n)),
        x,
        it,
        t,
    )

def fxrini(n):
    """Allocate the work tables ``IT`` and ``T`` and pass them to ``fxrini_``.

    Args:
        n: integer, passed by reference as a 64-bit int.

    Returns:
        tuple[np.ndarray, np.ndarray]:
        ``IT`` (int64, shape ``(n // 2,)``) and ``T`` (float64, shape
        ``(n * 3 // 2,)``). Both are allocated uninitialised with ``np.empty``
        and handed to ``fxrini_``, which presumably fills them. They are
        intended for ``fxrtfa`` / ``fxrtba`` with the same ``n``.
    """
    it_len = n // 2
    t_len = n * 3 // 2
    IT = np.empty((it_len,), dtype=np.int64)
    T = np.empty((t_len,), dtype=np.float64)

    routine = lib.fxrini_
    routine.argtypes = [
        ctypes.POINTER(ctypes.c_int64),                                    # N
        np.ctypeslib.ndpointer(dtype=np.int64, ndim=1, shape=(it_len,)),   # IT
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=(t_len,)),  # T
    ]
    routine.restype = None

    n_ref = ctypes.byref(ctypes.c_int64(n))
    routine(n_ref, IT, T)
    return IT, T


def fxrtfa(m, n, x, it, t):
    """Call the ISPACK3 routine ``fxrtfa_`` on the real array ``x``.

    ``x`` is passed by pointer and nothing is returned, so the result
    presumably replaces the contents of ``x`` (confirm in ISPACK3 docs).

    Args:
        m, n: integers, passed by reference as 64-bit ints.
        x: float64 array of shape ``(m * n,)``.
        it: int64 table of shape ``(n // 2,)``, as returned by ``fxrini(n)``.
        t: float64 table of shape ``(n * 3 // 2,)``, as returned by
            ``fxrini(n)``.

    Returns:
        None

    Raises:
        ctypes.ArgumentError: if any array's dtype, rank, or shape does not
            match the layout above.
    """
    x_shape = (m * n, )
    # The work-table shapes must agree with those allocated by fxrini(n).
    # Previously they were referenced here without being defined, so every
    # call raised NameError.
    it_shape = (n // 2, )
    t_shape = (n * 3 // 2, )

    lib.fxrtfa_.argtypes = [
        ctypes.POINTER(ctypes.c_int64),                                   # M
        ctypes.POINTER(ctypes.c_int64),                                   # N
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=x_shape),  # X
        np.ctypeslib.ndpointer(dtype=np.int64, ndim=1, shape=it_shape),   # IT
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=t_shape),  # T
    ]
    lib.fxrtfa_.restype = None

    lib.fxrtfa_(
        ctypes.byref(ctypes.c_int64(m)),
        ctypes.byref(ctypes.c_int64(n)),
        x,
        it,
        t
    )

    return


def fxrtba(m, n, x, it, t):
    """Call the ISPACK3 routine ``fxrtba_`` on the real array ``x``.

    This is the counterpart of ``fxrtfa`` and has an identical argument
    layout. The result presumably replaces the contents of ``x`` (confirm in
    ISPACK3 docs).

    Args:
        m, n: integers, passed by reference as 64-bit ints.
        x: float64 array of shape ``(m * n,)``.
        it: int64 table of shape ``(n // 2,)``, as returned by ``fxrini(n)``.
        t: float64 table of shape ``(n * 3 // 2,)``, as returned by
            ``fxrini(n)``.

    Returns:
        None

    Raises:
        ctypes.ArgumentError: if any array's dtype, rank, or shape does not
            match the layout above.
    """
    x_shape = (m * n, )
    # The work-table shapes must agree with those allocated by fxrini(n).
    # Previously they were referenced here without being defined, so every
    # call raised NameError.
    it_shape = (n // 2, )
    t_shape = (n * 3 // 2, )

    lib.fxrtba_.argtypes = [
        ctypes.POINTER(ctypes.c_int64),                                   # M
        ctypes.POINTER(ctypes.c_int64),                                   # N
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=x_shape),  # X
        np.ctypeslib.ndpointer(dtype=np.int64, ndim=1, shape=it_shape),   # IT
        np.ctypeslib.ndpointer(dtype=np.float64, ndim=1, shape=t_shape),  # T
    ]
    lib.fxrtba_.restype = None

    lib.fxrtba_(
        ctypes.byref(ctypes.c_int64(m)),
        ctypes.byref(ctypes.c_int64(n)),
        x,
        it,
        t
    )

    return
