-
Notifications
You must be signed in to change notification settings - Fork 5
/
Mypnufft_mc_func_cardiac.py
213 lines (165 loc) · 6.63 KB
/
Mypnufft_mc_func_cardiac.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import torch
from torch import nn, optim
from torch import Tensor
import numpy as np
import numpy
import scipy.misc
import matplotlib.pyplot
import scipy.io as sio
import nufft
# Module-level alias for single-precision complex numbers.
# NOTE(review): not referenced by the visible code in this file (which spells
# out np.complex64 directly) — kept for callers that may import it.
dtype = np.complex64
def np_to_torch_(img_np, device='cuda'):
    """Convert a NumPy array to a ``torch.Tensor`` on ``device``.

    Shape and dtype are carried over unchanged by ``torch.from_numpy``; no
    value scaling or axis reordering is performed.

    Args:
        img_np: NumPy array to convert.
        device: Target device. Defaults to ``'cuda'`` (the current CUDA
            device), which matches the previous hard-coded ``.cuda()`` call.
            Pass ``'cpu'`` to keep the tensor on the host.

    Returns:
        ``torch.Tensor`` on ``device``. For a CPU target the tensor shares
        memory with ``img_np`` (``from_numpy`` does not copy and ``.to`` is a
        no-op when the device already matches).
    """
    return torch.from_numpy(img_np).to(device)
def torch_to_np_(img_var):
    """Return a float32 NumPy copy of a torch tensor.

    The tensor is detached from the autograd graph and moved to host memory
    before conversion. Its shape is kept as-is: no leading axis is removed.
    ``astype`` always copies, so the result does not alias the tensor.
    """
    host_tensor = img_var.detach().cpu()
    as_array = host_tensor.numpy()
    return as_array.astype(np.float32)
class Mypnufft_cardiac_func(torch.autograd.Function):
    """Multi-coil 2-D type-3 NUFFT wrapped as a custom autograd Function.

    ``forward`` weights an image by coil sensitivities ``C`` and computes one
    ``nufft.nufft2d3`` per coil, sampling at the k-space locations in ``angle``.
    ``backward`` applies weights ``w`` to the incoming k-space gradient, runs
    the inverse-sign ``nufft2d3`` per coil, and combines coils as
    ``sum_c conj(C_c) * img_c / sum_c |C_c|^2``.

    Complex data is carried as real tensors whose last axis is (real, imag).

    NuFFT from https://github.com/jyhmiinlin/pynufft
    NOTE(review): the calls use ``nufft.nufft2d3(..., iflag=...)``; confirm
    which ``nufft`` package that refers to.
    """

    @staticmethod
    def _image_grid(N):
        """Return flattened (row, col) integer pixel coordinates of an N x N image.

        Both axes run over ``-(N // 2) .. N - 1 - N // 2`` (e.g. -2..1 for N=4,
        -2..2 for N=5), combined with ``indexing='ij'``.
        """
        # np.arange keeps unit spacing for every N. Beware that -N//2 parses as
        # (-N)//2, so linspace(-N//2, N//2, N) is NOT integer-spaced for odd N.
        coords = np.arange(N, dtype=np.float64) - N // 2
        rows, cols = np.meshgrid(coords, coords, indexing='ij')
        return rows.reshape(-1), cols.reshape(-1)

    @staticmethod
    def forward(ctx, input_r, angle, Nspoke, Nvec, Nc, C, w):
        """Map an image to multi-coil non-Cartesian samples.

        Args:
            ctx: autograd context; non-tensor state for ``backward`` is stored on it.
            input_r: image tensor with (real, imag) in the last axis;
                presumably shape (N, N, 2) — confirm with callers.
            angle: array whose columns 0 and 1 are the sample coordinates.
            Nspoke, Nvec: sample layout; output holds Nvec * Nspoke samples per coil.
            Nc: number of coils.
            C: coil sensitivities, broadcastable to (Nc, N, N).
            w: sample weights, used only in ``backward``.

        Returns:
            float32 tensor of shape (Nc, Nvec * Nspoke, 2), created via
            ``np_to_torch_`` (CUDA by default).
        """
        N = np.shape(input_r)[-2]
        x1, x2 = Mypnufft_cardiac_func._image_grid(N)

        # Stash everything backward needs directly on ctx (non-tensor objects).
        ctx.x1 = x1
        ctx.x2 = x2
        ctx.N = N
        ctx.Nc = Nc
        ctx.Nspoke = Nspoke
        ctx.Nvec = Nvec
        ctx.angle = angle
        ctx.wr = w
        ctx.C = C

        image = torch_to_np_(input_r)
        image_c = image[..., 0] + 1j * image[..., 1]
        # One copy of the image per coil, weighted by that coil's sensitivity.
        coil_images = np.tile(image_c[np.newaxis], (Nc, 1, 1))
        coil_images *= C

        y = np.zeros((Nc, Nvec * Nspoke), dtype=np.complex64)
        for it in range(Nc):
            y[it, :] = nufft.nufft2d3(-x1, -x2, coil_images[it, :, :].reshape(-1),
                                      angle[:, 0], angle[:, 1], iflag=0)
        # Weights w are not applied in forward; backward multiplies by them.
        y = y[..., np.newaxis]
        y_c = np.concatenate((np.real(y), np.imag(y)), axis=-1)
        return np_to_torch_(y_c.astype(np.float32))

    @staticmethod
    def backward(ctx, grad_output):
        """Map a k-space gradient back to image space.

        Args:
            ctx: context populated by ``forward``.
            grad_output: gradient w.r.t. the forward output, (real, imag) in
                the last axis; reshaped to (Nc, Nspoke, Nvec) complex.

        Returns:
            A 7-tuple: the image-space gradient (float32 tensor with
            (real, imag) last axis) for ``input_r``, then ``None`` for each of
            the six non-differentiable forward arguments.
        """
        angle = ctx.angle
        grad_np = torch_to_np_(grad_output)
        grad_c = grad_np[..., 0] + 1j * grad_np[..., 1]
        yr = np.reshape(grad_c, (ctx.Nc, ctx.Nspoke, ctx.Nvec))
        yc = yr * ctx.wr

        out = np.zeros((ctx.Nc, ctx.N, ctx.N), dtype=np.complex64)
        for it in range(ctx.Nc):
            x_re = nufft.nufft2d3(angle[:, 0], angle[:, 1], yc[it, :, :].reshape(-1),
                                  ctx.x1, ctx.x2, iflag=1)
            out[it, :, :] = x_re.reshape(ctx.N, ctx.N)

        # Sensitivity-weighted coil combination.
        # NOTE(review): pixels where every |C_c| is 0 divide by zero (inf/nan).
        out = (np.sum(out * np.conj(ctx.C), axis=0)
               / np.sum(np.abs(ctx.C) ** 2, axis=0))
        out = out[..., np.newaxis]
        out_c = np.concatenate((np.real(out), np.imag(out)), axis=-1)
        grad_input = np_to_torch_(out_c.astype(np.float32))
        return grad_input, None, None, None, None, None, None
class Mypnufft_cardiac_test(nn.Module):
    """NumPy-only counterpart of the multi-coil NUFFT operator (no autograd).

    ``forward`` and ``backward`` take and return NumPy arrays whose last axis
    holds (real, imag). Unlike ``Mypnufft_cardiac_func.forward``, this
    ``forward`` does not multiply by the coil sensitivities ``C``.
    """

    def __init__(self, ImageSize, angle, Nspoke, Nvec, Nc, C, w):
        """Precompute the pixel grid and store operator parameters.

        Args:
            ImageSize: image side length N (the image is N x N).
            angle: array whose columns 0 and 1 are the sample coordinates.
            Nspoke, Nvec: sample layout (Nvec * Nspoke samples per coil).
            Nc: number of coils.
            C: coil sensitivities, used in the ``backward`` coil combination.
            w: sample weights applied in ``backward``.
        """
        super(Mypnufft_cardiac_test, self).__init__()
        N = ImageSize
        # Integer-spaced centred grid. np.arange is correct for odd N too;
        # linspace(-N//2, N//2, N) is not, because -N//2 == (-N)//2.
        coords = np.arange(N, dtype=np.float64) - N // 2
        rows, cols = np.meshgrid(coords, coords, indexing='ij')
        self.x1 = rows.reshape(-1)
        self.x2 = cols.reshape(-1)
        self.N = N
        self.Nc = Nc
        self.Nspoke = Nspoke
        self.Nvec = Nvec
        self.angle = angle
        self.C = C
        self.wr = w  # NOTE(review): an np.sqrt(wr) variant was once considered here

    def forward(self, input):
        """Apply the per-coil forward NUFFT.

        Args:
            input: array with (real, imag) in the last axis, indexed per coil
                as ``[it, :, :]``; presumably (Nc, N, N, 2) — confirm.

        Returns:
            float32 array of shape (Nc, Nvec * Nspoke, 2).
        """
        angle = self.angle
        input_c = input[..., 0] + 1j * input[..., 1]
        # Coil sensitivities are NOT applied here; input is presumably already
        # per-coil image data — confirm with callers.
        y = np.zeros((self.Nc, self.Nvec * self.Nspoke), dtype=np.complex64)
        for it in range(self.Nc):
            y[it, :] = nufft.nufft2d3(-self.x1, -self.x2, input_c[it, :, :].reshape(-1),
                                      angle[:, 0], angle[:, 1], iflag=0)
        # Weights self.wr are not applied in forward; backward multiplies by them.
        y = y[..., np.newaxis]
        y_c = np.concatenate((np.real(y), np.imag(y)), axis=-1)
        return y_c.astype(np.float32)

    def backward(self, grad_output_n):
        """Weighted adjoint NUFFT followed by sensitivity-weighted coil combination.

        Args:
            grad_output_n: k-space array with (real, imag) in the last axis;
                reshaped to (Nc, Nspoke, Nvec) complex.

        Returns:
            Array with (real, imag) in the last axis, computed as
            ``sum_c conj(C_c) * img_c / sum_c |C_c|^2``.
        """
        angle = self.angle
        grad_output = grad_output_n[..., 0] + 1j * grad_output_n[..., 1]
        yr = np.reshape(grad_output, (self.Nc, self.Nspoke, self.Nvec))
        yc = yr * self.wr
        out = np.zeros((self.Nc, self.N, self.N), dtype=np.complex64)
        for it in range(self.Nc):
            x_re = nufft.nufft2d3(angle[:, 0], angle[:, 1], yc[it, :, :].reshape(-1),
                                  self.x1, self.x2, iflag=1)
            out[it, :, :] = x_re.reshape(self.N, self.N)
        # Coil combination.
        # NOTE(review): pixels where every |C_c| is 0 divide by zero (inf/nan).
        out = (np.sum(out * np.conj(self.C), axis=0)
               / np.sum(np.abs(self.C) ** 2, axis=0))
        out = out[..., np.newaxis]
        return np.concatenate((np.real(out), np.imag(out)), axis=-1)
class Mypnufft_cardiac(nn.Module):
    """Module that feeds an internal zero image through ``Mypnufft_cardiac_func``.

    NOTE(review): ``self.X`` has shape (ImageSize, ImageSize, Nc), but the
    autograd function reads only ``[..., 0]`` and ``[..., 1]`` of the last axis
    as real/imaginary parts. Nc must therefore be >= 2 here; confirm intent.
    ``self.X`` is a plain tensor, not an ``nn.Parameter``.
    """

    def __init__(self, ImageSize, Nc):
        super().__init__()
        # Zero-filled placeholder image, moved to the current CUDA device.
        placeholder = Tensor(ImageSize, ImageSize, Nc).zero_()
        self.X = placeholder.cuda()

    def forward(self, angles, Nspoke, Nvec, Nc, C, w):
        """Run the multi-coil forward NUFFT on ``self.X``.

        Arguments are passed positionally (``Function.apply`` takes no keywords).
        """
        op = Mypnufft_cardiac_func.apply
        return op(self.X, angles, Nspoke, Nvec, Nc, C, w)