Skip to content

Commit 0d19794

Browse files
committed
fix weighting
1 parent 6c621eb commit 0d19794

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

interpolate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def _fill(self, holed_img, params):
5151
idx_select = filled_idx[idx[:, i]] # size: num_coords, 2
5252

5353
# add value of those coords, weighted by their distance
54-
vals += holed_img[idx_select[:, 0], idx_select[:, 1]] * dist[:, i]
55-
vals /= torch.sum(dist, dim=1)
54+
vals += holed_img[idx_select[:, 0], idx_select[:, 1]] * (1.0 / dist[:, i])
55+
vals /= torch.sum((1.0 / dist), dim=1)
5656
holed_img[unfilled_idx[:, 0], unfilled_idx[:, 1]] = vals
5757
return holed_img
5858

@@ -125,8 +125,8 @@ def _fill(self, holed_img, params):
125125
idx_select = filled_idx[idx[:, i]] # size: num_coords, 2
126126

127127
# add value of those coords, weighted by their distance
128-
vals += holed_img[idx_select[:, 0], idx_select[:, 1]] * dist[:, i]
129-
vals /= torch.sum(dist, dim=1)
128+
vals += holed_img[idx_select[:, 0], idx_select[:, 1]] * (1.0 / dist[:, i])
129+
vals /= torch.sum((1.0 / dist), dim=1)
130130
holed_img[unfilled_idx[:, 0], unfilled_idx[:, 1]] = vals
131131
return holed_img
132132

model.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch.nn as nn
22
from interpolate import TreeMultiQuad, TreeMultiRandom
33
import torch
4+
from debayer import Debayer3x3
45

56

67
class TreeModel(nn.Module):
@@ -22,7 +23,7 @@ def __init__(self, coded_type='irregular', sz=512, num_channels=6):
2223
self.tree = TreeMultiQuad(sz=sz)
2324
else:
2425
# irregular spaced points
25-
self.tree = TreeMultiRandom(sz=sz, num_channels=6)
26+
self.tree = TreeMultiRandom(sz=sz, num_channels=num_channels)
2627

2728
def forward(self, coded, lookup_channels=None):
2829
''' Coded is the single channel image we want to stack into multiple channels '''
@@ -32,6 +33,16 @@ def forward(self, coded, lookup_channels=None):
3233
return self.tree(coded)
3334

3435

36+
class InterpModel(nn.Module):
37+
''' Can use a bilinear interpolation module for comparison'''
38+
def __init__(self):
39+
super().__init__()
40+
self.interp = Debayer3x3().cuda()
41+
42+
def forward(self, coded):
43+
return self.interp(coded)
44+
45+
3546
if __name__ == '__main__':
3647
device = 'cuda:0'
3748
tree = TreeModel(coded_type='irregular', sz=512, num_channels=6)

0 commit comments

Comments
 (0)