Skip to content

Commit c987ee4

Browse files
committed
fix convtranspose2d
1 parent 51bc1e1 commit c987ee4

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

tensorlayerx/backend/ops/torch_nn.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -1235,7 +1235,8 @@ def same_padding_deconvolution(input, weight, strides, dilations):
12351235
input_cols = input.size(3)
12361236

12371237
out_rows = input_rows * strides[0] - strides[0] + 1
1238-
out_cols = input_rows * strides[1] - strides[1] + 1
1238+
out_cols = input_cols * strides[1] - strides[1] + 1
1239+
12391240

12401241
padding_rows = max(0, (input_rows - 1) * strides[0] + (filter_rows - 1) * dilations[0] + 1 - out_rows)
12411242
padding_cols = max(0, (input_cols - 1) * strides[1] + (filter_cols - 1) * dilations[1] + 1 - out_cols)
@@ -1250,8 +1251,8 @@ def same_padding_deconvolution(input, weight, strides, dilations):
12501251
input_depth = input.size(4)
12511252

12521253
out_rows = input_rows * strides[0] - strides[0] + 1
1253-
out_cols = input_rows * strides[1] - strides[1] + 1
1254-
out_depth = input_rows * strides[2] - strides[2] + 1
1254+
out_cols = input_cols * strides[1] - strides[1] + 1
1255+
out_depth = input_depth * strides[2] - strides[2] + 1
12551256

12561257
padding_rows = max(0, (input_rows - 1) * strides[0] + (filter_rows - 1) * dilations[0] + 1 - out_rows)
12571258
padding_cols = max(0, (input_cols - 1) * strides[1] + (filter_cols - 1) * dilations[1] + 1 - out_cols)
@@ -1410,7 +1411,7 @@ def __call__(self, input, filters, output_size):
14101411
if self.data_format == 'NHWC':
14111412
input = nhwc_to_nchw(input)
14121413
if self.padding == 'same':
1413-
out = self.conv2d_transpore_same(input, filters, output_size)
1414+
out = self.conv2d_transpore_same(input, filters)
14141415
else:
14151416
out_padding = self._output_padding(input, output_size, self.strides, (0 if isinstance(self.padding, str) else self.padding),
14161417
filters.shape,
@@ -1428,15 +1429,13 @@ def __call__(self, input, filters, output_size):
14281429
out = nchw_to_nhwc(out)
14291430
return out
14301431

1431-
def conv2d_transpore_same(self,input, filters, output_size):
1432+
def conv2d_transpore_same(self,input, filters):
14321433
rows_odd, cols_odd, padding_rows, padding_cols = same_padding_deconvolution(input, filters, self.strides, (1, 1))
14331434
if rows_odd or cols_odd:
14341435
input = F.pad(input, [0, int(rows_odd), 0, int(cols_odd)])
14351436
out_padding = 0
14361437
else:
14371438
out_padding = 1
1438-
out_padding = self._output_padding(input, output_size, self.strides, (padding_rows // 2, padding_cols // 2), filters.shape,
1439-
2, self.dilations)
14401439
out = F.conv_transpose2d(input, weight=filters, padding=(padding_rows // 2, padding_cols // 2), stride=self.strides,
14411440
dilation=self.dilations, output_padding=out_padding, groups=self.groups)
14421441
return out

0 commit comments

Comments
 (0)