# Programs

# Matrix multiplication

from __future__ import division
from __future__ import print_function

import numpy as np
from mpi4py import MPI
from time import time

#=============================================================================#
# Per-process tile dimensions: each rank allocates my_N x my_M tiles of A and B.

my_N, my_M = 3000, 3000

#=============================================================================#
# Symbolic indices into the 4-entry neighbour list of the 2-D process grid.

NORTH, SOUTH, EAST, WEST = range(4)



def pprint(string, comm=MPI.COMM_WORLD):
    """Print *string* only on the process whose rank in *comm* is 0.

    Every process may call this; the message is emitted a single time,
    by rank 0 of the given communicator. All other ranks do nothing.
    """
    if comm.rank != 0:
        return
    print(string)


def grid_shape(nprocs):
    """Return ``(rows, cols)`` of the near-square process grid for *nprocs*.

    ``rows = floor(sqrt(nprocs))`` and ``cols = nprocs // rows``. Floor
    division guarantees ``rows * cols <= nprocs``. Any leftover processes
    are simply not part of the grid.
    """
    rows = int(np.floor(np.sqrt(nprocs)))
    cols = nprocs // rows
    return rows, cols


def _run_benchmark(ccomm, mpi_rows, mpi_cols):
    """Time the shifted tile-multiply loop on the Cartesian communicator *ccomm*.

    Each process allocates random float32 tiles A and B of shape
    ``(my_N, my_M)`` and then performs ``mpi_rows`` steps. In each step it
    starts sending its A tile one column west and its B tile one row south,
    and receiving the replacements from the opposite neighbours. The grid is
    periodic, so edges wrap. Meanwhile it adds ``A @ B`` into a local
    accumulator. The received tiles become the operands of the next step.

    No initial (Cannon-style) skew is applied and the accumulated C tile is
    not used afterwards; this is a timing workload, not a verified product.
    Note that ``np.dot`` of two ``(my_N, my_M)`` tiles requires
    ``my_N == my_M``.

    Results are printed by rank 0 of *ccomm*.
    """
    neigh = [0, 0, 0, 0]
    # Shift(d, 1) returns (source, dest): the ranks at coordinate -1 and +1
    # along dimension d. Dimension 0 is rows (north/south), 1 is columns.
    neigh[NORTH], neigh[SOUTH] = ccomm.Shift(0, 1)
    neigh[WEST], neigh[EAST] = ccomm.Shift(1, 1)

    # Create matrices
    my_A = np.random.normal(size=(my_N, my_M)).astype(np.float32)
    my_B = np.random.normal(size=(my_N, my_M)).astype(np.float32)
    my_C = np.zeros_like(my_A)

    # Double buffers: operate on tile_X while receiving into tile_X_.
    tile_A, tile_A_ = my_A, np.empty_like(my_A)
    tile_B, tile_B_ = my_B, np.empty_like(my_B)

    t0 = time()
    for _ in range(mpi_rows):
        # Post the exchange first so communication overlaps the local product.
        reqs = [
            ccomm.Isend(tile_A, neigh[WEST]),
            ccomm.Irecv(tile_A_, neigh[EAST]),
            ccomm.Isend(tile_B, neigh[SOUTH]),
            ccomm.Irecv(tile_B_, neigh[NORTH]),
        ]

        my_C += np.dot(tile_A, tile_B)

        MPI.Request.Waitall(reqs)
        # The freshly received tiles are the next step's operands; the old
        # tiles (no longer being sent) become the next receive buffers.
        tile_A, tile_A_ = tile_A_, tile_A
        tile_B, tile_B_ = tile_B_, tile_B
    ccomm.barrier()
    t_total = time() - t0

    t0 = time()
    np.dot(tile_A, tile_B)
    t_serial = time() - t0

    # Each grid process performs mpi_rows local products, so with perfect
    # overlap of communication the loop should take about mpi_rows * t_serial.
    t_expected = mpi_rows * t_serial

    pprint(78*"=", comm=ccomm)
    pprint("Computed (serial) %d x %d x %d in  %6.2f seconds" % (my_M, my_M, my_N, t_serial), comm=ccomm)
    pprint(" ... expecting parallel computation to take %6.2f seconds" % t_expected, comm=ccomm)
    pprint("Computed (parallel) %d x %d x %d in        %6.2f seconds" % (mpi_rows*my_M, mpi_rows*my_M, mpi_cols*my_N, t_total), comm=ccomm)


if __name__ == "__main__":
    comm = MPI.COMM_WORLD

    mpi_rows, mpi_cols = grid_shape(comm.size)

    pprint("Creating a %d x %d processor grid..." % (mpi_rows, mpi_cols) )

    ccomm = comm.Create_cart((mpi_rows, mpi_cols), periods=(True, True), reorder=True)

    # Ranks beyond mpi_rows*mpi_cols are left out of the grid and receive
    # COMM_NULL. They must not use ccomm, but they still join the final
    # world barrier below.
    if ccomm != MPI.COMM_NULL:
        _run_benchmark(ccomm, mpi_rows, mpi_cols)

    comm.barrier()