Skip to content

Commit 61fe298

Browse files
fix: additional records should be used for unrequired records
Link: https://datatracker.ietf.org/doc/html/rfc6763#section-12
1 parent d605749 commit 61fe298

File tree

4 files changed

+179
-103
lines changed

4 files changed

+179
-103
lines changed

src/dns_parser/builder.rs

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,21 @@ impl<T: MoveTo<Answers>> Builder<T> {
195195

196196
builder
197197
}
198+
199+
pub fn add_answers<'a, 'b>(
200+
self,
201+
name: &Name,
202+
cls: QueryClass,
203+
ttl: u32,
204+
data: impl Iterator<Item = RRData<'b>> + 'a,
205+
) -> Builder<Answers> {
206+
let mut builder = self.move_to::<Answers>();
207+
for item in data {
208+
builder.write_rr(name, cls, ttl, &item);
209+
Header::inc_answers(&mut builder.buf).expect("Too many answers");
210+
}
211+
builder
212+
}
198213
}
199214

200215
impl<T: MoveTo<Nameservers>> Builder<T> {
@@ -213,10 +228,25 @@ impl<T: MoveTo<Nameservers>> Builder<T> {
213228

214229
builder
215230
}
231+
232+
#[allow(dead_code)]
233+
pub fn add_nameservers<'a, 'b>(
234+
self,
235+
name: &Name,
236+
cls: QueryClass,
237+
ttl: u32,
238+
data: impl Iterator<Item = RRData<'b>> + 'a,
239+
) -> Builder<Nameservers> {
240+
let mut builder = self.move_to::<Nameservers>();
241+
for item in data {
242+
builder.write_rr(name, cls, ttl, &item);
243+
Header::inc_nameservers(&mut builder.buf).expect("Too many nameservers");
244+
}
245+
builder
246+
}
216247
}
217248

218249
impl<T: MoveTo<Additional>> Builder<T> {
219-
#[allow(dead_code)]
220250
pub fn add_additional(
221251
self,
222252
name: &Name,
@@ -231,6 +261,21 @@ impl<T: MoveTo<Additional>> Builder<T> {
231261

232262
builder
233263
}
264+
265+
pub fn add_additionals<'a, 'b>(
266+
self,
267+
name: &Name,
268+
cls: QueryClass,
269+
ttl: u32,
270+
data: impl Iterator<Item = RRData<'b>> + 'a,
271+
) -> Builder<Additional> {
272+
let mut builder = self.move_to::<Additional>();
273+
for item in data {
274+
builder.write_rr(name, cls, ttl, &item);
275+
Header::inc_additional(&mut builder.buf).expect("Too many additional answers");
276+
}
277+
builder
278+
}
234279
}
235280

236281
#[cfg(test)]

src/dns_parser/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@ pub use self::header::Header;
1212
mod rrdata;
1313
pub use self::rrdata::RRData;
1414
mod builder;
15-
pub use self::builder::{Answers, Builder, Questions};
15+
pub use self::builder::*;

src/fsm.rs

Lines changed: 116 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,10 @@ use tokio::{net::UdpSocket, sync::mpsc};
1818

1919
use super::{DEFAULT_TTL, MDNS_PORT};
2020
use crate::address_family::AddressFamily;
21-
use crate::services::{ServiceData, Services};
21+
use crate::services::{ServiceData, Services, ServicesInner};
2222

2323
pub type AnswerBuilder = dns_parser::Builder<dns_parser::Answers>;
24+
pub type AdditionalBuilder = dns_parser::Builder<dns_parser::Additional>;
2425

2526
const SERVICE_TYPE_ENUMERATION_NAME: Cow<'static, str> =
2627
Cow::Borrowed("_services._dns-sd._udp.local");
@@ -104,57 +105,46 @@ impl<AF: AddressFamily> FSM<AF> {
104105
return;
105106
}
106107

107-
let mut unicast_builder = dns_parser::Builder::new_response(packet.header.id, false, true)
108-
.move_to::<dns_parser::Answers>();
109-
let mut multicast_builder =
110-
dns_parser::Builder::new_response(packet.header.id, false, true)
111-
.move_to::<dns_parser::Answers>();
112-
unicast_builder.set_max_size(None);
113-
multicast_builder.set_max_size(None);
114-
115108
for question in packet.questions {
116109
debug!(
117110
"received question: {:?} {}",
118111
question.qclass, question.qname
119112
);
120113

121114
if question.qclass == QueryClass::IN || question.qclass == QueryClass::Any {
115+
let mut builder = dns_parser::Builder::new_response(packet.header.id, false, true)
116+
.move_to::<dns_parser::Answers>();
117+
builder.set_max_size(None);
118+
let builder = self.handle_question(&question, builder);
119+
if builder.is_empty() {
120+
continue;
121+
}
122+
let response = builder.build().unwrap_or_else(|x| x);
122123
if question.qu {
123-
unicast_builder = self.handle_question(&question, unicast_builder);
124+
self.outgoing.push_back((response, addr));
124125
} else {
125-
multicast_builder = self.handle_question(&question, multicast_builder);
126+
let addr = SocketAddr::new(AF::MDNS_GROUP.into(), MDNS_PORT);
127+
self.outgoing.push_back((response, addr));
126128
}
127129
}
128130
}
129-
130-
if !multicast_builder.is_empty() {
131-
let response = multicast_builder.build().unwrap_or_else(|x| x);
132-
let addr = SocketAddr::new(AF::MDNS_GROUP.into(), MDNS_PORT);
133-
self.outgoing.push_back((response, addr));
134-
}
135-
136-
if !unicast_builder.is_empty() {
137-
let response = unicast_builder.build().unwrap_or_else(|x| x);
138-
self.outgoing.push_back((response, addr));
139-
}
140131
}
141132

142133
/// https://www.rfc-editor.org/rfc/rfc6763#section-9
143134
fn handle_service_type_enumeration<'a>(
144135
question: &dns_parser::Question,
145-
services: impl Iterator<Item = &'a ServiceData>,
136+
services: &ServicesInner,
146137
mut builder: AnswerBuilder,
147138
) -> AnswerBuilder {
148139
let service_type_enumeration_name = Name::FromStr(SERVICE_TYPE_ENUMERATION_NAME);
149140
if question.qname == service_type_enumeration_name {
150-
for svc in services {
151-
let svc_type = ServiceData {
152-
name: svc.typ.clone(),
153-
typ: service_type_enumeration_name.clone(),
154-
port: svc.port,
155-
txt: vec![],
156-
};
157-
builder = svc_type.add_ptr_rr(builder, DEFAULT_TTL);
141+
for typ in services.all_types() {
142+
builder = builder.add_answer(
143+
&service_type_enumeration_name,
144+
QueryClass::IN,
145+
DEFAULT_TTL,
146+
&RRData::PTR(typ.clone()),
147+
);
158148
}
159149
}
160150

@@ -165,93 +155,136 @@ impl<AF: AddressFamily> FSM<AF> {
165155
&self,
166156
question: &dns_parser::Question,
167157
mut builder: AnswerBuilder,
168-
) -> AnswerBuilder {
158+
) -> AdditionalBuilder {
169159
let services = self.services.read().unwrap();
170160
let hostname = services.get_hostname();
171161

172162
match question.qtype {
173-
QueryType::A | QueryType::AAAA if question.qname == *hostname => {
174-
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
175-
}
163+
QueryType::A | QueryType::AAAA if question.qname == *hostname => builder
164+
.add_answers(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr())
165+
.move_to(),
176166
QueryType::All => {
167+
let mut include_ip_additionals = false;
177168
// A / AAAA
178169
if question.qname == *hostname {
179-
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
170+
builder =
171+
builder.add_answers(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr());
180172
}
181173
// PTR
182-
builder =
183-
Self::handle_service_type_enumeration(question, services.into_iter(), builder);
174+
builder = Self::handle_service_type_enumeration(question, &services, builder);
184175
for svc in services.find_by_type(&question.qname) {
185-
builder = svc.add_ptr_rr(builder, DEFAULT_TTL);
186-
builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL);
187-
builder = svc.add_txt_rr(builder, DEFAULT_TTL);
188-
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
176+
builder =
177+
builder.add_answer(&svc.typ, QueryClass::IN, DEFAULT_TTL, &svc.ptr_rr());
178+
include_ip_additionals = true;
189179
}
190180
// SRV
191181
if let Some(svc) = services.find_by_name(&question.qname) {
192-
builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL);
193-
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
182+
builder = builder.add_answer(
183+
&svc.name,
184+
QueryClass::IN,
185+
DEFAULT_TTL,
186+
&svc.srv_rr(hostname),
187+
);
188+
include_ip_additionals = true;
189+
}
190+
let mut builder = builder.move_to::<dns_parser::Additional>();
191+
// PTR (additional)
192+
for svc in services.find_by_type(&question.qname) {
193+
builder = builder
194+
.add_additional(
195+
&svc.name,
196+
QueryClass::IN,
197+
DEFAULT_TTL,
198+
&svc.srv_rr(hostname),
199+
)
200+
.add_additional(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr());
201+
include_ip_additionals = true;
202+
}
203+
204+
if include_ip_additionals {
205+
builder = builder.add_additionals(
206+
hostname,
207+
QueryClass::IN,
208+
DEFAULT_TTL,
209+
self.ip_rr(),
210+
);
194211
}
212+
builder
195213
}
196214
QueryType::PTR => {
197-
builder =
198-
Self::handle_service_type_enumeration(question, services.into_iter(), builder);
215+
let mut builder =
216+
Self::handle_service_type_enumeration(question, &services, builder);
217+
for svc in services.find_by_type(&question.qname) {
218+
builder =
219+
builder.add_answer(&svc.typ, QueryClass::IN, DEFAULT_TTL, &svc.ptr_rr())
220+
}
221+
let mut builder = builder.move_to::<dns_parser::Additional>();
199222
for svc in services.find_by_type(&question.qname) {
200-
builder = svc.add_ptr_rr(builder, DEFAULT_TTL);
201-
builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL);
202-
builder = svc.add_txt_rr(builder, DEFAULT_TTL);
203-
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
223+
builder = builder
224+
.add_additional(
225+
&svc.name,
226+
QueryClass::IN,
227+
DEFAULT_TTL,
228+
&svc.srv_rr(hostname),
229+
)
230+
.add_additional(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr())
231+
.add_additionals(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr());
204232
}
233+
builder
205234
}
206235
QueryType::SRV => {
207236
if let Some(svc) = services.find_by_name(&question.qname) {
208-
builder = svc.add_srv_rr(hostname, builder, DEFAULT_TTL);
209-
builder = self.add_ip_rr(hostname, builder, DEFAULT_TTL);
237+
builder
238+
.add_answer(
239+
&svc.name,
240+
QueryClass::IN,
241+
DEFAULT_TTL,
242+
&svc.srv_rr(hostname),
243+
)
244+
.add_additionals(hostname, QueryClass::IN, DEFAULT_TTL, self.ip_rr())
245+
.move_to()
246+
} else {
247+
builder.move_to()
210248
}
211249
}
212250
QueryType::TXT => {
213251
if let Some(svc) = services.find_by_name(&question.qname) {
214-
builder = svc.add_txt_rr(builder, DEFAULT_TTL);
252+
builder
253+
.add_answer(&svc.name, QueryClass::IN, DEFAULT_TTL, &svc.txt_rr())
254+
.move_to()
255+
} else {
256+
builder.move_to()
215257
}
216258
}
217-
_ => (),
259+
_ => builder.move_to(),
218260
}
219-
220-
builder
221261
}
222262

223-
fn add_ip_rr(&self, hostname: &Name, mut builder: AnswerBuilder, ttl: u32) -> AnswerBuilder {
263+
fn ip_rr(&self) -> impl Iterator<Item = RRData<'static>> + '_ {
224264
let interfaces = match get_if_addrs() {
225265
Ok(interfaces) => interfaces,
226266
Err(err) => {
227267
error!("could not get list of interfaces: {}", err);
228-
return builder;
268+
vec![]
229269
}
230270
};
231-
232-
for iface in interfaces {
271+
interfaces.into_iter().filter_map(move |iface| {
233272
if iface.is_loopback() {
234-
continue;
273+
return None;
235274
}
236275

237276
trace!("found interface {:?}", iface);
238277
if !self.allowed_ip.is_empty() && !self.allowed_ip.contains(&iface.ip()) {
239278
trace!(" -> interface dropped");
240-
continue;
279+
return None;
241280
}
242281

243282
match (iface.ip(), AF::DOMAIN) {
244-
(IpAddr::V4(ip), Domain::IPV4) => {
245-
builder = builder.add_answer(hostname, QueryClass::IN, ttl, &RRData::A(ip))
246-
}
247-
(IpAddr::V6(ip), Domain::IPV6) => {
248-
builder = builder.add_answer(hostname, QueryClass::IN, ttl, &RRData::AAAA(ip))
249-
}
250-
_ => (),
283+
(IpAddr::V4(ip), Domain::IPV4) => Some(RRData::A(ip)),
284+
(IpAddr::V6(ip), Domain::IPV6) => Some(RRData::AAAA(ip)),
285+
_ => None,
251286
}
252-
}
253-
254-
builder
287+
})
255288
}
256289

257290
fn send_unsolicited(&mut self, svc: &ServiceData, ttl: u32, include_ip: bool) {
@@ -261,11 +294,17 @@ impl<AF: AddressFamily> FSM<AF> {
261294

262295
let services = self.services.read().unwrap();
263296

264-
builder = svc.add_ptr_rr(builder, ttl);
265-
builder = svc.add_srv_rr(services.get_hostname(), builder, ttl);
266-
builder = svc.add_txt_rr(builder, ttl);
297+
builder = builder.add_answer(&svc.typ, QueryClass::IN, ttl, &svc.ptr_rr());
298+
builder = builder.add_answer(
299+
&svc.name,
300+
QueryClass::IN,
301+
ttl,
302+
&svc.srv_rr(services.get_hostname()),
303+
);
304+
builder = builder.add_answer(&svc.name, QueryClass::IN, ttl, &svc.txt_rr());
267305
if include_ip {
268-
builder = self.add_ip_rr(services.get_hostname(), builder, ttl);
306+
builder =
307+
builder.add_answers(services.get_hostname(), QueryClass::IN, ttl, self.ip_rr());
269308
}
270309

271310
if !builder.is_empty() {
@@ -349,7 +388,7 @@ mod tests {
349388

350389
answer_builder = FSM::<Inet>::handle_service_type_enumeration(
351390
&question,
352-
services.read().unwrap().into_iter(),
391+
&services.read().unwrap(),
353392
answer_builder,
354393
);
355394

0 commit comments

Comments
 (0)