|
1 | 1 | package httpserver |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "bytes" |
4 | 5 | "crypto/tls" |
| 6 | + "fmt" |
5 | 7 | "os" |
6 | 8 | "sync" |
7 | 9 | "time" |
@@ -35,43 +37,76 @@ func newFileBasedCert(certFile, keyFile string, certCacheDuration time.Duration) |
35 | 37 | // GetCertificate returns a certificate from the cache, or loads it from disk if |
36 | 38 | // it is not cached yet or certCacheDuration has passed. |
37 | 39 | func (c *fileBasedCert) GetCertificate() (*tls.Certificate, error) { |
38 | | - now := time.Now() |
39 | | - |
40 | 40 | c.mutex.Lock() |
41 | 41 | defer c.mutex.Unlock() |
| 42 | + now := time.Now() |
42 | 43 |
|
43 | | - // Make sure we force a refresh when the certificate has expired |
44 | | - if c.cert != nil && c.cert.Leaf != nil && now.After(c.cert.Leaf.NotAfter) { |
45 | | - log.Warn().Msg("TLS certificate has expired, reloading.") |
46 | | - c.cert = nil |
47 | | - } |
| 44 | + reload := func() (*tls.Certificate, error) { |
| 45 | + c.cert = nil // make sure we don't return the old certificate. |
| 46 | + cert, err := tls.LoadX509KeyPair(c.certFile, c.keyFile) |
48 | 47 |
|
49 | | - if c.cert != nil { |
50 | | - // Check if the certificate file has been changed by comparing the last |
51 | | - // modification time with the time we last refreshed the certificate. |
52 | | - if fileInfo, err := os.Stat(c.certFile); err == nil && fileInfo.ModTime().After(c.lastRefresh) { |
53 | | - log.Warn().Msg("TLS certificate file has been changed, reloading.") |
54 | | - c.cert = nil |
| 48 | + switch { |
| 49 | + case err != nil: |
| 50 | + return nil, err |
| 51 | + case cert.Leaf == nil: |
| 52 | + return nil, fmt.Errorf("certificate leaf is nil") |
| 53 | + case now.After(cert.Leaf.NotAfter): |
| 54 | + // This is a warning on purpose, as we don't want to fail the |
| 55 | + // server startup if the certificate is expired. We will just keep |
| 56 | + // using the expired certificate, which will be result in an error |
| 57 | + // for the client. |
| 58 | + log.Warn().Msg("reloaded TLS certificate has already expired.") |
| 59 | + |
| 60 | + // When certCacheDuration is set to a value higher than one minute, |
| 61 | + // we will retry within the next minute. This is to make sure that |
| 62 | + // we don't keep using an expired certificate for too long. |
| 63 | + if c.certCacheDuration > time.Minute { |
| 64 | + now = now.Add(time.Minute - c.certCacheDuration) |
| 65 | + } |
55 | 66 | } |
| 67 | + |
| 68 | + c.cert = &cert |
| 69 | + c.lastRefresh = now |
| 70 | + return &cert, nil |
56 | 71 | } |
57 | 72 |
|
58 | | - // Load the certificate from disk if it is not cached yet or certCacheDuration |
59 | | - // has passed. |
60 | | - if c.cert == nil || now.Sub(c.lastRefresh) > c.certCacheDuration { |
61 | | - if c.cert != nil { |
62 | | - log.Warn().Msg("TLS cache duration has expired, reloading certificate from disk.") |
| 73 | + switch { |
| 74 | + // Load the certificate from disk if we don't have one cached yet. |
| 75 | + case c.cert == nil: |
| 76 | + log.Info().Msg("No TLS certificate cached, loading.") |
| 77 | + return reload() |
| 78 | + |
| 79 | + // Reload the certificate if it has expired. |
| 80 | + case now.After(c.cert.Leaf.NotAfter): |
| 81 | + log.Warn().Msg("TLS certificate has expired, reloading.") |
| 82 | + return reload() |
| 83 | + |
| 84 | + // Check for a new certificate in regular intervals. |
| 85 | + // We only change the loaded certificate if the signature has changed. |
| 86 | + case now.Sub(c.lastRefresh) > c.certCacheDuration: |
| 87 | + log.Info().Msg("TLS certificate cache duration has passed.") |
| 88 | + |
| 89 | + // Check if the certificate file has been changed since the last |
| 90 | + // refresh. This is a simple check that only compares the modification |
| 91 | + // time of the file. |
| 92 | + if fileInfo, err := os.Stat(c.certFile); err == nil && fileInfo.ModTime().After(c.lastRefresh) { |
| 93 | + log.Warn().Msg("TLS certificate file has been modifed since last refresh.") |
| 94 | + return reload() |
63 | 95 | } |
| 96 | + |
| 97 | + // Check if the certificate signature has changed since the last |
| 98 | + // refresh. This is a more expensive check that compares the signature |
| 99 | + // of the certificate. |
64 | 100 | cert, err := tls.LoadX509KeyPair(c.certFile, c.keyFile) |
65 | | - if err != nil { |
66 | | - return nil, err |
| 101 | + if err != nil || cert.Leaf == nil { |
| 102 | + log.Error().Err(err).Msg("Failed to load TLS certificate for comparison, keeping cached certificate.") |
| 103 | + return c.cert, nil |
67 | 104 | } |
68 | 105 |
|
69 | | - if cert.Leaf != nil && now.After(cert.Leaf.NotAfter) { |
70 | | - log.Error().Msg("Reloaded TLS certificate has already expired.") |
| 106 | + if !bytes.Equal(cert.Leaf.Signature, c.cert.Leaf.Signature) { |
| 107 | + log.Warn().Msg("Detected certificate signature change, reloading.") |
| 108 | + return reload() |
71 | 109 | } |
72 | | - |
73 | | - c.cert = &cert |
74 | | - c.lastRefresh = time.Now() |
75 | 110 | } |
76 | 111 |
|
77 | 112 | return c.cert, nil |
|
0 commit comments