@@ -86,6 +86,11 @@ def __init__(self, subscription_info):
86
86
the client.
87
87
88
88
"""
89
+ # Python 2 v. 3 hack
90
+ try :
91
+ self .basestr = basestring
92
+ except NameError :
93
+ self .basestr = str
89
94
if 'endpoint' not in subscription_info :
90
95
raise WebPushException ("subscription_info missing endpoint URL" )
91
96
if 'keys' not in subscription_info :
@@ -95,22 +100,22 @@ def __init__(self, subscription_info):
95
100
for k in ['p256dh' , 'auth' ]:
96
101
if keys .get (k ) is None :
97
102
raise WebPushException ("Missing keys value: %s" , k )
98
- receiver_raw = base64 .urlsafe_b64decode (
99
- self ._repad (keys ['p256dh' ].encode ('utf8' )))
103
+ if isinstance (keys [k ], self .basestr ):
104
+ keys [k ] = bytes (keys [k ].encode ('utf8' ))
105
+ receiver_raw = base64 .urlsafe_b64decode (self ._repad (keys ['p256dh' ]))
100
106
if len (receiver_raw ) != 65 and receiver_raw [0 ] != "\x04 " :
101
107
raise WebPushException ("Invalid p256dh key specified" )
102
108
self .receiver_key = receiver_raw
103
- self .auth_key = base64 .urlsafe_b64decode (
104
- self ._repad (keys ['auth' ].encode ('utf8' )))
109
+ self .auth_key = base64 .urlsafe_b64decode (self ._repad (keys ['auth' ]))
105
110
106
- def _repad (self , str ):
111
+ def _repad (self , data ):
107
112
"""Add base64 padding to the end of a string, if required"""
108
- return str + "====" [:len (str ) % 4 ]
113
+ return data + b "====" [:len (data ) % 4 ]
109
114
110
115
def encode (self , data ):
111
116
"""Encrypt the data.
112
117
113
- :param data: A serialized block of data (String, JSON, bit array,
118
+ :param data: A serialized block of byte data (String, JSON, bit array,
114
119
etc.) Make sure that whatever you send, your client knows how
115
120
to understand it.
116
121
@@ -124,6 +129,9 @@ def encode(self, data):
124
129
# ID tag.
125
130
server_key_id = base64 .urlsafe_b64encode (server_key .get_pubkey ()[1 :])
126
131
132
+ if isinstance (data , self .basestr ):
133
+ data = bytes (data .encode ('utf8' ))
134
+
127
135
# http_ece requires that these both be set BEFORE encrypt or
128
136
# decrypt is called if you specify the key as "dh".
129
137
http_ece .keys [server_key_id ] = server_key
@@ -138,8 +146,8 @@ def encode(self, data):
138
146
139
147
return CaseInsensitiveDict ({
140
148
'crypto_key' : base64 .urlsafe_b64encode (
141
- server_key .get_pubkey ()).strip ('=' ),
142
- 'salt' : base64 .urlsafe_b64encode (salt ).strip ("=" ),
149
+ server_key .get_pubkey ()).strip (b '=' ),
150
+ 'salt' : base64 .urlsafe_b64encode (salt ).strip (b'=' ),
143
151
'body' : encrypted ,
144
152
})
145
153
@@ -160,11 +168,12 @@ def send(self, data, headers={}, ttl=0):
160
168
crypto_key = headers .get ("crypto-key" , "" )
161
169
if crypto_key :
162
170
crypto_key += ','
163
- crypto_key += "keyid=p256dh;dh=" + encoded ["crypto_key" ]
171
+ crypto_key += "keyid=p256dh;dh=" + encoded ["crypto_key" ]. decode ( 'utf8' )
164
172
headers .update ({
165
173
'crypto-key' : crypto_key ,
166
174
'content-encoding' : 'aesgcm' ,
167
- 'encryption' : "keyid=p256dh;salt=" + encoded ['salt' ],
175
+ 'encryption' : "keyid=p256dh;salt=" +
176
+ encoded ['salt' ].decode ('utf8' ),
168
177
})
169
178
if 'ttl' not in headers or ttl :
170
179
headers ['ttl' ] = ttl
0 commit comments