Skip to content

Commit 61f2f8b

Browse files
committed
Patch: if A-ttl is not expired but AAAA-ttl is expired, we should only send AAAA-query and vice versa
1. if A-ttl is not expired but AAAA-ttl is expired, we should only send AAAA-query and vice versa 2. `sendQuery` send each query in new goroutine so there is no need to run it in new goroutine.
1 parent cd4f1cd commit 61f2f8b

File tree

5 files changed

+85
-37
lines changed

5 files changed

+85
-37
lines changed

app/dns/cache_controller.go

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -102,20 +102,8 @@ func (c *CacheController) updateIP(req *dnsRequest, ipRec *IPRecord) {
102102
switch req.reqType {
103103
case dnsmessage.TypeA:
104104
c.pub.Publish(req.domain+"4", nil)
105-
if !c.disableCache {
106-
_, _, err := rec.AAAA.getIPs()
107-
if !go_errors.Is(err, errRecordNotFound) {
108-
c.pub.Publish(req.domain+"6", nil)
109-
}
110-
}
111105
case dnsmessage.TypeAAAA:
112106
c.pub.Publish(req.domain+"6", nil)
113-
if !c.disableCache {
114-
_, _, err := rec.A.getIPs()
115-
if !go_errors.Is(err, errRecordNotFound) {
116-
c.pub.Publish(req.domain+"4", nil)
117-
}
118-
}
119107
}
120108

121109
c.Unlock()
@@ -124,13 +112,13 @@ func (c *CacheController) updateIP(req *dnsRequest, ipRec *IPRecord) {
124112
}
125113
}
126114

127-
func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, int32, error) {
115+
func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, int32, bool, bool, error) {
128116
c.RLock()
129117
record, found := c.ips[domain]
130118
c.RUnlock()
131119

132120
if !found {
133-
return nil, 0, errRecordNotFound
121+
return nil, 0, true, true, errRecordNotFound
134122
}
135123

136124
var errs []error
@@ -139,10 +127,14 @@ func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPO
139127

140128
mergeReq := option.IPv4Enable && option.IPv6Enable
141129

130+
isARecordExpired := true
142131
if option.IPv4Enable {
143132
ips, ttl, err := record.A.getIPs()
144-
if !mergeReq || go_errors.Is(err, errRecordNotFound) {
145-
return ips, ttl, err
133+
if ttl > 0 {
134+
isARecordExpired = false
135+
}
136+
if !mergeReq {
137+
return ips, ttl, isARecordExpired, true, err
146138
}
147139
if ttl < rTTL {
148140
rTTL = ttl
@@ -154,10 +146,14 @@ func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPO
154146
}
155147
}
156148

149+
isAAAARecordExpired := true
157150
if option.IPv6Enable {
158151
ips, ttl, err := record.AAAA.getIPs()
159-
if !mergeReq || go_errors.Is(err, errRecordNotFound) {
160-
return ips, ttl, err
152+
if ttl > 0 {
153+
isAAAARecordExpired = false
154+
}
155+
if !mergeReq {
156+
return ips, ttl, true, isAAAARecordExpired, err
161157
}
162158
if ttl < rTTL {
163159
rTTL = ttl
@@ -169,13 +165,17 @@ func (c *CacheController) findIPsForDomain(domain string, option dns_feature.IPO
169165
}
170166
}
171167

168+
if go_errors.Is(errs[0], errRecordNotFound) || go_errors.Is(errs[1], errRecordNotFound) {
169+
return nil, 0, isARecordExpired, isAAAARecordExpired, errRecordNotFound
170+
}
171+
172172
if len(allIPs) > 0 {
173-
return allIPs, rTTL, nil
173+
return allIPs, rTTL, isARecordExpired, isAAAARecordExpired, nil
174174
}
175175
if go_errors.Is(errs[0], errs[1]) {
176-
return nil, rTTL, errs[0]
176+
return nil, rTTL, isARecordExpired, isAAAARecordExpired, errs[0]
177177
}
178-
return nil, rTTL, errors.Combine(errs...)
178+
return nil, rTTL, isARecordExpired, isAAAARecordExpired, errors.Combine(errs...)
179179
}
180180

181181
func (c *CacheController) registerSubscribers(domain string, option dns_feature.IPOption) (sub4 *pubsub.Subscriber, sub6 *pubsub.Subscriber) {

app/dns/nameserver_doh.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,10 +229,22 @@ func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_f
229229
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
230230
defer closeSubscribers(sub4, sub6)
231231

232+
queryOption := option
233+
232234
if s.cacheController.disableCache {
233235
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
234236
} else {
235-
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
237+
ips, ttl, isARecordExpired, isAAAARecordExpired, err := s.cacheController.findIPsForDomain(fqdn, option)
238+
if sub4 != nil && !isARecordExpired {
239+
sub4.Close()
240+
sub4 = nil
241+
queryOption.IPv4Enable = false
242+
}
243+
if sub6 != nil && !isAAAARecordExpired {
244+
sub6.Close()
245+
sub6 = nil
246+
queryOption.IPv6Enable = false
247+
}
236248
if !go_errors.Is(err, errRecordNotFound) {
237249
if ttl > 0 {
238250
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
@@ -241,14 +253,14 @@ func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_f
241253
}
242254
if s.cacheController.serveStale && (s.cacheController.serveExpiredTTL == 0 || s.cacheController.serveExpiredTTL < ttl) {
243255
errors.LogDebugInner(ctx, err, s.Name(), " cache OPTIMISTE ", domain, " -> ", ips)
244-
go s.sendQuery(ctx, nil, fqdn, option)
256+
s.sendQuery(ctx, nil, fqdn, queryOption)
245257
return ips, 1, err
246258
}
247259
}
248260
}
249261

250262
noResponseErrCh := make(chan error, 2)
251-
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
263+
s.sendQuery(ctx, noResponseErrCh, fqdn, queryOption)
252264
start := time.Now()
253265

254266
if sub4 != nil {
@@ -272,7 +284,7 @@ func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option dns_f
272284
}
273285
}
274286

275-
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
287+
ips, ttl, _, _, err := s.cacheController.findIPsForDomain(fqdn, option)
276288
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
277289
var rTTL uint32
278290
if ttl <= 0 {

app/dns/nameserver_quic.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,22 @@ func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, option dns_
196196
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
197197
defer closeSubscribers(sub4, sub6)
198198

199+
queryOption := option
200+
199201
if s.cacheController.disableCache {
200202
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
201203
} else {
202-
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
204+
ips, ttl, isARecordExpired, isAAAARecordExpired, err := s.cacheController.findIPsForDomain(fqdn, option)
205+
if sub4 != nil && !isARecordExpired {
206+
sub4.Close()
207+
sub4 = nil
208+
queryOption.IPv4Enable = false
209+
}
210+
if sub6 != nil && !isAAAARecordExpired {
211+
sub6.Close()
212+
sub6 = nil
213+
queryOption.IPv6Enable = false
214+
}
203215
if !go_errors.Is(err, errRecordNotFound) {
204216
if ttl > 0 {
205217
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
@@ -208,14 +220,14 @@ func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, option dns_
208220
}
209221
if s.cacheController.serveStale && (s.cacheController.serveExpiredTTL == 0 || s.cacheController.serveExpiredTTL < ttl) {
210222
errors.LogDebugInner(ctx, err, s.Name(), " cache OPTIMISTE ", domain, " -> ", ips)
211-
go s.sendQuery(ctx, nil, fqdn, option)
223+
s.sendQuery(ctx, nil, fqdn, queryOption)
212224
return ips, 1, err
213225
}
214226
}
215227
}
216228

217229
noResponseErrCh := make(chan error, 2)
218-
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
230+
s.sendQuery(ctx, noResponseErrCh, fqdn, queryOption)
219231
start := time.Now()
220232

221233
if sub4 != nil {
@@ -239,7 +251,7 @@ func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, option dns_
239251
}
240252
}
241253

242-
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
254+
ips, ttl, _, _, err := s.cacheController.findIPsForDomain(fqdn, option)
243255
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
244256
var rTTL uint32
245257
if ttl <= 0 {

app/dns/nameserver_tcp.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,22 @@ func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, option dns_f
224224
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
225225
defer closeSubscribers(sub4, sub6)
226226

227+
queryOption := option
228+
227229
if s.cacheController.disableCache {
228230
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
229231
} else {
230-
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
232+
ips, ttl, isARecordExpired, isAAAARecordExpired, err := s.cacheController.findIPsForDomain(fqdn, option)
233+
if sub4 != nil && !isARecordExpired {
234+
sub4.Close()
235+
sub4 = nil
236+
queryOption.IPv4Enable = false
237+
}
238+
if sub6 != nil && !isAAAARecordExpired {
239+
sub6.Close()
240+
sub6 = nil
241+
queryOption.IPv6Enable = false
242+
}
231243
if !go_errors.Is(err, errRecordNotFound) {
232244
if ttl > 0 {
233245
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
@@ -236,14 +248,14 @@ func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, option dns_f
236248
}
237249
if s.cacheController.serveStale && (s.cacheController.serveExpiredTTL == 0 || s.cacheController.serveExpiredTTL < ttl) {
238250
errors.LogDebugInner(ctx, err, s.Name(), " cache OPTIMISTE ", domain, " -> ", ips)
239-
go s.sendQuery(ctx, nil, fqdn, option)
251+
s.sendQuery(ctx, nil, fqdn, queryOption)
240252
return ips, 1, err
241253
}
242254
}
243255
}
244256

245257
noResponseErrCh := make(chan error, 2)
246-
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
258+
s.sendQuery(ctx, noResponseErrCh, fqdn, queryOption)
247259
start := time.Now()
248260

249261
if sub4 != nil {
@@ -267,7 +279,7 @@ func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, option dns_f
267279
}
268280
}
269281

270-
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
282+
ips, ttl, _, _, err := s.cacheController.findIPsForDomain(fqdn, option)
271283
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
272284
var rTTL uint32
273285
if ttl <= 0 {

app/dns/nameserver_udp.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,22 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option d
174174
sub4, sub6 := s.cacheController.registerSubscribers(fqdn, option)
175175
defer closeSubscribers(sub4, sub6)
176176

177+
queryOption := option
178+
177179
if s.cacheController.disableCache {
178180
errors.LogDebug(ctx, "DNS cache is disabled. Querying IP for ", domain, " at ", s.Name())
179181
} else {
180-
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
182+
ips, ttl, isARecordExpired, isAAAARecordExpired, err := s.cacheController.findIPsForDomain(fqdn, option)
183+
if sub4 != nil && !isARecordExpired {
184+
sub4.Close()
185+
sub4 = nil
186+
queryOption.IPv4Enable = false
187+
}
188+
if sub6 != nil && !isAAAARecordExpired {
189+
sub6.Close()
190+
sub6 = nil
191+
queryOption.IPv6Enable = false
192+
}
181193
if !go_errors.Is(err, errRecordNotFound) {
182194
if ttl > 0 {
183195
errors.LogDebugInner(ctx, err, s.Name(), " cache HIT ", domain, " -> ", ips)
@@ -186,14 +198,14 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option d
186198
}
187199
if s.cacheController.serveStale && (s.cacheController.serveExpiredTTL == 0 || s.cacheController.serveExpiredTTL < ttl) {
188200
errors.LogDebugInner(ctx, err, s.Name(), " cache OPTIMISTE ", domain, " -> ", ips)
189-
go s.sendQuery(ctx, nil, fqdn, option)
201+
s.sendQuery(ctx, nil, fqdn, queryOption)
190202
return ips, 1, err
191203
}
192204
}
193205
}
194206

195207
noResponseErrCh := make(chan error, 2)
196-
s.sendQuery(ctx, noResponseErrCh, fqdn, option)
208+
s.sendQuery(ctx, noResponseErrCh, fqdn, queryOption)
197209
start := time.Now()
198210

199211
if sub4 != nil {
@@ -217,7 +229,7 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option d
217229
}
218230
}
219231

220-
ips, ttl, err := s.cacheController.findIPsForDomain(fqdn, option)
232+
ips, ttl, _, _, err := s.cacheController.findIPsForDomain(fqdn, option)
221233
log.Record(&log.DNSLog{Server: s.Name(), Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err})
222234
var rTTL uint32
223235
if ttl <= 0 {

0 commit comments

Comments
 (0)