Skip to content

Commit 9bcd835

Browse files
authored
add matlab imresize bicubic (XPixelGroup#317)
1 parent df5816f commit 9bcd835

File tree

1 file changed

+169
-0
lines changed

1 file changed

+169
-0
lines changed

basicsr/utils/matlab_functions.py

+169
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,173 @@
1+
import math
12
import numpy as np
3+
import torch
4+
5+
6+
def cubic(x):
7+
"""cubic function used for calculate_weights_indices."""
8+
absx = torch.abs(x)
9+
absx2 = absx**2
10+
absx3 = absx**3
11+
return (1.5 * absx3 - 2.5 * absx2 + 1) * (
12+
(absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx +
13+
2) * (((absx > 1) *
14+
(absx <= 2)).type_as(absx))
15+
16+
17+
def calculate_weights_indices(in_length, out_length, scale, kernel,
18+
kernel_width, antialiasing):
19+
"""Calculate weights and indices, used for imresize function.
20+
21+
Args:
22+
in_length (int): Input length.
23+
out_length (int): Output length.
24+
scale (float): Scale factor.
25+
kernel_width (int): Kernel width.
26+
antialisaing (bool): Whether to apply anti-aliasing when downsampling.
27+
"""
28+
29+
if (scale < 1) and antialiasing:
30+
# Use a modified kernel (larger kernel width) to simultaneously
31+
# interpolate and antialias
32+
kernel_width = kernel_width / scale
33+
34+
# Output-space coordinates
35+
x = torch.linspace(1, out_length, out_length)
36+
37+
# Input-space coordinates. Calculate the inverse mapping such that 0.5
38+
# in output space maps to 0.5 in input space, and 0.5 + scale in output
39+
# space maps to 1.5 in input space.
40+
u = x / scale + 0.5 * (1 - 1 / scale)
41+
42+
# What is the left-most pixel that can be involved in the computation?
43+
left = torch.floor(u - kernel_width / 2)
44+
45+
# What is the maximum number of pixels that can be involved in the
46+
# computation? Note: it's OK to use an extra pixel here; if the
47+
# corresponding weights are all zero, it will be eliminated at the end
48+
# of this function.
49+
p = math.ceil(kernel_width) + 2
50+
51+
# The indices of the input pixels involved in computing the k-th output
52+
# pixel are in row k of the indices matrix.
53+
indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(
54+
0, p - 1, p).view(1, p).expand(out_length, p)
55+
56+
# The weights used to compute the k-th output pixel are in row k of the
57+
# weights matrix.
58+
distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
59+
60+
# apply cubic kernel
61+
if (scale < 1) and antialiasing:
62+
weights = scale * cubic(distance_to_center * scale)
63+
else:
64+
weights = cubic(distance_to_center)
65+
66+
# Normalize the weights matrix so that each row sums to 1.
67+
weights_sum = torch.sum(weights, 1).view(out_length, 1)
68+
weights = weights / weights_sum.expand(out_length, p)
69+
70+
# If a column in weights is all zero, get rid of it. only consider the
71+
# first and last column.
72+
weights_zero_tmp = torch.sum((weights == 0), 0)
73+
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
74+
indices = indices.narrow(1, 1, p - 2)
75+
weights = weights.narrow(1, 1, p - 2)
76+
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
77+
indices = indices.narrow(1, 0, p - 2)
78+
weights = weights.narrow(1, 0, p - 2)
79+
weights = weights.contiguous()
80+
indices = indices.contiguous()
81+
sym_len_s = -indices.min() + 1
82+
sym_len_e = indices.max() - in_length
83+
indices = indices + sym_len_s - 1
84+
return weights, indices, int(sym_len_s), int(sym_len_e)
85+
86+
87+
@torch.no_grad()
88+
def imresize(img, scale, antialiasing=True):
89+
"""imresize function same as MATLAB.
90+
91+
It now only supports bicubic.
92+
The same scale applies for both height and width.
93+
94+
Args:
95+
img (Tensor | Numpy array):
96+
Tensor: Input image with shape (c, h, w), [0, 1] range.
97+
Numpy: Input image with shape (h, w, c), [0, 1] range.
98+
scale (float): Scale factor. The same scale applies for both height
99+
and width.
100+
antialisaing (bool): Whether to apply anti-aliasing when downsampling.
101+
Default: True.
102+
103+
Returns:
104+
Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
105+
"""
106+
if type(img).__module__ == np.__name__: # numpy type
107+
numpy_type = True
108+
img = torch.from_numpy(img.transpose(2, 0, 1)).float()
109+
else:
110+
numpy_type = False
111+
112+
in_c, in_h, in_w = img.size()
113+
out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
114+
kernel_width = 4
115+
kernel = 'cubic'
116+
117+
# get weights and indices
118+
weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(
119+
in_h, out_h, scale, kernel, kernel_width, antialiasing)
120+
weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(
121+
in_w, out_w, scale, kernel, kernel_width, antialiasing)
122+
# process H dimension
123+
# symmetric copying
124+
img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
125+
img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
126+
127+
sym_patch = img[:, :sym_len_hs, :]
128+
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
129+
sym_patch_inv = sym_patch.index_select(1, inv_idx)
130+
img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
131+
132+
sym_patch = img[:, -sym_len_he:, :]
133+
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
134+
sym_patch_inv = sym_patch.index_select(1, inv_idx)
135+
img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
136+
137+
out_1 = torch.FloatTensor(in_c, out_h, in_w)
138+
kernel_width = weights_h.size(1)
139+
for i in range(out_h):
140+
idx = int(indices_h[i][0])
141+
for j in range(in_c):
142+
out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(
143+
0, 1).mv(weights_h[i])
144+
145+
# process W dimension
146+
# symmetric copying
147+
out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
148+
out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
149+
150+
sym_patch = out_1[:, :, :sym_len_ws]
151+
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
152+
sym_patch_inv = sym_patch.index_select(2, inv_idx)
153+
out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
154+
155+
sym_patch = out_1[:, :, -sym_len_we:]
156+
inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
157+
sym_patch_inv = sym_patch.index_select(2, inv_idx)
158+
out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
159+
160+
out_2 = torch.FloatTensor(in_c, out_h, out_w)
161+
kernel_width = weights_w.size(1)
162+
for i in range(out_w):
163+
idx = int(indices_w[i][0])
164+
for j in range(in_c):
165+
out_2[j, :, i] = out_1_aug[j, :,
166+
idx:idx + kernel_width].mv(weights_w[i])
167+
168+
if numpy_type:
169+
out_2 = out_2.numpy().transpose(1, 2, 0)
170+
return out_2
2171

3172

4173
def rgb2ycbcr(img, y_only=False):

0 commit comments

Comments
 (0)