Commit 769ff01

readme
Parent: 42797c8

5 files changed: +27 −25

Influence_function/influence_function.py

Lines changed: 0 additions & 8 deletions
```diff
@@ -1,15 +1,7 @@
-import faiss
-import numpy as np
-import torch
-import torchvision
 import loss
 from networks import Feat_resnet50_max_n, bninception
 from utils import get_wrong_indices
 import torch.nn as nn
-import os
-import matplotlib.pyplot as plt
-from torchvision.transforms.functional import normalize, resize, to_pil_image
-from torchvision.io.image import read_image
 from Influence_function.EIF_utils import *
 from Influence_function.IF_utils import *
 import pickle
```
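For readers skimming the diff, the surviving imports show what this file does: it scores training samples by how much they contribute to a model's test-time mistakes. Below is a minimal sketch of the first-order gradient-alignment idea behind such influence scores; the function name, argument layout, and scoring rule are illustrative assumptions, not the repository's actual `EIF_utils`/`IF_utils` API.

```python
import torch

def influence_scores(model, loss_fn, train_xs, train_ys, test_x, test_y):
    """Illustrative sketch: score each training example by the dot product
    of its loss gradient with the test-loss gradient (a first-order
    influence approximation; positive = moves parameters the same way
    reducing the test loss would)."""
    params = [p for p in model.parameters() if p.requires_grad]

    # Gradient of the test loss w.r.t. the model parameters.
    test_loss = loss_fn(model(test_x), test_y)
    g_test = torch.autograd.grad(test_loss, params)

    scores = []
    for x, y in zip(train_xs, train_ys):
        train_loss = loss_fn(model(x.unsqueeze(0)), y.unsqueeze(0))
        g_train = torch.autograd.grad(train_loss, params)
        # Sum of per-parameter inner products between the two gradients.
        scores.append(sum((a * b).sum() for a, b in zip(g_test, g_train)))
    return torch.stack(scores)
```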

README.md

Lines changed: 22 additions & 10 deletions
````diff
@@ -3,22 +3,22 @@ Debugging and Explaining Metric Learning Approach: An Influence Function Perspec
 ==============================================================================
 
 ## Introduction
-Deep metric learning (DML) learns a generalizable embedding space of a dataset,
-where semantically similar samples are mapped closer.
+Deep metric learning (DML) learns a generalizable embedding space of a dataset, where semantically similar samples are mapped closer.
 Recently, the record-breaking methodologies have been generally evolving from pairwise-based approaches to proxy-based approaches.
 However, many recent works begin to achieve only marginal improvements on the classical datasets.
 Thus, the explanation approaches of DML are in need for understanding
-**why the trained model can confuse the dissimilar samples and cannot recognize the similar samples**.
+**why the trained model confuses the dissimilar samples**.
 
-To answer the above question, we conduct extensive experiments by running 2 comparable state-of-the-art DML approaches.
-The observation leads us to design an influence function based explanation framework to investigate the existing datasets, consisting of:
+The question motivates us to design an influence function based explanation framework to investigate the existing datasets, consisting of:
 - [x] Scalable training-sample attribution:
   - We propose an empirical influence function to identify which training samples contribute to the generalization errors, and quantify how much contribution they make to the errors.
 - [x] Dataset relabelling recommendation:
   - We further aim to identify the potentially ``buggy'' training samples with mistaken labels and generate their relabelling recommendation.
 
 ## Requirements
-Install torch, torchvision compatible with your CUDA, see here: https://pytorch.org/get-started/previous-versions/
+- Step 1: Install torch and torchvision compatible with your CUDA version, see: [https://pytorch.org/get-started/previous-versions/](https://pytorch.org/get-started/previous-versions/)
+- Step 2: Install faiss compatible with your CUDA version, see: [https://github.com/facebookresearch/faiss/blob/main/INSTALL.md](https://github.com/facebookresearch/faiss/blob/main/INSTALL.md)
+- Step 3:
 ```
 pip install -r requirements.txt
 ```
@@ -37,6 +37,22 @@ Put them under mnt/datasets/
 - We use the same hyperparameters specified in [Proxy-NCA++](https://github.com/euwern/proxynca_pp), except for In-Shop we reduce the batch size to 32 due to the limit of our GPU resources.
 
 ## Project Structure
+```
+|__ config/: training config json files
+|__ dataset/: dataloader definitions
+|__ mnt/datasets/
+    |__ CARS_196/
+    |__ CUB200_2011/
+    |__ inshop/
+|__ evaluation/: evaluation scripts for recall@k, NMI, etc.
+|__ experiments/: scripts for experiments
+|__ Influence_function/: implementation of IF and EIF
+|__ train.py: normal training script
+|__ train_noisy_data.py: noisy-data training script
+|__ train_sample_reweight.py: re-weighted training script
+```
+
+## Instructions
 - Training the original models
   - Training the DML models with Proxy-NCA++ loss or with SoftTriple loss
 ```
@@ -81,10 +97,6 @@ python train_noisydata.py --dataset [cub_noisy|cars_noisy|inshop_noisy] \
 
    See experiments/sample_recommendation_evaluation.py
 
-- Implementation of EIF
-
-    See Influence_function/influence_function.py
-
 ## Results
 - All trained models: https://drive.google.com/drive/folders/1uzy3J78iwKZMCx_k5yESDLbcLl9RADDb?usp=sharing
 - For the detailed statistics of Table 1, please see https://docs.google.com/spreadsheets/d/1f4OXVLO2Mu2CHrBVm72a2ztTHx5nNG92dczTNNw7io4/edit?usp=sharing
````
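The "relabelling recommendation" bullet in the hunk above is easiest to grasp with a concrete heuristic. One simple possibility, shown below purely for illustration and not the repository's actual method, is to propose, for a suspected mislabeled sample, the class whose embedding centroid lies closest in the learned metric space.

```python
import torch

def recommend_label(embeddings, labels, suspect_idx, num_classes):
    """Illustrative heuristic: suggest the class whose centroid is nearest
    to the suspect sample in embedding space (assumes every class has at
    least one sample)."""
    centroids = torch.stack([embeddings[labels == c].mean(dim=0)
                             for c in range(num_classes)])
    # Distance from the suspect's embedding to each class centroid.
    dists = torch.cdist(embeddings[suspect_idx].unsqueeze(0), centroids)
    return int(dists.argmin())  # proposed replacement label
```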

loss.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -3,10 +3,6 @@
 from similarity import pairwise_distance
 import torch.nn.functional as F
 import sklearn.preprocessing
-import logging
-from scipy.optimize import linear_sum_assignment
-import utils
-from tqdm import tqdm
 
 def masked_softmax(A, dim, t=1.0):
     '''
```
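The hunk ends right at the `masked_softmax` signature, whose body this diff does not show. For context, a helper with this signature commonly computes a temperature-scaled softmax that skips masked entries; the sketch below assumes the convention that zeros in `A` mark invalid pairs (e.g. self-similarity), which may differ from the file's actual implementation.

```python
import torch

def masked_softmax_sketch(A, dim, t=1.0):
    """Sketch: temperature-scaled softmax over A along `dim`, treating
    zero entries of A as masked out."""
    A_max = A.max(dim=dim, keepdim=True).values   # subtract max for stability
    A_exp = torch.exp((A - A_max) / t)
    A_exp = A_exp * (A != 0).float()              # zero out masked entries
    return A_exp / (A_exp.sum(dim=dim, keepdim=True) + 1e-6)
```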

requirements.txt

Lines changed: 4 additions & 2 deletions
```diff
@@ -2,7 +2,9 @@ scikit-learn
 matplotlib
 scipy
 tqdm
-faiss-cpu
 h5py
 torchcam
-rembg
+rembg
+numpy
+argparse
+Pillow
```

train_sample_reweight.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -6,13 +6,13 @@
 import torch
 import numpy as np
 import matplotlib
-matplotlib.use('agg', force=True)
 import time
 import argparse
 import json
 from tqdm import tqdm
 from torch.utils.data import Dataset, DataLoader
 import random
+matplotlib.use('agg', force=True)
 
 os.environ["CUDA_VISIBLE_DEVICES"]="1, 0"
 
```
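A note on the reordering above: matplotlib backend selection is conventionally done right after `import matplotlib` and before the first `import matplotlib.pyplot`, since pyplot locks in a backend as it loads. A minimal sketch of the conventional ordering (a general matplotlib pattern, not code from this repository):

```python
import matplotlib
matplotlib.use("agg", force=True)   # choose the non-interactive Agg backend early
import matplotlib.pyplot as plt     # pyplot now comes up on the Agg backend

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])
fig.savefig("sanity_check.png")     # Agg renders to files rather than windows
```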