
Commit 70d77df

Author: chenxj

DeepFM model

File tree

1 file changed: +127 -0 lines changed


DeepFM.py

# -*- coding: utf-8 -*-

"""
A PyTorch implementation of DeepFM for rating prediction problems.

Created on Aug 10, 2018
"""

__author__ = 'Xijun Chen'

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class DeepFM(nn.Module):
    """
    A DeepFM network with RMSE loss for rating prediction problems.

    The architecture has two parts: an FM part that models low-order feature
    interactions and a deep part that models higher-order interactions. Batch
    normalization and dropout are applied to every hidden layer, and the
    network is intended to be trained with the Adam optimizer.

    You may find more details in this paper:
    DeepFM: A Factorization-Machine based Neural Network for CTR Prediction,
    Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He.
    """
    def __init__(self, feature_sizes, embedding_size=4,
                 hidden_dims=[32, 32], num_classes=10, dropout=[0.5, 0.5],
                 use_cuda=True, verbose=False):
        """
        Initialize a new network.

        Inputs:
        - feature_sizes: A list of integers giving the number of distinct
          values in each field.
        - embedding_size: An integer giving the size of each feature embedding.
        - hidden_dims: A list of integers giving the size of each hidden layer.
        - num_classes: An integer giving the number of classes to predict. For
          example, someone may rate a film 1, 2, 3, 4, or 5 stars.
        - use_cuda: Bool, whether to run on CUDA when available.
        - verbose: Bool.
        """
        super().__init__()
        self.field_size = len(feature_sizes)
        self.feature_sizes = feature_sizes
        self.embedding_size = embedding_size
        self.hidden_dims = hidden_dims
        self.num_classes = num_classes
        self.bias = torch.nn.Parameter(torch.randn(1))
        # Check whether to use CUDA.
        if use_cuda and torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')
        # Init FM part: a 1-dimensional embedding per field for the
        # first-order terms, and an embedding_size-dimensional embedding
        # per field for the second-order terms.
        self.fm_first_order_embeddings = nn.ModuleList(
            [nn.Embedding(feature_size, 1) for feature_size in self.feature_sizes])
        self.fm_second_order_embeddings = nn.ModuleList(
            [nn.Embedding(feature_size, self.embedding_size) for feature_size in self.feature_sizes])
        # Init deep part.
        all_dims = [self.field_size * self.embedding_size] + \
            self.hidden_dims + [self.num_classes]
        for i in range(1, len(hidden_dims) + 1):
            setattr(self, 'linear_' + str(i),
                    nn.Linear(all_dims[i - 1], all_dims[i]))
            # nn.init.kaiming_normal_(self.fc1.weight)
            setattr(self, 'batchNorm_' + str(i),
                    nn.BatchNorm1d(all_dims[i]))
            setattr(self, 'dropout_' + str(i),
                    nn.Dropout(dropout[i - 1]))

    def forward(self, Xi, Xv):
        """
        Forward process of the network.

        Inputs:
        - Xi: A tensor of input indices, of shape (N, field_size, 1).
        - Xv: A tensor of input values, of shape (N, field_size).
        """
        # FM part.
        fm_first_order_emb_arr = [(torch.sum(emb(Xi[:, i, :]), 1).t() *
                                   Xv[:, i]).t() for i, emb in enumerate(self.fm_first_order_embeddings)]
        fm_first_order = torch.cat(fm_first_order_emb_arr, 1)
        # Use 2xy = (x+y)^2 - x^2 - y^2 to reduce calculation.
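        # Why the identity helps: with u_i = v_i * x_i denoting the weighted
        # second-order embedding of field i, elementwise
        #   sum_{i<j} u_i * u_j = 0.5 * ((sum_i u_i)^2 - sum_i u_i^2),
        # so all pairwise interactions cost one sum and one square per field
        # rather than a quadratic number of pairwise products.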
        fm_second_order_emb_arr = [(torch.sum(emb(Xi[:, i, :]), 1).t() *
                                    Xv[:, i]).t() for i, emb in enumerate(self.fm_second_order_embeddings)]
        fm_sum_second_order_emb = sum(fm_second_order_emb_arr)
        fm_sum_second_order_emb_square = fm_sum_second_order_emb * \
            fm_sum_second_order_emb  # (x+y)^2
        fm_second_order_emb_square = [
            item * item for item in fm_second_order_emb_arr]
        fm_second_order_emb_square_sum = sum(
            fm_second_order_emb_square)  # x^2 + y^2
        fm_second_order = (fm_sum_second_order_emb_square -
                           fm_second_order_emb_square_sum) * 0.5
        # Deep part: the concatenated second-order embeddings feed the MLP.
        deep_emb = torch.cat(fm_second_order_emb_arr, 1)
        deep_out = deep_emb
        for i in range(1, len(self.hidden_dims) + 1):
            deep_out = getattr(self, 'linear_' + str(i))(deep_out)
            deep_out = getattr(self, 'batchNorm_' + str(i))(deep_out)
            deep_out = getattr(self, 'dropout_' + str(i))(deep_out)
        # Sum the FM first-order, FM second-order, and deep parts.
        total_sum = torch.sum(fm_first_order, 1) + \
            torch.sum(fm_second_order, 1) + torch.sum(deep_out, 1) + self.bias
        return total_sum
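
For reference, here is a minimal usage sketch of the module above; the feature sizes, batch size, and input tensors are illustrative assumptions, not part of the commit:

feature_sizes = [10, 20, 30]  # three categorical fields (assumed sizes)
model = DeepFM(feature_sizes, use_cuda=False)
model.eval()  # disable BatchNorm/Dropout randomness for a sanity check

N = 8  # batch size (assumed)
# Xi holds the active index of each field, shape (N, field_size, 1).
Xi = torch.stack([torch.randint(0, s, (N, 1)) for s in feature_sizes], dim=1)
# Xv holds the value of each field (1.0 for one-hot categorical features),
# shape (N, field_size).
Xv = torch.ones(N, len(feature_sizes))

scores = model(Xi, Xv)
print(scores.shape)  # torch.Size([8])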
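And since the module docstring calls for RMSE loss with the Adam optimizer, a sketch of one training step under that setup; the targets and learning rate are assumptions:

optimizer = optim.Adam(model.parameters(), lr=1e-3)  # lr is an assumption
targets = torch.randint(1, 6, (N,)).float()  # e.g. 1-5 star ratings (assumed)

model.train()
optimizer.zero_grad()
preds = model(Xi, Xv)
loss = torch.sqrt(F.mse_loss(preds, targets))  # RMSE, as in the docstring
loss.backward()
optimizer.step()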
