! Tests functionality of ScaLAPACK:
! computes the SVD (pdgesvd) of a random 8 x 8 matrix A distributed over a
! 2 x 2 grid of 4 processes. The singular values are returned in the array S.
! 
! Usage: mpirun -np 4 ./a.out
!
! stali@purdue.edu

program parallel_svd
  use, intrinsic :: iso_fortran_env, only: dp => real64, error_unit
  implicit none

  ! Global arrays A and S: A is generated on process (0,0) and broadcast so that every
  ! process can copy out its own block; S receives the singular values from pdgesvd.
  real(dp), allocatable :: A(:,:), S(:)
  integer :: n

  ! Variables for BLACS init and proc grid generation
  integer :: iam, nprocs, cntxt, nprow, npcol, myrow, mycol

  ! Variables needed for distributing global arrays across the processor grid
  integer :: desca(9), descu(9), descvt(9)
  integer :: info, rsrc, csrc

  ! Local arrays A, U, VT (proc specific) and the ScaLAPACK workspace
  real(dp), allocatable :: la(:,:), lu(:,:), lvt(:,:), work(:)
  integer :: ln, lwork

  !--- Problem and grid sizes
  n = 8        ! Size of the global matrix A (8 x 8)
  nprow = 2    ! 2 x 2 process grid -> 4 processes
  npcol = 2
  ! Each process holds exactly one ln x ln block of A (4 x 4 here). This assumes
  ! n is divisible by nprow and nprow == npcol, so the blocks are square.
  ln = n/nprow

  !--- Initialize BLACS
  call blacs_pinfo(iam, nprocs)

  ! The grid needs nprow*npcol processes; with fewer, it cannot be built.
  if (nprocs < nprow*npcol) then
     if (iam == 0) write(error_unit, '(a,i0,a,i0)') &
          'parallel_svd: needs at least ', nprow*npcol, ' processes, got ', nprocs
     call blacs_exit(0)
     error stop
  end if

  call blacs_get(-1, 0, cntxt)

  !--- Set up the proc grid
  call blacs_gridinit(cntxt, 'r', nprow, npcol)
  call blacs_gridinfo(cntxt, nprow, npcol, myrow, mycol)

  ! Processes left out of the grid (launched with more than nprow*npcol procs) get
  ! myrow = -1 and must not use the grid context; they go straight to blacs_exit.
  in_grid: if (myrow /= -1) then

     !--- Define the necessary array descriptors (block size ln, rooted at process (0,0))
     rsrc = 0; csrc = 0
     call descinit(desca, n, n, ln, ln, rsrc, csrc, cntxt, ln, info)
     call check_info(cntxt, info, 'descinit(desca)')
     call descinit(descu, n, n, ln, ln, rsrc, csrc, cntxt, ln, info)
     call check_info(cntxt, info, 'descinit(descu)')
     call descinit(descvt, n, n, ln, ln, rsrc, csrc, cntxt, ln, info)
     call check_info(cntxt, info, 'descinit(descvt)')

     !--- Set up and populate global/local arrays
     allocate (A(n,n))
     allocate (S(n))

     ! Process (0,0) populates A and broadcasts it to all others (alternatively
     ! point to point comm. can be used to directly fill the local arrays)
     if (myrow == 0 .and. mycol == 0) then
        call random_number(A)
        call dgebs2d(cntxt, 'All', 'i-ring', n, n, A, n)
     else
        call dgebr2d(cntxt, 'All', 'i-ring', n, n, A, n, 0, 0)
     end if

     allocate (la(ln,ln), lu(ln,ln), lvt(ln,ln))

     ! With block size ln and one block per process, process (myrow, mycol) owns
     ! global rows myrow*ln+1..(myrow+1)*ln and columns mycol*ln+1..(mycol+1)*ln.
     la = A(myrow*ln+1:(myrow+1)*ln, mycol*ln+1:(mycol+1)*ln)

     !--- ScaLAPACK call

     ! First carry out a workspace query (lwork = -1) to obtain the required lwork
     lwork = -1
     allocate (work(1))
     call pdgesvd('v', 'v', n, n, la, 1, 1, desca, S, lu, 1, 1, descu, lvt, &
          &       1, 1, descvt, work, lwork, info)
     call check_info(cntxt, info, 'pdgesvd (workspace query)')
     lwork = int(work(1))
     deallocate (work)

     ! With lwork now known call the main computational routine to compute the SVD
     allocate (work(lwork))
     call pdgesvd('v', 'v', n, n, la, 1, 1, desca, S, lu, 1, 1, descu, lvt, &
          &       1, 1, descvt, work, lwork, info)
     call check_info(cntxt, info, 'pdgesvd')
     deallocate (work)

     !--- pdgesvd returns S on every process; print it once, from process (0,0)
     if (myrow == 0 .and. mycol == 0) then
        print*, "Singular values:", S
     end if

     !--- Free up local/global arrays and release the proc grid
     deallocate (la, lu, lvt)
     deallocate (A, S)
     call blacs_gridexit(cntxt)

  end if in_grid

  !--- Terminate BLACS (every process, in the grid or not)
  call blacs_exit(0)

contains

  !> Report a non-zero ScaLAPACK/BLACS status and abort all processes in the grid.
  !! ictxt   : BLACS context to abort
  !! status  : info value returned by the routine (0 = success)
  !! routine : name of the routine, used in the error message
  subroutine check_info(ictxt, status, routine)
    integer, intent(in) :: ictxt, status
    character(len=*), intent(in) :: routine
    if (status /= 0) then
       write(error_unit, '(a,a,i0)') routine, ' failed with info = ', status
       call blacs_abort(ictxt, 1)
       error stop   ! fallback in case blacs_abort returns
    end if
  end subroutine check_info

end program parallel_svd
