!============================================================================
!
! Module fftw_m, originally by DAS 1/14/2011
!
!   Routines used with FFTW, as well as interfaces for library calls.
!
!   Interfaces for FFTW2 functions, formulated from fftw-2.1.5/fftw/fftwf77.c
!   and http://www.fftw.org/fftw2_doc/fftw_5.html
!   Every FFTW2 function used in the code should be listed here, and this
!   module should be used in every routine containing FFTW2 calls to ensure
!   the argument types are correct.
!
!   Includes the file fftw_f77.i, which defines the parameters used in FFTW calls.
!
!============================================================================

#include "f_defs.h"

module fftw_m

  use global_m
  implicit none

  include 'fftw_f77.i'

  public ::            &
    check_FFT_size,    &
    setup_FFT_sizes,   &
    gvec_to_fft_index, &
    put_into_fftbox,   &
    get_from_fftbox,   &
    do_FFT,            &
    conjg_fftbox,      &
    multiply_fftboxes, &
    destroy_fftw_plans

    ! FFT box size the cached plans belong to. Sentinel values, as used by
    ! do_FFT and destroy_fftw_plans: all 0 = no plan has been created yet,
    ! all -1 = the plans have been destroyed.
    integer, private :: Nfftold(3) = 0
    ! Cached FFTW plan handles. do_FFT creates plus_plan with FFTW_BACKWARD and
    ! uses it for sign = +1; minus_plan is created with FFTW_FORWARD and used
    ! for sign = -1.
    integer*8, private :: plus_plan = 0
    integer*8, private :: minus_plan = 0

! fftw_plan type in C is recommended to be integer*8 by FFTW documentation
!
! Explicit interface for the FFTW2 plan-creation wrapper, so that calls are
! type-checked. Arguments, as this module uses them in do_FFT:
!   p     -- plan handle; the variable passed here is later handed to
!            fftwnd_f77_one and fftwnd_f77_destroy_plan
!   rank  -- number of dimensions of the transform (3 is passed)
!   n     -- transform size along each dimension (Nfft is passed)
!   idir  -- transform direction (FFTW_FORWARD or FFTW_BACKWARD is passed)
!   flags -- sum of FFTW flags (FFTW_MEASURE+FFTW_IN_PLACE+FFTW_USE_WISDOM
!            is passed)
  interface
    subroutine fftwnd_f77_create_plan(p, rank, n, idir, flags)
      integer*8 :: p
      integer :: rank, n(*), idir, flags
    end subroutine fftwnd_f77_create_plan
  end interface

  ! Explicit interface for the FFTW2 wrapper that executes one transform.
  !   p   -- plan handle obtained from fftwnd_f77_create_plan
  !   in  -- data to transform; do_FFT passes the FFT box here and relies on
  !          the transform being done in place
  !   out -- placeholder, see the note below
  interface
    subroutine fftwnd_f77_one(p, in, out)
      integer*8 :: p
      complex*16 :: in(*)
      integer :: out
! The argument is really complex*16 out(*), but we only use in-place transforms,
! in which case this argument is ignored. For simplicity we just pass it 0.
    end subroutine fftwnd_f77_one
  end interface
  
  ! Explicit interface for the FFTW2 wrapper that releases a plan.
  !   p -- plan handle obtained from fftwnd_f77_create_plan
  ! Called from destroy_fftw_plans for both cached plans.
  interface
    subroutine fftwnd_f77_destroy_plan(p)
      integer*8 :: p
    end subroutine fftwnd_f77_destroy_plan
  end interface

contains

! Originally by gsm      Last Modified: 4/10/2010 (gsm)
!
! Tests whether Nfft is a good FFT grid dimension, i.e. whether it can be
! written as 2^a*3^b*5^c*7^d*11^e*13^f with a,b,c,d arbitrary and e,f
! equal to 0 or 1.
! Ref: http://www.fftw.org/fftw2_doc/fftw_3.html
!
!   Input
!     Nfft -- FFT grid dimension to test (must be >= 1)
!     Nfac -- how many of the primes 2,3,5,7,11,13 to try, counted from
!             the left (must satisfy 1 <= Nfac <= 6)
!   Result
!     .true. if nothing is left of Nfft after dividing out the tried primes
!     and the powers of 11 and 13 are at most 1; .false. otherwise.
!   Invalid input, or a failed internal consistency check, ends in die().
  logical function check_FFT_size(Nfft, Nfac)
    integer, intent(in) :: Nfft, Nfac

    integer, parameter :: maxfac = 6
    integer, parameter :: fac(maxfac) = (/ 2, 3, 5, 7, 11, 13 /)
    integer :: pow(maxfac)
    integer :: left, rebuilt, k, itry, ntry

    PUSH_SUB(check_FFT_size)

    if (Nfft .lt. 1 .or. Nfac .lt. 1 .or. Nfac .gt. maxfac) then
      call die('check_FFT_size input')
    endif

    pow(:) = 0
    left = Nfft

    ! Divide out each prime as often as it goes. ntry is an upper bound on the
    ! power of fac(k) that can be contained in what is left.
    do k = 1, Nfac
      ntry = int(log(dble(left)) / log(dble(fac(k)))) + 1
      do itry = 1, ntry
        if (mod(left, fac(k)) .ne. 0) exit
        left = left / fac(k)
        pow(k) = pow(k) + 1
      enddo
    enddo

    ! Sanity check: the factors found, times what is left, must give Nfft back.
    rebuilt = left
    do k = 1, Nfac
      rebuilt = rebuilt * fac(k)**pow(k)
    enddo
    if (rebuilt .ne. Nfft) then
      call die('Internal error in check_FFT_size; factorization failed')
    endif

    ! Good size: fully factorized, and 11 and 13 each appear at most once.
    check_FFT_size = (left .eq. 1) .and. all(pow(5:6) .le. 1)

    POP_SUB(check_FFT_size)

  end function check_FFT_size

! The former "fft_routines.f90"
! Sohrab Ismail-Beigi   Feb 28 2001
!
! This is a set of fast-Fourier-transform-related routines that are used
! to compute matrix elements of the type <nk|e^(i*G.r)|mk'>.
! For many G-vectors, FFTs will be the fastest way to compute them.
!
! The FFTW (http://www.fftw.org) suite of routines do the actual work.
! Most of what is below is interfacing code and routines that simplify
! small and useful tasks.

!
! Given gvec%kmax(1:3) values (in kmax), finds appropriate FFT box
! sizes to use in Nfft(1:3): each kmax(i) is increased in steps of one
! until check_FFT_size accepts it with Nfac = 3 (primes 2, 3 and 5).
! scale = 1/(Nfftx*Nffty*Nfftz).
!
!   kmax  -- (in)  starting values for the three box dimensions
!   Nfft  -- (out) accepted FFT box dimensions, Nfft(i) >= kmax(i)
!   scale -- (out) reciprocal of the number of points in the FFT box
!
  subroutine setup_FFT_sizes(kmax,Nfft,scale)
    integer, intent(in) :: kmax(3)
    integer, intent(out) :: Nfft(3)
    real(DP), intent(out) :: scale

    integer, parameter :: Nfac = 3
    integer :: i

    PUSH_SUB(setup_FFT_sizes)

    do i=1,3
      Nfft(i) = kmax(i)
      do while (.not. check_FFT_size(Nfft(i), Nfac))
        Nfft(i) = Nfft(i) + 1
      enddo
    enddo

    ! Form the product in double precision: the number of grid points can
    ! exceed the default integer range for large boxes (1300^3 > 2^31 - 1),
    ! and an integer product would then silently wrap around.
    scale = 1.0d0/product(dble(Nfft(1:3)))

    POP_SUB(setup_FFT_sizes)

    return
  end subroutine setup_FFT_sizes

!
! Maps the G-vector g(1:3) onto the 1-based grid point idx(1:3) of an FFT
! box of size Nfft(1:3): a non-negative component g(i) goes to g(i)+1, a
! negative component is wrapped around to g(i)+1+Nfft(i).
!
  subroutine gvec_to_fft_index(g,idx,Nfft)
    integer, intent(in) :: g(3), Nfft(3)
    integer, intent(out) :: idx(3)

! no push/pop since called too frequently.

    idx(1:3) = g(1:3) + 1 + merge(Nfft(1:3), 0, g(1:3) < 0)

  end subroutine gvec_to_fft_index

!
! Scatters the coefficients data(1:ndata) onto the 3D grid fftbox(:,:,:).
! The whole FFT box is set to zero first; afterwards data(j) is stored at
! the grid point that gvec_to_fft_index assigns to the G-vector
! glist(1:3,gindex(j)).
!
!   ndata -- number of entries of data(:) to place
!   data -- the data set, real or complex, depending on ifdef CPLX
!   ng -- number of g vectors in glist (not referenced in the body)
!   glist -- master list of g vectors, one vector per column
!   gindex -- for each j, the column of glist that data(j) refers to
!   fftbox(:,:,:) -- 3D complex FFT box where the data is put
!   Nfft(1:3) -- sizes of FFT box Nx,Ny,Nz
!
  subroutine put_into_fftbox(ndata, data, ng, glist, gindex, fftbox, Nfft)
    integer, intent(in) :: ndata
    SCALAR,  intent(in) :: data(:) !< (ndata) this is to avoid creation of array temporary
    integer, intent(in) :: ng
    integer, intent(in) :: glist(:,:) !< (3, ng)
    integer, intent(in) :: gindex(:) !< (ng)
    integer, intent(in) :: Nfft(:) !< (3)
    complex(DPC), intent(out) :: fftbox(:,:,:) !< (Nfft(1), Nfft(2), Nfft(3))

    integer :: idata, ipos(3)

    PUSH_SUB(put_into_fftbox)

    ! Start from an empty box; grid points without data stay zero.
    fftbox = (0.0d0,0.0d0)

    do idata = 1, ndata
      call gvec_to_fft_index(glist(:,gindex(idata)), ipos, Nfft)
      fftbox(ipos(1),ipos(2),ipos(3)) = data(idata)
    enddo

    POP_SUB(put_into_fftbox)

  end subroutine put_into_fftbox

!
! Counterpart of put_into_fftbox: gathers ndata values from the grid
! fftbox(:,:,:) into data(1:ndata). data(j) is read from the grid point
! that gvec_to_fft_index assigns to the G-vector glist(:,gindex(j)) and
! is multiplied by scale before being stored.
!
! All of data(:) is set to zero before the values are gathered.
!
!   ndata -- number of entries to extract
!   data -- receives the scaled values
!   ng -- number of g vectors in glist (not referenced in the body)
!   glist -- master list of g vectors, one vector per column
!   gindex -- for each j, the column of glist that data(j) refers to
!   fftbox(:,:,:) -- 3D complex FFT box the data is read from
!   Nfft(1:3) -- sizes of the FFT box
!   scale -- factor applied to every extracted value
!
  subroutine get_from_fftbox(ndata, data, ng, glist, gindex, fftbox, Nfft, scale)
    integer, intent(in) :: ndata
    SCALAR, intent(out) :: data(:) !< (ndata)
    integer, intent(in) :: ng
    integer, intent(in) :: glist(:,:) !< (3, ng)
    integer, intent(in) :: gindex(:) !< (ng)
    integer, intent(in) :: Nfft(:)
    complex(DPC), intent(in) :: fftbox(:,:,:) !< (Nfft(1), Nfft(2), Nfft(3))
    real(DP), intent(in) :: scale

    integer :: idata, ipos(3)

    PUSH_SUB(get_from_fftbox)

    ! Entries beyond ndata, if any, are left at zero.
    data = 0.0

    do idata = 1, ndata
      call gvec_to_fft_index(glist(:,gindex(idata)), ipos, Nfft)
      data(idata) = fftbox(ipos(1),ipos(2),ipos(3))*scale
    enddo

    POP_SUB(get_from_fftbox)

  end subroutine get_from_fftbox

!
! Do an FFT on the fftbox in place:  destroys contents of fftbox
! and replaces them by the Fourier transform.
!
! The FFT done is:
!
!   fftbox(p) <- sum_j { fftbox(j)*e^{sign*i*j.p} }
!
! where j and p are integer 3-vectors ranging over Nfft(1:3).
!
!   fftbox -- (inout) data to transform, overwritten with the result
!   Nfft   -- (in)    FFT box dimensions the plans are built for
!   sign   -- (in)    +1 or -1, the sign in the exponent above; any other
!                     value ends in die()
!
! The two FFTW plans are cached in module variables and are rebuilt only
! when Nfft differs from the size recorded in Nfftold.
!
  subroutine do_FFT(fftbox, Nfft, sign)
    complex(DPC), intent(inout) :: fftbox(:,:,:)
    integer, intent(in) :: Nfft(3)
    integer, intent(in) :: sign

    character (len=100) :: str

    PUSH_SUB(do_FFT)

    if(any(Nfftold(1:3) .ne. Nfft(1:3))) then
#ifdef VERBOSE
      write(str,'(a,3(i4,a))') &
        'Creating ',Nfft(1),' x',Nfft(2),' x',Nfft(3),' FFTW plans.'
      call logit(str)
#endif

      ! Release the plans built for the previous box size, otherwise there is
      ! a memory leak. Plans exist only while Nfftold holds a real size:
      ! all 0 means none was ever created, all -1 means they were already
      ! destroyed (destroying again would abort in destroy_fftw_plans).
      if(.not. all(Nfftold(1:3) == 0) .and. .not. all(Nfftold(1:3) == -1)) then
        call destroy_fftw_plans()
      endif

      call fftwnd_f77_create_plan(plus_plan,3,Nfft,FFTW_BACKWARD, &
        FFTW_MEASURE+FFTW_IN_PLACE+FFTW_USE_WISDOM)
      call fftwnd_f77_create_plan(minus_plan,3,Nfft,FFTW_FORWARD, &
        FFTW_MEASURE+FFTW_IN_PLACE+FFTW_USE_WISDOM)
      Nfftold(1:3) = Nfft(1:3)
#ifdef VERBOSE
      call logit('Done creating plans')
#endif
    endif

    ! The last argument is a dummy: the plans are in-place, so no separate
    ! output array is used.
    if (sign == 1) then
      call fftwnd_f77_one(plus_plan,fftbox,0)
    else if (sign == -1) then
      call fftwnd_f77_one(minus_plan,fftbox,0)
    else
      call die('sign is not 1 or -1 in do_FFT')
    endif

    POP_SUB(do_FFT)

    return
  end subroutine do_FFT

  ! Releases the two cached FFTW plans and marks them as gone by setting
  ! Nfftold to -1. Does nothing if no plan was ever created (Nfftold all 0);
  ! calls die() if the plans have already been destroyed (Nfftold all -1).
  subroutine destroy_fftw_plans()

    ! Nothing to release: no FFTW plan has been created so far.
    if (all(Nfftold(1:3) == 0)) return

    PUSH_SUB(destroy_fftw_plans)

    if (all(Nfftold(1:3) == -1)) then
      call die("Cannot destroy FFTW plan for a second time.")
    endif

#ifdef VERBOSE
    if (peinf%inode == 0) then
      write(6,'(a,3(i4,a))') '*** VERBOSE: Destroying ',Nfftold(1),' x',Nfftold(2),' x',Nfftold(3),' FFTW plans.'
    endif
#endif

    call fftwnd_f77_destroy_plan(plus_plan)
    call fftwnd_f77_destroy_plan(minus_plan)
    ! should forget wisdom here, but I cannot figure out how... --DAS

    ! Record that no plan exists anymore, so a second destroy is detected.
    Nfftold(1:3) = -1

    POP_SUB(destroy_fftw_plans)

  end subroutine destroy_fftw_plans

!
! Replaces every element of fftbox(1:Nfft(1),1:Nfft(2),1:Nfft(3)) by its
! complex conjugate, in place.
!
  subroutine conjg_fftbox(fftbox,Nfft)
    integer, intent(in) :: Nfft(3)
    complex(DPC), intent(inout) :: fftbox(:,:,:) !< (Nfft(1), Nfft(2), Nfft(3))
    ! kept assumed-shape: absoft was reported to segfault when the dimensions
    ! are spelled out for fftbox as in the trailing comment above

    integer :: i1, i2, i3

    PUSH_SUB(conjg_fftbox)

    ! First index innermost, so memory is traversed contiguously.
    do i3 = 1, Nfft(3)
      do i2 = 1, Nfft(2)
        do i1 = 1, Nfft(1)
          fftbox(i1,i2,i3) = conjg(fftbox(i1,i2,i3))
        enddo
      enddo
    enddo

    POP_SUB(conjg_fftbox)

  end subroutine conjg_fftbox

!
! Element-wise product of two FFT boxes over the index range
! (1:Nfft(1),1:Nfft(2),1:Nfft(3)); the result overwrites fftbox2,
! fftbox1 is left unchanged.
!
  subroutine multiply_fftboxes(fftbox1, fftbox2, Nfft)
    integer, intent(in) :: Nfft(3)
    complex(DPC), intent(in) :: fftbox1(:,:,:) !< (Nfft(1), Nfft(2), Nfft(3))
    complex(DPC), intent(inout) :: fftbox2(:,:,:) !< (Nfft(1), Nfft(2), Nfft(3))

    integer :: i1, i2, i3

    PUSH_SUB(multiply_fftboxes)

    ! First index innermost, so memory is traversed contiguously.
    do i3 = 1, Nfft(3)
      do i2 = 1, Nfft(2)
        do i1 = 1, Nfft(1)
          fftbox2(i1,i2,i3) = fftbox1(i1,i2,i3) * fftbox2(i1,i2,i3)
        enddo
      enddo
    enddo

    POP_SUB(multiply_fftboxes)

  end subroutine multiply_fftboxes

end module fftw_m
