From 56db2b110bcfe91fcb0ff226098161a83d8a3298 Mon Sep 17 00:00:00 2001 From: AprilNEA Date: Tue, 31 Mar 2026 19:27:14 +0800 Subject: [PATCH] fix(net): store DNS response bytes in cache instead of empty records Parse answer resource records from raw DNS response bytes in cache_response() instead of storing Vec::new(). Extracts record type, TTL, and rdata (A/AAAA/raw) so cache hits return actual data. Uses minimum TTL from answer records for cache expiry. --- virt/arcbox-net/src/dns.rs | 273 ++++++++++++++++++++++++++++++++++++- 1 file changed, 268 insertions(+), 5 deletions(-) diff --git a/virt/arcbox-net/src/dns.rs b/virt/arcbox-net/src/dns.rs index 9f2a58a1..b826ef9c 100644 --- a/virt/arcbox-net/src/dns.rs +++ b/virt/arcbox-net/src/dns.rs @@ -617,18 +617,181 @@ impl DnsForwarder { Err(NetError::Dns("all upstream servers failed".to_string())) } - /// Caches a DNS response. - fn cache_response(&mut self, name: &str, qtype: DnsRecordType, _response: &[u8]) { - // Simplified caching - just store the query info + /// Caches a DNS response by parsing answer records from the raw bytes. + fn cache_response(&mut self, name: &str, qtype: DnsRecordType, response: &[u8]) { + let records = Self::parse_answer_records(response); + if records.is_empty() { + return; + } + + // Use the minimum TTL from the answer records, falling back to config. + let min_ttl = records + .iter() + .map(|r| r.ttl) + .min() + .unwrap_or(self.config.cache_ttl.as_secs() as u32); + let ttl = Duration::from_secs(u64::from(min_ttl)); + let key = (name.to_lowercase(), qtype); let entry = CacheEntry { - records: Vec::new(), // Would parse from response in full implementation + records, cached_at: Instant::now(), - ttl: self.config.cache_ttl, + ttl, }; self.cache.insert(key, entry); } + /// Parses answer resource records from a raw DNS response. + /// + /// Skips the header (12 bytes) and question section, then reads each + /// answer RR. Returns an empty vec on any parse failure -- the caller + /// simply skips caching in that case. + fn parse_answer_records(response: &[u8]) -> Vec { + if response.len() < 12 { + return Vec::new(); + } + + let ancount = u16::from_be_bytes([response[6], response[7]]) as usize; + if ancount == 0 { + return Vec::new(); + } + + // Skip past the question section. QDCOUNT is at bytes 4-5. + let qdcount = u16::from_be_bytes([response[4], response[5]]) as usize; + let mut offset = 12; + for _ in 0..qdcount { + if Self::skip_dns_name(response, &mut offset).is_err() { + return Vec::new(); + } + offset += 4; // QTYPE + QCLASS + if offset > response.len() { + return Vec::new(); + } + } + + // Parse answer records. + let mut records = Vec::with_capacity(ancount); + for _ in 0..ancount { + let Some(record) = Self::parse_one_rr(response, &mut offset) else { + break; + }; + records.push(record); + } + records + } + + /// Skips a DNS name (label sequence or compressed pointer) and advances + /// `offset` past it. Returns `Err` on malformed data. + fn skip_dns_name(data: &[u8], offset: &mut usize) -> std::result::Result<(), ()> { + loop { + if *offset >= data.len() { + return Err(()); + } + let b = data[*offset]; + if b == 0 { + *offset += 1; + return Ok(()); + } + // Compression pointer (two bytes). + if b & 0xC0 == 0xC0 { + *offset += 2; + return Ok(()); + } + // Normal label. + let len = b as usize; + *offset += 1 + len; + } + } + + /// Reads a DNS name (handling compression pointers) into a dotted string. + fn read_dns_name(data: &[u8], start: usize) -> Option { + let mut parts = Vec::new(); + let mut pos = start; + let mut jumps = 0; + loop { + if pos >= data.len() || jumps > 10 { + return None; + } + let b = data[pos]; + if b == 0 { + break; + } + if b & 0xC0 == 0xC0 { + if pos + 1 >= data.len() { + return None; + } + pos = u16::from_be_bytes([b & 0x3F, data[pos + 1]]) as usize; + jumps += 1; + continue; + } + let len = b as usize; + if pos + 1 + len > data.len() { + return None; + } + parts.push(String::from_utf8_lossy(&data[pos + 1..pos + 1 + len]).into_owned()); + pos += 1 + len; + } + Some(parts.join(".")) + } + + /// Parses a single resource record at `offset`, advancing it past the RR. + fn parse_one_rr(data: &[u8], offset: &mut usize) -> Option { + let name_start = *offset; + let name = Self::read_dns_name(data, name_start)?; + Self::skip_dns_name(data, offset).ok()?; + + if *offset + 10 > data.len() { + return None; + } + + let rtype_raw = u16::from_be_bytes([data[*offset], data[*offset + 1]]); + let class_raw = u16::from_be_bytes([data[*offset + 2], data[*offset + 3]]); + let ttl = u32::from_be_bytes([ + data[*offset + 4], + data[*offset + 5], + data[*offset + 6], + data[*offset + 7], + ]); + let rdlength = u16::from_be_bytes([data[*offset + 8], data[*offset + 9]]) as usize; + *offset += 10; + + if *offset + rdlength > data.len() { + return None; + } + + let rdata_bytes = &data[*offset..*offset + rdlength]; + *offset += rdlength; + + let rtype = DnsRecordType::try_from(rtype_raw).ok()?; + let class = if class_raw == 1 { + DnsClass::In + } else { + return None; + }; + + let rdata = match rtype { + DnsRecordType::A if rdlength == 4 => DnsRdata::A(Ipv4Addr::new( + rdata_bytes[0], + rdata_bytes[1], + rdata_bytes[2], + rdata_bytes[3], + )), + DnsRecordType::Aaaa if rdlength == 16 => { + let octets: [u8; 16] = rdata_bytes.try_into().ok()?; + DnsRdata::Aaaa(Ipv6Addr::from(octets)) + } + _ => DnsRdata::Raw(rdata_bytes.to_vec()), + }; + + Some(DnsRecord { + name, + rtype, + class, + ttl, + rdata, + }) + } + /// Clears the DNS cache. pub fn clear_cache(&mut self) { self.cache.clear(); @@ -975,4 +1138,104 @@ nameserver 10.0.0.2 let servers = parse_resolv_conf_nameservers(conf); assert!(servers.is_empty()); } + + /// Builds a minimal DNS response with one A record answer. + fn build_test_response(query: &[u8], ip: Ipv4Addr, ttl: u32) -> Vec { + let mut resp = Vec::with_capacity(64); + // Copy the 12-byte header from query. + resp.extend_from_slice(&query[..12]); + // QR=1, RD=1, RA=1, RCODE=0 + resp[2] = 0x81; + resp[3] = 0x80; + // ANCOUNT = 1 + resp[6] = 0x00; + resp[7] = 0x01; + + // Copy question section (everything after header). + resp.extend_from_slice(&query[12..]); + + // Answer: name pointer to offset 12 (question name). + resp.extend_from_slice(&[0xC0, 0x0C]); + // TYPE = A + resp.extend_from_slice(&[0x00, 0x01]); + // CLASS = IN + resp.extend_from_slice(&[0x00, 0x01]); + // TTL + resp.extend_from_slice(&ttl.to_be_bytes()); + // RDLENGTH = 4 + resp.extend_from_slice(&[0x00, 0x04]); + // RDATA + resp.extend_from_slice(&ip.octets()); + + resp + } + + #[test] + fn test_cache_response_stores_records() { + let config = DnsConfig::default(); + let mut forwarder = DnsForwarder::new(config); + + let query_pkt = build_test_query("example.com"); + let response_pkt = build_test_response(&query_pkt, Ipv4Addr::new(93, 184, 216, 34), 120); + + forwarder.cache_response("example.com", DnsRecordType::A, &response_pkt); + + // Cache should contain the entry with a non-empty records vec. + let cached = forwarder + .check_cache("example.com", DnsRecordType::A) + .expect("cache should contain the entry"); + assert_eq!(cached.len(), 1, "should have one cached record"); + + let rec = &cached[0]; + assert_eq!(rec.rtype, DnsRecordType::A); + assert_eq!(rec.ttl, 120); + match &rec.rdata { + DnsRdata::A(ip) => assert_eq!(*ip, Ipv4Addr::new(93, 184, 216, 34)), + other => panic!("expected A record, got {:?}", other), + } + } + + #[test] + fn test_cache_hit_returns_valid_response() { + let config = DnsConfig::default(); + let mut forwarder = DnsForwarder::new(config); + + let query_pkt = build_test_query("cached.test"); + let ip = Ipv4Addr::new(10, 0, 0, 42); + let response_pkt = build_test_response(&query_pkt, ip, 60); + + forwarder.cache_response("cached.test", DnsRecordType::A, &response_pkt); + + // handle_query should return the cached response without forwarding. + let result = forwarder + .handle_query(&query_pkt) + .expect("handle_query should succeed from cache"); + + // Verify the response is valid: QR=1, ANCOUNT=1, contains the IP. + assert_eq!(result[2] & 0x80, 0x80, "QR bit should be set"); + assert_eq!(result[7], 1, "ANCOUNT should be 1"); + // The A record RDATA (last 4 bytes of the answer) should match the IP. + let rdata_start = result.len() - 4; + assert_eq!(&result[rdata_start..], &ip.octets()); + } + + #[test] + fn test_cache_skips_empty_response() { + let config = DnsConfig::default(); + let mut forwarder = DnsForwarder::new(config); + + // Build a response with ANCOUNT=0 (no answers). + let query_pkt = build_test_query("empty.test"); + let mut response_pkt = query_pkt.clone(); + response_pkt[2] = 0x81; + response_pkt[3] = 0x80; + response_pkt[6] = 0x00; + response_pkt[7] = 0x00; // ANCOUNT=0 + + forwarder.cache_response("empty.test", DnsRecordType::A, &response_pkt); + + // Nothing should be cached for a response with no answers. + let cached = forwarder.check_cache("empty.test", DnsRecordType::A); + assert!(cached.is_none(), "should not cache empty responses"); + } }