!===========================================================================
!
! Included from inversion.f90
!
!============================================================================

!---------------------- Use scaLAPACK For Inversion -----------------------------------


#ifdef USESCALAPACK

!> Invert a block-cyclically distributed nmtx x nmtx matrix in place with ScaLAPACK.
!!
!! The inverse is obtained by solving matrix * X = I with pX(gesv). The
!! identity right-hand side is built directly in this process's local
!! block-cyclic storage.
!!
!! nmtx   - global order of the matrix
!! scal   - ScaLAPACK distribution info (BLACS context, process grid shape,
!!          this process's grid coordinates, block size, local dimensions)
!! matrix - on entry, this process's local part of the matrix;
!!          on exit, the local part of its inverse
!!
!! Calls die() if descinit or pX(gesv) return a nonzero info.
subroutine X(invert_with_scalapack)(nmtx, scal, matrix)
  integer, intent(in) :: nmtx
  type (scalapack), intent(in) :: scal
  SCALAR, intent(inout) :: matrix(scal%npr,scal%npc)

  integer :: il, jl, ig, jg, info
  integer :: descaA(9), descaB(9)
  character*16 :: tmpstr1
  character*256 :: tmpstr
  integer, allocatable :: ipiv(:)
  SCALAR, allocatable :: temp(:, :)

!-------------------------
! scaLAPACK Case (Added by JRD - must specify USESCALAPACK in arch.mk)

  PUSH_SUB(X(invert_with_scalapack))

#ifdef VERBOSE
  if (peinf%inode .eq. 0) then
    write(6,*) ' '
    write(6,*) 'We are doing inversion with scaLAPACK.'
    write(6,*) ' '
    write(6,*) 'Peinf: ', peinf%inode,' Scalapack Grid:'
    write(6,*) scal%nprow,scal%npcol,scal%myprow,scal%mypcol
    write(6,*) nmtx,scal%nbl,scal%npr,scal%npc,scal%nqrhs
    write(6,*) ' '
  end if
#endif

  ! Descriptors for the matrix (A) and the right-hand side / solution (B).
  ! Both use square nbl x nbl blocks with the first block on process
  ! row 0 and process column 0.
  call descinit(descaA,nmtx,nmtx,scal%nbl,scal%nbl,0, &
    0,scal%icntxt,MAX(1,scal%npr),info)
  if (info.ne.0) then
    write(tmpstr1,'(i16)') info
    write(tmpstr,'(a)') 'failed in descinit (matrix descriptor) with error code ' // &
     TRUNC(tmpstr1) // '.'
    call die(tmpstr)
  endif
  call descinit(descaB,nmtx,nmtx,scal%nbl,scal%nbl,0, &
    0,scal%icntxt,MAX(1,scal%npr),info)
  if (info.ne.0) then
    write(tmpstr1,'(i16)') info
    write(tmpstr,'(a)') 'failed in descinit (rhs descriptor) with error code ' // &
     TRUNC(tmpstr1) // '.'
    call die(tmpstr)
  endif

  ! pX(gesv) needs LOCr(M_A) + MB_A pivot entries.
  SAFE_ALLOCATE(ipiv, (scal%npr+scal%nbl))

!------------------
! Construct identity for rhs

  SAFE_ALLOCATE(temp, (scal%npr,scal%npc))
  temp = ZERO

  ! Only diagonal entries are nonzero. Map each local index to its global
  ! index under the block-cyclic layout declared in descinit (source process
  ! row/column 0). Set ONE where the global row equals the global column.
  ! This touches only the npr x npc local entries. The alternative would
  ! scan all nmtx**2 global entries on every process.
  do jl = 1, scal%npc
    jg = (((jl-1)/scal%nbl)*scal%npcol + scal%mypcol)*scal%nbl + mod(jl-1,scal%nbl) + 1
    do il = 1, scal%npr
      ig = (((il-1)/scal%nbl)*scal%nprow + scal%myprow)*scal%nbl + mod(il-1,scal%nbl) + 1
      if (ig .eq. jg .and. ig .le. nmtx) temp(il,jl) = ONE
    enddo
  enddo

  ! Solve matrix * X = I; on exit temp holds the local part of the inverse.
  call pX(gesv)(nmtx,nmtx,matrix,1,1,descaA,ipiv,temp,1,1,descaB,info)

  if (info.ne.0) then
    write(tmpstr1,'(i16)') info
    write(tmpstr,'(a)') 'failed in ' // TOSTRING(pX(gesv)) // &
     ' with error code ' // TRUNC(tmpstr1) // '.'
    call die(tmpstr)
  endif

  matrix = temp

  SAFE_DEALLOCATE(ipiv)
  SAFE_DEALLOCATE(temp)

  POP_SUB(X(invert_with_scalapack))
end subroutine X(invert_with_scalapack)

#endif

!------------------------------------------------------------

!> Invert an nmtx x nmtx matrix in place on a single process.
!!
!! The inverse is obtained by solving matrix * X = I. The solver is
!! ESSL (X(gef) factorization + X(gesm) solve) when USEESSL is defined,
!! and LAPACK X(gesv) otherwise.
!!
!! An empty matrix (nmtx <= 0) is returned untouched. There is nothing to
!! invert, and the solvers would otherwise reject a leading dimension of 0
!! as an illegal argument.
!!
!! nmtx   - order of the matrix
!! matrix - on entry, the matrix; on exit, its inverse
!!
!! In the LAPACK path, calls die() if X(gesv) returns a nonzero info.
subroutine X(invert_serial)(nmtx, matrix)
  integer, intent(in) :: nmtx
  SCALAR, intent(inout) :: matrix(nmtx,nmtx)

  integer :: ii, info
  character*16 :: tmpstr1
  character*256 :: tmpstr
  integer, allocatable :: ipvt(:)
  SCALAR, allocatable :: temp(:, :)

  PUSH_SUB(X(invert_serial))

  ! Nothing to invert. Also avoids passing LDA=0 and indexing (1,1) of
  ! zero-size arrays below.
  if (nmtx .le. 0) then
    POP_SUB(X(invert_serial))
    return
  endif

  SAFE_ALLOCATE(ipvt, (nmtx))
  SAFE_ALLOCATE(temp, (nmtx, nmtx))

  ! Right-hand side: the nmtx x nmtx identity.
  temp(:,:) = ZERO
  do ii=1,nmtx
    temp(ii,ii) = ONE
  enddo

#ifdef USEESSL
  call X(gef)(matrix(1,1),nmtx,nmtx,ipvt(1))
  call X(gesm)('N',matrix(1,1),nmtx,nmtx,ipvt(1),temp(1,1),nmtx,nmtx)
#else
! Call LAPACK routine to solve for X in matrix*X=temp
! and put the result into temp.

  call X(gesv)(nmtx,nmtx,matrix,nmtx,ipvt,temp,nmtx,info)

  if (info.ne.0) then
    write(tmpstr1,'(i16)') info
    write(tmpstr,'(a)') 'failed in ' // TOSTRING(X(gesv)) // &
     ' with error code ' // TRUNC(tmpstr1) // '.'
    call die(tmpstr)
  endif
#endif

! Copy result back into matrix.

  matrix = temp

  SAFE_DEALLOCATE(ipvt)
  SAFE_DEALLOCATE(temp)

  POP_SUB(X(invert_serial))
end subroutine X(invert_serial)
