# Forked from jbornschein/mpi4py-examples.
# File: 08-matrix-matrix-product.py (executable file; 96 lines, 63 loc, 2.42 KB).
#!/usr/bin/env python
from __future__ import division
from __future__ import print_function
import numpy as np
from mpi4py import MPI
from time import time
#=============================================================================#
# Shape of the matrix tile held by each process: my_N rows by my_M columns.
my_N, my_M = 3000, 3000

#=============================================================================#

# Slot indices used for the per-direction neighbour and request lists.
NORTH, SOUTH, EAST, WEST = range(4)
def pprint(string, comm=MPI.COMM_WORLD):
    """Print *string* from the rank-0 process of *comm* only.

    Every other rank returns without producing output, so a message that
    all processes would otherwise emit appears exactly once.
    """
    is_root = comm.rank == 0
    if not is_root:
        return
    print(string)
if __name__ == "__main__":
    # Benchmark driver: every rank of a periodic 2-D process grid owns one
    # tile of A and one tile of B, and repeatedly multiplies its current
    # tiles while the next tiles travel to/from its grid neighbours.
    #
    # NOTE(review): the tiles are not pre-skewed as in the textbook Cannon
    # algorithm, so this script measures compute/communication overlap; it
    # does not validate my_C as a global matrix product.
    comm = MPI.COMM_WORLD

    # rows = floor(sqrt(size)) and cols = size // rows, hence
    # rows * cols <= size always holds and no further shrinking is needed.
    mpi_rows = int(np.floor(np.sqrt(comm.size)))
    mpi_cols = comm.size // mpi_rows

    pprint("Creating a %d x %d processor grid..." % (mpi_rows, mpi_cols) )

    ccomm = comm.Create_cart(
        (mpi_rows, mpi_cols), periods=(True, True), reorder=True
    )

    # When comm.size is not exactly rows * cols, the surplus ranks receive
    # COMM_NULL from Create_cart. They must not touch ccomm, but they still
    # have to reach the world-communicator barriers below.
    in_grid = ccomm != MPI.COMM_NULL

    neigh = [MPI.PROC_NULL] * 4
    if in_grid:
        # Shift returns (source, dest) for a displacement of +1 along the
        # given grid dimension.
        neigh[NORTH], neigh[SOUTH] = ccomm.Shift(0, 1)
        neigh[EAST], neigh[WEST] = ccomm.Shift(1, 1)

    # Create matrices
    my_A = np.random.normal(size=(my_N, my_M)).astype(np.float32)
    my_B = np.random.normal(size=(my_N, my_M)).astype(np.float32)
    my_C = np.zeros_like(my_A)

    # tile_A / tile_B are the operands of the current step, tile_A_ / tile_B_
    # the receive buffers for the next step. The roles are exchanged after
    # every step, so the storage of my_A / my_B is reused as a receive
    # buffer from the second step onwards.
    tile_A = my_A
    tile_B = my_B
    tile_A_ = np.empty_like(my_A)
    tile_B_ = np.empty_like(my_B)
    req = [None, None, None, None]

    t0 = time()
    if in_grid:
        for _ in range(mpi_rows):
            # Start the non-blocking exchange first so that it can overlap
            # with the local product below.
            req[EAST] = ccomm.Isend(tile_A, neigh[EAST])
            req[WEST] = ccomm.Irecv(tile_A_, neigh[WEST])
            req[SOUTH] = ccomm.Isend(tile_B, neigh[SOUTH])
            req[NORTH] = ccomm.Irecv(tile_B_, neigh[NORTH])

            my_C += np.dot(tile_A, tile_B)

            # The send buffers may only be reused once all four requests
            # have completed.
            MPI.Request.Waitall(req)

            # Make the freshly received tiles the operands of the next
            # step; without this the same tiles would be multiplied and
            # sent again on every iteration.
            tile_A, tile_A_ = tile_A_, tile_A
            tile_B, tile_B_ = tile_B_, tile_B
    comm.barrier()

    t_total = time() - t0

    # Time a single local tile product as the serial reference.
    t0 = time()
    np.dot(tile_A, tile_B)
    t_serial = time() - t0

    pprint(78*"=")
    pprint("Computed (serial) %d x %d x %d in %6.2f seconds" % (my_M, my_M, my_N, t_serial))
    pprint(" ... expecting parallel computation to take %6.2f seconds" % (mpi_rows*mpi_rows*mpi_cols*t_serial / comm.size))
    pprint("Computed (parallel) %d x %d x %d in %6.2f seconds" % (mpi_rows*my_M, mpi_rows*my_M, mpi_cols*my_N, t_total))

    comm.barrier()