@@ -18,9 +18,10 @@ use tokio::{net::UdpSocket, sync::mpsc};
18
18
19
19
use super :: { DEFAULT_TTL , MDNS_PORT } ;
20
20
use crate :: address_family:: AddressFamily ;
21
- use crate :: services:: { ServiceData , Services } ;
21
+ use crate :: services:: { ServiceData , Services , ServicesInner } ;
22
22
23
23
pub type AnswerBuilder = dns_parser:: Builder < dns_parser:: Answers > ;
24
+ pub type AdditionalBuilder = dns_parser:: Builder < dns_parser:: Additional > ;
24
25
25
26
const SERVICE_TYPE_ENUMERATION_NAME : Cow < ' static , str > =
26
27
Cow :: Borrowed ( "_services._dns-sd._udp.local" ) ;
@@ -104,57 +105,46 @@ impl<AF: AddressFamily> FSM<AF> {
104
105
return ;
105
106
}
106
107
107
- let mut unicast_builder = dns_parser:: Builder :: new_response ( packet. header . id , false , true )
108
- . move_to :: < dns_parser:: Answers > ( ) ;
109
- let mut multicast_builder =
110
- dns_parser:: Builder :: new_response ( packet. header . id , false , true )
111
- . move_to :: < dns_parser:: Answers > ( ) ;
112
- unicast_builder. set_max_size ( None ) ;
113
- multicast_builder. set_max_size ( None ) ;
114
-
115
108
for question in packet. questions {
116
109
debug ! (
117
110
"received question: {:?} {}" ,
118
111
question. qclass, question. qname
119
112
) ;
120
113
121
114
if question. qclass == QueryClass :: IN || question. qclass == QueryClass :: Any {
115
+ let mut builder = dns_parser:: Builder :: new_response ( packet. header . id , false , true )
116
+ . move_to :: < dns_parser:: Answers > ( ) ;
117
+ builder. set_max_size ( None ) ;
118
+ let builder = self . handle_question ( & question, builder) ;
119
+ if builder. is_empty ( ) {
120
+ continue ;
121
+ }
122
+ let response = builder. build ( ) . unwrap_or_else ( |x| x) ;
122
123
if question. qu {
123
- unicast_builder = self . handle_question ( & question , unicast_builder ) ;
124
+ self . outgoing . push_back ( ( response , addr ) ) ;
124
125
} else {
125
- multicast_builder = self . handle_question ( & question, multicast_builder) ;
126
+ let addr = SocketAddr :: new ( AF :: MDNS_GROUP . into ( ) , MDNS_PORT ) ;
127
+ self . outgoing . push_back ( ( response, addr) ) ;
126
128
}
127
129
}
128
130
}
129
-
130
- if !multicast_builder. is_empty ( ) {
131
- let response = multicast_builder. build ( ) . unwrap_or_else ( |x| x) ;
132
- let addr = SocketAddr :: new ( AF :: MDNS_GROUP . into ( ) , MDNS_PORT ) ;
133
- self . outgoing . push_back ( ( response, addr) ) ;
134
- }
135
-
136
- if !unicast_builder. is_empty ( ) {
137
- let response = unicast_builder. build ( ) . unwrap_or_else ( |x| x) ;
138
- self . outgoing . push_back ( ( response, addr) ) ;
139
- }
140
131
}
141
132
142
133
/// https://www.rfc-editor.org/rfc/rfc6763#section-9
143
134
fn handle_service_type_enumeration < ' a > (
144
135
question : & dns_parser:: Question ,
145
- services : impl Iterator < Item = & ' a ServiceData > ,
136
+ services : & ServicesInner ,
146
137
mut builder : AnswerBuilder ,
147
138
) -> AnswerBuilder {
148
139
let service_type_enumeration_name = Name :: FromStr ( SERVICE_TYPE_ENUMERATION_NAME ) ;
149
140
if question. qname == service_type_enumeration_name {
150
- for svc in services {
151
- let svc_type = ServiceData {
152
- name : svc. typ . clone ( ) ,
153
- typ : service_type_enumeration_name. clone ( ) ,
154
- port : svc. port ,
155
- txt : vec ! [ ] ,
156
- } ;
157
- builder = svc_type. add_ptr_rr ( builder, DEFAULT_TTL ) ;
141
+ for typ in services. all_types ( ) {
142
+ builder = builder. add_answer (
143
+ & service_type_enumeration_name,
144
+ QueryClass :: IN ,
145
+ DEFAULT_TTL ,
146
+ & RRData :: PTR ( typ. clone ( ) ) ,
147
+ ) ;
158
148
}
159
149
}
160
150
@@ -165,93 +155,136 @@ impl<AF: AddressFamily> FSM<AF> {
165
155
& self ,
166
156
question : & dns_parser:: Question ,
167
157
mut builder : AnswerBuilder ,
168
- ) -> AnswerBuilder {
158
+ ) -> AdditionalBuilder {
169
159
let services = self . services . read ( ) . unwrap ( ) ;
170
160
let hostname = services. get_hostname ( ) ;
171
161
172
162
match question. qtype {
173
- QueryType :: A | QueryType :: AAAA if question. qname == * hostname => {
174
- builder = self . add_ip_rr ( hostname, builder , DEFAULT_TTL ) ;
175
- }
163
+ QueryType :: A | QueryType :: AAAA if question. qname == * hostname => builder
164
+ . add_answers ( hostname, QueryClass :: IN , DEFAULT_TTL , self . ip_rr ( ) )
165
+ . move_to ( ) ,
176
166
QueryType :: All => {
167
+ let mut include_ip_additionals = false ;
177
168
// A / AAAA
178
169
if question. qname == * hostname {
179
- builder = self . add_ip_rr ( hostname, builder, DEFAULT_TTL ) ;
170
+ builder =
171
+ builder. add_answers ( hostname, QueryClass :: IN , DEFAULT_TTL , self . ip_rr ( ) ) ;
180
172
}
181
173
// PTR
182
- builder =
183
- Self :: handle_service_type_enumeration ( question, services. into_iter ( ) , builder) ;
174
+ builder = Self :: handle_service_type_enumeration ( question, & services, builder) ;
184
175
for svc in services. find_by_type ( & question. qname ) {
185
- builder = svc. add_ptr_rr ( builder, DEFAULT_TTL ) ;
186
- builder = svc. add_srv_rr ( hostname, builder, DEFAULT_TTL ) ;
187
- builder = svc. add_txt_rr ( builder, DEFAULT_TTL ) ;
188
- builder = self . add_ip_rr ( hostname, builder, DEFAULT_TTL ) ;
176
+ builder =
177
+ builder. add_answer ( & svc. typ , QueryClass :: IN , DEFAULT_TTL , & svc. ptr_rr ( ) ) ;
178
+ include_ip_additionals = true ;
189
179
}
190
180
// SRV
191
181
if let Some ( svc) = services. find_by_name ( & question. qname ) {
192
- builder = svc. add_srv_rr ( hostname, builder, DEFAULT_TTL ) ;
193
- builder = self . add_ip_rr ( hostname, builder, DEFAULT_TTL ) ;
182
+ builder = builder. add_answer (
183
+ & svc. name ,
184
+ QueryClass :: IN ,
185
+ DEFAULT_TTL ,
186
+ & svc. srv_rr ( hostname) ,
187
+ ) ;
188
+ include_ip_additionals = true ;
189
+ }
190
+ let mut builder = builder. move_to :: < dns_parser:: Additional > ( ) ;
191
+ // PTR (additional)
192
+ for svc in services. find_by_type ( & question. qname ) {
193
+ builder = builder
194
+ . add_additional (
195
+ & svc. name ,
196
+ QueryClass :: IN ,
197
+ DEFAULT_TTL ,
198
+ & svc. srv_rr ( hostname) ,
199
+ )
200
+ . add_additional ( & svc. name , QueryClass :: IN , DEFAULT_TTL , & svc. txt_rr ( ) ) ;
201
+ include_ip_additionals = true ;
202
+ }
203
+
204
+ if include_ip_additionals {
205
+ builder = builder. add_additionals (
206
+ hostname,
207
+ QueryClass :: IN ,
208
+ DEFAULT_TTL ,
209
+ self . ip_rr ( ) ,
210
+ ) ;
194
211
}
212
+ builder
195
213
}
196
214
QueryType :: PTR => {
197
- builder =
198
- Self :: handle_service_type_enumeration ( question, services. into_iter ( ) , builder) ;
215
+ let mut builder =
216
+ Self :: handle_service_type_enumeration ( question, & services, builder) ;
217
+ for svc in services. find_by_type ( & question. qname ) {
218
+ builder =
219
+ builder. add_answer ( & svc. typ , QueryClass :: IN , DEFAULT_TTL , & svc. ptr_rr ( ) )
220
+ }
221
+ let mut builder = builder. move_to :: < dns_parser:: Additional > ( ) ;
199
222
for svc in services. find_by_type ( & question. qname ) {
200
- builder = svc. add_ptr_rr ( builder, DEFAULT_TTL ) ;
201
- builder = svc. add_srv_rr ( hostname, builder, DEFAULT_TTL ) ;
202
- builder = svc. add_txt_rr ( builder, DEFAULT_TTL ) ;
203
- builder = self . add_ip_rr ( hostname, builder, DEFAULT_TTL ) ;
223
+ builder = builder
224
+ . add_additional (
225
+ & svc. name ,
226
+ QueryClass :: IN ,
227
+ DEFAULT_TTL ,
228
+ & svc. srv_rr ( hostname) ,
229
+ )
230
+ . add_additional ( & svc. name , QueryClass :: IN , DEFAULT_TTL , & svc. txt_rr ( ) )
231
+ . add_additionals ( hostname, QueryClass :: IN , DEFAULT_TTL , self . ip_rr ( ) ) ;
204
232
}
233
+ builder
205
234
}
206
235
QueryType :: SRV => {
207
236
if let Some ( svc) = services. find_by_name ( & question. qname ) {
208
- builder = svc. add_srv_rr ( hostname, builder, DEFAULT_TTL ) ;
209
- builder = self . add_ip_rr ( hostname, builder, DEFAULT_TTL ) ;
237
+ builder
238
+ . add_answer (
239
+ & svc. name ,
240
+ QueryClass :: IN ,
241
+ DEFAULT_TTL ,
242
+ & svc. srv_rr ( hostname) ,
243
+ )
244
+ . add_additionals ( hostname, QueryClass :: IN , DEFAULT_TTL , self . ip_rr ( ) )
245
+ . move_to ( )
246
+ } else {
247
+ builder. move_to ( )
210
248
}
211
249
}
212
250
QueryType :: TXT => {
213
251
if let Some ( svc) = services. find_by_name ( & question. qname ) {
214
- builder = svc. add_txt_rr ( builder, DEFAULT_TTL ) ;
252
+ builder
253
+ . add_answer ( & svc. name , QueryClass :: IN , DEFAULT_TTL , & svc. txt_rr ( ) )
254
+ . move_to ( )
255
+ } else {
256
+ builder. move_to ( )
215
257
}
216
258
}
217
- _ => ( ) ,
259
+ _ => builder . move_to ( ) ,
218
260
}
219
-
220
- builder
221
261
}
222
262
223
- fn add_ip_rr ( & self , hostname : & Name , mut builder : AnswerBuilder , ttl : u32 ) -> AnswerBuilder {
263
+ fn ip_rr ( & self ) -> impl Iterator < Item = RRData < ' static > > + ' _ {
224
264
let interfaces = match get_if_addrs ( ) {
225
265
Ok ( interfaces) => interfaces,
226
266
Err ( err) => {
227
267
error ! ( "could not get list of interfaces: {}" , err) ;
228
- return builder ;
268
+ vec ! [ ]
229
269
}
230
270
} ;
231
-
232
- for iface in interfaces {
271
+ interfaces. into_iter ( ) . filter_map ( move |iface| {
233
272
if iface. is_loopback ( ) {
234
- continue ;
273
+ return None ;
235
274
}
236
275
237
276
trace ! ( "found interface {:?}" , iface) ;
238
277
if !self . allowed_ip . is_empty ( ) && !self . allowed_ip . contains ( & iface. ip ( ) ) {
239
278
trace ! ( " -> interface dropped" ) ;
240
- continue ;
279
+ return None ;
241
280
}
242
281
243
282
match ( iface. ip ( ) , AF :: DOMAIN ) {
244
- ( IpAddr :: V4 ( ip) , Domain :: IPV4 ) => {
245
- builder = builder. add_answer ( hostname, QueryClass :: IN , ttl, & RRData :: A ( ip) )
246
- }
247
- ( IpAddr :: V6 ( ip) , Domain :: IPV6 ) => {
248
- builder = builder. add_answer ( hostname, QueryClass :: IN , ttl, & RRData :: AAAA ( ip) )
249
- }
250
- _ => ( ) ,
283
+ ( IpAddr :: V4 ( ip) , Domain :: IPV4 ) => Some ( RRData :: A ( ip) ) ,
284
+ ( IpAddr :: V6 ( ip) , Domain :: IPV6 ) => Some ( RRData :: AAAA ( ip) ) ,
285
+ _ => None ,
251
286
}
252
- }
253
-
254
- builder
287
+ } )
255
288
}
256
289
257
290
fn send_unsolicited ( & mut self , svc : & ServiceData , ttl : u32 , include_ip : bool ) {
@@ -261,11 +294,17 @@ impl<AF: AddressFamily> FSM<AF> {
261
294
262
295
let services = self . services . read ( ) . unwrap ( ) ;
263
296
264
- builder = svc. add_ptr_rr ( builder, ttl) ;
265
- builder = svc. add_srv_rr ( services. get_hostname ( ) , builder, ttl) ;
266
- builder = svc. add_txt_rr ( builder, ttl) ;
297
+ builder = builder. add_answer ( & svc. typ , QueryClass :: IN , ttl, & svc. ptr_rr ( ) ) ;
298
+ builder = builder. add_answer (
299
+ & svc. name ,
300
+ QueryClass :: IN ,
301
+ ttl,
302
+ & svc. srv_rr ( services. get_hostname ( ) ) ,
303
+ ) ;
304
+ builder = builder. add_answer ( & svc. name , QueryClass :: IN , ttl, & svc. txt_rr ( ) ) ;
267
305
if include_ip {
268
- builder = self . add_ip_rr ( services. get_hostname ( ) , builder, ttl) ;
306
+ builder =
307
+ builder. add_answers ( services. get_hostname ( ) , QueryClass :: IN , ttl, self . ip_rr ( ) ) ;
269
308
}
270
309
271
310
if !builder. is_empty ( ) {
@@ -349,7 +388,7 @@ mod tests {
349
388
350
389
answer_builder = FSM :: < Inet > :: handle_service_type_enumeration (
351
390
& question,
352
- services. read ( ) . unwrap ( ) . into_iter ( ) ,
391
+ & services. read ( ) . unwrap ( ) ,
353
392
answer_builder,
354
393
) ;
355
394
0 commit comments