# test_darray.py
# Tests for mpi4py_fft.DistArray and mpi4py_fft.newDistArray.
import numpy as np
from mpi4py import MPI
from mpi4py_fft import DistArray, newDistArray, PFFT
from mpi4py_fft.pencil import Subcomm
# World communicator shared by the tests in this module (used for size
# checks and to build explicit Subcomm decompositions).
comm = MPI.COMM_WORLD
def test_1Darray():
    """A 1D DistArray created with ``val=2`` holds that value and its shape."""
    shape = (8,)
    arr = DistArray(shape, val=2)
    assert arr[0] == 2
    assert arr.shape == shape
def test_2Darray():
    """Exercise DistArray construction and redistribution for a 2D grid.

    For every subcommunicator layout and tensor rank 0-2 on an (8, 8)
    global grid this checks:

    * basic attributes (``rank``, ``global_shape``, ``commsizes``) and that
      the property getters (``substart``, ``subcomm``, ``pencil``) work;
    * component indexing returns a DistArray of rank ``rank - 1``;
    * ``get`` gathers full rows/columns of the ``val=1`` array (skipped if
      it raises ModuleNotFoundError, presumably an optional dependency);
    * redistributing to the other axis and back preserves the global
      squared norm, and redistributing to the current alignment is a no-op
      that returns the same object.
    """
    N = (8, 8)
    for subcomm in ((0, 1), (1, 0), None, Subcomm(comm, (0, 1))):
        for rank in (0, 1, 2):
            M = (2,)*rank + N
            alignment = None
            if subcomm is None and rank == 1:
                alignment = 1
            a = DistArray(M, subcomm=subcomm, val=1, rank=rank, alignment=alignment)
            assert a.rank == rank
            assert a.global_shape == M
            # Touch the remaining properties so their getters are exercised.
            _ = a.substart
            _ = a.subcomm
            z = a.commsizes
            _ = a.pencil
            assert np.prod(np.array(z)) == comm.Get_size()
            if rank > 0:
                a0 = a[0]
                assert isinstance(a0, DistArray)
                assert a0.rank == rank-1
            assert isinstance(a.v, np.ndarray)
            try:
                # The gathered result is only checked on rank 0.
                k = a.get((0,)*rank+(0, slice(None)))
                if comm.Get_rank() == 0:
                    assert len(k) == N[1]
                    assert np.sum(k) == N[1]
                k = a.get((0,)*rank+(slice(None), 0))
                if comm.Get_rank() == 0:
                    assert len(k) == N[0]
                    assert np.sum(k) == N[0]
            except ModuleNotFoundError:
                pass
            _ = a.local_slice()
            newaxis = (a.alignment+1)%2
            _, t = a.get_pencil_and_transfer(newaxis)
            # Release the transfer object even if a redistribution check fails.
            try:
                a[:] = MPI.COMM_WORLD.Get_rank()
                b = a.redistribute(newaxis)
                a = b.redistribute(out=a)
                a = b.redistribute(a.alignment, out=a)
                # reduce() delivers the summed value on the root (rank 0) only.
                s0 = MPI.COMM_WORLD.reduce(np.linalg.norm(a)**2)
                s1 = MPI.COMM_WORLD.reduce(np.linalg.norm(b)**2)
                if MPI.COMM_WORLD.Get_rank() == 0:
                    assert abs(s0-s1) < 1e-1
                # Redistributing to the current alignment must return self.
                c = a.redistribute(a.alignment)
                assert c is a
            finally:
                t.destroy()
def test_3Darray():
    """Exercise DistArray construction and redistribution for a 3D grid.

    For every subcommunicator layout and tensor rank 0-2 on an (8, 8, 8)
    global grid this checks:

    * basic attributes (``rank``, ``global_shape``, ``commsizes``) and that
      the property getters (``substart``, ``subcomm``, ``pencil``) work;
    * component indexing (``a[0]`` and, for rank 2, ``a[0, 1]``) returns a
      DistArray of the reduced rank;
    * ``get`` gathers full lines of the ``val=1`` array (skipped if it
      raises ModuleNotFoundError, presumably an optional dependency);
    * redistributing to another axis and back, both implicitly via ``out``
      and with an explicit axis, preserves the global squared norm, and
      redistributing to the current alignment returns the same object.
    """
    N = (8, 8, 8)
    subcomms = ((0, 0, 1), (0, 1, 0), (1, 0, 0), (0, 1, 1), (1, 0, 1),
                (1, 1, 0), None, Subcomm(comm, (0, 0, 1)))
    for subcomm in subcomms:
        for rank in (0, 1, 2):
            M = (3,)*rank + N
            alignment = None
            if subcomm is None and rank == 1:
                alignment = 2
            a = DistArray(M, subcomm=subcomm, val=1, rank=rank, alignment=alignment)
            assert a.rank == rank
            assert a.global_shape == M
            # Touch the remaining properties so their getters are exercised.
            _ = a.substart
            _ = a.subcomm
            z = a.commsizes
            _ = a.pencil
            assert np.prod(np.array(z)) == comm.Get_size()
            if rank > 0:
                a0 = a[0]
                assert isinstance(a0, DistArray)
                assert a0.rank == rank-1
            if rank == 2:
                a0 = a[0, 1]
                assert isinstance(a0, DistArray)
                assert a0.rank == 0
            assert isinstance(a.v, np.ndarray)
            try:
                # The gathered result is only checked on rank 0.
                k = a.get((0,)*rank+(0, 0, slice(None)))
                if comm.Get_rank() == 0:
                    assert len(k) == N[2]
                    assert np.sum(k) == N[2]
                k = a.get((0,)*rank+(slice(None), 0, 0))
                if comm.Get_rank() == 0:
                    assert len(k) == N[0]
                    assert np.sum(k) == N[0]
            except ModuleNotFoundError:
                pass
            _ = a.local_slice()
            newaxis = (a.alignment+1)%3
            _, t = a.get_pencil_and_transfer(newaxis)
            # Release the transfer object even if a redistribution check fails.
            try:
                a[:] = MPI.COMM_WORLD.Get_rank()
                b = a.redistribute(newaxis)
                a = b.redistribute(out=a)
                # Same round trip with the target axis given explicitly.
                a = b.redistribute(a.alignment, out=a)
                # reduce() delivers the summed value on the root (rank 0) only.
                s0 = MPI.COMM_WORLD.reduce(np.linalg.norm(a)**2)
                s1 = MPI.COMM_WORLD.reduce(np.linalg.norm(b)**2)
                if MPI.COMM_WORLD.Get_rank() == 0:
                    assert abs(s0-s1) < 1e-1
                # Redistributing to the current alignment must return self.
                c = a.redistribute(a.alignment)
                assert c is a
            finally:
                t.destroy()
def test_newDistArray():
    """Check what newDistArray returns and that its arrays can seed a PFFT.

    For each ``forward_output``/``view``/``rank`` combination on an
    (8, 8, 8) PFFT:

    * ``view=True`` yields a plain ndarray whose ``base`` carries the rank;
    * ``view=False`` yields a DistArray of that rank, and a scalar component
      of it (the array itself, ``[0]`` or ``[0, 0]``) is accepted as the
      ``darray`` argument of a new PFFT.
    """
    shape = (8, 8, 8)
    base_fft = PFFT(MPI.COMM_WORLD, shape)
    for forward_output in (True, False):
        for view in (True, False):
            for rank in (0, 1, 2):
                arr = newDistArray(base_fft, forward_output=forward_output,
                                   rank=rank, view=view)
                if view:
                    assert isinstance(arr, np.ndarray)
                    assert arr.base.rank == rank
                    continue
                assert isinstance(arr, DistArray)
                assert arr.rank == rank
                # Reduce the tensor to one scalar component before planning.
                if rank == 0:
                    component = arr
                elif rank == 1:
                    component = arr[0]
                else:
                    component = arr[0, 0]
                derived_fft = PFFT(MPI.COMM_WORLD, darray=component)
                derived_fft.destroy()
    base_fft.destroy()
if __name__ == '__main__':
    # When executed as a script, run every test in this module in order.
    for _test_func in (test_1Darray, test_2Darray, test_3Darray,
                       test_newDistArray):
        _test_func()