diff --git a/telegram/cdn.go b/telegram/cdn.go index 49f68a00b8..111c77e8ce 100644 --- a/telegram/cdn.go +++ b/telegram/cdn.go @@ -1,17 +1,60 @@ package telegram import ( + "context" "crypto/rsa" "encoding/pem" "github.com/go-faster/errors" "github.com/gotd/td/crypto" + "github.com/gotd/td/exchange" "github.com/gotd/td/tg" ) -func parseCDNKeys(keys ...tg.CDNPublicKey) ([]*rsa.PublicKey, error) { - r := make([]*rsa.PublicKey, 0, len(keys)) +// Single key because help.getCdnConfig has no request params. +const helpGetCDNConfigSingleflightKey = "help.getCdnConfig" + +type cdnKeyEntry struct { + dcID int + key *rsa.PublicKey +} + +type fetchedCDNKeys struct { + all []exchange.PublicKey + byDC map[int][]exchange.PublicKey +} + +func clonePublicKeys(keys []exchange.PublicKey) []exchange.PublicKey { + return append([]exchange.PublicKey(nil), keys...) +} + +func mergePublicKeys(primary, fallback []exchange.PublicKey) []exchange.PublicKey { + if len(primary) == 0 && len(fallback) == 0 { + return nil + } + + out := make([]exchange.PublicKey, 0, len(primary)+len(fallback)) + seen := make(map[int64]struct{}, len(primary)+len(fallback)) + appendUnique := func(keys []exchange.PublicKey) { + for _, key := range keys { + fp := key.Fingerprint() + if _, ok := seen[fp]; ok { + continue + } + seen[fp] = struct{}{} + out = append(out, key) + } + } + + // Prefer primary keyset order and use fallback only for missing fingerprints. + appendUnique(primary) + appendUnique(fallback) + return out +} + +func parseCDNKeyEntries(keys ...tg.CDNPublicKey) ([]cdnKeyEntry, error) { + r := make([]cdnKeyEntry, 0, len(keys)) for _, key := range keys { block, _ := pem.Decode([]byte(key.PublicKey)) @@ -19,13 +62,230 @@ func parseCDNKeys(keys ...tg.CDNPublicKey) ([]*rsa.PublicKey, error) { continue } - key, err := crypto.ParseRSA(block.Bytes) + parsedKey, err := crypto.ParseRSA(block.Bytes) if err != nil { return nil, errors.Wrap(err, "parse RSA from PEM") } - r = append(r, key) + r = append(r, cdnKeyEntry{ + dcID: key.DCID, + key: parsedKey, + }) } return r, nil } + +func buildCDNKeysCache(entries []cdnKeyEntry) fetchedCDNKeys { + result := fetchedCDNKeys{ + all: make([]exchange.PublicKey, 0, len(entries)), + byDC: make(map[int][]exchange.PublicKey), + } + + seenAll := make(map[int64]struct{}, len(entries)) + seenByDC := make(map[int]map[int64]struct{}) + + for _, entry := range entries { + key := exchange.PublicKey{RSA: entry.key} + fingerprint := key.Fingerprint() + + if _, ok := seenAll[fingerprint]; !ok { + seenAll[fingerprint] = struct{}{} + result.all = append(result.all, key) + } + + seen, ok := seenByDC[entry.dcID] + if !ok { + seen = map[int64]struct{}{} + seenByDC[entry.dcID] = seen + } + if _, ok := seen[fingerprint]; ok { + continue + } + seen[fingerprint] = struct{}{} + result.byDC[entry.dcID] = append(result.byDC[entry.dcID], key) + } + + return result +} + +func copyCDNKeysByDC(byDC map[int][]exchange.PublicKey) map[int][]exchange.PublicKey { + if len(byDC) == 0 { + return nil + } + + r := make(map[int][]exchange.PublicKey, len(byDC)) + for dcID, keys := range byDC { + r[dcID] = append([]exchange.PublicKey(nil), keys...) 
+ } + return r +} + +func cloneFetchedCDNKeys(keys fetchedCDNKeys) fetchedCDNKeys { + return fetchedCDNKeys{ + all: clonePublicKeys(keys.all), + byDC: copyCDNKeysByDC(keys.byDC), + } +} + +func (c *Client) cachedCDNKeys() ([]exchange.PublicKey, bool, uint64) { + c.cdnKeysMux.Lock() + defer c.cdnKeysMux.Unlock() + + return clonePublicKeys(c.cdnKeys), c.cdnKeysSet, c.cdnKeysGen +} + +func (c *Client) cachedCDNKeysForDC(dcID int) ([]exchange.PublicKey, bool) { + c.cdnKeysMux.Lock() + defer c.cdnKeysMux.Unlock() + + return clonePublicKeys(c.cdnKeysByDC[dcID]), c.cdnKeysSet +} + +func (c *Client) cdnConfigFetchContext(caller context.Context) context.Context { + if c.ctx != nil { + // Bind network request lifetime to client lifecycle, not to the first + // singleflight caller. + return c.ctx + } + + // Caller cancellation is handled outside singleflight wait loop; request + // itself should not inherit first caller deadline/cancellation. + return context.WithoutCancel(caller) +} + +func (c *Client) loadCDNKeys(ctx context.Context) (fetchedCDNKeys, error) { + resultCh := c.cdnKeysLoad.DoChan(helpGetCDNConfigSingleflightKey, func() (interface{}, error) { + // singleflight ensures only one goroutine issues help.getCdnConfig; + // others wait and reuse same result. + cfg, err := c.tg.HelpGetCDNConfig(c.cdnConfigFetchContext(ctx)) + if err != nil { + return nil, errors.Wrap(err, "help.getCdnConfig") + } + + entries, err := parseCDNKeyEntries(cfg.PublicKeys...) + if err != nil { + return nil, errors.Wrap(err, "parse CDN public keys") + } + return buildCDNKeysCache(entries), nil + }) + + select { + case <-ctx.Done(): + return fetchedCDNKeys{}, ctx.Err() + case result := <-resultCh: + if result.Err != nil { + return fetchedCDNKeys{}, result.Err + } + + keys, ok := result.Val.(fetchedCDNKeys) + if !ok { + return fetchedCDNKeys{}, errors.Errorf("unexpected CDN keys type %T", result.Val) + } + return cloneFetchedCDNKeys(keys), nil + } +} + +func (c *Client) fetchCDNKeys(ctx context.Context) ([]exchange.PublicKey, error) { + const maxVersionRetries = 3 + for attempt := 0; attempt < maxVersionRetries; attempt++ { + // Fast path: fully cached, no network requests. + cached, set, startGen := c.cachedCDNKeys() + if set { + return cached, nil + } + // Snapshot generation to detect invalidation races after in-flight load. + + keys, err := c.loadCDNKeys(ctx) + if err != nil { + return nil, err + } + + c.cdnKeysMux.Lock() + switch { + case c.cdnKeysSet: + // Another goroutine already populated cache while we were waiting. + cached := clonePublicKeys(c.cdnKeys) + c.cdnKeysMux.Unlock() + return cached, nil + case c.cdnKeysGen != startGen: + // Cache was invalidated (fingerprint miss) during in-flight request. + // Discard stale result and retry from fresh generation. + c.cdnKeysMux.Unlock() + continue + default: + // Safe to commit fetched keys into cache. 
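+ // Illustrative interleaving the generation check guards against
+ // (editor's sketch, assuming two goroutines G1/G2):
+ //
+ //   G1: startGen = 5, awaits help.getCdnConfig
+ //   G2: fingerprint miss -> cache cleared, cdnKeysGen = 6
+ //   G1: result arrives, cdnKeysGen != startGen, result dropped, retry
+ //
+ // Reaching this default branch means neither case above fired, so the
+ // fetched keyset is still current and may be cached.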
+ c.cdnKeys = clonePublicKeys(keys.all) + c.cdnKeysByDC = copyCDNKeysByDC(keys.byDC) + c.cdnKeysSet = true + cached := clonePublicKeys(c.cdnKeys) + c.cdnKeysMux.Unlock() + return cached, nil + } + } + + return nil, errors.New("cdn keys cache changed concurrently") +} + +func (c *Client) refreshCDNKeys(ctx context.Context) ([]exchange.PublicKey, error) { + const maxVersionRetries = 3 + for attempt := 0; attempt < maxVersionRetries; attempt++ { + c.cdnKeysMux.Lock() + startGen := c.cdnKeysGen + c.cdnKeysMux.Unlock() + + keys, err := c.loadCDNKeys(ctx) + if err != nil { + return nil, err + } + + c.cdnKeysMux.Lock() + if c.cdnKeysGen != startGen { + // Fingerprint invalidation happened while refresh was in-flight. + // Discard stale result and refetch for fresh generation. + c.cdnKeysMux.Unlock() + continue + } + c.cdnKeys = clonePublicKeys(keys.all) + c.cdnKeysByDC = copyCDNKeysByDC(keys.byDC) + c.cdnKeysSet = true + cached := clonePublicKeys(c.cdnKeys) + c.cdnKeysMux.Unlock() + + return cached, nil + } + + return nil, errors.New("cdn keys cache changed concurrently") +} + +func (c *Client) fetchCDNKeysForDC(ctx context.Context, dcID int) ([]exchange.PublicKey, error) { + keys, set := c.cachedCDNKeysForDC(dcID) + if !set { + if _, err := c.fetchCDNKeys(ctx); err != nil { + return nil, err + } + } + + const maxRefreshAttempts = 3 + for attempt := 0; attempt < maxRefreshAttempts; attempt++ { + if err := ctx.Err(); err != nil { + return nil, err + } + + keys, _ = c.cachedCDNKeysForDC(dcID) + if len(keys) > 0 { + return keys, nil + } + if attempt == maxRefreshAttempts-1 { + break + } + + // Requested CDN DC is missing in current snapshot; retry bounded + // help.getCdnConfig refreshes to handle eventual config propagation. + if _, err := c.refreshCDNKeys(ctx); err != nil { + return nil, err + } + } + + return nil, errors.Errorf("no CDN public keys for CDN DC %d after %d refresh attempts", dcID, maxRefreshAttempts) +} diff --git a/telegram/cdn_conn_dead.go b/telegram/cdn_conn_dead.go new file mode 100644 index 0000000000..e1c8a4f023 --- /dev/null +++ b/telegram/cdn_conn_dead.go @@ -0,0 +1,60 @@ +package telegram + +import ( + "github.com/go-faster/errors" + "go.uber.org/zap" + + "github.com/gotd/td/exchange" + "github.com/gotd/td/mtproto" + "github.com/gotd/td/pool" +) + +func (c *Client) handleCDNConnDead(dcID int, err error) { + if errors.Is(err, exchange.ErrKeyFingerprintNotFound) { + c.log.Warn("Resetting cached CDN keys after fingerprint miss", + zap.Int("dc_id", dcID), + ) + c.cdnKeysMux.Lock() + c.cdnKeys = nil + c.cdnKeysByDC = nil + c.cdnKeysSet = false + // Bump generation so in-flight help.getCdnConfig results are discarded if + // they were started before invalidation. + c.cdnKeysGen++ + c.cdnKeysMux.Unlock() + // Drop current singleflight entry so next attempt refetches keys. + c.cdnKeysLoad.Forget(helpGetCDNConfigSingleflightKey) + + // Close asynchronously: callback may be invoked from pool worker + // goroutine, and synchronous self-close can deadlock on Wait(). + // Queue closes through bounded workers to avoid unbounded goroutine fan-out. + c.cdnPools.invalidateDC(dcID) + // Fingerprint miss is recoverable and handled internally by invalidation + // + reconnect with fresh keys, no external onDead signal is needed. + return + } + + if !errors.Is(err, mtproto.ErrPFSDropKeysRequired) { + // Keep legacy callback semantics for all non-PFS errors. 
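+ // That is, forward the raw error to the user-provided onDead hook
+ // without any CDN-specific recovery.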
+ if c.onDead != nil { + c.onDead(err) + } + return + } + + c.log.Warn("Dropping stored CDN session key after PFS key reset request", + zap.Int("dc_id", dcID), + ) + c.sessionsMux.Lock() + s, ok := c.cdnSessions[dcID] + if !ok { + s = pool.NewSyncSession(pool.Session{DC: dcID}) + c.cdnSessions[dcID] = s + } + s.Store(pool.Session{DC: dcID}) + c.sessionsMux.Unlock() + + if c.onDead != nil { + c.onDead(err) + } +} diff --git a/telegram/cdn_pool_manager.go b/telegram/cdn_pool_manager.go new file mode 100644 index 0000000000..1d920bc4be --- /dev/null +++ b/telegram/cdn_pool_manager.go @@ -0,0 +1,334 @@ +package telegram + +import ( + "context" + "math/bits" + "sync" + "sync/atomic" + + "github.com/go-faster/errors" + + "github.com/gotd/td/bin" +) + +type cachedCDNPool struct { + conn CloseInvoker + // max is normalized bucket size used for reuse matching. + max int64 +} + +var ( + errCDNPoolHandleClosed = errors.New("CDN pool handle is closed") + errCDNPoolHandleDouble = errors.New("CDN pool handle already closed") +) + +// cdnPoolHandle is a per-call wrapper around shared cached CDN pool. +// Close() releases only this borrowed handle; underlying pool is managed by +// client cache lifecycle (fingerprint invalidation or client shutdown). +type cdnPoolHandle struct { + manager *cdnPoolManager + conn CloseInvoker + closed atomic.Bool +} + +var _ CloseInvoker = (*cdnPoolHandle)(nil) + +func (h *cdnPoolHandle) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder) error { + if h.closed.Load() { + return errCDNPoolHandleClosed + } + return h.conn.Invoke(ctx, input, output) +} + +func (h *cdnPoolHandle) Close() error { + if h.closed.Swap(true) { + return errCDNPoolHandleDouble + } + if !h.manager.releaseCachedHandle(h.conn) { + return nil + } + return h.conn.Close() +} + +type cdnPoolManager struct { + mux sync.Mutex + + conns map[int][]cachedCDNPool + // refs tracks active handle references for shared CDN pools. + refs map[CloseInvoker]int + // closing tracks pools already known to close pipeline to avoid duplicate + // queue entries and duplicate Close() calls. + // + // Value denotes whether conn is currently queued for worker processing. + closing map[CloseInvoker]bool + + // closeQueue contains stale CDN pools waiting for async close. + // Close() may block on unstable network/proxy, so queue is processed by + // bounded worker count. + closeQueue []CloseInvoker + closePending []CloseInvoker + closeWorkers int + closeBusy int +} + +const ( + maxCDNCloseWorkers = 4 + // Historical sizing hint for close backlog in tests and heuristics. + // Queue growth is controlled by stale-pool production rate and de-dup via + // closing map, while close fan-out remains bounded by workers. + maxCDNCloseQueue = 256 +) + +func newCDNPoolManager() cdnPoolManager { + return cdnPoolManager{ + conns: map[int][]cachedCDNPool{}, + refs: map[CloseInvoker]int{}, + closing: map[CloseInvoker]bool{}, + } +} + +func (p cachedCDNPool) covers(need int64) bool { + // pool max < 1 means unlimited in pool package. + if p.max < 1 { + return true + } + // Requested max < 1 means unlimited, finite pool does not satisfy it. + if need < 1 { + return false + } + return p.max >= need +} + +func pickCDNPool(pools []cachedCDNPool, need int64) (CloseInvoker, bool) { + // Pick the smallest pool that still covers requested capacity. + best := -1 + for i, p := range pools { + if !p.covers(need) { + continue + } + if best == -1 { + best = i + continue + } + // Prefer tighter finite limit to avoid over-allocating. 
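+ // For example, with cached pools of max {4, 16, unlimited} and
+ // need = 8: max=4 is skipped (does not cover), and the finite max=16
+ // pool is preferred over the unlimited one.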
+ if pools[best].max < 1 { + best = i + continue + } + if p.max > 0 && p.max < pools[best].max { + best = i + } + } + if best == -1 { + return nil, false + } + return pools[best].conn, true +} + +func (m *cdnPoolManager) cachedHandleLocked(conn CloseInvoker) CloseInvoker { + refs, ok := m.refs[conn] + if !ok { + // Keep one cache-owner reference so pool can be reused between + // sequential downloads. + refs = 1 + } + m.refs[conn] = refs + 1 + + return &cdnPoolHandle{ + manager: m, + conn: conn, + } +} + +func (m *cdnPoolManager) releaseCachedHandle(conn CloseInvoker) bool { + m.mux.Lock() + defer m.mux.Unlock() + + refs, ok := m.refs[conn] + if !ok || refs < 1 { + // Connection is already evicted/closed by another path. + return false + } + refs-- + if refs == 0 { + delete(m.refs, conn) + return true + } + m.refs[conn] = refs + return false +} + +func (m *cdnPoolManager) acquire(dc int, need int64) (CloseInvoker, bool) { + m.mux.Lock() + defer m.mux.Unlock() + + cached, ok := pickCDNPool(m.conns[dc], need) + if !ok { + return nil, false + } + return m.cachedHandleLocked(cached), true +} + +func (m *cdnPoolManager) publishOrAcquire(dc int, need int64, created CloseInvoker) (CloseInvoker, bool) { + m.mux.Lock() + defer m.mux.Unlock() + + if existing, ok := pickCDNPool(m.conns[dc], need); ok { + return m.cachedHandleLocked(existing), true + } + m.conns[dc] = append(m.conns[dc], cachedCDNPool{ + conn: created, + max: need, + }) + return m.cachedHandleLocked(created), false +} + +func (m *cdnPoolManager) drain() []CloseInvoker { + m.mux.Lock() + defer m.mux.Unlock() + + seen := map[CloseInvoker]struct{}{} + cdnConns := make([]CloseInvoker, 0, len(m.conns)+len(m.closeQueue)+len(m.closePending)) + add := func(conn CloseInvoker) { + if conn == nil { + return + } + if _, ok := seen[conn]; ok { + return + } + seen[conn] = struct{}{} + cdnConns = append(cdnConns, conn) + } + for _, pools := range m.conns { + for _, cached := range pools { + add(cached.conn) + } + } + for _, conn := range m.closeQueue { + add(conn) + } + for _, conn := range m.closePending { + add(conn) + } + m.conns = map[int][]cachedCDNPool{} + m.refs = map[CloseInvoker]int{} + m.closing = map[CloseInvoker]bool{} + m.closeQueue = nil + m.closePending = nil + return cdnConns +} + +func (m *cdnPoolManager) refillCloseQueueLocked() { + for len(m.closeQueue) < maxCDNCloseQueue && len(m.closePending) > 0 { + conn := m.closePending[0] + m.closePending[0] = nil + m.closePending = m.closePending[1:] + if conn == nil { + continue + } + + queued, ok := m.closing[conn] + if !ok || queued { + // Already closed/queued by another path. + continue + } + m.closing[conn] = true + m.closeQueue = append(m.closeQueue, conn) + } +} + +func (m *cdnPoolManager) enqueueCloseLocked(stale []CloseInvoker) { + if len(stale) == 0 { + return + } + + for _, conn := range stale { + if conn == nil { + continue + } + if _, ok := m.closing[conn]; ok { + continue + } + if len(m.closeQueue) < maxCDNCloseQueue { + m.closing[conn] = true + m.closeQueue = append(m.closeQueue, conn) + continue + } + + // Queue is saturated, keep pending task deduplicated and promote when + // worker frees queue slots. + m.closing[conn] = false + m.closePending = append(m.closePending, conn) + } + + m.refillCloseQueueLocked() + + // Start enough workers to avoid head-of-line blocking on slow Close(), + // but keep fan-out bounded. 
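+ //
+ // E.g. with closeWorkers = 2, closeBusy = 2 and three queued closes,
+ // available = 0 < 3, so workers are spawned until available covers the
+ // queue or maxCDNCloseWorkers (4) is reached.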
+ for m.closeWorkers < maxCDNCloseWorkers { + available := m.closeWorkers - m.closeBusy + if available >= len(m.closeQueue) { + break + } + m.closeWorkers++ + go m.runCloseWorker() + } +} + +func (m *cdnPoolManager) runCloseWorker() { + for { + m.mux.Lock() + if len(m.closeQueue) == 0 { + m.closeWorkers-- + m.mux.Unlock() + return + } + conn := m.closeQueue[0] + m.closeQueue[0] = nil + m.closeQueue = m.closeQueue[1:] + m.closeBusy++ + m.mux.Unlock() + + _ = conn.Close() + + m.mux.Lock() + delete(m.closing, conn) + m.closeBusy-- + m.refillCloseQueueLocked() + m.mux.Unlock() + } +} + +func (m *cdnPoolManager) invalidateDC(dcID int) { + m.mux.Lock() + stale := append([]cachedCDNPool(nil), m.conns[dcID]...) + for _, cached := range stale { + delete(m.refs, cached.conn) + } + delete(m.conns, dcID) + + toClose := make([]CloseInvoker, 0, len(stale)) + for _, cached := range stale { + toClose = append(toClose, cached.conn) + } + m.enqueueCloseLocked(toClose) + m.mux.Unlock() +} + +func normalizeCDNPoolMax(max int64) int64 { + // Keep unlimited pools as-is. + if max < 1 { + return max + } + // Collapse close finite values into power-of-two buckets to cap the number + // of cached CDN pools per DC. + if max < 2 { + return max + } + shift := bits.Len64(uint64(max - 1)) + // Guard signed overflow for extremely large values. + if shift >= 63 { + return max + } + return int64(1) << shift +} diff --git a/telegram/cdn_test.go b/telegram/cdn_test.go index f3de20e0d1..e3e3615332 100644 --- a/telegram/cdn_test.go +++ b/telegram/cdn_test.go @@ -1,13 +1,35 @@ package telegram import ( + "context" + "crypto/rsa" + "sync" + "sync/atomic" "testing" + "time" + "github.com/go-faster/errors" "github.com/stretchr/testify/require" + "go.uber.org/zap" + "github.com/gotd/td/bin" + "github.com/gotd/td/exchange" "github.com/gotd/td/tg" ) +func parseCDNKeysForTest(keys ...tg.CDNPublicKey) ([]*rsa.PublicKey, error) { + entries, err := parseCDNKeyEntries(keys...) + if err != nil { + return nil, err + } + + r := make([]*rsa.PublicKey, 0, len(entries)) + for _, entry := range entries { + r = append(r, entry.key) + } + return r, nil +} + func Test_parseCDNKeys(t *testing.T) { keys := []string{ `-----BEGIN RSA PUBLIC KEY----- @@ -37,7 +59,655 @@ FRX7DdwIA/FdOzfWyXYLlCFaSX8K/6CnSQIDAQAB }) } - publicKeys, err := parseCDNKeys(cdnKeys...) + publicKeys, err := parseCDNKeysForTest(cdnKeys...) require.NoError(t, err) require.Len(t, publicKeys, 2) } + +func Test_fetchCDNKeysRetriesAfterFailure(t *testing.T) { + // Regression guard: + // failed first fetch must not poison cache; second call should retry network + // and then populate cache for subsequent calls. 
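+ //
+ // Expected sequence: first fetch surfaces the invoke error, second
+ // fetch retries and fills the cache, third fetch is served from cache,
+ // so the invoker is hit exactly twice.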
+ a := require.New(t) + + const key = `-----BEGIN RSA PUBLIC KEY----- +MIIBCgKCAQEA+Lf3PvgE1yxbJUCMaEAk +V0QySTVpnaDjiednB5RbtNWjCeqSVakYHbqqGMIIv5WCGdFdrqOfMNcNSstPtSU6 +R9UmRw6tquOIykpSuUOje9H+4XVIKqujyL2ISdK+4ZOMl4hCMkqauw4bP1Sbr03v +ZRQbU6qEA04V4j879BAyBVhr3WG9+Zi+t5XfGSTgSExPYEl8rZNHYNV5RB+BuroV +H2HLTOpT/mJVfikYpgjfWF5ldezV4Wo9LSH0cZGSFIaeJl8d0A8Eiy5B9gtBO8mL ++XfQRKOOmr7a4BM4Ro2de5rr2i2od7hYXd3DO9FRSl4y1zA8Am48Rfd95WHF3N/O +mQIDAQAB +-----END RSA PUBLIC KEY-----` + + var calls int + c := &Client{} + c.init() + c.log = zap.NewNop() + c.tg = tg.NewClient(InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error { + _, ok := input.(*tg.HelpGetCDNConfigRequest) + a.True(ok) + result, ok := output.(*tg.CDNConfig) + a.True(ok) + + calls++ + if calls == 1 { + return errors.New("temporary fetch error") + } + result.PublicKeys = []tg.CDNPublicKey{{ + DCID: 1, + PublicKey: key, + }} + return nil + })) + + _, err := c.fetchCDNKeys(context.Background()) + a.Error(err) + + keys, err := c.fetchCDNKeys(context.Background()) + a.NoError(err) + a.Len(keys, 1) + + cached, err := c.fetchCDNKeys(context.Background()) + a.NoError(err) + a.Len(cached, 1) + a.Equal(2, calls) +} + +func Test_fetchCDNKeysInvalidationDropsStaleResult(t *testing.T) { + // Critical race case: + // if fingerprint miss invalidates cache while singleflight fetch is in-flight, + // stale result must be discarded and replaced with fresh key set. + a := require.New(t) + + const staleKey = `-----BEGIN RSA PUBLIC KEY----- +MIIBCgKCAQEA+Lf3PvgE1yxbJUCMaEAk +V0QySTVpnaDjiednB5RbtNWjCeqSVakYHbqqGMIIv5WCGdFdrqOfMNcNSstPtSU6 +R9UmRw6tquOIykpSuUOje9H+4XVIKqujyL2ISdK+4ZOMl4hCMkqauw4bP1Sbr03v +ZRQbU6qEA04V4j879BAyBVhr3WG9+Zi+t5XfGSTgSExPYEl8rZNHYNV5RB+BuroV +H2HLTOpT/mJVfikYpgjfWF5ldezV4Wo9LSH0cZGSFIaeJl8d0A8Eiy5B9gtBO8mL ++XfQRKOOmr7a4BM4Ro2de5rr2i2od7hYXd3DO9FRSl4y1zA8Am48Rfd95WHF3N/O +mQIDAQAB +-----END RSA PUBLIC KEY-----` + + const freshKey = `-----BEGIN RSA PUBLIC KEY----- +MIIBCgKCAQEAyu5PXyfp+VFLc2hKJsq/cvQ+wq9V2s1iGMMwcrkXrKAqX0S5QEcY +W9b6pV5LulbsvNcxp/YniiSL4FsAja28B9fH//Y+AolWASomCB0NSVHwS1Pqfe3m +GdLTwDmqU17tSWk/48+Kfn4B+WT85ZIKt8bOnABwnM1AtykX0zKwzm9yKcTX0MeY +rwzgiOQax6J1cfgtLdxl8HVKT6wCOS1e43zpXMU+UoWqRqIan+J6q+ubi1yF4PWl +DyDgJSw8uxlhNNMP4tAnshIRZ1ZZ25O/g58jw1qz5XMztZwLNA2pUxaFtyy1LdHC +FRX7DdwIA/FdOzfWyXYLlCFaSX8K/6CnSQIDAQAB +-----END RSA PUBLIC KEY-----` + + firstStarted := make(chan struct{}) + unblockFirst := make(chan struct{}) + + var calls int + c := &Client{} + c.init() + c.log = zap.NewNop() + c.tg = tg.NewClient(InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error { + _, ok := input.(*tg.HelpGetCDNConfigRequest) + a.True(ok) + result, ok := output.(*tg.CDNConfig) + a.True(ok) + + calls++ + switch calls { + case 1: + close(firstStarted) + <-unblockFirst + result.PublicKeys = []tg.CDNPublicKey{{ + DCID: 1, + PublicKey: staleKey, + }} + default: + result.PublicKeys = []tg.CDNPublicKey{{ + DCID: 1, + PublicKey: freshKey, + }} + } + + return nil + })) + + type fetchResult struct { + keys []exchange.PublicKey + err error + } + done := make(chan fetchResult, 1) + go func() { + keys, err := c.fetchCDNKeys(context.Background()) + done <- fetchResult{keys: keys, err: err} + }() + + <-firstStarted + c.handleCDNConnDead(203, exchange.ErrKeyFingerprintNotFound) + close(unblockFirst) + + result := <-done + a.NoError(result.err) + a.Len(result.keys, 1) + a.Equal(2, calls) + + parsed, err := parseCDNKeysForTest(tg.CDNPublicKey{DCID: 1, PublicKey: freshKey}) + a.NoError(err) + 
a.Len(parsed, 1) + a.Equal(exchange.PublicKey{RSA: parsed[0]}.Fingerprint(), result.keys[0].Fingerprint()) + + cached, err := c.fetchCDNKeys(context.Background()) + a.NoError(err) + a.Len(cached, 1) + a.Equal(result.keys[0].Fingerprint(), cached[0].Fingerprint()) + a.Equal(2, calls) +} + +func Test_refreshCDNKeysInvalidationDropsStaleResult(t *testing.T) { + // Critical race case for forced refresh path: + // if fingerprint miss happens while refresh is in-flight, stale result must + // not become the cached keyset. + a := require.New(t) + + const staleKey = `-----BEGIN RSA PUBLIC KEY----- +MIIBCgKCAQEA+Lf3PvgE1yxbJUCMaEAk +V0QySTVpnaDjiednB5RbtNWjCeqSVakYHbqqGMIIv5WCGdFdrqOfMNcNSstPtSU6 +R9UmRw6tquOIykpSuUOje9H+4XVIKqujyL2ISdK+4ZOMl4hCMkqauw4bP1Sbr03v +ZRQbU6qEA04V4j879BAyBVhr3WG9+Zi+t5XfGSTgSExPYEl8rZNHYNV5RB+BuroV +H2HLTOpT/mJVfikYpgjfWF5ldezV4Wo9LSH0cZGSFIaeJl8d0A8Eiy5B9gtBO8mL ++XfQRKOOmr7a4BM4Ro2de5rr2i2od7hYXd3DO9FRSl4y1zA8Am48Rfd95WHF3N/O +mQIDAQAB +-----END RSA PUBLIC KEY-----` + + const freshKey = `-----BEGIN RSA PUBLIC KEY----- +MIIBCgKCAQEAyu5PXyfp+VFLc2hKJsq/cvQ+wq9V2s1iGMMwcrkXrKAqX0S5QEcY +W9b6pV5LulbsvNcxp/YniiSL4FsAja28B9fH//Y+AolWASomCB0NSVHwS1Pqfe3m +GdLTwDmqU17tSWk/48+Kfn4B+WT85ZIKt8bOnABwnM1AtykX0zKwzm9yKcTX0MeY +rwzgiOQax6J1cfgtLdxl8HVKT6wCOS1e43zpXMU+UoWqRqIan+J6q+ubi1yF4PWl +DyDgJSw8uxlhNNMP4tAnshIRZ1ZZ25O/g58jw1qz5XMztZwLNA2pUxaFtyy1LdHC +FRX7DdwIA/FdOzfWyXYLlCFaSX8K/6CnSQIDAQAB +-----END RSA PUBLIC KEY-----` + + firstStarted := make(chan struct{}) + unblockFirst := make(chan struct{}) + var calls atomic.Int32 + + c := &Client{} + c.init() + c.log = zap.NewNop() + c.tg = tg.NewClient(InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error { + _, ok := input.(*tg.HelpGetCDNConfigRequest) + a.True(ok) + result, ok := output.(*tg.CDNConfig) + a.True(ok) + + switch calls.Add(1) { + case 1: + close(firstStarted) + <-unblockFirst + result.PublicKeys = []tg.CDNPublicKey{{ + DCID: 1, + PublicKey: staleKey, + }} + default: + result.PublicKeys = []tg.CDNPublicKey{{ + DCID: 1, + PublicKey: freshKey, + }} + } + + return nil + })) + + type fetchResult struct { + keys []exchange.PublicKey + err error + } + done := make(chan fetchResult, 1) + go func() { + keys, err := c.refreshCDNKeys(context.Background()) + done <- fetchResult{keys: keys, err: err} + }() + + <-firstStarted + c.handleCDNConnDead(203, exchange.ErrKeyFingerprintNotFound) + close(unblockFirst) + + refreshResult := <-done + a.NoError(refreshResult.err) + a.Len(refreshResult.keys, 1) + + parsedFresh, err := parseCDNKeysForTest(tg.CDNPublicKey{ + DCID: 1, + PublicKey: freshKey, + }) + a.NoError(err) + a.Len(parsedFresh, 1) + freshFingerprint := exchange.PublicKey{RSA: parsedFresh[0]}.Fingerprint() + + keys, err := c.fetchCDNKeys(context.Background()) + a.NoError(err) + a.Len(keys, 1) + a.Equal(freshFingerprint, keys[0].Fingerprint()) + a.GreaterOrEqual(calls.Load(), int32(2)) + + callsBeforeCacheRead := calls.Load() + cached, err := c.fetchCDNKeys(context.Background()) + a.NoError(err) + a.Len(cached, 1) + a.Equal(freshFingerprint, cached[0].Fingerprint()) + a.Equal(callsBeforeCacheRead, calls.Load(), "second fetch should hit cache") +} + +func Test_fetchCDNKeysForDCReturnsOnlyRequestedDC(t *testing.T) { + a := require.New(t) + + const keyDC1 = `-----BEGIN RSA PUBLIC KEY----- +MIIBCgKCAQEA+Lf3PvgE1yxbJUCMaEAk +V0QySTVpnaDjiednB5RbtNWjCeqSVakYHbqqGMIIv5WCGdFdrqOfMNcNSstPtSU6 +R9UmRw6tquOIykpSuUOje9H+4XVIKqujyL2ISdK+4ZOMl4hCMkqauw4bP1Sbr03v 
+ZRQbU6qEA04V4j879BAyBVhr3WG9+Zi+t5XfGSTgSExPYEl8rZNHYNV5RB+BuroV +H2HLTOpT/mJVfikYpgjfWF5ldezV4Wo9LSH0cZGSFIaeJl8d0A8Eiy5B9gtBO8mL ++XfQRKOOmr7a4BM4Ro2de5rr2i2od7hYXd3DO9FRSl4y1zA8Am48Rfd95WHF3N/O +mQIDAQAB +-----END RSA PUBLIC KEY-----` + const keyDC2 = `-----BEGIN RSA PUBLIC KEY----- +MIIBCgKCAQEAyu5PXyfp+VFLc2hKJsq/cvQ+wq9V2s1iGMMwcrkXrKAqX0S5QEcY +W9b6pV5LulbsvNcxp/YniiSL4FsAja28B9fH//Y+AolWASomCB0NSVHwS1Pqfe3m +GdLTwDmqU17tSWk/48+Kfn4B+WT85ZIKt8bOnABwnM1AtykX0zKwzm9yKcTX0MeY +rwzgiOQax6J1cfgtLdxl8HVKT6wCOS1e43zpXMU+UoWqRqIan+J6q+ubi1yF4PWl +DyDgJSw8uxlhNNMP4tAnshIRZ1ZZ25O/g58jw1qz5XMztZwLNA2pUxaFtyy1LdHC +FRX7DdwIA/FdOzfWyXYLlCFaSX8K/6CnSQIDAQAB +-----END RSA PUBLIC KEY-----` + + var calls int + c := &Client{} + c.init() + c.log = zap.NewNop() + c.tg = tg.NewClient(InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error { + _, ok := input.(*tg.HelpGetCDNConfigRequest) + a.True(ok) + result, ok := output.(*tg.CDNConfig) + a.True(ok) + calls++ + result.PublicKeys = []tg.CDNPublicKey{ + { + DCID: 1, + PublicKey: keyDC1, + }, + { + DCID: 2, + PublicKey: keyDC2, + }, + } + return nil + })) + + all, err := c.fetchCDNKeys(context.Background()) + a.NoError(err) + a.Len(all, 2) + + dc1, err := c.fetchCDNKeysForDC(context.Background(), 1) + a.NoError(err) + a.Len(dc1, 1) + + dc2, err := c.fetchCDNKeysForDC(context.Background(), 2) + a.NoError(err) + a.Len(dc2, 1) + + a.NotEqual(dc1[0].Fingerprint(), dc2[0].Fingerprint()) + a.Equal(1, calls, "help.getCdnConfig must stay cached") +} + +func Test_fetchCDNKeysCanceledCallerDoesNotPoisonConcurrentWaiters(t *testing.T) { + a := require.New(t) + + const key = `-----BEGIN RSA PUBLIC KEY----- +MIIBCgKCAQEA+Lf3PvgE1yxbJUCMaEAk +V0QySTVpnaDjiednB5RbtNWjCeqSVakYHbqqGMIIv5WCGdFdrqOfMNcNSstPtSU6 +R9UmRw6tquOIykpSuUOje9H+4XVIKqujyL2ISdK+4ZOMl4hCMkqauw4bP1Sbr03v +ZRQbU6qEA04V4j879BAyBVhr3WG9+Zi+t5XfGSTgSExPYEl8rZNHYNV5RB+BuroV +H2HLTOpT/mJVfikYpgjfWF5ldezV4Wo9LSH0cZGSFIaeJl8d0A8Eiy5B9gtBO8mL ++XfQRKOOmr7a4BM4Ro2de5rr2i2od7hYXd3DO9FRSl4y1zA8Am48Rfd95WHF3N/O +mQIDAQAB +-----END RSA PUBLIC KEY-----` + + started := make(chan struct{}) + release := make(chan struct{}) + var startedOnce sync.Once + var calls atomic.Int32 + + c := &Client{} + c.init() + c.log = zap.NewNop() + c.ctx, c.cancel = context.WithCancel(context.Background()) + defer c.cancel() + c.tg = tg.NewClient(InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error { + _, ok := input.(*tg.HelpGetCDNConfigRequest) + a.True(ok) + result, ok := output.(*tg.CDNConfig) + a.True(ok) + + calls.Add(1) + startedOnce.Do(func() { + close(started) + }) + + select { + case <-ctx.Done(): + return ctx.Err() + case <-release: + result.PublicKeys = []tg.CDNPublicKey{{ + DCID: 1, + PublicKey: key, + }} + return nil + } + })) + + type fetchResult struct { + keys []exchange.PublicKey + err error + } + + firstCtx, firstCancel := context.WithCancel(context.Background()) + defer firstCancel() + + firstDone := make(chan fetchResult, 1) + go func() { + keys, err := c.fetchCDNKeys(firstCtx) + firstDone <- fetchResult{keys: keys, err: err} + }() + + <-started + + secondDone := make(chan fetchResult, 1) + go func() { + keys, err := c.fetchCDNKeys(context.Background()) + secondDone <- fetchResult{keys: keys, err: err} + }() + + firstCancel() + close(release) + + first := <-firstDone + second := <-secondDone + + a.ErrorIs(first.err, context.Canceled) + a.NoError(second.err) + a.Len(second.keys, 1) + a.GreaterOrEqual(calls.Load(), int32(1)) + a.LessOrEqual(calls.Load(), int32(2)) 
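+ // One call when both fetches share a single flight; two when the
+ // second fetch starts a fresh flight after the first caller bailed out.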
+} + +func Test_fetchCDNKeysDeadlineCallerDoesNotPoisonConcurrentWaiters(t *testing.T) { + a := require.New(t) + + const key = `-----BEGIN RSA PUBLIC KEY----- +MIIBCgKCAQEA+Lf3PvgE1yxbJUCMaEAk +V0QySTVpnaDjiednB5RbtNWjCeqSVakYHbqqGMIIv5WCGdFdrqOfMNcNSstPtSU6 +R9UmRw6tquOIykpSuUOje9H+4XVIKqujyL2ISdK+4ZOMl4hCMkqauw4bP1Sbr03v +ZRQbU6qEA04V4j879BAyBVhr3WG9+Zi+t5XfGSTgSExPYEl8rZNHYNV5RB+BuroV +H2HLTOpT/mJVfikYpgjfWF5ldezV4Wo9LSH0cZGSFIaeJl8d0A8Eiy5B9gtBO8mL ++XfQRKOOmr7a4BM4Ro2de5rr2i2od7hYXd3DO9FRSl4y1zA8Am48Rfd95WHF3N/O +mQIDAQAB +-----END RSA PUBLIC KEY-----` + + started := make(chan struct{}) + release := make(chan struct{}) + var startedOnce sync.Once + var calls atomic.Int32 + + c := &Client{} + c.init() + c.log = zap.NewNop() + c.ctx, c.cancel = context.WithCancel(context.Background()) + defer c.cancel() + c.tg = tg.NewClient(InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error { + _, ok := input.(*tg.HelpGetCDNConfigRequest) + a.True(ok) + result, ok := output.(*tg.CDNConfig) + a.True(ok) + + calls.Add(1) + startedOnce.Do(func() { + close(started) + }) + + select { + case <-ctx.Done(): + return ctx.Err() + case <-release: + result.PublicKeys = []tg.CDNPublicKey{{ + DCID: 1, + PublicKey: key, + }} + return nil + } + })) + + type fetchResult struct { + keys []exchange.PublicKey + err error + } + + firstCtx, firstCancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer firstCancel() + + firstDone := make(chan fetchResult, 1) + go func() { + keys, err := c.fetchCDNKeys(firstCtx) + firstDone <- fetchResult{keys: keys, err: err} + }() + + <-started + + secondDone := make(chan fetchResult, 1) + go func() { + keys, err := c.fetchCDNKeys(context.Background()) + secondDone <- fetchResult{keys: keys, err: err} + }() + + time.Sleep(35 * time.Millisecond) + close(release) + + first := <-firstDone + second := <-secondDone + + a.Error(first.err) + a.True(errors.Is(first.err, context.DeadlineExceeded)) + a.NoError(second.err) + a.Len(second.keys, 1) + a.GreaterOrEqual(calls.Load(), int32(1)) + a.LessOrEqual(calls.Load(), int32(2)) +} + +func Test_fetchCDNKeysForDCRetriesWhenCachedConfigMissesRequestedDC(t *testing.T) { + a := require.New(t) + + const keyDC1 = `-----BEGIN RSA PUBLIC KEY----- +MIIBCgKCAQEA+Lf3PvgE1yxbJUCMaEAk +V0QySTVpnaDjiednB5RbtNWjCeqSVakYHbqqGMIIv5WCGdFdrqOfMNcNSstPtSU6 +R9UmRw6tquOIykpSuUOje9H+4XVIKqujyL2ISdK+4ZOMl4hCMkqauw4bP1Sbr03v +ZRQbU6qEA04V4j879BAyBVhr3WG9+Zi+t5XfGSTgSExPYEl8rZNHYNV5RB+BuroV +H2HLTOpT/mJVfikYpgjfWF5ldezV4Wo9LSH0cZGSFIaeJl8d0A8Eiy5B9gtBO8mL ++XfQRKOOmr7a4BM4Ro2de5rr2i2od7hYXd3DO9FRSl4y1zA8Am48Rfd95WHF3N/O +mQIDAQAB +-----END RSA PUBLIC KEY-----` + const keyDC2 = `-----BEGIN RSA PUBLIC KEY----- +MIIBCgKCAQEAyu5PXyfp+VFLc2hKJsq/cvQ+wq9V2s1iGMMwcrkXrKAqX0S5QEcY +W9b6pV5LulbsvNcxp/YniiSL4FsAja28B9fH//Y+AolWASomCB0NSVHwS1Pqfe3m +GdLTwDmqU17tSWk/48+Kfn4B+WT85ZIKt8bOnABwnM1AtykX0zKwzm9yKcTX0MeY +rwzgiOQax6J1cfgtLdxl8HVKT6wCOS1e43zpXMU+UoWqRqIan+J6q+ubi1yF4PWl +DyDgJSw8uxlhNNMP4tAnshIRZ1ZZ25O/g58jw1qz5XMztZwLNA2pUxaFtyy1LdHC +FRX7DdwIA/FdOzfWyXYLlCFaSX8K/6CnSQIDAQAB +-----END RSA PUBLIC KEY-----` + + var calls int + c := &Client{} + c.init() + c.log = zap.NewNop() + c.tg = tg.NewClient(InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error { + _, ok := input.(*tg.HelpGetCDNConfigRequest) + a.True(ok) + result, ok := output.(*tg.CDNConfig) + a.True(ok) + + calls++ + switch calls { + case 1: + // First load has no keys for DC 2. 
+ result.PublicKeys = []tg.CDNPublicKey{{ + DCID: 1, + PublicKey: keyDC1, + }} + default: + // Refresh includes keys for requested DC. + result.PublicKeys = []tg.CDNPublicKey{ + { + DCID: 1, + PublicKey: keyDC1, + }, + { + DCID: 2, + PublicKey: keyDC2, + }, + } + } + return nil + })) + + _, err := c.fetchCDNKeys(context.Background()) + a.NoError(err) + a.Equal(1, calls) + + keysDC2, err := c.fetchCDNKeysForDC(context.Background(), 2) + a.NoError(err) + a.Len(keysDC2, 1) + a.Equal(2, calls) + + keysDC2Cached, err := c.fetchCDNKeysForDC(context.Background(), 2) + a.NoError(err) + a.Len(keysDC2Cached, 1) + a.Equal(2, calls, "successful refresh should be cached") +} + +func Test_fetchCDNKeysForDCRecoversFromEmptyCachedSnapshot(t *testing.T) { + a := require.New(t) + + const keyDC2 = `-----BEGIN RSA PUBLIC KEY----- +MIIBCgKCAQEAyu5PXyfp+VFLc2hKJsq/cvQ+wq9V2s1iGMMwcrkXrKAqX0S5QEcY +W9b6pV5LulbsvNcxp/YniiSL4FsAja28B9fH//Y+AolWASomCB0NSVHwS1Pqfe3m +GdLTwDmqU17tSWk/48+Kfn4B+WT85ZIKt8bOnABwnM1AtykX0zKwzm9yKcTX0MeY +rwzgiOQax6J1cfgtLdxl8HVKT6wCOS1e43zpXMU+UoWqRqIan+J6q+ubi1yF4PWl +DyDgJSw8uxlhNNMP4tAnshIRZ1ZZ25O/g58jw1qz5XMztZwLNA2pUxaFtyy1LdHC +FRX7DdwIA/FdOzfWyXYLlCFaSX8K/6CnSQIDAQAB +-----END RSA PUBLIC KEY-----` + + var calls int + c := &Client{} + c.init() + c.log = zap.NewNop() + c.cdnKeysSet = true + c.cdnKeys = nil + c.cdnKeysByDC = map[int][]PublicKey{} + c.tg = tg.NewClient(InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error { + _, ok := input.(*tg.HelpGetCDNConfigRequest) + a.True(ok) + cfg, ok := output.(*tg.CDNConfig) + a.True(ok) + + calls++ + cfg.PublicKeys = []tg.CDNPublicKey{{ + DCID: 2, + PublicKey: keyDC2, + }} + return nil + })) + + keys, err := c.fetchCDNKeysForDC(context.Background(), 2) + a.NoError(err) + a.Len(keys, 1) + a.Equal(1, calls) + + // Ensure recovered keyset is now cached and reused. + keysCached, err := c.fetchCDNKeysForDC(context.Background(), 2) + a.NoError(err) + a.Len(keysCached, 1) + a.Equal(1, calls) +} + +func Test_fetchCDNKeysForDCMissingAfterRefreshRecoversWithinSingleCall(t *testing.T) { + a := require.New(t) + + const keyDC1 = `-----BEGIN RSA PUBLIC KEY----- +MIIBCgKCAQEA+Lf3PvgE1yxbJUCMaEAk +V0QySTVpnaDjiednB5RbtNWjCeqSVakYHbqqGMIIv5WCGdFdrqOfMNcNSstPtSU6 +R9UmRw6tquOIykpSuUOje9H+4XVIKqujyL2ISdK+4ZOMl4hCMkqauw4bP1Sbr03v +ZRQbU6qEA04V4j879BAyBVhr3WG9+Zi+t5XfGSTgSExPYEl8rZNHYNV5RB+BuroV +H2HLTOpT/mJVfikYpgjfWF5ldezV4Wo9LSH0cZGSFIaeJl8d0A8Eiy5B9gtBO8mL ++XfQRKOOmr7a4BM4Ro2de5rr2i2od7hYXd3DO9FRSl4y1zA8Am48Rfd95WHF3N/O +mQIDAQAB +-----END RSA PUBLIC KEY-----` + const keyDC2 = `-----BEGIN RSA PUBLIC KEY----- +MIIBCgKCAQEAyu5PXyfp+VFLc2hKJsq/cvQ+wq9V2s1iGMMwcrkXrKAqX0S5QEcY +W9b6pV5LulbsvNcxp/YniiSL4FsAja28B9fH//Y+AolWASomCB0NSVHwS1Pqfe3m +GdLTwDmqU17tSWk/48+Kfn4B+WT85ZIKt8bOnABwnM1AtykX0zKwzm9yKcTX0MeY +rwzgiOQax6J1cfgtLdxl8HVKT6wCOS1e43zpXMU+UoWqRqIan+J6q+ubi1yF4PWl +DyDgJSw8uxlhNNMP4tAnshIRZ1ZZ25O/g58jw1qz5XMztZwLNA2pUxaFtyy1LdHC +FRX7DdwIA/FdOzfWyXYLlCFaSX8K/6CnSQIDAQAB +-----END RSA PUBLIC KEY-----` + + var calls int + c := &Client{} + c.init() + c.log = zap.NewNop() + c.tg = tg.NewClient(InvokeFunc(func(ctx context.Context, input bin.Encoder, output bin.Decoder) error { + _, ok := input.(*tg.HelpGetCDNConfigRequest) + a.True(ok) + result, ok := output.(*tg.CDNConfig) + a.True(ok) + + calls++ + switch calls { + case 1: + // Initial load has no keys for DC 2. + result.PublicKeys = []tg.CDNPublicKey{{ + DCID: 1, + PublicKey: keyDC1, + }} + case 2: + // First refresh still misses requested DC. 
+ result.PublicKeys = []tg.CDNPublicKey{{ + DCID: 1, + PublicKey: keyDC1, + }} + default: + // Next refresh returns keys for requested DC. + result.PublicKeys = []tg.CDNPublicKey{ + { + DCID: 1, + PublicKey: keyDC1, + }, + { + DCID: 2, + PublicKey: keyDC2, + }, + } + } + return nil + })) + + _, err := c.fetchCDNKeys(context.Background()) + a.NoError(err) + a.Equal(1, calls) + + recovered, err := c.fetchCDNKeysForDC(context.Background(), 2) + a.NoError(err) + a.Len(recovered, 1) + a.Equal(3, calls) + + cached, err := c.fetchCDNKeysForDC(context.Background(), 2) + a.NoError(err) + a.Len(cached, 1) + a.Equal(3, calls, "successful recovery should be cached") +} diff --git a/telegram/client.go b/telegram/client.go index 38337a8085..bfd4e0d792 100644 --- a/telegram/client.go +++ b/telegram/client.go @@ -10,6 +10,7 @@ import ( "go.opentelemetry.io/otel/trace" "go.uber.org/atomic" "go.uber.org/zap" + "golang.org/x/sync/singleflight" "github.com/gotd/td/bin" "github.com/gotd/td/clock" @@ -102,8 +103,22 @@ type Client struct { // Connections to non-primary DC. subConns map[int]CloseInvoker subConnsMux sync.Mutex - sessions map[int]*pool.SyncSession + // Shared CDN pools and handle references. + cdnPools cdnPoolManager + // sessions stores regular non-primary DC sessions. + sessions map[int]*pool.SyncSession + // cdnSessions stores session state for CDN pools separately from regular DCs. + cdnSessions map[int]*pool.SyncSession sessionsMux sync.Mutex + // CDN public keys loaded from help.getCdnConfig and cached per CDN DC. + cdnKeys []PublicKey + cdnKeysByDC map[int][]PublicKey + cdnKeysSet bool + // cdnKeysGen increments on cache invalidation to avoid storing stale + // singleflight result after fingerprint miss. + cdnKeysGen uint64 + cdnKeysMux sync.Mutex + cdnKeysLoad singleflight.Group // Wrappers for external world, like logs or PRNG. rand io.Reader // immutable @@ -117,6 +132,8 @@ type Client struct { // Client config. appID int // immutable appHash string // immutable + // allowCDN is the explicit downloader policy copied from Options.AllowCDN. + allowCDN bool // immutable // Session storage. storage clientStorage // immutable, nillable @@ -155,6 +172,7 @@ func NewClient(appID int, appHash string, opt Options) *Client { log: opt.Logger, appID: appID, appHash: appHash, + allowCDN: opt.AllowCDN, updateHandler: opt.UpdateHandler, session: pool.NewSyncSession(pool.Session{ DC: opt.DC, @@ -232,7 +250,15 @@ func (c *Client) init() { c.restart = make(chan struct{}) c.migration = make(chan struct{}, 1) c.sessions = map[int]*pool.SyncSession{} + c.cdnSessions = map[int]*pool.SyncSession{} c.subConns = map[int]CloseInvoker{} + c.cdnPools = newCDNPoolManager() + // CDN key cache is cold-started and filled lazily on first CDN pool create. + c.cdnKeys = nil + c.cdnKeysByDC = nil + c.cdnKeysSet = false + c.cdnKeysGen = 0 + c.cdnKeysLoad = singleflight.Group{} c.invoker = chainMiddlewares(InvokeFunc(c.invokeDirect), c.mw...) c.tg = tg.NewClient(c.invoker) } diff --git a/telegram/conn_builder.go b/telegram/conn_builder.go index 8f13af2a17..44cf7ccf52 100644 --- a/telegram/conn_builder.go +++ b/telegram/conn_builder.go @@ -32,6 +32,26 @@ func (c *Client) asHandler() manager.Handler { } } +type cdnClientHandler struct { + client *Client +} + +func (c cdnClientHandler) OnSession(cfg tg.Config, s mtproto.Session) error { + // CDN sessions are stored separately from regular DC sessions. 
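+ // onCDNSession (not shown in this hunk) is expected to persist into
+ // Client.cdnSessions keyed by DC ID, mirroring cdn_conn_dead.go.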
+ return c.client.onCDNSession(cfg, s) +} + +func (cdnClientHandler) OnMessage(*bin.Buffer) error { + // CDN connections never deliver updates. + return nil +} + +func (c *Client) asCDNHandler() manager.Handler { + return cdnClientHandler{ + client: c, + } +} + type connConstructor func( create mtproto.Dialer, mode manager.ConnMode, diff --git a/telegram/connect.go b/telegram/connect.go index c096c7c652..be08d749b7 100644 --- a/telegram/connect.go +++ b/telegram/connect.go @@ -142,18 +142,32 @@ func (c *Client) Run(ctx context.Context, f func(ctx context.Context) error) (er c.log.Info("Starting") defer c.log.Info("Closed") - // Cancel client on exit. - defer c.cancel() defer func() { c.subConnsMux.Lock() - defer c.subConnsMux.Unlock() - + subConns := make([]CloseInvoker, 0, len(c.subConns)) for _, conn := range c.subConns { + subConns = append(subConns, conn) + } + c.subConns = map[int]CloseInvoker{} + c.subConnsMux.Unlock() + + cdnConns := c.cdnPools.drain() + + // Close outside locks to avoid lock inversion with pool callbacks. + for _, conn := range subConns { + if closeErr := conn.Close(); !errors.Is(closeErr, context.Canceled) { + multierr.AppendInto(&err, closeErr) + } + } + for _, conn := range cdnConns { if closeErr := conn.Close(); !errors.Is(closeErr, context.Canceled) { multierr.AppendInto(&err, closeErr) } } }() + // Cancel client before deferred cleanup snapshot to block concurrent pool + // creations while we are draining cached connections. + defer c.cancel() c.resetReady() if err := c.restoreConnection(ctx); err != nil { diff --git a/telegram/connect_cdn_test.go b/telegram/connect_cdn_test.go new file mode 100644 index 0000000000..6001d5d6e7 --- /dev/null +++ b/telegram/connect_cdn_test.go @@ -0,0 +1,148 @@ +package telegram + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/gotd/td/bin" +) + +type runClientConn struct { + run func(ctx context.Context) error +} + +func (r runClientConn) Run(ctx context.Context) error { + return r.run(ctx) +} + +func (runClientConn) Invoke(context.Context, bin.Encoder, bin.Decoder) error { + return nil +} + +func (runClientConn) Ping(context.Context) error { + return nil +} + +type cancelCheckInvoker struct { + client *Client + + closed atomic.Bool + canceledBeforeClose atomic.Bool +} + +type inProgressCloseInvoker struct { + started chan struct{} + unblock chan struct{} + done chan struct{} + + calls atomic.Int32 + once sync.Once +} + +func (*inProgressCloseInvoker) Invoke(context.Context, bin.Encoder, bin.Decoder) error { + return nil +} + +func (c *inProgressCloseInvoker) Close() error { + if c.calls.Add(1) > 1 { + return errors.New("DC already closed") + } + + c.once.Do(func() { + close(c.started) + }) + <-c.unblock + close(c.done) + return nil +} + +func (*cancelCheckInvoker) Invoke(context.Context, bin.Encoder, bin.Decoder) error { + return nil +} + +func (c *cancelCheckInvoker) Close() error { + c.closed.Store(true) + + if c.client != nil && c.client.ctx != nil { + select { + case <-c.client.ctx.Done(): + c.canceledBeforeClose.Store(true) + default: + c.canceledBeforeClose.Store(false) + } + } + + return nil +} + +func TestClientRunCancelsContextBeforeClosingManagedConns(t *testing.T) { + c := NewClient(1, "hash", Options{ + NoUpdates: true, + Logger: zap.NewNop(), + }) + + checker := &cancelCheckInvoker{client: c} + c.subConns[1] = checker + c.conn = runClientConn{ + run: func(ctx context.Context) error { + c.onReady() + 
<-ctx.Done() + return ctx.Err() + }, + } + + err := c.Run(context.Background(), func(context.Context) error { return nil }) + require.NoError(t, err) + require.True(t, checker.closed.Load()) + require.True(t, checker.canceledBeforeClose.Load()) +} + +func TestClientRunSkipsDoubleCloseForAlreadyClosingCDNConn(t *testing.T) { + c := NewClient(1, "hash", Options{ + NoUpdates: true, + Logger: zap.NewNop(), + }) + + inv := &inProgressCloseInvoker{ + started: make(chan struct{}), + unblock: make(chan struct{}), + done: make(chan struct{}), + } + c.cdnPools.conns[203] = []cachedCDNPool{{ + conn: inv, + max: 1, + }} + c.cdnPools.invalidateDC(203) + + select { + case <-inv.started: + case <-time.After(time.Second): + t.Fatal("expected close worker to start") + } + + c.conn = runClientConn{ + run: func(ctx context.Context) error { + c.onReady() + <-ctx.Done() + return ctx.Err() + }, + } + + err := c.Run(context.Background(), func(context.Context) error { return nil }) + require.NoError(t, err) + require.EqualValues(t, 1, inv.calls.Load(), "shutdown must not issue second close") + + close(inv.unblock) + select { + case <-inv.done: + case <-time.After(time.Second): + t.Fatal("expected close worker to finish") + } +} diff --git a/telegram/dc_transfer.go b/telegram/dc_transfer.go new file mode 100644 index 0000000000..9b91ef02e3 --- /dev/null +++ b/telegram/dc_transfer.go @@ -0,0 +1,20 @@ +package telegram + +import ( + "context" + + "github.com/gotd/td/telegram/auth" + "github.com/gotd/td/telegram/internal/manager" + "github.com/gotd/td/tg" +) + +func (c *Client) dcTransferSetup(dcID int) manager.SetupCallback { + return func(ctx context.Context, invoker tg.Invoker) error { + // Run export/import authorization only when the connection is already up. + _, err := c.transfer(ctx, tg.NewClient(invoker), dcID) + if auth.IsUnauthorized(err) { + return nil + } + return err + } +} diff --git a/telegram/dcs/plain.go b/telegram/dcs/plain.go index 8fd2a491a9..bf1c6d927f 100644 --- a/telegram/dcs/plain.go +++ b/telegram/dcs/plain.go @@ -57,7 +57,17 @@ func (p plain) MediaOnly(ctx context.Context, dc int, list List) (transport.Conn } func (p plain) CDN(ctx context.Context, dc int, list List) (transport.Conn, error) { - return nil, errors.Errorf("can't resolve %d: CDN is unsupported", dc) + candidates := FindDCs(list.Options, dc, p.preferIPv6) + // Filter (in place) from SliceTricks. + n := 0 + for _, x := range candidates { + if x.CDN { + candidates[n] = x + n++ + } + } + // connect() keeps existing racing-dial behavior used by other DC selectors. 
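+ // For example, if list.Options holds a plain and two CDN options for
+ // the requested DC, candidates[:n] keeps only the two CDN entries, so
+ // a non-CDN endpoint is never dialed from this resolver.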
+ return p.connect(ctx, dc, list.Test, candidates[:n]) } func (p plain) dialTransport(ctx context.Context, test bool, dc tg.DCOption) (_ transport.Conn, rerr error) { diff --git a/telegram/download.go b/telegram/download.go new file mode 100644 index 0000000000..8e0e78824f --- /dev/null +++ b/telegram/download.go @@ -0,0 +1,94 @@ +package telegram + +import ( + "context" + "io" + + "github.com/go-faster/errors" + + "github.com/gotd/td/telegram/downloader" + "github.com/gotd/td/tg" +) + +type downloadClient struct { + client *Client +} + +func (d downloadClient) api() *tg.Client { + return d.client.API() +} + +func (d downloadClient) UploadGetFile( + ctx context.Context, + request *tg.UploadGetFileRequest, +) (tg.UploadFileClass, error) { + resp, err := d.api().UploadGetFile(ctx, request) + + return resp, err +} + +func (d downloadClient) UploadGetFileHashes( + ctx context.Context, + request *tg.UploadGetFileHashesRequest, +) ([]tg.FileHash, error) { + resp, err := d.api().UploadGetFileHashes(ctx, request) + + return resp, err +} + +func (d downloadClient) UploadReuploadCDNFile( + ctx context.Context, + request *tg.UploadReuploadCDNFileRequest, +) ([]tg.FileHash, error) { + resp, err := d.api().UploadReuploadCDNFile(ctx, request) + + return resp, err +} + +func (d downloadClient) UploadGetCDNFileHashes( + ctx context.Context, + request *tg.UploadGetCDNFileHashesRequest, +) ([]tg.FileHash, error) { + resp, err := d.api().UploadGetCDNFileHashes(ctx, request) + + return resp, err +} + +func (d downloadClient) UploadGetWebFile( + ctx context.Context, + request *tg.UploadGetWebFileRequest, +) (*tg.UploadWebFile, error) { + return d.api().UploadGetWebFile(ctx, request) +} + +func (d downloadClient) CDN(ctx context.Context, dc int, max int64) (downloader.CDN, io.Closer, error) { + invoker, err := d.client.CDN(ctx, dc, max) + if err != nil { + return nil, nil, err + } + if invoker == nil { + return nil, nil, errors.New("telegram CDN pool returned nil invoker") + } + + // CDN pools are cached on client level; downloader should not close them + // by itself; lifecycle is controlled by caller via returned closer. + cdnClient := tg.NewClient(invoker) + + return cdnClient, invoker, nil +} + +// Downloader returns file downloader configured for current client. +func (c *Client) Downloader() *downloader.Downloader { + // Propagate explicit client-level CDN policy into downloader. + return downloader.NewDownloader().WithAllowCDN(c.allowCDN) +} + +// Download creates Builder for plain file downloads. +func (c *Client) Download(location tg.InputFileLocationClass) *downloader.Builder { + return c.Downloader().Download(downloadClient{client: c}, location) +} + +// DownloadWeb creates Builder for web file downloads. 
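+//
+// Illustrative usage, assuming client is a connected *Client and location
+// points at an existing web file:
+//
+//	_, err := client.DownloadWeb(location).ToPath(ctx, "web_file.bin")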
+func (c *Client) DownloadWeb(location tg.InputWebFileLocationClass) *downloader.Builder { + return c.Downloader().Web(downloadClient{client: c}, location) +} diff --git a/telegram/download_cdn_test.go b/telegram/download_cdn_test.go new file mode 100644 index 0000000000..34364e2d31 --- /dev/null +++ b/telegram/download_cdn_test.go @@ -0,0 +1,34 @@ +package telegram + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDownloadClientCDNCloserKeepsSharedPoolCached(t *testing.T) { + c := newCDNPoolTestClient() + defer c.cancel() + + d := downloadClient{client: c} + _, closer, err := d.CDN(context.Background(), 203, 1) + require.NoError(t, err) + require.NotNil(t, closer) + + c.cdnPools.mux.Lock() + require.EqualValues(t, 1, len(c.cdnPools.refs)) + require.EqualValues(t, 1, len(c.cdnPools.conns[203])) + c.cdnPools.mux.Unlock() + + require.NoError(t, closer.Close()) + + c.cdnPools.mux.Lock() + require.EqualValues(t, 1, len(c.cdnPools.refs)) + require.EqualValues(t, 1, len(c.cdnPools.conns[203])) + c.cdnPools.mux.Unlock() + + _, closer2, err := d.CDN(context.Background(), 203, 1) + require.NoError(t, err) + require.NoError(t, closer2.Close()) +} diff --git a/telegram/download_more_test.go b/telegram/download_more_test.go new file mode 100644 index 0000000000..ae9d7f3a51 --- /dev/null +++ b/telegram/download_more_test.go @@ -0,0 +1,131 @@ +package telegram + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gotd/td/bin" + "github.com/gotd/td/tg" +) + +func TestDownloadClientDelegatesUploadMethods(t *testing.T) { + ctx := context.Background() + location := &tg.InputFileLocation{ + VolumeID: 1, + LocalID: 2, + Secret: 3, + FileReference: []byte{4}, + } + webLocation := &tg.InputWebFileLocation{ + URL: "https://example.com/file.bin", + AccessHash: 55, + } + + client := newTestClient(func(_ int64, body bin.Encoder) (bin.Encoder, error) { + switch req := body.(type) { + case *tg.UploadGetFileRequest: + require.EqualValues(t, 64, req.Offset) + require.EqualValues(t, 16, req.Limit) + require.Same(t, location, req.Location) + return &tg.UploadFile{ + Type: &tg.StorageFileUnknown{}, + Bytes: []byte("file"), + }, nil + case *tg.UploadGetFileHashesRequest: + require.EqualValues(t, 128, req.Offset) + require.Same(t, location, req.Location) + return &tg.FileHashVector{Elems: []tg.FileHash{ + {Offset: 128, Limit: 4, Hash: []byte{1, 2, 3, 4}}, + }}, nil + case *tg.UploadReuploadCDNFileRequest: + require.Equal(t, []byte{9}, req.FileToken) + require.Equal(t, []byte{8}, req.RequestToken) + return &tg.FileHashVector{Elems: []tg.FileHash{ + {Offset: 0, Limit: 8, Hash: []byte{5, 6}}, + }}, nil + case *tg.UploadGetCDNFileHashesRequest: + require.Equal(t, []byte{9}, req.FileToken) + require.EqualValues(t, 256, req.Offset) + return &tg.FileHashVector{Elems: []tg.FileHash{ + {Offset: 256, Limit: 8, Hash: []byte{7, 8}}, + }}, nil + case *tg.UploadGetWebFileRequest: + require.Same(t, webLocation, req.Location) + require.EqualValues(t, 7, req.Offset) + require.EqualValues(t, 11, req.Limit) + return &tg.UploadWebFile{ + FileType: &tg.StorageFileUnknown{}, + Bytes: []byte("web"), + }, nil + default: + t.Fatalf("unexpected request type %T", body) + return nil, nil + } + }) + + d := downloadClient{client: client} + + file, err := d.UploadGetFile(ctx, &tg.UploadGetFileRequest{ + Location: location, + Offset: 64, + Limit: 16, + }) + require.NoError(t, err) + typedFile, ok := file.(*tg.UploadFile) + require.True(t, ok) + require.Equal(t, 
[]byte("file"), typedFile.Bytes) + + hashes, err := d.UploadGetFileHashes(ctx, &tg.UploadGetFileHashesRequest{ + Location: location, + Offset: 128, + }) + require.NoError(t, err) + require.Len(t, hashes, 1) + require.EqualValues(t, 128, hashes[0].Offset) + + reupload, err := d.UploadReuploadCDNFile(ctx, &tg.UploadReuploadCDNFileRequest{ + FileToken: []byte{9}, + RequestToken: []byte{8}, + }) + require.NoError(t, err) + require.Len(t, reupload, 1) + require.EqualValues(t, 8, reupload[0].Limit) + + cdnHashes, err := d.UploadGetCDNFileHashes(ctx, &tg.UploadGetCDNFileHashesRequest{ + FileToken: []byte{9}, + Offset: 256, + }) + require.NoError(t, err) + require.Len(t, cdnHashes, 1) + require.EqualValues(t, 256, cdnHashes[0].Offset) + + web, err := d.UploadGetWebFile(ctx, &tg.UploadGetWebFileRequest{ + Location: webLocation, + Offset: 7, + Limit: 11, + }) + require.NoError(t, err) + require.Equal(t, []byte("web"), web.Bytes) +} + +func TestClientDownloadBuilders(t *testing.T) { + c := &Client{allowCDN: true} + + require.NotNil(t, c.Downloader()) + require.NotNil(t, c.Download(&tg.InputFileLocation{})) + require.NotNil(t, c.DownloadWeb(&tg.InputWebFileLocation{ + URL: "https://example.com/file.bin", + AccessHash: 1, + })) +} + +func TestDownloadClientCDNErrorPropagation(t *testing.T) { + c := newCDNPoolTestClient() + defer c.cancel() + + d := downloadClient{client: c} + _, _, err := d.CDN(context.Background(), 404, 1) + require.Error(t, err) +} diff --git a/telegram/downloader/builder.go b/telegram/downloader/builder.go index b324e12fb3..f0dfbd10f7 100644 --- a/telegram/downloader/builder.go +++ b/telegram/downloader/builder.go @@ -16,8 +16,11 @@ import ( type Builder struct { downloader *Downloader - schema schema - hashes []tg.FileHash + schema schema + hashes []tg.FileHash + // verify controls legacy outer verifier (reader + verifier queue). + // CDN redirect path has mandatory protocol-level verification in cdn schema, + // independent from this flag. verify bool threads int } @@ -38,14 +41,116 @@ func (b *Builder) WithThreads(threads int) *Builder { return b } -// WithVerify sets verify parameter. -// If verify is true, file hashes will be checked -// Verify is true by default for CDN downloads. +// WithRetryHandler sets callback for transient download errors that are retried +// internally. +// +// Handler can be called concurrently from download workers. +func (b *Builder) WithRetryHandler(handler RetryHandler) *Builder { + switch s := b.schema.(type) { + case master: + s.retryHandler = handler + b.schema = s + case web: + s.retryHandler = handler + b.schema = s + case *cdn: + s.retryHandler = handler + } + + return b +} + +// WithVerify controls global hash verification behavior. +// +// `true` enables classic verifier reader (preloads hash queue and validates all +// chunks, both legacy and CDN). +// `false` disables classic verifier reader. +// +// If not called explicitly: +// - non-CDN path preserves old behavior (no upfront hash requests); +// - CDN path enables strict inline CDN verification after redirect. +// +// Use WithVerify(true) to force verifier-queue mode on all paths. func (b *Builder) WithVerify(verify bool) *Builder { b.verify = verify return b } +func (b *Builder) prepareMaster(m master, allowCDN bool) *Builder { + clone := *b + masterSchema := m + // Keep explicit switch in schema to guarantee old request path when + // CDN is disabled or unavailable. 
+ masterSchema.allowCDN = allowCDN + clone.schema = masterSchema + clone.hashes = nil + return &clone +} + +func (b *Builder) prepareCDNPath(m master, provider CDNProvider) *Builder { + // Enable redirect errors on master schema (`upload.fileCdnRedirect`) while + // still serving regular files from master when redirect is not required. + m.allowCDN = true + + clone := *b + // Avoid outer verifier on default path to keep non-redirect requests + // equivalent to legacy master flow; CDN schema still verifies redirected + // chunks inline according to Telegram CDN protocol. + verifyCDNInline := !clone.verify + clone.hashes = nil + clone.schema = newCDNSchema( + m, + provider, + b.downloader.pool, + int64(b.threads), + verifyCDNInline, + m.retryHandler, + ) + return &clone +} + +func (b *Builder) shouldAllowCDN() bool { + // CDN redirect flow is explicit-only to avoid hidden behavior changes and + // preserve backwards compatibility for callers that never opted in. + if b.downloader.allowCDN == nil { + return false + } + return *b.downloader.allowCDN +} + +func closeSchema(s schema) func() error { + if closer, ok := s.(interface{ Close() error }); ok { + return closer.Close + } + return nil +} + +func (b *Builder) prepare() (_ *Builder, closeCDN func() error, err error) { + m, ok := b.schema.(master) + if !ok { + return b, closeSchema(b.schema), nil + } + + // Fast path compatibility guarantee: + // if CDN is not explicitly allowed we keep legacy master flow exactly as is, + // without CDN pool creation and without extra hash requests. + if !b.shouldAllowCDN() { + prepared := b.prepareMaster(m, false) + return prepared, closeSchema(prepared.schema), nil + } + + // Even with AllowCDN=true, fallback to legacy master flow if client does not + // provide CDN transport factory. + provider, hasProvider := m.client.(CDNProvider) + if !hasProvider { + prepared := b.prepareMaster(m, false) + return prepared, closeSchema(prepared.schema), nil + } + + prepared := b.prepareCDNPath(m, provider) + return prepared, closeSchema(prepared.schema), nil +} + func (b *Builder) reader() *reader { if b.verify { return verifiedReader(b.schema, newVerifier(b.schema, b.hashes...)) @@ -56,13 +161,34 @@ func (b *Builder) reader() *reader { // Stream downloads file to given io.Writer. // NB: in this mode download can't be parallel. -func (b *Builder) Stream(ctx context.Context, output io.Writer) (tg.StorageFileTypeClass, error) { - return b.downloader.stream(ctx, b.reader(), output) +func (b *Builder) Stream(ctx context.Context, output io.Writer) (_ tg.StorageFileTypeClass, err error) { + prepared, closeCDN, err := b.prepare() + if err != nil { + return nil, err + } + defer func() { + if closeCDN != nil { + multierr.AppendInto(&err, closeCDN()) + } + }() + typ, runErr := prepared.downloader.stream(ctx, prepared.reader(), output) + return typ, runErr } // Parallel downloads file to given io.WriterAt. -func (b *Builder) Parallel(ctx context.Context, output io.WriterAt) (tg.StorageFileTypeClass, error) { - return b.downloader.parallel(ctx, b.reader(), b.threads, output) +func (b *Builder) Parallel(ctx context.Context, output io.WriterAt) (_ tg.StorageFileTypeClass, err error) { + prepared, closeCDN, err := b.prepare() + if err != nil { + return nil, err + } + defer func() { + if closeCDN != nil { + multierr.AppendInto(&err, closeCDN()) + } + }() + + typ, runErr := prepared.downloader.parallel(ctx, prepared.reader(), prepared.threads, output) + return typ, runErr } // ToPath downloads file to given path. 
@@ -75,5 +201,6 @@ func (b *Builder) ToPath(ctx context.Context, path string) (_ tg.StorageFileType multierr.AppendInto(&err, f.Close()) }() - return b.Parallel(ctx, f) + typ, runErr := b.Parallel(ctx, f) + return typ, runErr } diff --git a/telegram/downloader/builder_prepare_test.go b/telegram/downloader/builder_prepare_test.go new file mode 100644 index 0000000000..20777b458b --- /dev/null +++ b/telegram/downloader/builder_prepare_test.go @@ -0,0 +1,98 @@ +package downloader + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBuilder_preparePaths(t *testing.T) { + newMock := func() *mock { + return &mock{data: []byte("hello")} + } + + tests := []struct { + name string + build func() *Builder + wantCDNSchema bool + wantMasterAllow bool + wantInlineCDN bool + wantOuterVerify bool + }{ + { + name: "LegacyFastPathWhenFlagUnset", + build: func() *Builder { + return NewDownloader().Download(newMock(), nil) + }, + wantCDNSchema: false, + wantMasterAllow: false, + wantOuterVerify: false, + }, + { + name: "LegacyFastPathWhenNoProvider", + build: func() *Builder { + m := newMock() + return NewDownloader().WithAllowCDN(true).Download(&noCDNClient{base: m}, nil) + }, + wantCDNSchema: false, + wantMasterAllow: false, + wantOuterVerify: false, + }, + { + name: "CDNPathDefaultEnablesInlineVerify", + build: func() *Builder { + return NewDownloader().WithAllowCDN(true).Download(newMock(), nil) + }, + wantCDNSchema: true, + wantMasterAllow: true, + wantInlineCDN: true, + wantOuterVerify: false, + }, + { + name: "CDNPathExplicitVerifyKeepsOuterVerifier", + build: func() *Builder { + return NewDownloader().WithAllowCDN(true).Download(newMock(), nil).WithVerify(true) + }, + wantCDNSchema: true, + wantMasterAllow: true, + wantInlineCDN: false, + wantOuterVerify: true, + }, + { + name: "CDNPathExplicitDisableKeepsInlineVerify", + build: func() *Builder { + return NewDownloader().WithAllowCDN(true).Download(newMock(), nil).WithVerify(false) + }, + wantCDNSchema: true, + wantMasterAllow: true, + wantInlineCDN: true, + wantOuterVerify: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + prepared, closeCDN, err := tc.build().prepare() + require.NoError(t, err) + if closeCDN != nil { + defer func() { + require.NoError(t, closeCDN()) + }() + } + + require.Equal(t, tc.wantOuterVerify, prepared.verify) + + if tc.wantCDNSchema { + s, ok := prepared.schema.(*cdn) + require.True(t, ok) + require.Equal(t, tc.wantMasterAllow, s.master.allowCDN) + require.Equal(t, tc.wantInlineCDN, s.verify) + return + } + + s, ok := prepared.schema.(master) + require.True(t, ok) + require.Equal(t, tc.wantMasterAllow, s.allowCDN) + }) + } +} diff --git a/telegram/downloader/cdn.go b/telegram/downloader/cdn.go index 4fb912e673..36cf967799 100644 --- a/telegram/downloader/cdn.go +++ b/telegram/downloader/cdn.go @@ -1,98 +1,176 @@ package downloader import ( - "context" - "crypto/aes" - "crypto/cipher" - "encoding/binary" + "io" + "sync" - "github.com/go-faster/errors" + "golang.org/x/sync/singleflight" "github.com/gotd/td/bin" "github.com/gotd/td/tg" ) -// ExpiredTokenError error is returned when Downloader get expired file token for CDN. -// See https://core.telegram.org/constructor/upload.fileCdnRedirect. -type ExpiredTokenError struct { - *tg.UploadCDNFileReuploadNeeded -} - -// Error implements error interface. -func (r *ExpiredTokenError) Error() string { - return "redirect to master DC for requesting new file token" -} - -// cdn is a CDN DC download schema. 
-// See https://core.telegram.org/cdn#getting-files-from-a-cdn. +// cdn is a download schema that starts on the master DC and switches to CDN on +// upload.fileCdnRedirect without losing the original request. type cdn struct { - cdn CDN + provider CDNProvider client Client pool *bin.Pool - redirect *tg.UploadFileCDNRedirect + // retryHandler observes retried transient downloader errors. + retryHandler RetryHandler + + // master preserves regular path that may return redirect errors when + // allowCDN=true. + master master + // max is forwarded to provider pool creation, usually mapped from number of + // download threads. + max int64 + // verify enables inline verification for decrypted CDN chunks. + verify bool + + // stateMux guards mode/redirect/client pointer/revision. + stateMux sync.RWMutex + // refreshMux serializes redirect refreshes so only one goroutine asks master + // for new token when CDN reports token invalid. + refreshMux sync.Mutex + // clientMux serializes CDN client (re)creation per schema instance. + clientMux sync.Mutex + // hashesMux guards in-memory cache of CDN hashes by offset. + hashesMux sync.RWMutex + // windowsMux guards bounded cache of verified CDN hash windows used to + // handle custom part sizes that split hash windows. + windowsMux sync.Mutex + // windowsLoad deduplicates concurrent fetches of the same full hash window. + windowsLoad singleflight.Group + + mode cdnMode + redirect *tg.UploadFileCDNRedirect + cdn CDN + closer io.Closer + clientDC int + rev uint64 + hashes map[int64]tg.FileHash + hashOffsets []int64 + windows map[int64][]byte + windowsFIFO []int64 } -var _ schema = cdn{} +var _ schema = (*cdn)(nil) -// decrypt decrypts file chunk from Telegram CDN. -// See https://core.telegram.org/cdn#decrypting-files. -func (c cdn) decrypt(src []byte, offset int64) ([]byte, error) { - block, err := aes.NewCipher(c.redirect.EncryptionKey) - if err != nil { - return nil, errors.Wrap(err, "create cipher") - } +type cdnMode uint8 - if block.BlockSize() != len(c.redirect.EncryptionIv) { - return nil, errors.Errorf( - "invalid IV or key length, block size %d != IV %d", - block.BlockSize(), len(c.redirect.EncryptionIv), - ) +const ( + // modeMaster means request master first and switch only on redirect. + modeMaster cdnMode = iota + // modeCDN means active redirect exists and chunks should be fetched from CDN. + modeCDN +) + +// maxVerifiedWindowCache bounds memory used by split-window verification. +// +// Split windows are needed only when downloader part size does not align with +// Telegram CDN hash window size (typically 128KB). In that case we may fetch a +// full hash window once, verify it, and reuse verified bytes for neighboring +// chunks. A small bounded cache is enough because sequential/parallel readers +// usually work on nearby offsets. +const maxVerifiedWindowCache = 16 + +func newCDNSchema( + masterSchema master, + provider CDNProvider, + pool *bin.Pool, + max int64, + verifyCDNInline bool, + retryHandler RetryHandler, +) *cdn { + if max < 1 { + max = 1 } - // Copy IV to buffer from Pool. 
- iv := c.pool.GetSize(len(c.redirect.EncryptionIv)) - defer c.pool.Put(iv) - copy(iv.Buf, c.redirect.EncryptionIv) + return &cdn{ + provider: provider, + client: masterSchema.client, + pool: pool, + retryHandler: retryHandler, + master: masterSchema, + max: max, + verify: verifyCDNInline, + mode: modeMaster, + } +} - // For IV, it should use the value of encryption_iv, modified in the following manner: - // for each offset replace the last 4 bytes of the encryption_iv with offset / 16 in big-endian. - binary.BigEndian.PutUint32(iv.Buf[iv.Len()-4:], uint32(offset/16)) +func (c *cdn) reportRetry(operation string, attempt int, err error) { + if attempt < 1 || err == nil || c.retryHandler == nil { + return + } + c.retryHandler(RetryEvent{ + Operation: operation, + Attempt: attempt, + Err: err, + }) +} - dst := make([]byte, len(src)) - cipher.NewCTR(block, iv.Buf).XORKeyStream(dst, src) - return dst, nil +func (c *cdn) Close() error { + // Close is called by Builder defer path and should release only schema-local + // CDN resources. Shared client-level pools are managed in telegram package. + c.stateMux.Lock() + closer := c.closer + c.cdn = nil + c.closer = nil + c.clientDC = 0 + c.stateMux.Unlock() + + if closer != nil { + return closer.Close() + } + return nil } -func (c cdn) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) { - r, err := c.cdn.UploadGetCDNFile(ctx, &tg.UploadGetCDNFileRequest{ - Offset: offset, - Limit: limit, - FileToken: c.redirect.FileToken, - }) - if err != nil { - return chunk{}, err +func (c *cdn) closeClient() { + // Internal best-effort close used on fingerprint/token recovery loops. + c.stateMux.Lock() + closer := c.closer + c.cdn = nil + c.closer = nil + c.clientDC = 0 + c.stateMux.Unlock() + + if closer != nil { + _ = closer.Close() } +} - switch result := r.(type) { - case *tg.UploadCDNFile: - data, err := c.decrypt(result.Bytes, offset) - if err != nil { - return chunk{}, err - } - - return chunk{ - data: data, - }, nil - case *tg.UploadCDNFileReuploadNeeded: - return chunk{}, &ExpiredTokenError{UploadCDNFileReuploadNeeded: result} - default: - return chunk{}, errors.Errorf("unexpected type %T", r) +func (c *cdn) snapshot() (mode cdnMode, redirect *tg.UploadFileCDNRedirect, rev uint64) { + c.stateMux.RLock() + defer c.stateMux.RUnlock() + + return c.mode, c.redirect, c.rev +} + +func (c *cdn) setRedirect(redirect *tg.UploadFileCDNRedirect) { + c.stateMux.Lock() + c.mode = modeCDN + c.redirect = redirect + c.rev++ + c.stateMux.Unlock() + + // Redirect update invalidates hash cache scope (file token / offset range + // may change). Seed with hashes returned in redirect when available. + c.resetHashes() + c.resetWindows() + if redirect != nil { + c.cacheHashes(redirect.FileHashes) } } -func (c cdn) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) { - return c.client.UploadGetCDNFileHashes(ctx, &tg.UploadGetCDNFileHashesRequest{ - FileToken: c.redirect.FileToken, - Offset: offset, - }) +func (c *cdn) setMaster() { + c.stateMux.Lock() + c.mode = modeMaster + c.redirect = nil + c.rev++ + c.stateMux.Unlock() + + // Leaving CDN mode invalidates CDN hash cache. 
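+	// The rev bump above lets concurrent goroutines holding an older snapshot
+	// detect the change and retry instead of committing stale results.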
+ c.resetHashes() + c.resetWindows() } diff --git a/telegram/downloader/cdn_plan.go b/telegram/downloader/cdn_plan.go new file mode 100644 index 0000000000..e4a5f5f18b --- /dev/null +++ b/telegram/downloader/cdn_plan.go @@ -0,0 +1,64 @@ +package downloader + +import "github.com/go-faster/errors" + +type cdnRequestRange struct { + offset int64 + limit int +} + +const ( + cdnMinChunk = 4 * 1024 + cdnMaxChunk = 1024 * 1024 +) + +func largestCDNValidLimit(max int) int { + for size := max; size >= cdnMinChunk; size -= cdnMinChunk { + if cdnMaxChunk%size == 0 { + return size + } + } + return 0 +} + +func buildCDNRequestPlan(offset int64, limit int) ([]cdnRequestRange, error) { + if limit <= 0 { + return nil, errors.Errorf("invalid CDN limit %d", limit) + } + if offset < 0 { + return nil, errors.Errorf("invalid CDN offset %d", offset) + } + if offset%cdnMinChunk != 0 { + return nil, errors.Errorf("CDN offset %d must be divisible by %d", offset, cdnMinChunk) + } + if limit%cdnMinChunk != 0 { + return nil, errors.Errorf("CDN limit %d must be divisible by %d", limit, cdnMinChunk) + } + + remaining := limit + current := offset + plan := make([]cdnRequestRange, 0, 1+limit/cdnMaxChunk) + for remaining > 0 { + mbUsed := int(current % cdnMaxChunk) + mbLeft := cdnMaxChunk - mbUsed + maxForStep := remaining + if maxForStep > mbLeft { + maxForStep = mbLeft + } + // Step size is chosen from values allowed by CDN docs: + // - divisible by 4KB + // - divisor of 1MB. + step := largestCDNValidLimit(maxForStep) + if step == 0 { + return nil, errors.Errorf("unable to build CDN request plan for offset=%d limit=%d", offset, limit) + } + plan = append(plan, cdnRequestRange{ + offset: current, + limit: step, + }) + current += int64(step) + remaining -= step + } + + return plan, nil +} diff --git a/telegram/downloader/cdn_state_machine.go b/telegram/downloader/cdn_state_machine.go new file mode 100644 index 0000000000..87cd660ca3 --- /dev/null +++ b/telegram/downloader/cdn_state_machine.go @@ -0,0 +1,350 @@ +package downloader + +import ( + "context" + + "github.com/go-faster/errors" + + "github.com/gotd/td/tg" +) + +const cdnRefreshProbeLimit = 4 * 1024 + +func (c *cdn) ensureClient(ctx context.Context, dcID int) (CDN, error) { + c.clientMux.Lock() + defer c.clientMux.Unlock() + + c.stateMux.RLock() + current := c.cdn + currentDC := c.clientDC + c.stateMux.RUnlock() + if current != nil && currentDC == dcID { + return current, nil + } + + // Redirect may switch DC; recreate client lazily on demand. 
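+	// Holding clientMux here means close-then-dial below cannot interleave
+	// with another goroutine recreating the client for a different DC.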
+ c.closeClient() + + cdnClient, closer, err := c.provider.CDN(ctx, dcID, c.max) + if err != nil { + return nil, err + } + if cdnClient == nil { + if closer != nil { + _ = closer.Close() + } + return nil, errors.New("cdn provider returned nil client") + } + + c.stateMux.Lock() + c.cdn = cdnClient + c.closer = closer + c.clientDC = dcID + c.stateMux.Unlock() + + return cdnClient, nil +} + +func (c *cdn) activateRedirect(ctx context.Context, redirect *tg.UploadFileCDNRedirect) error { + if redirect == nil { + c.setMaster() + return nil + } + if _, err := c.ensureClient(ctx, redirect.DCID); err != nil { + return err + } + c.setRedirect(redirect) + return nil +} + +func (c *cdn) recoverCDNControlError( + ctx context.Context, + err error, + offset int64, + limit int, + rev uint64, + loopAttempt int, +) (fallback *chunk, retry bool, handled bool, outErr error) { + if isCDNFingerprintErr(err) { + c.closeClient() + return nil, true, true, nil + } + if !isCDNMasterFallbackErr(err) { + return nil, false, false, nil + } + + masterChunk, refreshErr := c.refreshRedirect(ctx, offset, limit, rev, loopAttempt) + if refreshErr != nil { + return nil, false, true, refreshErr + } + if masterChunk != nil { + return masterChunk, false, true, nil + } + + return nil, true, true, nil +} + +func (c *cdn) refreshRedirect( + ctx context.Context, + offset int64, + limit int, + prevRev uint64, + loopAttempt int, +) (*chunk, error) { + if limit <= 0 { + limit = cdnRefreshProbeLimit + } + + c.refreshMux.Lock() + defer c.refreshMux.Unlock() + + _, _, currentRev := c.snapshot() + if currentRev != prevRev { + // Another goroutine already refreshed state; just retry outer loop. + return nil, nil + } + + masterChunk, err := retryRequest( + ctx, + "refresh CDN redirect", + func(attempt int, err error) { + c.reportRetry(RetryOperationRefreshRedirect, attempt, err) + }, + func() (chunk, error) { + return c.master.Chunk(ctx, offset, limit) + }, + ) + if err == nil { + // Server stopped redirecting this file/token; return to master mode. + c.setMaster() + return &masterChunk, nil + } + + var redirectErr *RedirectError + if errors.As(err, &redirectErr) { + if err := c.activateRedirect(ctx, redirectErr.Redirect); err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, errors.Wrapf(err, "create CDN client for DC %d", redirectErr.Redirect.DCID) + } + if isCDNFingerprintErr(err) { + c.reportRetry(RetryOperationCreateClient, loopAttempt, err) + c.closeClient() + return nil, nil + } + return nil, errors.Wrapf(err, "create CDN client for DC %d", redirectErr.Redirect.DCID) + } + return nil, nil + } + + return nil, errors.Wrap(err, "refresh CDN redirect") +} + +func (c *cdn) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) { + // Unified state machine: + // modeMaster -> try master and switch on redirect; + // modeCDN -> serve from CDN with token/keys refresh handling. + for attempt := 0; attempt < maxRetryAttempts; attempt++ { + if err := ctx.Err(); err != nil { + return chunk{}, err + } + + mode, redirect, rev := c.snapshot() + switch mode { + case modeMaster: + r, err := c.master.Chunk(ctx, offset, limit) + if err == nil { + return r, nil + } + + var redirectErr *RedirectError + if errors.As(err, &redirectErr) { + // Redirect is expected protocol path when file is CDN-backed. 
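+				// activateRedirect dials the CDN DC before flipping mode, so a
+				// failed dial leaves the schema safely in master mode.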
+ if err := c.activateRedirect(ctx, redirectErr.Redirect); err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return chunk{}, errors.Wrapf(err, "create CDN client for DC %d", redirectErr.Redirect.DCID) + } + if isCDNFingerprintErr(err) { + // CDN keys changed while pool still uses stale keys. + // Close and retry so provider can reopen with fresh keys. + c.reportRetry(RetryOperationCreateClient, attempt+1, err) + c.closeClient() + continue + } + return chunk{}, errors.Wrapf(err, "create CDN client for DC %d", redirectErr.Redirect.DCID) + } + continue + } + + return chunk{}, errors.Wrapf(err, "master chunk offset=%d limit=%d", offset, limit) + + case modeCDN: + if redirect == nil { + c.setMaster() + continue + } + + cdnClient, err := c.ensureClient(ctx, redirect.DCID) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return chunk{}, errors.Wrapf(err, "create CDN client for DC %d", redirect.DCID) + } + if isCDNFingerprintErr(err) { + // Force recreate with fresh key set. + c.reportRetry(RetryOperationCreateClient, attempt+1, err) + c.closeClient() + continue + } + return chunk{}, errors.Wrapf(err, "create CDN client for DC %d", redirect.DCID) + } + + plan, err := buildCDNRequestPlan(offset, limit) + if err != nil { + return chunk{}, errors.Wrapf(err, "cdn request plan offset=%d limit=%d", offset, limit) + } + + data := make([]byte, 0, limit) + retryChunk := false + + partLoop: + for _, req := range plan { + result, err := cdnClient.UploadGetCDNFile(ctx, &tg.UploadGetCDNFileRequest{ + Offset: req.offset, + Limit: req.limit, + FileToken: redirect.FileToken, + }) + if err != nil { + fallback, retry, handled, recoverErr := c.recoverCDNControlError(ctx, err, offset, limit, rev, attempt+1) + if recoverErr != nil { + return chunk{}, recoverErr + } + if handled { + if fallback != nil { + return *fallback, nil + } + if retry { + c.reportRetry(RetryOperationGetFile, attempt+1, err) + retryChunk = true + break partLoop + } + continue + } + return chunk{}, errors.Wrapf( + err, + "cdn chunk dc=%d offset=%d limit=%d", + redirect.DCID, req.offset, req.limit, + ) + } + + switch typed := result.(type) { + case *tg.UploadCDNFile: + part, err := c.decrypt(typed.Bytes, req.offset, redirect) + if err != nil { + return chunk{}, err + } + data = append(data, part...) + if len(part) < req.limit { + // Reached file tail, remaining plan segments are beyond EOF. + break partLoop + } + + case *tg.UploadCDNFileReuploadNeeded: + // Ask master DC to reissue CDN token window for this file. + hashes, err := c.client.UploadReuploadCDNFile(ctx, &tg.UploadReuploadCDNFileRequest{ + FileToken: redirect.FileToken, + RequestToken: typed.RequestToken, + }) + if err != nil { + fallback, retry, handled, recoverErr := c.recoverCDNControlError(ctx, err, offset, limit, rev, attempt+1) + if recoverErr != nil { + return chunk{}, recoverErr + } + if handled { + if fallback != nil { + return *fallback, nil + } + if retry { + c.reportRetry(RetryOperationReupload, attempt+1, err) + retryChunk = true + break partLoop + } + continue + } + return chunk{}, errors.Wrapf( + err, + "cdn reupload dc=%d offset=%d limit=%d", + redirect.DCID, req.offset, req.limit, + ) + } + // Reupload returns fresh CDN hashes for the requested token + // window. Cache them immediately (same strategy as TDesktop) to + // avoid an extra UploadGetCDNFileHashes call on retry. 
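+					// After caching, retry the whole chunk from the top of the
+					// state loop; the CDN should now accept the same file token.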
+ c.cacheHashes(hashes) + retryChunk = true + break partLoop + + default: + return chunk{}, errors.Errorf("unexpected type %T", result) + } + } + if retryChunk { + continue + } + + if err := c.verifyChunk(ctx, offset, limit, data); err != nil { + return chunk{}, err + } + return chunk{data: data}, nil + } + } + + return chunk{}, retryLimitErr("cdn chunk", maxRetryAttempts, errors.New("state loop")) +} + +func (c *cdn) Hashes(ctx context.Context, offset int64) ([]tg.FileHash, error) { + // Hash retrieval follows same state machine as chunks to stay consistent + // during concurrent token/redirect changes. + for attempt := 0; attempt < maxRetryAttempts; attempt++ { + if err := ctx.Err(); err != nil { + return nil, err + } + + mode, redirect, rev := c.snapshot() + switch mode { + case modeMaster: + hashes, err := c.master.Hashes(ctx, offset) + if err != nil { + return nil, errors.Wrapf(err, "master hashes offset=%d", offset) + } + return hashes, nil + + case modeCDN: + if redirect == nil { + c.setMaster() + continue + } + + hashes, err := c.client.UploadGetCDNFileHashes(ctx, &tg.UploadGetCDNFileHashesRequest{ + FileToken: redirect.FileToken, + Offset: offset, + }) + if err != nil { + _, retry, handled, recoverErr := c.recoverCDNControlError(ctx, err, offset, cdnRefreshProbeLimit, rev, attempt+1) + if recoverErr != nil { + return nil, recoverErr + } + if handled && retry { + c.reportRetry(RetryOperationGetFileHashes, attempt+1, err) + continue + } + if handled { + continue + } + return nil, errors.Wrapf(err, "cdn hashes dc=%d offset=%d", redirect.DCID, offset) + } + c.cacheHashes(hashes) + return hashes, nil + } + } + + return nil, retryLimitErr("cdn hashes", maxRetryAttempts, errors.New("state loop")) +} diff --git a/telegram/downloader/cdn_test.go b/telegram/downloader/cdn_test.go index 2b2dd91a9b..e81abff5eb 100644 --- a/telegram/downloader/cdn_test.go +++ b/telegram/downloader/cdn_test.go @@ -1,6 +1,7 @@ package downloader import ( + "github.com/gotd/td/crypto" "testing" "github.com/stretchr/testify/require" @@ -27,10 +28,164 @@ func Test_cdn_decrypt(t *testing.T) { EncryptionIv: test.iv, }, } - _, err := c.decrypt(testdata, 0) + _, err := c.decrypt(testdata, 0, c.redirect) if test.err { require.Error(t, err) } }) } } + +func Test_cdn_hashLookupUnalignedOffset(t *testing.T) { + c := &cdn{} + c.cacheHashes([]tg.FileHash{ + {Offset: 128, Limit: 128}, + {Offset: 0, Limit: 128}, + {Offset: 256, Limit: 64}, + }) + + type check struct { + offset int64 + ok bool + start int64 + } + checks := []check{ + {offset: 0, ok: true, start: 0}, + {offset: 64, ok: true, start: 0}, + {offset: 190, ok: true, start: 128}, + {offset: 319, ok: true, start: 256}, + {offset: 320, ok: false}, + } + for _, tc := range checks { + got, ok := c.hash(tc.offset) + require.Equal(t, tc.ok, ok) + if !tc.ok { + continue + } + require.Equal(t, tc.start, got.Offset) + } +} + +func Test_cdn_cacheHashesMaintainsSortedOffsets(t *testing.T) { + c := &cdn{} + + c.cacheHashes([]tg.FileHash{ + {Offset: 256, Limit: 128}, + {Offset: 0, Limit: 128}, + {Offset: 128, Limit: 128}, + {Offset: 128, Limit: 64}, // overwrite existing window, no duplicate offset. 
+ }) + + require.Equal(t, []int64{0, 128, 256}, c.hashOffsets) + h, ok := c.hash(160) + require.True(t, ok) + require.EqualValues(t, 128, h.Offset) + require.EqualValues(t, 64, h.Limit) +} + +func Test_cdn_stateTransitions(t *testing.T) { + makeHash := func(offset int64) tg.FileHash { + payload := []byte{byte(offset), byte(offset + 1), byte(offset + 2), byte(offset + 3)} + return tg.FileHash{ + Offset: offset, + Limit: len(payload), + Hash: crypto.SHA256(payload), + } + } + makeRedirect := func(offset int64) *tg.UploadFileCDNRedirect { + hash := makeHash(offset) + return &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{1, 2, 3}, + FileHashes: []tg.FileHash{hash}, + } + } + + redirectA := makeRedirect(0) + redirectB := makeRedirect(64) + + tests := []struct { + name string + setup func(c *cdn) + action func(c *cdn) + check func(t *testing.T, c *cdn, prevRev uint64) + }{ + { + name: "MasterToCDNSeedsHashesAndClearsWindows", + setup: func(c *cdn) { + c.cacheHashes([]tg.FileHash{makeHash(32)}) + c.cacheWindow(makeHash(32), []byte{1, 2, 3, 4}) + }, + action: func(c *cdn) { + c.setRedirect(redirectA) + }, + check: func(t *testing.T, c *cdn, prevRev uint64) { + require.Equal(t, modeCDN, c.mode) + require.Same(t, redirectA, c.redirect) + require.Equal(t, prevRev+1, c.rev) + _, oldOK := c.hash(32) + require.False(t, oldOK, "old hash scope should be dropped on redirect") + seeded, seededOK := c.hash(0) + require.True(t, seededOK) + require.EqualValues(t, 0, seeded.Offset) + require.Nil(t, c.windows) + require.Nil(t, c.windowsFIFO) + }, + }, + { + name: "CDNToMasterClearsRedirectAndCaches", + setup: func(c *cdn) { + c.setRedirect(redirectA) + c.cacheWindow(redirectA.FileHashes[0], []byte{5, 6, 7, 8}) + }, + action: func(c *cdn) { + c.setMaster() + }, + check: func(t *testing.T, c *cdn, prevRev uint64) { + require.Equal(t, modeMaster, c.mode) + require.Nil(t, c.redirect) + require.Equal(t, prevRev+1, c.rev) + _, ok := c.hash(0) + require.False(t, ok, "master mode should not keep CDN hash cache") + require.Nil(t, c.windows) + require.Nil(t, c.windowsFIFO) + }, + }, + { + name: "CDNToCDNReplacesScope", + setup: func(c *cdn) { + c.setRedirect(redirectA) + c.cacheWindow(redirectA.FileHashes[0], []byte{9, 9, 9, 9}) + }, + action: func(c *cdn) { + c.setRedirect(redirectB) + }, + check: func(t *testing.T, c *cdn, prevRev uint64) { + require.Equal(t, modeCDN, c.mode) + require.Same(t, redirectB, c.redirect) + require.Equal(t, prevRev+1, c.rev) + _, oldOK := c.hash(0) + require.False(t, oldOK, "previous redirect hash scope must be invalidated") + seeded, newOK := c.hash(64) + require.True(t, newOK) + require.EqualValues(t, 64, seeded.Offset) + require.Nil(t, c.windows) + require.Nil(t, c.windowsFIFO) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + c := &cdn{ + mode: modeMaster, + } + if tc.setup != nil { + tc.setup(c) + } + prevRev := c.rev + tc.action(c) + tc.check(t, c, prevRev) + }) + } +} diff --git a/telegram/downloader/cdn_verify.go b/telegram/downloader/cdn_verify.go new file mode 100644 index 0000000000..ce80aa32f4 --- /dev/null +++ b/telegram/downloader/cdn_verify.go @@ -0,0 +1,335 @@ +package downloader + +import ( + "bytes" + "context" + "crypto/aes" + "crypto/cipher" + "encoding/binary" + "sort" + "strconv" + + "github.com/go-faster/errors" + + "github.com/gotd/td/crypto" + "github.com/gotd/td/tg" +) + +func (c *cdn) resetHashes() { + c.hashesMux.Lock() + c.hashes = nil + c.hashOffsets = nil + c.hashesMux.Unlock() +} + +func (c *cdn) resetWindows() { 
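+	// Called from setRedirect/setMaster right after a mode change: verified
+	// window payloads are only meaningful for the hash scope they were checked
+	// against, so they are dropped together with the hash index.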
+ c.windowsMux.Lock() + // Drop all previously verified window payloads whenever redirect/token scope + // changes to avoid mixing data that belongs to different CDN contexts. + c.windows = nil + c.windowsFIFO = nil + c.windowsMux.Unlock() +} + +func (c *cdn) cachedWindow(hash tg.FileHash) ([]byte, bool) { + c.windowsMux.Lock() + defer c.windowsMux.Unlock() + + data, ok := c.windows[hash.Offset] + if !ok { + return nil, false + } + if len(data) == 0 || len(data) > hash.Limit { + return nil, false + } + + // Returned slice is treated as read-only by callers. + return data, true +} + +func (c *cdn) cacheWindow(hash tg.FileHash, data []byte) { + if hash.Limit <= 0 || len(data) == 0 || len(data) > hash.Limit { + return + } + + c.windowsMux.Lock() + defer c.windowsMux.Unlock() + if c.windows == nil { + c.windows = make(map[int64][]byte) + } + if _, ok := c.windows[hash.Offset]; !ok { + c.windowsFIFO = append(c.windowsFIFO, hash.Offset) + } + // Store a copy to keep cache immutable relative to caller buffers. + c.windows[hash.Offset] = append([]byte(nil), data...) + for len(c.windowsFIFO) > maxVerifiedWindowCache { + evict := c.windowsFIFO[0] + c.windowsFIFO = c.windowsFIFO[1:] + // FIFO eviction is enough here: access pattern is near-sequential and we + // only need to cap memory, not optimize for perfect hit rate. + delete(c.windows, evict) + } +} + +func (c *cdn) cacheHashes(hashes []tg.FileHash) { + if len(hashes) == 0 { + return + } + + c.hashesMux.Lock() + if c.hashes == nil { + c.hashes = make(map[int64]tg.FileHash, len(hashes)) + } + for _, hash := range hashes { + if hash.Limit <= 0 { + continue + } + + // Keep sorted unique offsets index for O(log n) range lookup. + if _, exists := c.hashes[hash.Offset]; !exists { + idx := sort.Search(len(c.hashOffsets), func(i int) bool { + return c.hashOffsets[i] >= hash.Offset + }) + if idx == len(c.hashOffsets) { + c.hashOffsets = append(c.hashOffsets, hash.Offset) + } else if c.hashOffsets[idx] != hash.Offset { + c.hashOffsets = append(c.hashOffsets, 0) + copy(c.hashOffsets[idx+1:], c.hashOffsets[idx:]) + c.hashOffsets[idx] = hash.Offset + } + } + c.hashes[hash.Offset] = hash + } + c.hashesMux.Unlock() +} + +func (c *cdn) hash(offset int64) (tg.FileHash, bool) { + c.hashesMux.RLock() + hash, ok := c.hashes[offset] + if ok { + c.hashesMux.RUnlock() + return hash, true + } + + // Fast-path map lookup works only for exact hash offsets. For unaligned + // part sizes resolve containing window by predecessor offset in sorted index. + if len(c.hashOffsets) == 0 { + c.hashesMux.RUnlock() + return tg.FileHash{}, false + } + + idx := sort.Search(len(c.hashOffsets), func(i int) bool { + return c.hashOffsets[i] > offset + }) - 1 + if idx < 0 { + c.hashesMux.RUnlock() + return tg.FileHash{}, false + } + + candidate, exists := c.hashes[c.hashOffsets[idx]] + if !exists || candidate.Limit <= 0 { + c.hashesMux.RUnlock() + return tg.FileHash{}, false + } + end := candidate.Offset + int64(candidate.Limit) + hash = candidate + ok = offset >= candidate.Offset && offset < end + c.hashesMux.RUnlock() + return hash, ok +} + +func (c *cdn) hashForOffset(ctx context.Context, offset int64) (tg.FileHash, error) { + if hash, ok := c.hash(offset); ok { + return hash, nil + } + + // Ask server for the current offset window and cache returned range. 
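+	// The attempt bound guards against a server that keeps returning batches
+	// which never cover the requested offset.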
+ for attempt := 0; attempt < maxRetryAttempts; attempt++ { + if err := ctx.Err(); err != nil { + return tg.FileHash{}, err + } + + hashes, err := c.Hashes(ctx, offset) + if err != nil { + return tg.FileHash{}, errors.Wrapf(err, "load CDN hashes at offset=%d", offset) + } + // Cache batch and retry lookup: server may return a range of windows + // where requested offset is not the first element. + c.cacheHashes(hashes) + if hash, ok := c.hash(offset); ok { + return hash, nil + } + } + + return tg.FileHash{}, retryLimitErr( + "cdn hash lookup", + maxRetryAttempts, + errors.Errorf("hash for offset %d not found", offset), + ) +} + +func windowLoadKey(hash tg.FileHash) string { + key := make([]byte, 0, len(hash.Hash)+64) + key = strconv.AppendInt(key, hash.Offset, 10) + key = append(key, ':') + key = strconv.AppendInt(key, int64(hash.Limit), 10) + key = append(key, ':') + key = append(key, hash.Hash...) + return string(key) +} + +func (c *cdn) loadAndVerifyWindow(ctx context.Context, hash tg.FileHash) ([]byte, error) { + if window, ok := c.cachedWindow(hash); ok { + return window, nil + } + + key := windowLoadKey(hash) + v, err, _ := c.windowsLoad.Do(key, func() (interface{}, error) { + if window, ok := c.cachedWindow(hash); ok { + return window, nil + } + + // Fetching a whole window can return a shorter payload only for the + // last file segment, so we accept len(full.data) <= hash.Limit. + full, err := c.Chunk(ctx, hash.Offset, hash.Limit) + if err != nil { + return nil, errors.Wrapf(err, "load full hash window at offset=%d limit=%d", hash.Offset, hash.Limit) + } + if len(full.data) == 0 || len(full.data) > hash.Limit { + return nil, errors.Errorf( + "invalid CDN window length at offset=%d max=%d got=%d", + hash.Offset, hash.Limit, len(full.data), + ) + } + if !bytes.Equal(crypto.SHA256(full.data), hash.Hash) { + return nil, errors.Wrapf( + ErrHashMismatch, + "at offset=%d size=%d", + hash.Offset, hash.Limit, + ) + } + + c.cacheWindow(hash, full.data) + return full.data, nil + }) + if err != nil { + return nil, err + } + + window, ok := v.([]byte) + if !ok { + return nil, errors.Errorf("unexpected window type %T", v) + } + return window, nil +} + +func (c *cdn) verifyChunk(ctx context.Context, offset int64, requestedLimit int, data []byte) error { + if !c.verify || len(data) == 0 { + return nil + } + shortResponse := requestedLimit > 0 && len(data) < requestedLimit + + // Inline mode validates every hash window covered by this chunk. + // For windows split by custom part sizes, we load and verify the full window + // (cached) and then patch overlapping bytes in this chunk with verified data. + chunkStart := offset + chunkEnd := offset + int64(len(data)) + for current := chunkStart; current < chunkEnd; { + hash, err := c.hashForOffset(ctx, current) + if err != nil { + return err + } + if hash.Limit <= 0 { + return errors.Errorf("invalid CDN hash limit %d at offset %d", hash.Limit, current) + } + windowStart := hash.Offset + windowEnd := hash.Offset + int64(hash.Limit) + if windowEnd <= current { + return errors.Errorf("invalid CDN hash window [%d,%d) at offset %d", windowStart, windowEnd, current) + } + + switch { + case windowStart >= chunkStart && windowEnd <= chunkEnd: + // Full hash window is present in current chunk: verify directly. 
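+			// Example: a 256KB chunk at offset 0 with 128KB hash windows hits
+			// this branch twice, for [0, 128K) and [128K, 256K).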
+ from := int(windowStart - chunkStart) + to := int(windowEnd - chunkStart) + if !bytes.Equal(crypto.SHA256(data[from:to]), hash.Hash) { + return errors.Wrapf( + ErrHashMismatch, + "at offset=%d size=%d", + windowStart, hash.Limit, + ) + } + + case shortResponse && windowStart >= chunkStart && windowStart < chunkEnd && windowEnd > chunkEnd: + // Final short chunk: Telegram keeps nominal hash limit, but hash is + // computed on actual remaining tail bytes. + from := int(windowStart - chunkStart) + if !bytes.Equal(crypto.SHA256(data[from:]), hash.Hash) { + return ErrHashMismatch + } + return nil + + default: + // Hash window crosses current chunk boundaries. + // + // TDesktop-style behavior: validate complete hash window and then apply + // verified overlap to current chunk. This preserves integrity checks for + // custom part sizes without forcing eager verifier mode globally. + window, err := c.loadAndVerifyWindow(ctx, hash) + if err != nil { + return err + } + + overlapStart := chunkStart + if windowStart > overlapStart { + overlapStart = windowStart + } + windowDataEnd := windowStart + int64(len(window)) + overlapEnd := chunkEnd + if windowDataEnd < overlapEnd { + overlapEnd = windowDataEnd + } + if overlapEnd <= overlapStart { + return errors.Errorf( + "invalid overlap for hash window [%d,%d) and chunk [%d,%d)", + windowStart, windowEnd, chunkStart, chunkEnd, + ) + } + chunkFrom := int(overlapStart - chunkStart) + chunkTo := int(overlapEnd - chunkStart) + windowFrom := int(overlapStart - windowStart) + windowTo := int(overlapEnd - windowStart) + // Replace bytes in-place with verified overlap so the caller receives a + // fully verified chunk even when hash windows are split by part size. + copy(data[chunkFrom:chunkTo], window[windowFrom:windowTo]) + } + current = windowEnd + } + return nil +} + +// decrypt decrypts file chunk from Telegram CDN. +// See https://core.telegram.org/cdn#decrypting-files. +func (c *cdn) decrypt(src []byte, offset int64, redirect *tg.UploadFileCDNRedirect) ([]byte, error) { + block, err := aes.NewCipher(redirect.EncryptionKey) + if err != nil { + return nil, errors.Wrap(err, "create cipher") + } + if block.BlockSize() != len(redirect.EncryptionIv) { + return nil, errors.Errorf( + "invalid IV or key length, block size %d != IV %d", + block.BlockSize(), len(redirect.EncryptionIv), + ) + } + + iv := c.pool.GetSize(len(redirect.EncryptionIv)) + defer c.pool.Put(iv) + copy(iv.Buf, redirect.EncryptionIv) + + binary.BigEndian.PutUint32(iv.Buf[iv.Len()-4:], uint32(offset/16)) + + dst := make([]byte, len(src)) + cipher.NewCTR(block, iv.Buf).XORKeyStream(dst, src) + return dst, nil +} diff --git a/telegram/downloader/client.go b/telegram/downloader/client.go index c365a4571c..b876b56d17 100644 --- a/telegram/downloader/client.go +++ b/telegram/downloader/client.go @@ -2,6 +2,7 @@ package downloader import ( "context" + "io" "github.com/gotd/td/tg" ) @@ -22,6 +23,13 @@ type Client interface { UploadGetWebFile(ctx context.Context, request *tg.UploadGetWebFileRequest) (*tg.UploadWebFile, error) } +// CDNProvider creates client bound to requested CDN DC. +// Returned closer is schema-scoped; for shared client-level pool adapters this +// can be a no-op closer. 
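+//
+// A minimal adapter sketch for a provider backed by a shared, client-owned
+// pool (pooledProvider, cdnPool and its Client accessor are hypothetical):
+//
+//	type noopCloser struct{}
+//
+//	func (noopCloser) Close() error { return nil }
+//
+//	type pooledProvider struct{ pool *cdnPool }
+//
+//	func (p pooledProvider) CDN(ctx context.Context, dc int, max int64) (CDN, io.Closer, error) {
+//		client, err := p.pool.Client(ctx, dc, max) // hypothetical: pool owns connection lifetime
+//		return client, noopCloser{}, err           // nothing schema-scoped to close
+//	}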
+type CDNProvider interface { + CDN(ctx context.Context, dc int, max int64) (CDN, io.Closer, error) +} + type chunk struct { data []byte tag tg.StorageFileTypeClass diff --git a/telegram/downloader/coverage_extra_test.go b/telegram/downloader/coverage_extra_test.go new file mode 100644 index 0000000000..9734684d4c --- /dev/null +++ b/telegram/downloader/coverage_extra_test.go @@ -0,0 +1,136 @@ +package downloader + +import ( + "context" + "errors" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gotd/td/tg" + "github.com/gotd/td/tgerr" +) + +func TestRetryLimitErrWraps(t *testing.T) { + baseErr := errors.New("boom") + err := retryLimitErr("download", 3, baseErr) + require.ErrorIs(t, err, baseErr) + require.Contains(t, err.Error(), "download: retry limit reached (3)") +} + +func TestRetryRequestTimeoutLimit(t *testing.T) { + var attempts []int + calls := 0 + + _, err := retryRequest[int]( + context.Background(), + "read chunk", + func(attempt int, err error) { + require.Error(t, err) + attempts = append(attempts, attempt) + }, + func() (int, error) { + calls++ + return 0, tgerr.New(500, tg.ErrTimeout) + }, + ) + require.Error(t, err) + require.Equal(t, maxRetryAttempts, calls) + require.Len(t, attempts, maxRetryAttempts-1) + require.Equal(t, maxRetryAttempts-1, attempts[len(attempts)-1]) + require.Contains(t, err.Error(), "read chunk: retry limit reached") +} + +func TestCDNPlanValidationAndSplit(t *testing.T) { + _, err := buildCDNRequestPlan(0, 0) + require.ErrorContains(t, err, "invalid CDN limit") + + _, err = buildCDNRequestPlan(-1, cdnMinChunk) + require.ErrorContains(t, err, "invalid CDN offset") + + _, err = buildCDNRequestPlan(1, cdnMinChunk) + require.ErrorContains(t, err, "must be divisible") + + _, err = buildCDNRequestPlan(0, cdnMinChunk+1) + require.ErrorContains(t, err, "must be divisible") + + plan, err := buildCDNRequestPlan(cdnMaxChunk-int64(cdnMinChunk), 2*cdnMinChunk) + require.NoError(t, err) + require.Equal(t, []cdnRequestRange{ + {offset: cdnMaxChunk - int64(cdnMinChunk), limit: cdnMinChunk}, + {offset: cdnMaxChunk, limit: cdnMinChunk}, + }, plan) + + require.Zero(t, largestCDNValidLimit(cdnMinChunk-1)) + require.Equal(t, cdnMinChunk, largestCDNValidLimit(cdnMinChunk)) +} + +func TestWebReportRetryAndHashes(t *testing.T) { + var events []RetryEvent + w := web{ + retryHandler: func(event RetryEvent) { + events = append(events, event) + }, + } + + w.reportRetry("noop", 0, errors.New("skip")) + w.reportRetry("noop", 1, nil) + require.Empty(t, events) + + retryErr := errors.New("retry") + w.reportRetry("getWebFile", 2, retryErr) + require.Len(t, events, 1) + require.Equal(t, "getWebFile", events[0].Operation) + require.Equal(t, 2, events[0].Attempt) + require.ErrorIs(t, events[0].Err, retryErr) + + hashes, err := w.Hashes(context.Background(), 0) + require.Nil(t, hashes) + require.ErrorIs(t, err, errHashesNotSupported) +} + +func TestBuilderToPath(t *testing.T) { + ctx := context.Background() + data := []byte("hello downloader") + m := &mock{data: data} + + path := filepath.Join(t.TempDir(), "out.bin") + typ, err := NewDownloader(). + Download(m, nil). + WithThreads(2). 
+ ToPath(ctx, path) + require.NoError(t, err) + require.Nil(t, typ) + + content, err := os.ReadFile(path) + require.NoError(t, err) + require.Equal(t, data, content) +} + +func TestBuilderToPathCreateError(t *testing.T) { + ctx := context.Background() + m := &mock{data: []byte("x")} + + path := filepath.Join(t.TempDir(), "missing", "out.bin") + _, err := NewDownloader().Download(m, nil).ToPath(ctx, path) + require.Error(t, err) + require.Contains(t, err.Error(), "create output file") +} + +func TestDownloaderWithRetryHandlerAndRedirectError(t *testing.T) { + var called bool + handler := func(RetryEvent) { + called = true + } + + d := NewDownloader() + require.Same(t, d, d.WithRetryHandler(handler)) + require.NotNil(t, d.retryHandler) + d.retryHandler(RetryEvent{Operation: "x", Attempt: 1, Err: errors.New("e")}) + require.True(t, called) + + err := (&RedirectError{Redirect: &tg.UploadFileCDNRedirect{DCID: 203}}).Error() + require.Equal(t, "redirect to CDN DC 203", err) +} diff --git a/telegram/downloader/downloader.go b/telegram/downloader/downloader.go index 57a5d9cf15..cc4061c36d 100644 --- a/telegram/downloader/downloader.go +++ b/telegram/downloader/downloader.go @@ -10,6 +10,13 @@ import ( type Downloader struct { partSize int pool *bin.Pool + // allowCDN is tri-state: + // nil -> keep default downloader behavior (CDN disabled), + // true -> allow redirect flow, + // false-> force legacy master-only flow. + allowCDN *bool + // retryHandler observes transient downloader errors that are retried. + retryHandler RetryHandler } const defaultPartSize = 512 * 1024 // 512 kb @@ -29,20 +36,42 @@ func (d *Downloader) WithPartSize(partSize int) *Downloader { return d } +// WithAllowCDN explicitly enables or disables CDN redirect flow. +// +// This flag is explicit: if it is not set, downloader keeps legacy +// master-DC-only behavior and does not attempt CDN redirect handling. +// Client integration (`telegram.Client.Downloader`) sets this option from +// `telegram.Options.AllowCDN`. +func (d *Downloader) WithAllowCDN(allow bool) *Downloader { + d.allowCDN = &allow + return d +} + +// WithRetryHandler sets callback for transient download errors that are retried +// internally by downloader. +// +// Handler can be called concurrently from download workers. +func (d *Downloader) WithRetryHandler(handler RetryHandler) *Downloader { + d.retryHandler = handler + return d +} + // Download creates Builder for plain downloads. func (d *Downloader) Download(rpc Client, location tg.InputFileLocationClass) *Builder { return newBuilder(d, master{ - client: rpc, - precise: true, - allowCDN: false, - location: location, + client: rpc, + precise: true, + allowCDN: false, + retryHandler: d.retryHandler, + location: location, }) } // Web creates Builder for web files downloads. 
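+//
+// Web downloads never take the CDN redirect path: Builder.prepare only
+// rewires master schemas, and web file hashes are not supported
+// (errHashesNotSupported).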
func (d *Downloader) Web(rpc Client, location tg.InputWebFileLocationClass) *Builder { return newBuilder(d, web{ - client: rpc, - location: location, + client: rpc, + retryHandler: d.retryHandler, + location: location, }) } diff --git a/telegram/downloader/downloader_test.go b/telegram/downloader/downloader_test.go index e0050d4bfc..467483c982 100644 --- a/telegram/downloader/downloader_test.go +++ b/telegram/downloader/downloader_test.go @@ -10,29 +10,89 @@ import ( "io" "runtime" "strconv" + "sync" + "sync/atomic" "testing" + "time" "github.com/go-faster/errors" "github.com/stretchr/testify/require" "github.com/gotd/td/crypto" + "github.com/gotd/td/exchange" "github.com/gotd/td/syncio" "github.com/gotd/td/testutil" "github.com/gotd/td/tg" + "github.com/gotd/td/tgerr" ) type mock struct { - data []byte - hashes mockHashes + data []byte + hashes mockHashes + // reupload emulates hashes returned by UploadReuploadCDNFile. When empty, + // mock keeps old behavior and returns nil hashes. + reupload []tg.FileHash migrate bool err bool hashesErr bool redirect *tg.UploadFileCDNRedirect + // enforceCDNRequestRules enables strict Telegram CDN parameter checks + // from docs: offset/limit are 4KB-aligned, limit divides 1MB and request + // stays within a single 1MB window. + enforceCDNRequestRules bool + // If > 0, redirect starts from this offset (when cdn_supported is set). + redirectAtOffset int64 + + // trackWindow* is test-only instrumentation for full hash-window fetches + // used by split-window verification path. + trackWindowOffset int64 + trackWindowLimit int + trackWindowBlock <-chan struct{} + trackWindowCalls atomic.Int32 + + migrateOnce atomic.Bool + reuploadNeeded atomic.Bool + cdnUploadTO atomic.Bool + tokenInvalid atomic.Bool + getTimeout atomic.Bool + cdnGetTimeout atomic.Bool + cdnHashTimeout atomic.Bool + cdnFingerprint atomic.Bool + cdnHashFP atomic.Bool + getFileCalls atomic.Int32 + hashesCalls atomic.Int32 + cdnGetCalls atomic.Int32 + cdnReupCalls atomic.Int32 + cdnHashCalls atomic.Int32 } var testErr = testutil.TestError() -func (m mock) getPart(offset int64, limit int) []byte { +func validCDNRequest(offset int64, limit int) bool { + if limit <= 0 { + return false + } + if offset < 0 { + return false + } + if offset%4096 != 0 { + return false + } + if limit%4096 != 0 { + return false + } + const oneMB = 1024 * 1024 + if oneMB%limit != 0 { + return false + } + end := offset + int64(limit) - 1 + if end < offset { + return false + } + return (offset / oneMB) == (end / oneMB) +} + +func (m *mock) getPart(offset int64, limit int) []byte { length := len(m.data) if offset >= int64(length) { return []byte{} @@ -48,12 +108,24 @@ func (m mock) getPart(offset int64, limit int) []byte { return r } -func (m mock) UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error) { +func (m *mock) UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error) { + m.getFileCalls.Add(1) if m.err { return nil, testErr } + if m.getTimeout.CompareAndSwap(true, false) { + return nil, tgerr.New(500, tg.ErrTimeout) + } + + if request.GetCDNSupported() && m.migrateOnce.CompareAndSwap(true, false) { + return m.redirect, nil + } + + if request.GetCDNSupported() && m.redirectAtOffset > 0 && request.Offset >= m.redirectAtOffset { + return m.redirect, nil + } - if m.migrate { + if request.GetCDNSupported() && m.migrate { return m.redirect, nil } @@ -62,7 +134,8 @@ func (m mock) UploadGetFile(ctx context.Context, request 
*tg.UploadGetFileReques }, nil } -func (m mock) UploadGetFileHashes(ctx context.Context, request *tg.UploadGetFileHashesRequest) ([]tg.FileHash, error) { +func (m *mock) UploadGetFileHashes(ctx context.Context, request *tg.UploadGetFileHashesRequest) ([]tg.FileHash, error) { + m.hashesCalls.Add(1) if m.hashesErr { return nil, testErr } @@ -70,16 +143,58 @@ func (m mock) UploadGetFileHashes(ctx context.Context, request *tg.UploadGetFile return m.hashes.Hashes(ctx, request.Offset) } -func (m mock) UploadReuploadCDNFile(ctx context.Context, request *tg.UploadReuploadCDNFileRequest) ([]tg.FileHash, error) { - panic("implement me") +func (m *mock) UploadReuploadCDNFile(ctx context.Context, request *tg.UploadReuploadCDNFileRequest) ([]tg.FileHash, error) { + m.cdnReupCalls.Add(1) + if m.err { + return nil, testErr + } + if m.cdnUploadTO.CompareAndSwap(true, false) { + return nil, tgerr.New(500, "CDN_UPLOAD_TIMEOUT") + } + + // Explicit copy avoids accidental aliasing between downloader cache and test + // fixture slices. + if len(m.reupload) == 0 { + return nil, nil + } + + r := make([]tg.FileHash, len(m.reupload)) + copy(r, m.reupload) + return r, nil } -func (m mock) UploadGetCDNFile(ctx context.Context, request *tg.UploadGetCDNFileRequest) (tg.UploadCDNFileClass, error) { +func (m *mock) UploadGetCDNFile(ctx context.Context, request *tg.UploadGetCDNFileRequest) (tg.UploadCDNFileClass, error) { + m.cdnGetCalls.Add(1) if m.err { return nil, testErr } + if m.trackWindowLimit > 0 && + request.Offset == m.trackWindowOffset && + request.Limit == m.trackWindowLimit { + m.trackWindowCalls.Add(1) + if m.trackWindowBlock != nil { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-m.trackWindowBlock: + } + } + } + if m.enforceCDNRequestRules && !validCDNRequest(request.Offset, request.Limit) { + return nil, tgerr.New(400, "LIMIT_INVALID") + } + if m.cdnGetTimeout.CompareAndSwap(true, false) { + return nil, tgerr.New(500, tg.ErrTimeout) + } + if m.cdnFingerprint.CompareAndSwap(true, false) { + return nil, exchange.ErrKeyFingerprintNotFound + } + + if m.tokenInvalid.CompareAndSwap(true, false) { + return nil, tgerr.New(400, "FILE_TOKEN_INVALID") + } - if m.migrate { + if m.reuploadNeeded.CompareAndSwap(true, false) { return &tg.UploadCDNFileReuploadNeeded{ RequestToken: []byte{1, 2, 3}, }, nil @@ -102,7 +217,15 @@ func (m mock) UploadGetCDNFile(ctx context.Context, request *tg.UploadGetCDNFile }, nil } -func (m mock) UploadGetCDNFileHashes(ctx context.Context, request *tg.UploadGetCDNFileHashesRequest) ([]tg.FileHash, error) { +func (m *mock) UploadGetCDNFileHashes(ctx context.Context, request *tg.UploadGetCDNFileHashesRequest) ([]tg.FileHash, error) { + m.cdnHashCalls.Add(1) + m.hashesCalls.Add(1) + if m.cdnHashTimeout.CompareAndSwap(true, false) { + return nil, tgerr.New(500, tg.ErrTimeout) + } + if m.cdnHashFP.CompareAndSwap(true, false) { + return nil, exchange.ErrKeyFingerprintNotFound + } if m.hashesErr { return nil, testErr } @@ -110,7 +233,7 @@ func (m mock) UploadGetCDNFileHashes(ctx context.Context, request *tg.UploadGetC return m.hashes.Hashes(ctx, request.Offset) } -func (m mock) UploadGetWebFile(ctx context.Context, request *tg.UploadGetWebFileRequest) (*tg.UploadWebFile, error) { +func (m *mock) UploadGetWebFile(ctx context.Context, request *tg.UploadGetWebFileRequest) (*tg.UploadWebFile, error) { if m.err { return nil, testErr } @@ -120,6 +243,148 @@ func (m mock) UploadGetWebFile(ctx context.Context, request *tg.UploadGetWebFile }, nil } +type noopCloser struct{} + +func 
(noopCloser) Close() error { + return nil +} + +func (m *mock) CDN(ctx context.Context, dc int, max int64) (CDN, io.Closer, error) { + return m, noopCloser{}, nil +} + +type noCDNClient struct { + base *mock +} + +func (c *noCDNClient) UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error) { + return c.base.UploadGetFile(ctx, request) +} + +func (c *noCDNClient) UploadGetFileHashes(ctx context.Context, request *tg.UploadGetFileHashesRequest) ([]tg.FileHash, error) { + return c.base.UploadGetFileHashes(ctx, request) +} + +func (c *noCDNClient) UploadReuploadCDNFile(ctx context.Context, request *tg.UploadReuploadCDNFileRequest) ([]tg.FileHash, error) { + return c.base.UploadReuploadCDNFile(ctx, request) +} + +func (c *noCDNClient) UploadGetCDNFileHashes(ctx context.Context, request *tg.UploadGetCDNFileHashesRequest) ([]tg.FileHash, error) { + return c.base.UploadGetCDNFileHashes(ctx, request) +} + +func (c *noCDNClient) UploadGetWebFile(ctx context.Context, request *tg.UploadGetWebFileRequest) (*tg.UploadWebFile, error) { + return c.base.UploadGetWebFile(ctx, request) +} + +type nilCDNProvider struct { + *mock +} + +func (c *nilCDNProvider) CDN(ctx context.Context, dc int, max int64) (CDN, io.Closer, error) { + return nil, noopCloser{}, nil +} + +type errCDNProvider struct { + *mock + err error +} + +func (c *errCDNProvider) CDN(ctx context.Context, dc int, max int64) (CDN, io.Closer, error) { + return nil, nil, c.err +} + +// retryAttemptProvider emulates redirect refresh to another DC and one +// fingerprint error during client creation on refresh path. +type retryAttemptProvider struct { + base *mock + + initialRedirect *tg.UploadFileCDNRedirect + refreshRedirect *tg.UploadFileCDNRedirect + + masterCalls atomic.Int32 + cdnCalls atomic.Int32 +} + +func (p *retryAttemptProvider) UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error) { + if request.GetCDNSupported() { + if p.masterCalls.Add(1) == 1 { + return p.initialRedirect, nil + } + return p.refreshRedirect, nil + } + return p.base.UploadGetFile(ctx, request) +} + +func (p *retryAttemptProvider) UploadGetFileHashes(ctx context.Context, request *tg.UploadGetFileHashesRequest) ([]tg.FileHash, error) { + return p.base.UploadGetFileHashes(ctx, request) +} + +func (p *retryAttemptProvider) UploadReuploadCDNFile(ctx context.Context, request *tg.UploadReuploadCDNFileRequest) ([]tg.FileHash, error) { + return p.base.UploadReuploadCDNFile(ctx, request) +} + +func (p *retryAttemptProvider) UploadGetCDNFileHashes(ctx context.Context, request *tg.UploadGetCDNFileHashesRequest) ([]tg.FileHash, error) { + return p.base.UploadGetCDNFileHashes(ctx, request) +} + +func (p *retryAttemptProvider) UploadGetWebFile(ctx context.Context, request *tg.UploadGetWebFileRequest) (*tg.UploadWebFile, error) { + return p.base.UploadGetWebFile(ctx, request) +} + +func (p *retryAttemptProvider) UploadGetCDNFile(ctx context.Context, request *tg.UploadGetCDNFileRequest) (tg.UploadCDNFileClass, error) { + return p.base.UploadGetCDNFile(ctx, request) +} + +func (p *retryAttemptProvider) CDN(ctx context.Context, dc int, max int64) (CDN, io.Closer, error) { + if p.cdnCalls.Add(1) == 3 { + return nil, nil, exchange.ErrKeyFingerprintNotFound + } + return p, noopCloser{}, nil +} + +// refreshRetryProvider emulates one timeout from master while refreshing CDN +// redirect after token invalidation. 
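+// The timeout is transient by design: refreshRedirect wraps its master call in
+// retryRequest, so a single tg.ErrTimeout must be absorbed without failing the
+// download.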
+type refreshRetryProvider struct { + *mock + + redirect *tg.UploadFileCDNRedirect + masterCalls atomic.Int32 + refreshTimeoutOnce atomic.Bool +} + +func (p *refreshRetryProvider) UploadGetFile(ctx context.Context, request *tg.UploadGetFileRequest) (tg.UploadFileClass, error) { + if request.GetCDNSupported() { + if p.masterCalls.Add(1) == 2 && p.refreshTimeoutOnce.CompareAndSwap(true, false) { + return nil, tgerr.New(500, tg.ErrTimeout) + } + return p.redirect, nil + } + return p.mock.UploadGetFile(ctx, request) +} + +func (p *refreshRetryProvider) CDN(ctx context.Context, dc int, max int64) (CDN, io.Closer, error) { + return p, noopCloser{}, nil +} + +// reuploadRetryProvider emulates one retryable token error on reupload call. +type reuploadRetryProvider struct { + *mock + + tokenInvalidOnce atomic.Bool +} + +func (p *reuploadRetryProvider) UploadReuploadCDNFile(ctx context.Context, request *tg.UploadReuploadCDNFileRequest) ([]tg.FileHash, error) { + if p.tokenInvalidOnce.CompareAndSwap(true, false) { + return nil, tgerr.New(400, "REQUEST_TOKEN_INVALID") + } + return p.mock.UploadReuploadCDNFile(ctx, request) +} + +func (p *reuploadRetryProvider) CDN(ctx context.Context, dc int, max int64) (CDN, io.Closer, error) { + return p, noopCloser{}, nil +} + func countHashes(data []byte, partSize int) (r [][]tg.FileHash) { actions := data batchSize := partSize @@ -194,26 +459,30 @@ func TestDownloader(t *testing.T) { } tests := []struct { - name string - data []byte - migrate bool - err bool - hashesErr bool + name string + data []byte + migrate bool + cdnReupload bool + cdnTokenErr bool + err bool + hashesErr bool }{ - {"5b", []byte{1, 2, 3, 4, 5}, false, false, false}, - {strconv.Itoa(len(testData)) + "b", testData, false, false, false}, - {"Error", []byte{}, false, true, false}, - {"HashesError", []byte{}, false, true, true}, - {"Migrate", []byte{}, true, false, false}, + {"5b", []byte{1, 2, 3, 4, 5}, false, false, false, false, false}, + {strconv.Itoa(len(testData)) + "b", testData, false, false, false, false, false}, + {"Error", []byte{}, false, false, false, true, false}, + {"HashesError", testData, false, false, false, false, true}, + {"Migrate", testData, true, false, false, false, false}, + {"MigrateReupload", testData, true, true, false, false, false}, + {"MigrateTokenInvalid", testData, true, false, true, false, false}, } schemas := []struct { name string - creator func(c Client, cdn CDN) *Builder + creator func(c Client) *Builder }{ - {"Master", func(c Client, cdn CDN) *Builder { - return NewDownloader().Download(c, nil) + {"Master", func(c Client) *Builder { + return NewDownloader().WithAllowCDN(true).Download(c, nil) }}, - {"Web", func(c Client, cdn CDN) *Builder { + {"Web", func(c Client) *Builder { return NewDownloader().Web(c, nil) }}, } @@ -272,20 +541,25 @@ func TestDownloader(t *testing.T) { hashes: mockHashes{ ranges: countHashes(test.data, 128*1024), }, - migrate: test.migrate, - err: test.err, - redirect: redirect, + migrate: test.migrate, + err: test.err, + hashesErr: test.hashesErr, + redirect: redirect, + } + if test.cdnReupload { + client.reuploadNeeded.Store(true) + } + if test.cdnTokenErr { + client.tokenInvalid.Store(true) } - b := schema.creator(client, client) + b := schema.creator(client) b = option.action(b) data, err := way.action(b) - switch { - case test.migrate: - a.Error(err) - case test.err: + shouldErr := test.err || (test.hashesErr && option.name == "Verify") + if shouldErr { a.Error(err) - default: + } else { a.NoError(err) a.True(bytes.Equal(test.data, 
data)) } @@ -298,3 +572,1648 @@ func TestDownloader(t *testing.T) { }) } } + +func TestDownloader_CDNFallbackWithoutProvider(t *testing.T) { + ctx := context.Background() + data := []byte("fallback-without-cdn-provider") + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + redirect := &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + } + + t.Run("NoProvider", func(t *testing.T) { + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: redirect, + } + output := new(bytes.Buffer) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(&noCDNClient{base: m}, nil). + WithVerify(true). + Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + require.EqualValues(t, 1, m.getFileCalls.Load()) + }) + + t.Run("NilProvider", func(t *testing.T) { + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: redirect, + } + output := new(bytes.Buffer) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(&nilCDNProvider{mock: m}, nil). + WithVerify(true). + Stream(ctx, output) + require.Error(t, err) + }) + + t.Run("ProviderErrorReturnsError", func(t *testing.T) { + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: redirect, + } + output := new(bytes.Buffer) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(&errCDNProvider{ + mock: m, + err: testErr, + }, nil). + WithVerify(true). + Stream(ctx, output) + require.Error(t, err) + require.ErrorIs(t, err, testErr) + }) + + t.Run("ProviderContextError", func(t *testing.T) { + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: redirect, + } + output := new(bytes.Buffer) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(&errCDNProvider{ + mock: m, + err: context.Canceled, + }, nil). + WithVerify(true). + Stream(ctx, output) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + }) + + t.Run("TokenInvalidFallbackToMaster", func(t *testing.T) { + m := &mock{ + data: data, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: redirect, + } + m.migrateOnce.Store(true) + m.tokenInvalid.Store(true) + + output := new(bytes.Buffer) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(m, nil). + WithVerify(true). + Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + }) +} +func TestDownloader_CDNDisabledByDefault(t *testing.T) { + // Default NewDownloader() must stay strictly backward compatible: + // redirect-capable mock still should run through master-only flow. 
+ ctx := context.Background() + data := []byte("cdn-policy-disabled") + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + redirect := &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + } + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: redirect, + } + output := new(bytes.Buffer) + _, err := NewDownloader().Download(m, nil).WithVerify(true).Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + // CDN is opt-in only. + require.EqualValues(t, 1, m.getFileCalls.Load()) +} + +func TestDownloader_AllowCDNNoRedirectKeepsLegacyLoad(t *testing.T) { + // Main compatibility check for explicit AllowCDN=true: + // when server does not return redirect we should not issue extra hash RPCs. + ctx := context.Background() + const threads = 4 + data := make([]byte, defaultPartSize*2) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + } + + output := new(syncio.BufWriterAt) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(m, nil). + WithThreads(threads). + Parallel(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + require.Zero(t, m.hashesCalls.Load()) + require.Zero(t, m.cdnHashCalls.Load()) +} + +func TestDownloader_AllowCDNNoRedirectNoExtraProbeRequest(t *testing.T) { + ctx := context.Background() + data := make([]byte, defaultPartSize*2) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + } + + output := new(syncio.BufWriterAt) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(m, nil). + WithThreads(1). + Parallel(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + // 2 parts + EOF probe, no prepare-stage prefetch. + require.EqualValues(t, 3, m.getFileCalls.Load()) +} + +func TestDownloader_NonCDNDefaultAvoidsExtraGetFile(t *testing.T) { + ctx := context.Background() + const threads = 4 + data := make([]byte, defaultPartSize*2) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + } + + output := new(syncio.BufWriterAt) + _, err := NewDownloader().Download(m, nil).WithThreads(threads).Parallel(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + // Unknown-size parallel downloads may have up to threads-1 in-flight tail probes. + // Baseline here is 3 calls: 2 full parts + 1 EOF probe. 
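+	// With threads = 4 that bounds the range asserted below to [3, 6].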
+ calls := m.getFileCalls.Load() + require.GreaterOrEqual(t, calls, int32(3)) + require.LessOrEqual(t, calls, int32(3+(threads-1))) +} + +func TestDownloader_WithAllowCDNDisabledMatchesLegacy(t *testing.T) { + ctx := context.Background() + data := []byte("legacy-master-only") + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 1, + FileToken: []byte{1}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + + output := new(bytes.Buffer) + _, err := NewDownloader().WithAllowCDN(false).Download(m, nil).Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + require.EqualValues(t, 1, m.getFileCalls.Load()) +} + +func TestDownloader_CDNLateRedirectDefaultEnablesVerify(t *testing.T) { + ctx := context.Background() + data := make([]byte, defaultPartSize*2) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirectAtOffset: defaultPartSize, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + + output := new(bytes.Buffer) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(m, nil). + Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + // Hash validation must be enabled by default when CDN flow is allowed, + // even if redirect happens after the first chunk. + require.Greater(t, m.hashesCalls.Load(), int32(0)) +} + +func TestDownloader_CDNDefaultVerifyDetectsHashMismatch(t *testing.T) { + ctx := context.Background() + data := make([]byte, defaultPartSize) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + hashes := countHashes(data, 128*1024) + hashes[0][0].Hash = bytes.Repeat([]byte{0x42}, len(hashes[0][0].Hash)) + + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: hashes, + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + + _, err := NewDownloader(). + WithAllowCDN(true). + Download(m, nil). 
+ Stream(ctx, io.Discard) + require.ErrorIs(t, err, ErrHashMismatch) +} + +func TestDownloader_CDNVerifyCannotBeDisabled(t *testing.T) { + ctx := context.Background() + data := make([]byte, defaultPartSize) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + hashes := countHashes(data, 128*1024) + hashes[0][0].Hash = bytes.Repeat([]byte{0x42}, len(hashes[0][0].Hash)) + + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: hashes, + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + + _, err := NewDownloader(). + WithAllowCDN(true). + Download(m, nil). + WithVerify(false). + Stream(ctx, io.Discard) + require.ErrorIs(t, err, ErrHashMismatch) +} + +func TestDownloader_CDNSplitWindowFullFetchDeduplicated(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + data := make([]byte, 128*1024) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + windowBlock := make(chan struct{}) + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{11}, + EncryptionKey: key, + EncryptionIv: iv, + }, + trackWindowOffset: 0, + trackWindowLimit: 128 * 1024, + trackWindowBlock: windowBlock, + } + + writer := new(syncio.BufWriterAt) + errCh := make(chan error, 1) + go func() { + _, err := NewDownloader(). + WithPartSize(64*1024). + WithAllowCDN(true). + Download(m, nil). + WithThreads(2). + Parallel(ctx, writer) + errCh <- err + }() + + require.Eventually(t, func() bool { + return m.trackWindowCalls.Load() >= 1 + }, time.Second, 10*time.Millisecond) + // Keep the first full-window request blocked for a short period to give + // concurrent chunk verification a chance to request the same window. + time.Sleep(50 * time.Millisecond) + close(windowBlock) + + require.NoError(t, <-errCh) + require.Equal(t, int32(1), m.trackWindowCalls.Load()) + require.Equal(t, data, writer.Bytes()) +} + +func TestDownloader_CDNDefaultVerifyAllowsShortFinalChunk(t *testing.T) { + ctx := context.Background() + data := make([]byte, 131072+10093) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + // Last hash has nominal 128KB limit but hash bytes are for short tail. + ranges: countHashes(data, 128*1024), + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + + output := new(bytes.Buffer) + _, err := NewDownloader(). + WithPartSize(128*1024). + WithAllowCDN(true). + Download(m, nil). 
+ Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) +} + +func TestDownloader_WithAllowCDNDisabledNoCDNMethodsCalled(t *testing.T) { + ctx := context.Background() + const threads = 4 + data := make([]byte, defaultPartSize*2) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + migrate: true, // Server would redirect if cdn_supported is set. + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + + output := new(syncio.BufWriterAt) + _, err := NewDownloader(). + WithAllowCDN(false). + Download(m, nil). + WithThreads(threads). + Parallel(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + require.Zero(t, m.cdnGetCalls.Load()) + require.Zero(t, m.cdnReupCalls.Load()) + require.Zero(t, m.cdnHashCalls.Load()) +} + +func TestDownloader_CDNFingerprintMissRetriesGetFile(t *testing.T) { + ctx := context.Background() + data := []byte("cdn-fingerprint-retry") + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + m.cdnFingerprint.Store(true) + + output := new(bytes.Buffer) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(m, nil). + WithVerify(true). + Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + require.GreaterOrEqual(t, m.cdnGetCalls.Load(), int32(2)) +} + +func TestDownloader_CDNFingerprintMissRetriesHashes(t *testing.T) { + ctx := context.Background() + data := []byte("cdn-hash-fingerprint-retry") + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + m.cdnHashFP.Store(true) + + output := new(bytes.Buffer) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(m, nil). + WithVerify(true). 
+ Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + require.GreaterOrEqual(t, m.cdnHashCalls.Load(), int32(2)) +} + +func TestDownloader_CDNParallelMultiThread(t *testing.T) { + ctx := context.Background() + const threads = 8 + data := make([]byte, defaultPartSize*6+777) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + + output := new(syncio.BufWriterAt) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(m, nil). + WithThreads(threads). + Parallel(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + require.Greater(t, m.cdnGetCalls.Load(), int32(0)) + require.Greater(t, m.cdnHashCalls.Load(), int32(0)) +} + +func TestDownloader_ConcurrentMixedCDNAndNonCDN(t *testing.T) { + ctx := context.Background() + + mustRandom := func(size int) []byte { + b := make([]byte, size) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + t.Fatal(err) + } + return b + } + newRedirect := func() *tg.UploadFileCDNRedirect { + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + return &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + } + } + + cdnData := mustRandom(defaultPartSize*2 + 111) + legacyData := mustRandom(defaultPartSize*2 + 222) + noRedirectData := mustRandom(defaultPartSize*2 + 333) + + cdnMock := &mock{ + data: cdnData, + migrate: true, + hashes: mockHashes{ranges: countHashes(cdnData, 128*1024)}, + redirect: newRedirect(), + } + legacyMock := &mock{ + data: legacyData, + migrate: true, + hashes: mockHashes{ranges: countHashes(legacyData, 128*1024)}, + redirect: newRedirect(), + } + noRedirectMock := &mock{ + data: noRedirectData, + hashes: mockHashes{ranges: countHashes(noRedirectData, 128*1024)}, + redirect: newRedirect(), + } + + type result struct { + name string + data []byte + err error + } + results := make(chan result, 3) + var wg sync.WaitGroup + wg.Add(3) + + go func() { + defer wg.Done() + out := new(syncio.BufWriterAt) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(cdnMock, nil). + WithThreads(4). + Parallel(ctx, out) + results <- result{name: "cdn", data: out.Bytes(), err: err} + }() + + go func() { + defer wg.Done() + out := new(syncio.BufWriterAt) + _, err := NewDownloader(). + WithAllowCDN(false). + Download(legacyMock, nil). + WithThreads(4). + Parallel(ctx, out) + results <- result{name: "legacy", data: out.Bytes(), err: err} + }() + + go func() { + defer wg.Done() + out := new(syncio.BufWriterAt) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(noRedirectMock, nil). + WithThreads(4). 
+ Parallel(ctx, out) + results <- result{name: "no-redirect", data: out.Bytes(), err: err} + }() + + wg.Wait() + close(results) + + got := map[string]result{} + for r := range results { + got[r.name] = r + require.NoError(t, r.err) + } + + require.Equal(t, cdnData, got["cdn"].data) + require.Equal(t, legacyData, got["legacy"].data) + require.Equal(t, noRedirectData, got["no-redirect"].data) + require.Greater(t, cdnMock.cdnGetCalls.Load(), int32(0)) + require.Zero(t, legacyMock.cdnGetCalls.Load()) + require.Zero(t, noRedirectMock.cdnGetCalls.Load()) +} + +func TestDownloader_CDNRetriesOnTimeout(t *testing.T) { + ctx := context.Background() + data := make([]byte, defaultPartSize*2) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + m.cdnGetTimeout.Store(true) + m.cdnHashTimeout.Store(true) + + output := new(bytes.Buffer) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(m, nil). + Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + require.GreaterOrEqual(t, m.cdnGetCalls.Load(), int32(2)) + require.GreaterOrEqual(t, m.cdnHashCalls.Load(), int32(2)) +} + +func TestDownloader_RetryHandlerCDNPath(t *testing.T) { + ctx := context.Background() + data := make([]byte, defaultPartSize) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + // Cover two retry sources: + // - reader/verifier retry on timeout from CDN hashes + // - CDN state machine retry on fingerprint miss from getCdnFile + m.cdnHashTimeout.Store(true) + m.cdnFingerprint.Store(true) + + output := new(bytes.Buffer) + var ( + mu sync.Mutex + events []RetryEvent + ) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(m, nil). + WithRetryHandler(func(event RetryEvent) { + mu.Lock() + events = append(events, event) + mu.Unlock() + }). + WithVerify(true). 
+ Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + + mu.Lock() + defer mu.Unlock() + require.NotEmpty(t, events) + + ops := make(map[string]bool, len(events)) + for _, event := range events { + require.NotEmpty(t, event.Operation) + require.GreaterOrEqual(t, event.Attempt, 1) + require.Error(t, event.Err) + ops[event.Operation] = true + } + + require.True(t, ops[RetryOperationReaderHashes]) + require.True(t, ops[RetryOperationGetFile]) +} + +func TestDownloader_RetryHandlerLegacyPath(t *testing.T) { + ctx := context.Background() + data := make([]byte, defaultPartSize+13) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + } + m.getTimeout.Store(true) + + output := new(bytes.Buffer) + var ( + mu sync.Mutex + events []RetryEvent + ) + _, err := NewDownloader(). + Download(m, nil). + WithRetryHandler(func(event RetryEvent) { + mu.Lock() + events = append(events, event) + mu.Unlock() + }). + Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + + mu.Lock() + defer mu.Unlock() + require.NotEmpty(t, events) + + hasReaderChunk := false + for _, event := range events { + if event.Operation == RetryOperationReaderChunk { + hasReaderChunk = true + require.GreaterOrEqual(t, event.Attempt, 1) + require.Error(t, event.Err) + } + } + require.True(t, hasReaderChunk) +} + +func TestDownloader_RetryHandlerBuilderIsolationCDNPath(t *testing.T) { + ctx := context.Background() + data := make([]byte, defaultPartSize) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + // Force retries from both verifier and CDN state machine. + m.cdnHashTimeout.Store(true) + m.cdnFingerprint.Store(true) + + downloader := NewDownloader().WithAllowCDN(true) + + var ( + firstMu sync.Mutex + first []RetryEvent + secondMu sync.Mutex + second []RetryEvent + ) + + firstBuilder := downloader. + Download(m, nil). + WithRetryHandler(func(event RetryEvent) { + firstMu.Lock() + first = append(first, event) + firstMu.Unlock() + }). + WithVerify(true) + + // Build second request on same downloader but do not run it. Its handler + // must not receive retries from first builder. 
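+	// (If handler state leaked through the shared Downloader, the second
+	// handler would observe the verifier and CDN state machine retries
+	// forced via cdnHashTimeout/cdnFingerprint above.)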
+ _ = downloader.Download(m, nil).WithRetryHandler(func(event RetryEvent) { + secondMu.Lock() + second = append(second, event) + secondMu.Unlock() + }) + + output := new(bytes.Buffer) + _, err := firstBuilder.Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + + firstMu.Lock() + firstEvents := len(first) + firstMu.Unlock() + secondMu.Lock() + secondEvents := len(second) + secondMu.Unlock() + + require.NotZero(t, firstEvents) + require.Zero(t, secondEvents) +} + +func TestDownloader_RetryHandlerCreateClientAttemptFromRefresh(t *testing.T) { + ctx := context.Background() + data := make([]byte, defaultPartSize) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + redirectOne := &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + } + redirectTwo := &tg.UploadFileCDNRedirect{ + DCID: 204, + FileToken: []byte{11}, + EncryptionKey: key, + EncryptionIv: iv, + } + + base := &mock{ + data: data, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: redirectOne, + } + // First CDN request: fingerprint miss (outer retry #1). + // Second CDN request: token invalid -> refresh redirect. + base.cdnFingerprint.Store(true) + base.tokenInvalid.Store(true) + + provider := &retryAttemptProvider{ + base: base, + initialRedirect: redirectOne, + refreshRedirect: redirectTwo, + } + + output := new(bytes.Buffer) + var ( + mu sync.Mutex + events []RetryEvent + ) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(provider, nil). + WithRetryHandler(func(event RetryEvent) { + mu.Lock() + events = append(events, event) + mu.Unlock() + }). + Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + + mu.Lock() + defer mu.Unlock() + + createAttempts := make([]int, 0, 1) + for _, event := range events { + if event.Operation == RetryOperationCreateClient { + createAttempts = append(createAttempts, event.Attempt) + } + } + require.Equal(t, []int{3}, createAttempts) +} + +func TestDownloader_RetryHandlerRefreshRedirect(t *testing.T) { + ctx := context.Background() + data := make([]byte, defaultPartSize) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + redirect := &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + } + + base := &mock{ + data: data, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + migrate: true, + redirect: redirect, + } + // Trigger refresh path from CDN state machine. + base.tokenInvalid.Store(true) + provider := &refreshRetryProvider{ + mock: base, + redirect: redirect, + } + provider.refreshTimeoutOnce.Store(true) + + output := new(bytes.Buffer) + var ( + mu sync.Mutex + events []RetryEvent + ) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(provider, nil). + WithRetryHandler(func(event RetryEvent) { + mu.Lock() + events = append(events, event) + mu.Unlock() + }). 
+ Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + + mu.Lock() + defer mu.Unlock() + + var refreshAttempts []int + for _, event := range events { + if event.Operation == RetryOperationRefreshRedirect { + refreshAttempts = append(refreshAttempts, event.Attempt) + } + } + require.Equal(t, []int{1}, refreshAttempts) +} + +func TestDownloader_RetryHandlerReupload(t *testing.T) { + ctx := context.Background() + data := make([]byte, defaultPartSize) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + base := &mock{ + data: data, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + migrate: true, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + base.reuploadNeeded.Store(true) + provider := &reuploadRetryProvider{mock: base} + provider.tokenInvalidOnce.Store(true) + + output := new(bytes.Buffer) + var ( + mu sync.Mutex + events []RetryEvent + ) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(provider, nil). + WithRetryHandler(func(event RetryEvent) { + mu.Lock() + events = append(events, event) + mu.Unlock() + }). + Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + + mu.Lock() + defer mu.Unlock() + + var reuploadAttempts []int + for _, event := range events { + if event.Operation == RetryOperationReupload { + reuploadAttempts = append(reuploadAttempts, event.Attempt) + } + } + require.Len(t, reuploadAttempts, 1) + require.GreaterOrEqual(t, reuploadAttempts[0], 1) +} + +func TestDownloader_RetryHandlerGetFileHashes(t *testing.T) { + ctx := context.Background() + data := make([]byte, defaultPartSize) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + migrate: true, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + m.cdnHashFP.Store(true) + + output := new(bytes.Buffer) + var ( + mu sync.Mutex + events []RetryEvent + ) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(m, nil). + WithRetryHandler(func(event RetryEvent) { + mu.Lock() + events = append(events, event) + mu.Unlock() + }). 
+ Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + + mu.Lock() + defer mu.Unlock() + + var hashAttempts []int + for _, event := range events { + if event.Operation == RetryOperationGetFileHashes { + hashAttempts = append(hashAttempts, event.Attempt) + } + } + require.Equal(t, []int{1}, hashAttempts) +} + +func TestDownloader_LegacyRetriesOnTimeout(t *testing.T) { + ctx := context.Background() + data := make([]byte, defaultPartSize+13) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + } + m.getTimeout.Store(true) + + output := new(bytes.Buffer) + _, err := NewDownloader().Download(m, nil).Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + require.GreaterOrEqual(t, m.getFileCalls.Load(), int32(2)) +} + +// Verifies split-window happy path for small custom part size (64KB): +// download must not fail with "hash window exceeds remaining chunk". +func TestDownloader_CDNSmallPartSizeNoHashWindowError(t *testing.T) { + ctx := context.Background() + data := make([]byte, 131072*3+10093) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + + output := new(bytes.Buffer) + _, err := NewDownloader(). + WithPartSize(64*1024). + WithAllowCDN(true). + Download(m, nil). + Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) +} + +// Verifies split-window happy path for unaligned part size (160KB) where CDN +// hash windows (128KB) are crossed by request boundaries. +func TestDownloader_CDNUnalignedPartSizeNoHashWindowError(t *testing.T) { + ctx := context.Background() + data := make([]byte, 131072*5+7777) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + + output := new(bytes.Buffer) + _, err := NewDownloader(). + WithPartSize(160*1024). + WithAllowCDN(true). + Download(m, nil). 
+ Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) +} + +func TestDownloader_CDNUnalignedPartSizeRespectsCDNRequestLimits(t *testing.T) { + ctx := context.Background() + data := make([]byte, 131072*5+7777) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + migrate: true, + enforceCDNRequestRules: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + + output := new(bytes.Buffer) + _, err := NewDownloader(). + WithPartSize(160*1024). + WithAllowCDN(true). + Download(m, nil). + Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + require.Greater(t, m.cdnGetCalls.Load(), int32(0)) +} + +// Covers tricky tail case: second hash window is shorter than nominal limit +// and also split by unaligned part size. Download must stay successful. +func TestDownloader_CDNUnalignedPartSizeFinalSplitNoHashWindowError(t *testing.T) { + ctx := context.Background() + // Two hash windows where the second one is short and split by part size: + // [0, 128KB) + [128KB, ~200KB), partSize=160KB. + data := make([]byte, 200*1024) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + + output := new(bytes.Buffer) + _, err := NewDownloader(). + WithPartSize(160*1024). + WithAllowCDN(true). + Download(m, nil). + Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) +} + +// Integrity regression test for small part size: tampered payload must be +// rejected even when hash windows are split between chunks. +func TestDownloader_CDNSmallPartSizeDetectsHashMismatch(t *testing.T) { + ctx := context.Background() + original := make([]byte, 131072*3+10093) + if _, err := io.ReadFull(rand.Reader, original); err != nil { + t.Fatal(err) + } + hashRanges := countHashes(original, 128*1024) + + // Tamper payload after hash snapshot: downloader must reject. + data := append([]byte(nil), original...) + data[100] ^= 0xFF + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: hashRanges, + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + + _, err := NewDownloader(). + WithPartSize(64*1024). + WithAllowCDN(true). + Download(m, nil). + Stream(ctx, io.Discard) + require.ErrorIs(t, err, ErrHashMismatch) +} + +// Integrity regression test for unaligned part size. 
Corruption inside a +// window that spans chunks must still be detected. +func TestDownloader_CDNUnalignedPartSizeDetectsHashMismatch(t *testing.T) { + ctx := context.Background() + original := make([]byte, 131072*5+7777) + if _, err := io.ReadFull(rand.Reader, original); err != nil { + t.Fatal(err) + } + hashRanges := countHashes(original, 128*1024) + + // Corrupt bytes inside hash window [128KB, 256KB) that spans request chunks + // when part size is 160KB. + data := append([]byte(nil), original...) + data[140*1024] ^= 0xFF + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: hashRanges, + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + + _, err := NewDownloader(). + WithPartSize(160*1024). + WithAllowCDN(true). + Download(m, nil). + Stream(ctx, io.Discard) + require.ErrorIs(t, err, ErrHashMismatch) +} + +// Ensures hashes returned by UploadReuploadCDNFile are consumed immediately: +// retry should proceed without extra UploadGetCDNFileHashes call. +func TestDownloader_CDNUsesReuploadHashes(t *testing.T) { + ctx := context.Background() + data := make([]byte, defaultPartSize) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + hashRanges := countHashes(data, 128*1024) + m := &mock{ + data: data, + migrate: true, + reupload: hashRanges[0], + hashesErr: true, + hashes: mockHashes{ + ranges: hashRanges, + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + m.reuploadNeeded.Store(true) + + output := new(bytes.Buffer) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(m, nil). + Stream(ctx, output) + require.NoError(t, err) + require.Equal(t, data, output.Bytes()) + require.EqualValues(t, 1, m.cdnReupCalls.Load()) + require.Zero(t, m.cdnHashCalls.Load()) +} + +func TestDownloader_CDNReuploadTimeoutDoesNotFallbackToMaster(t *testing.T) { + ctx := context.Background() + data := make([]byte, defaultPartSize) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + key := make([]byte, 32) + iv := make([]byte, aes.BlockSize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + t.Fatal(err) + } + + m := &mock{ + data: data, + migrate: true, + hashes: mockHashes{ + ranges: countHashes(data, 128*1024), + }, + redirect: &tg.UploadFileCDNRedirect{ + DCID: 203, + FileToken: []byte{10}, + EncryptionKey: key, + EncryptionIv: iv, + }, + } + // First CDN chunk requests reupload; first reupload call returns + // CDN_UPLOAD_TIMEOUT. TDesktop does not switch to master refresh flow for + // this error and fails the task. + m.reuploadNeeded.Store(true) + m.cdnUploadTO.Store(true) + + output := new(bytes.Buffer) + _, err := NewDownloader(). + WithAllowCDN(true). + Download(m, nil). 
+ Stream(ctx, output) + require.Error(t, err) + require.ErrorContains(t, err, "CDN_UPLOAD_TIMEOUT") + // Only initial redirect request to master is expected: no fallback refresh. + require.EqualValues(t, 1, m.getFileCalls.Load()) + require.EqualValues(t, 1, m.cdnReupCalls.Load()) +} diff --git a/telegram/downloader/master.go b/telegram/downloader/master.go index 9eab062858..02a0d49a65 100644 --- a/telegram/downloader/master.go +++ b/telegram/downloader/master.go @@ -27,11 +27,24 @@ type master struct { precise bool allowCDN bool - location tg.InputFileLocationClass + // retryHandler observes retried transient downloader errors. + retryHandler RetryHandler + location tg.InputFileLocationClass } var _ schema = master{} +func (c master) reportRetry(operation string, attempt int, err error) { + if attempt < 1 || err == nil || c.retryHandler == nil { + return + } + c.retryHandler(RetryEvent{ + Operation: operation, + Attempt: attempt, + Err: err, + }) +} + func (c master) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) { req := &tg.UploadGetFileRequest{ Offset: offset, diff --git a/telegram/downloader/reader.go b/telegram/downloader/reader.go index 4c0c564a02..33d64be0a8 100644 --- a/telegram/downloader/reader.go +++ b/telegram/downloader/reader.go @@ -88,11 +88,14 @@ func (r *reader) nextPlain(ctx context.Context) (block, error) { } func (r *reader) next(ctx context.Context, offset int64, limit int) (block, error) { + retryAttempt := 0 for { ch, err := r.sch.Chunk(ctx, offset, limit) if flood, err := tgerr.FloodWait(ctx, err); err != nil { if flood || tgerr.Is(err, tg.ErrTimeout) { + retryAttempt++ + reportSchemaRetry(r.sch, RetryOperationReaderChunk, retryAttempt, err) continue } return block{}, errors.Wrap(err, "get next chunk") diff --git a/telegram/downloader/retry.go b/telegram/downloader/retry.go new file mode 100644 index 0000000000..9d063394cd --- /dev/null +++ b/telegram/downloader/retry.go @@ -0,0 +1,81 @@ +package downloader + +import ( + "context" + + "github.com/go-faster/errors" + + "github.com/gotd/td/exchange" + "github.com/gotd/td/tg" + "github.com/gotd/td/tgerr" +) + +const maxRetryAttempts = 20 + +func retryLimitErr(op string, attempts int, err error) error { + return errors.Wrapf(err, "%s: retry limit reached (%d)", op, attempts) +} + +func isCDNFingerprintErr(err error) bool { + return errors.Is(err, exchange.ErrKeyFingerprintNotFound) +} + +func isCDNMasterFallbackErr(err error) bool { + // Token invalidation requires fetching fresh redirect/token window from + // master DC. + return tgerr.Is( + err, + "FILE_TOKEN_INVALID", + "REQUEST_TOKEN_INVALID", + ) +} + +func retryRequest[T any]( + ctx context.Context, + op string, + onRetry func(attempt int, err error), + fn func() (T, error), +) (_ T, err error) { + var zero T + timeoutRetries := 0 + retryAttempt := 0 + for { + if err := ctx.Err(); err != nil { + return zero, err + } + + result, err := fn() + if flood, waitErr := tgerr.FloodWait(ctx, err); waitErr != nil { + if flood { + // FloodWait helper already slept required amount. + if ctxErr := ctx.Err(); ctxErr != nil { + return zero, ctxErr + } + retryAttempt++ + if onRetry != nil { + onRetry(retryAttempt, waitErr) + } + continue + } + if tgerr.Is(waitErr, tg.ErrTimeout) { + if ctxErr := ctx.Err(); ctxErr != nil { + return zero, ctxErr + } + // Timeout can happen on unstable proxy links; retry with bounded + // attempts to avoid infinite tight loops. 
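+				// Note that flood waits handled above are not counted against
+				// this budget: only Timeout replies consume maxRetryAttempts.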
+ timeoutRetries++ + if timeoutRetries >= maxRetryAttempts { + return zero, retryLimitErr(op, timeoutRetries, waitErr) + } + retryAttempt++ + if onRetry != nil { + onRetry(retryAttempt, waitErr) + } + continue + } + return zero, waitErr + } + + return result, nil + } +} diff --git a/telegram/downloader/retry_event.go b/telegram/downloader/retry_event.go new file mode 100644 index 0000000000..e81ccec1f6 --- /dev/null +++ b/telegram/downloader/retry_event.go @@ -0,0 +1,41 @@ +package downloader + +// RetryEvent describes retried transient downloader error. +type RetryEvent struct { + // Operation identifies retry source. + Operation string + // Attempt is 1-based counter inside current retry loop. + Attempt int + // Err is the error that triggered retry. + Err error +} + +const ( + RetryOperationGetFile = "cdn.get_file" + RetryOperationGetFileHashes = "cdn.get_file_hashes" + RetryOperationReupload = "cdn.reupload" + RetryOperationRefreshRedirect = "cdn.refresh_redirect" + RetryOperationCreateClient = "cdn.create_client" + RetryOperationReaderChunk = "reader.chunk" + RetryOperationReaderHashes = "reader.hashes" +) + +// RetryHandler is called for every downloader error that is retried internally. +type RetryHandler func(event RetryEvent) + +type retryReporter interface { + reportRetry(operation string, attempt int, err error) +} + +func reportSchemaRetry(s schema, operation string, attempt int, err error) { + if attempt < 1 || err == nil { + return + } + + reporter, ok := s.(retryReporter) + if !ok { + return + } + + reporter.reportRetry(operation, attempt, err) +} diff --git a/telegram/downloader/verifier.go b/telegram/downloader/verifier.go index 1463e26a59..212ce9e1b5 100644 --- a/telegram/downloader/verifier.go +++ b/telegram/downloader/verifier.go @@ -19,7 +19,9 @@ var ErrHashMismatch = errors.New("file hash mismatch") type verifier struct { client schema + // hashes is ordered queue consumed by reader. hashes []tg.FileHash + // offset points to next hash window to request from server. offset int64 mux sync.Mutex } @@ -31,8 +33,22 @@ func newVerifier(client schema, hashes ...tg.FileHash) *verifier { sort.SliceStable(r, func(i, j int) bool { return r[i].Offset < r[j].Offset }) + var nextOffset int64 + for _, hash := range r { + if hash.Limit <= 0 { + continue + } + end := hash.Offset + int64(hash.Limit) + if end > nextOffset { + nextOffset = end + } + } - return &verifier{client: client, hashes: r} + return &verifier{ + client: client, + hashes: r, + offset: nextOffset, + } } func (v *verifier) pop() (tg.FileHash, bool) { @@ -85,10 +101,14 @@ func (v *verifier) next(ctx context.Context) (tg.FileHash, bool, error) { return hash, ok, nil } + retryAttempt := 0 for { hashes, err := v.client.Hashes(ctx, v.offset) if flood, err := tgerr.FloodWait(ctx, err); err != nil { if flood || tgerr.Is(err, tg.ErrTimeout) { + // Keep retrying transient server throttling/timeouts. + retryAttempt++ + reportSchemaRetry(v.client, RetryOperationReaderHashes, retryAttempt, err) continue } return tg.FileHash{}, false, errors.Wrap(err, "get hashes") diff --git a/telegram/downloader/web.go b/telegram/downloader/web.go index f15673413d..2d5191f9e4 100644 --- a/telegram/downloader/web.go +++ b/telegram/downloader/web.go @@ -14,12 +14,25 @@ var errHashesNotSupported = errors.New("this schema does not support hashes fetc // See https://core.telegram.org/api/files#downloading-webfiles. type web struct { client Client + // retryHandler observes retried transient downloader errors. 
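+	// A nil handler disables reporting (see reportRetry).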
+ retryHandler RetryHandler location tg.InputWebFileLocationClass } var _ schema = web{} +func (w web) reportRetry(operation string, attempt int, err error) { + if attempt < 1 || err == nil || w.retryHandler == nil { + return + } + w.retryHandler(RetryEvent{ + Operation: operation, + Attempt: attempt, + Err: err, + }) +} + func (w web) Chunk(ctx context.Context, offset int64, limit int) (chunk, error) { file, err := w.client.UploadGetWebFile(ctx, &tg.UploadGetWebFileRequest{ Location: w.location, diff --git a/telegram/internal/manager/conn.go b/telegram/internal/manager/conn.go index da2b5a3fba..da3d2b531b 100644 --- a/telegram/internal/manager/conn.go +++ b/telegram/internal/manager/conn.go @@ -3,6 +3,7 @@ package manager import ( "context" "sync" + "sync/atomic" "time" "github.com/cenkalti/backoff/v4" @@ -42,6 +43,7 @@ const ( type Conn struct { // Connection parameters. mode ConnMode // immutable + dc int // immutable // MTProto connection. proto protoConn // immutable @@ -66,6 +68,9 @@ type Conn struct { // State fields. cfg tg.Config + // cdnNeedsInit mirrors TDesktop connectionInited state for CDN transport. + // true means requests must go via invokeWithLayer(initConnection). + cdnNeedsInit atomic.Bool // pending buffers OnSession events until initConnection config is available. pending []mtproto.Session ongoing int @@ -175,7 +180,7 @@ func (c *Conn) Run(ctx context.Context) (err error) { func (c *Conn) waitSession(ctx context.Context) error { select { - // Connection is considered ready only after initConnection succeeded. + // Connection is considered ready only after mode-specific init succeeded. case <-c.gotConfig.Ready(): return nil case <-c.dead.Ready(): @@ -200,14 +205,77 @@ func (c *Conn) Invoke(ctx context.Context, input bin.Encoder, output bin.Decoder return errors.Wrap(err, "waitSession") } + if c.mode == ConnModeCDN { + // CDN mode has dedicated request wrapping rules (see invokeCDN). + err := c.invokeCDN(ctx, input, output) + return err + } q := c.wrapRequest(noopDecoder{input}) req := c.wrapRequest(&tg.InvokeWithLayerRequest{ Layer: tg.Layer, Query: q, }) + err := c.proto.Invoke(ctx, req, output) + return err +} +func (c *Conn) invokeCDN( + ctx context.Context, + input bin.Encoder, + output bin.Decoder, +) error { + // TDesktop model: + // - while connection is "not inited": wrap every query in invokeWithLayer(initConnection); + // - after first successful reply: use raw CDN methods; + // - if server returns CONNECTION_NOT_INITED/LAYER_INVALID on raw call: + // mark "not inited" and retry wrapped once. 
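+	//
+	// Shape of the wrapped form built by invokeCDNWrapped:
+	//
+	//	invokeWithLayer(layer, initConnection{deviceModel: "n/a", ..., query})
+	//
+	// The raw form is the bare CDN method (invokeCDNRaw).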
+ if c.cdnNeedsInit.Load() { + err := c.invokeCDNWrapped(ctx, input, output) + if err == nil { + c.cdnNeedsInit.Store(false) + return nil + } + return err + } + err := c.invokeCDNRaw(ctx, input, output) + if err == nil { + return nil + } + if c.shouldCDNRetryWrapped(err) { + c.cdnNeedsInit.Store(true) + retryErr := c.invokeCDNWrapped(ctx, input, output) + if retryErr == nil { + c.cdnNeedsInit.Store(false) + return nil + } + return retryErr + } + return err +} +func (c *Conn) invokeCDNWrapped(ctx context.Context, input bin.Encoder, output bin.Decoder) error { + req := &tg.InvokeWithLayerRequest{ + Layer: tg.Layer, + Query: c.cdnInitRequest(noopDecoder{input}), + } return c.proto.Invoke(ctx, req, output) } +func (c *Conn) invokeCDNRaw(ctx context.Context, input bin.Encoder, output bin.Decoder) error { + return c.proto.Invoke(ctx, input, output) +} +func (c *Conn) shouldCDNRetryWrapped(err error) bool { + if err == nil { + return false + } + if rpcErr, ok := tgerr.As(err); ok { + // Retry wrapped only for not-inited/layer-invalid transport state. + v := rpcErr.IsOneOf( + "CONNECTION_NOT_INITED", + "CONNECTION_LAYER_INVALID", + ) + return v + } + return false +} // OnMessage implements mtproto.Handler. func (c *Conn) OnMessage(b *bin.Buffer) error { @@ -223,7 +291,7 @@ func (n noopDecoder) Decode(b *bin.Buffer) error { } func (c *Conn) wrapRequest(req bin.Object) bin.Object { - if c.mode != ConnModeUpdates { + if c.mode == ConnModeData { return &tg.InvokeWithoutUpdatesRequest{ Query: req, } @@ -232,9 +300,38 @@ func (c *Conn) wrapRequest(req bin.Object) bin.Object { return req } +func (c *Conn) cdnInitRequest(query bin.Object) bin.Object { + // Match TDesktop CDN init wrapper: + // only device/system are anonymized, the rest of initConnection + // parameters stay aligned with regular connection settings. + return &tg.InitConnectionRequest{ + APIID: c.appID, + DeviceModel: "n/a", + SystemVersion: "n/a", + AppVersion: c.device.AppVersion, + SystemLangCode: c.device.SystemLangCode, + LangPack: c.device.LangPack, + LangCode: c.device.LangCode, + Proxy: c.device.Proxy, + Params: c.device.Params, + Query: query, + } +} func (c *Conn) init(ctx context.Context) error { c.log.Debug("Initializing") + if c.mode == ConnModeCDN { + // CDN connections skip help.getConfig init flow and become ready + // immediately after MTProto auth-key exchange. + c.cdnNeedsInit.Store(true) + c.mux.Lock() + c.latest = c.clock.Now() + c.cfg = tg.Config{ThisDC: c.dc} + c.mux.Unlock() + c.gotConfig.Signal() + err := c.flushPendingSession() + return err + } q := c.wrapRequest(&tg.InitConnectionRequest{ APIID: c.appID, DeviceModel: c.device.DeviceModel, @@ -287,7 +384,8 @@ func (c *Conn) init(ctx context.Context) error { c.mux.Unlock() c.gotConfig.Signal() - return c.flushPendingSession() + err := c.flushPendingSession() + return err } // Ping calls ping for underlying protocol connection. 
diff --git a/telegram/internal/manager/conn_cdn_test.go b/telegram/internal/manager/conn_cdn_test.go new file mode 100644 index 0000000000..e477f69a89 --- /dev/null +++ b/telegram/internal/manager/conn_cdn_test.go @@ -0,0 +1,283 @@ +package manager + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/gotd/td/bin" + "github.com/gotd/td/clock" + "github.com/gotd/td/tdsync" + "github.com/gotd/td/tg" + "github.com/gotd/td/tgerr" +) + +type captureProto struct { + invokeCalls int + lastInput bin.Encoder + inputs []bin.Encoder +} + +func (p *captureProto) Invoke(_ context.Context, input bin.Encoder, _ bin.Decoder) error { + p.invokeCalls++ + p.lastInput = input + p.inputs = append(p.inputs, input) + return nil +} + +func (p *captureProto) Run(ctx context.Context, f func(ctx context.Context) error) error { + return f(ctx) +} + +func (*captureProto) Ping(context.Context) error { + return nil +} + +func newTestConn(mode ConnMode, proto protoConn) *Conn { + c := &Conn{ + mode: mode, + dc: 203, + appID: 42, + device: DeviceConfig{AppVersion: "test-app"}, + proto: proto, + clock: clock.System, + log: zap.NewNop(), + handler: NoopHandler{}, + sessionInit: tdsync.NewReady(), + gotConfig: tdsync.NewReady(), + dead: tdsync.NewReady(), + } + if mode == ConnModeCDN { + c.cdnNeedsInit.Store(true) + } + return c +} + +func TestConnInitCDNNoHelpGetConfig(t *testing.T) { + a := require.New(t) + p := &captureProto{} + c := newTestConn(ConnModeCDN, p) + + a.NoError(c.init(context.Background())) + a.Equal(0, p.invokeCalls) + + select { + case <-c.gotConfig.Ready(): + case <-time.After(time.Second): + a.Fail("gotConfig should be signaled for CDN mode") + } + + c.mux.Lock() + cfg := c.cfg + c.mux.Unlock() + a.Equal(203, cfg.ThisDC) +} + +func TestConnInvokeCDNWrappedUsesInitConnection(t *testing.T) { + a := require.New(t) + p := &captureProto{} + c := newTestConn(ConnModeCDN, p) + c.device = DeviceConfig{ + DeviceModel: "private-model", + SystemVersion: "private-os", + AppVersion: "1.2.3", + SystemLangCode: "ru", + LangPack: "ru-pack", + LangCode: "ru", + Proxy: tg.InputClientProxy{ + Address: "127.0.0.1", + Port: 1080, + }, + Params: &tg.JSONObject{ + Value: []tg.JSONObjectValue{ + { + Key: "tz_offset", + Value: &tg.JSONNumber{Value: 10800}, + }, + }, + }, + } + err := c.invokeCDNWrapped( + context.Background(), + &tg.UploadGetCDNFileRequest{FileToken: []byte{1}, Offset: 0, Limit: 1024}, + &tg.UploadCDNFileBox{}, + ) + a.NoError(err) + + req, ok := p.lastInput.(*tg.InvokeWithLayerRequest) + a.True(ok) + a.Equal(tg.Layer, req.Layer) + + _, wrapped := req.Query.(*tg.InvokeWithoutUpdatesRequest) + a.False(wrapped, "CDN query must not use invokeWithoutUpdates") + + initReq, ok := req.Query.(*tg.InitConnectionRequest) + a.True(ok) + a.Equal(42, initReq.APIID) + a.Equal("n/a", initReq.DeviceModel) + a.Equal("n/a", initReq.SystemVersion) + a.Equal("1.2.3", initReq.AppVersion) + a.Equal("ru", initReq.SystemLangCode) + a.Equal("ru-pack", initReq.LangPack) + a.Equal("ru", initReq.LangCode) + a.Equal(tg.InputClientProxy{Address: "127.0.0.1", Port: 1080}, initReq.Proxy) + + params, ok := initReq.Params.(*tg.JSONObject) + a.True(ok) + a.Equal( + []tg.JSONObjectValue{{Key: "tz_offset", Value: &tg.JSONNumber{Value: 10800}}}, + params.Value, + ) + + query, ok := initReq.Query.(noopDecoder) + a.True(ok) + _, ok = query.Encoder.(*tg.UploadGetCDNFileRequest) + a.True(ok) +} + +type retryOnRawNotInitedProto struct { + calls []bin.Encoder + rawErrBudget int +} + 
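+// Invoke fails raw upload.getCdnFile calls with CONNECTION_NOT_INITED while
+// rawErrBudget lasts; wrapped (invokeWithLayer) calls always succeed.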
+func (p *retryOnRawNotInitedProto) Invoke(_ context.Context, input bin.Encoder, _ bin.Decoder) error { + p.calls = append(p.calls, input) + if _, ok := input.(*tg.UploadGetCDNFileRequest); ok && p.rawErrBudget > 0 { + p.rawErrBudget-- + return tgerr.New(400, "CONNECTION_NOT_INITED") + } + return nil +} + +func (p *retryOnRawNotInitedProto) Run(ctx context.Context, f func(ctx context.Context) error) error { + return f(ctx) +} + +func (*retryOnRawNotInitedProto) Ping(context.Context) error { + return nil +} + +type rawMethodInvalidProto struct { + calls []bin.Encoder + rawErrBudget int +} + +func (p *rawMethodInvalidProto) Invoke(_ context.Context, input bin.Encoder, _ bin.Decoder) error { + p.calls = append(p.calls, input) + if _, ok := input.(*tg.UploadGetCDNFileRequest); ok && p.rawErrBudget > 0 { + p.rawErrBudget-- + return tgerr.New(400, "METHOD_INVALID") + } + return nil +} + +func (p *rawMethodInvalidProto) Run(ctx context.Context, f func(ctx context.Context) error) error { + return f(ctx) +} + +func (*rawMethodInvalidProto) Ping(context.Context) error { + return nil +} + +func TestConnInvokeCDNFirstCallWrapped(t *testing.T) { + a := require.New(t) + p := &captureProto{} + c := newTestConn(ConnModeCDN, p) + c.gotConfig.Signal() + + err := c.Invoke( + context.Background(), + &tg.UploadGetCDNFileRequest{FileToken: []byte{1}, Offset: 0, Limit: 1024}, + &tg.UploadCDNFileBox{}, + ) + a.NoError(err) + a.Len(p.inputs, 1) + + _, wrapped := p.inputs[0].(*tg.InvokeWithLayerRequest) + a.True(wrapped, "first CDN call must be wrapped with invokeWithLayer(initConnection)") +} + +func TestConnInvokeCDNSecondCallRawAfterInit(t *testing.T) { + a := require.New(t) + p := &captureProto{} + c := newTestConn(ConnModeCDN, p) + c.gotConfig.Signal() + + req := &tg.UploadGetCDNFileRequest{FileToken: []byte{1}, Offset: 0, Limit: 1024} + a.NoError(c.Invoke(context.Background(), req, &tg.UploadCDNFileBox{})) + a.NoError(c.Invoke(context.Background(), req, &tg.UploadCDNFileBox{})) + a.Len(p.inputs, 2) + + _, firstWrapped := p.inputs[0].(*tg.InvokeWithLayerRequest) + a.True(firstWrapped, "first CDN call must initialize connection via wrapper") + _, secondRaw := p.inputs[1].(*tg.UploadGetCDNFileRequest) + a.True(secondRaw, "after successful init, next CDN call must be raw") +} + +func TestConnInvokeCDNRawNotInitedRetryWrappedThenRaw(t *testing.T) { + a := require.New(t) + p := &retryOnRawNotInitedProto{rawErrBudget: 1} + c := newTestConn(ConnModeCDN, p) + c.gotConfig.Signal() + + req := &tg.UploadGetCDNFileRequest{FileToken: []byte{1}, Offset: 0, Limit: 1024} + a.NoError(c.Invoke(context.Background(), req, &tg.UploadCDNFileBox{})) + a.NoError(c.Invoke(context.Background(), req, &tg.UploadCDNFileBox{})) + a.NoError(c.Invoke(context.Background(), req, &tg.UploadCDNFileBox{})) + a.Len(p.calls, 4) + + _, ok := p.calls[0].(*tg.InvokeWithLayerRequest) + a.True(ok, "cold start must be wrapped") + _, ok = p.calls[1].(*tg.UploadGetCDNFileRequest) + a.True(ok, "inited state must try raw") + _, ok = p.calls[2].(*tg.InvokeWithLayerRequest) + a.True(ok, "CONNECTION_NOT_INITED from raw must retry wrapped") + _, ok = p.calls[3].(*tg.UploadGetCDNFileRequest) + a.True(ok, "after successful retry, state must return to raw") +} + +func TestConnInvokeCDNRawMethodInvalidNoWrappedFallback(t *testing.T) { + a := require.New(t) + p := &rawMethodInvalidProto{rawErrBudget: 1} + c := newTestConn(ConnModeCDN, p) + c.gotConfig.Signal() + + req := &tg.UploadGetCDNFileRequest{FileToken: []byte{1}, Offset: 0, Limit: 1024} + 
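+	// First call is wrapped (cold start) and succeeds; the second goes raw
+	// and must surface METHOD_INVALID as-is, without a wrapped retry.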
a.NoError(c.Invoke(context.Background(), req, &tg.UploadCDNFileBox{})) + + err := c.Invoke(context.Background(), req, &tg.UploadCDNFileBox{}) + a.Error(err) + a.True(tgerr.Is(err, "METHOD_INVALID")) + a.Len(p.calls, 2) + _, ok := p.calls[0].(*tg.InvokeWithLayerRequest) + a.True(ok, "initial request must be wrapped") + _, ok = p.calls[1].(*tg.UploadGetCDNFileRequest) + a.True(ok, "raw METHOD_INVALID should be returned as-is") +} + +func TestConnInvokeDataKeepsInvokeWithoutUpdates(t *testing.T) { + a := require.New(t) + p := &captureProto{} + c := newTestConn(ConnModeData, p) + c.gotConfig.Signal() + + a.NoError(c.Invoke(context.Background(), &tg.HelpGetConfigRequest{}, &tg.Config{})) + + outer, ok := p.lastInput.(*tg.InvokeWithoutUpdatesRequest) + a.True(ok) + + withLayer, ok := outer.Query.(*tg.InvokeWithLayerRequest) + a.True(ok) + a.Equal(tg.Layer, withLayer.Layer) + + inner, ok := withLayer.Query.(*tg.InvokeWithoutUpdatesRequest) + a.True(ok) + + query, ok := inner.Query.(noopDecoder) + a.True(ok) + _, ok = query.Encoder.(*tg.HelpGetConfigRequest) + a.True(ok) +} diff --git a/telegram/internal/manager/create.go b/telegram/internal/manager/create.go index 2f6e2ec2f8..f7ce1a2e3c 100644 --- a/telegram/internal/manager/create.go +++ b/telegram/internal/manager/create.go @@ -61,7 +61,9 @@ func CreateConn( ) *Conn { connOpts.setDefaults(opts.Clock) conn := &Conn{ - mode: mode, + mode: mode, + // Store real DC id for mode-specific init shortcuts (CDN mode). + dc: connOpts.DC, appID: appID, device: connOpts.Device, clock: opts.Clock, @@ -73,6 +75,9 @@ func CreateConn( onDead: connOpts.OnDead, connBackoff: connOpts.Backoff, } + if mode == ConnModeCDN { + conn.cdnNeedsInit.Store(true) + } conn.log = opts.Logger opts.DC = connOpts.DC diff --git a/telegram/options.go b/telegram/options.go index 4c32a1d4bb..c98bd6c981 100644 --- a/telegram/options.go +++ b/telegram/options.go @@ -46,6 +46,14 @@ type Options struct { // Enabled by default if no UpdateHandler is provided. NoUpdates bool + // AllowCDN enables downloader CDN redirect flow for clients that support + // downloader integration. + // + // If false, downloader will stay on the master DC path (legacy behavior). + // If true and server does not return redirect, requests still go through the + // same master path (no extra CDN round-trips). + // Default is false. + AllowCDN bool // ReconnectionBackoff configures and returns reconnection backoff object. ReconnectionBackoff func() backoff.BackOff // OnDead will be called on connection dead. diff --git a/telegram/pfs.go b/telegram/pfs.go index 9c9597fd66..d143c853b2 100644 --- a/telegram/pfs.go +++ b/telegram/pfs.go @@ -1,16 +1,11 @@ package telegram import ( - "context" - "github.com/go-faster/errors" "go.uber.org/zap" "github.com/gotd/td/mtproto" "github.com/gotd/td/pool" - "github.com/gotd/td/telegram/auth" - "github.com/gotd/td/telegram/internal/manager" - "github.com/gotd/td/tg" ) func (c *Client) handlePrimaryConnDead(err error) { @@ -57,14 +52,3 @@ func (c *Client) handleDCConnDead(dcID int, err error) { c.onDead(err) } } - -func (c *Client) dcTransferSetup(dcID int) manager.SetupCallback { - return func(ctx context.Context, invoker tg.Invoker) error { - // Run export/import authorization only when the connection is already up. 
-		_, err := c.transfer(ctx, tg.NewClient(invoker), dcID)
-		if auth.IsUnauthorized(err) {
-			return nil
-		}
-		return err
-	}
-}
diff --git a/telegram/pfs_test.go b/telegram/pfs_test.go
index 6d18e480a1..eb92ad3a21 100644
--- a/telegram/pfs_test.go
+++ b/telegram/pfs_test.go
@@ -1,13 +1,17 @@
 package telegram
 
 import (
+	"context"
 	"errors"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/require"
 	"go.uber.org/zap"
 
+	"github.com/gotd/td/bin"
 	"github.com/gotd/td/crypto"
+	"github.com/gotd/td/exchange"
 	"github.com/gotd/td/mtproto"
 	"github.com/gotd/td/pool"
 )
@@ -73,3 +77,283 @@ func TestClientHandleDCConnDeadPassThrough(t *testing.T) {
 	a.Equal(int64(88), data.Salt)
 	a.Equal(1, onDeadCalls)
 }
+
+func TestClientHandleCDNConnDeadPFSDropResetsCDNSession(t *testing.T) {
+	a := require.New(t)
+	client := Client{
+		log: zap.NewNop(),
+	}
+	client.init()
+
+	dcID := 7
+	key := crypto.Key{3}.WithID()
+	session := pool.NewSyncSession(pool.Session{
+		DC:      dcID,
+		AuthKey: key,
+		Salt:    99,
+	})
+	client.cdnSessions[dcID] = session
+
+	onDeadCalls := 0
+	client.onDead = func(error) {
+		onDeadCalls++
+	}
+
+	client.handleCDNConnDead(dcID, mtproto.ErrPFSDropKeysRequired)
+
+	data := session.Load()
+	a.True(data.AuthKey.Zero())
+	a.Zero(data.Salt)
+	a.Equal(1, onDeadCalls)
+}
+
+func TestClientHandleCDNConnDeadDoesNotTouchRegularSession(t *testing.T) {
+	a := require.New(t)
+	client := Client{
+		log: zap.NewNop(),
+	}
+	client.init()
+
+	dcID := 8
+	regularKey := crypto.Key{4}.WithID()
+	cdnKey := crypto.Key{5}.WithID()
+	client.sessions[dcID] = pool.NewSyncSession(pool.Session{
+		DC:      dcID,
+		AuthKey: regularKey,
+		Salt:    11,
+	})
+	client.cdnSessions[dcID] = pool.NewSyncSession(pool.Session{
+		DC:      dcID,
+		AuthKey: cdnKey,
+		Salt:    22,
+	})
+
+	client.handleCDNConnDead(dcID, mtproto.ErrPFSDropKeysRequired)
+
+	regular := client.sessions[dcID].Load()
+	cdn := client.cdnSessions[dcID].Load()
+	a.Equal(regularKey, regular.AuthKey)
+	a.Equal(int64(11), regular.Salt)
+	a.True(cdn.AuthKey.Zero())
+	a.Zero(cdn.Salt)
+}
+
+type closeInvokerStub struct {
+	closed   bool
+	closedCh chan struct{}
+}
+
+func (*closeInvokerStub) Invoke(context.Context, bin.Encoder, bin.Decoder) error {
+	return nil
+}
+
+func (s *closeInvokerStub) Close() error {
+	s.closed = true
+	if s.closedCh != nil {
+		close(s.closedCh)
+	}
+	return nil
+}
+
+type blockingCloseInvokerStub struct {
+	closed chan struct{}
+	unlock chan struct{}
+}
+
+func (*blockingCloseInvokerStub) Invoke(context.Context, bin.Encoder, bin.Decoder) error {
+	return nil
+}
+
+func (s *blockingCloseInvokerStub) Close() error {
+	close(s.closed)
+	<-s.unlock
+	return nil
+}
+
+func TestClientHandleCDNConnDeadFingerprintMissInvalidatesCache(t *testing.T) {
+	a := require.New(t)
+	client := Client{
+		log: zap.NewNop(),
+	}
+	client.init()
+
+	const dcID = 9
+	conn := &closeInvokerStub{closedCh: make(chan struct{})}
+	client.cdnPools.conns[dcID] = []cachedCDNPool{{
+		conn: conn,
+		max:  1,
+	}}
+	client.cdnKeysSet = true
+	client.cdnKeys = []PublicKey{{}}
+
+	onDeadCalls := 0
+	client.onDead = func(error) {
+		onDeadCalls++
+	}
+
+	client.handleCDNConnDead(dcID, exchange.ErrKeyFingerprintNotFound)
+
+	client.cdnPools.mux.Lock()
+	_, ok := client.cdnPools.conns[dcID]
+	client.cdnPools.mux.Unlock()
+	a.False(ok)
+	select {
+	case <-conn.closedCh:
+	case <-time.After(time.Second):
+		t.Fatal("expected async close call")
+	}
+	a.True(conn.closed)
+	a.False(client.cdnKeysSet)
+	a.Nil(client.cdnKeys)
+	a.Equal(0, onDeadCalls, "fingerprint miss should be handled internally without onDead callback")
+}
+
+func TestClientHandleCDNConnDeadFingerprintMissDoesNotBlockOnClose(t *testing.T) {
+	client := Client{
+		log: zap.NewNop(),
+	}
+	client.init()
+
+	const dcID = 10
+	conn := &blockingCloseInvokerStub{
+		closed: make(chan struct{}),
+		unlock: make(chan struct{}),
+	}
+	client.cdnPools.conns[dcID] = []cachedCDNPool{{
+		conn: conn,
+		max:  1,
+	}}
+
+	done := make(chan struct{})
+	go func() {
+		client.handleCDNConnDead(dcID, exchange.ErrKeyFingerprintNotFound)
+		close(done)
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(time.Second):
+		t.Fatal("handleCDNConnDead blocked on pool close")
+	}
+
+	select {
+	case <-conn.closed:
+	case <-time.After(time.Second):
+		t.Fatal("expected async close call")
+	}
+	close(conn.unlock)
+}
+
+type observedBlockingCloseInvokerStub struct {
+	started chan struct{}
+	unlock  chan struct{}
+}
+
+func (*observedBlockingCloseInvokerStub) Invoke(context.Context, bin.Encoder, bin.Decoder) error {
+	return nil
+}
+
+func (s *observedBlockingCloseInvokerStub) Close() error {
+	select {
+	case s.started <- struct{}{}:
+	default:
+	}
+	<-s.unlock
+	return nil
+}
+
+func TestClientHandleCDNConnDeadFingerprintMissStartsMultipleCloseWorkers(t *testing.T) {
+	client := Client{
+		log: zap.NewNop(),
+	}
+	client.init()
+
+	const (
+		dcID         = 11
+		totalStale   = 32
+		maxWorkers   = 4
+		minWorkers   = 2
+		waitForStart = time.Second
+	)
+	started := make(chan struct{}, totalStale)
+	unlock := make(chan struct{})
+	pools := make([]cachedCDNPool, 0, totalStale)
+	for i := 0; i < totalStale; i++ {
+		pools = append(pools, cachedCDNPool{
+			conn: &observedBlockingCloseInvokerStub{
+				started: started,
+				unlock:  unlock,
+			},
+			max: int64(i + 1),
+		})
+	}
+	client.cdnPools.conns[dcID] = pools
+
+	client.handleCDNConnDead(dcID, exchange.ErrKeyFingerprintNotFound)
+
+	startedCount := 0
+	deadline := time.After(waitForStart)
+	for startedCount < minWorkers {
+		select {
+		case <-started:
+			startedCount++
+		case <-deadline:
+			t.Fatalf("expected at least %d close workers to start, got %d", minWorkers, startedCount)
+		}
+	}
+
+	// Workers are blocked in Close(); after a short wait the number of started
+	// workers should stay bounded.
+	time.Sleep(100 * time.Millisecond)
+drain:
+	for {
+		select {
+		case <-started:
+			startedCount++
+		default:
+			break drain
+		}
+	}
+	if startedCount > maxWorkers {
+		t.Fatalf("expected bounded close workers <= %d, got %d", maxWorkers, startedCount)
+	}
+
+	close(unlock)
+}
+
+func TestClientHandleCDNConnDeadFingerprintMissProcessesMultipleCallsInParallel(t *testing.T) {
+	client := Client{
+		log: zap.NewNop(),
+	}
+	client.init()
+
+	started := make(chan struct{}, 2)
+	unlock := make(chan struct{})
+	first := &observedBlockingCloseInvokerStub{started: started, unlock: unlock}
+	second := &observedBlockingCloseInvokerStub{started: started, unlock: unlock}
+
+	client.cdnPools.conns[12] = []cachedCDNPool{{conn: first, max: 1}}
+	client.handleCDNConnDead(12, exchange.ErrKeyFingerprintNotFound)
+
+	select {
+	case <-started:
+	case <-time.After(time.Second):
+		t.Fatal("expected first close call")
+	}
+
+	client.cdnPools.conns[13] = []cachedCDNPool{{conn: second, max: 1}}
+	client.handleCDNConnDead(13, exchange.ErrKeyFingerprintNotFound)
+
+	select {
+	case <-started:
+	case <-time.After(time.Second):
+		t.Fatal("expected second close call to start without waiting first close")
+	}
+
+	close(unlock)
+}
diff --git a/telegram/pool.go b/telegram/pool.go
index fae90db104..3cc161f97b 100644
--- a/telegram/pool.go
+++ b/telegram/pool.go
@@ -56,7 +56,13 @@ func (c *Client) Pool(max int64) (CloseInvoker, error) {
 	})
 }
 
-func (c *Client) dc(ctx context.Context, dcID int, max int64, dialer mtproto.Dialer) (*pool.DC, error) {
+func (c *Client) dc(
+	ctx context.Context,
+	dcID int,
+	max int64,
+	dialer mtproto.Dialer,
+	mode manager.ConnMode,
+) (*pool.DC, error) {
 	if max < 0 {
 		return nil, errors.Errorf("invalid max value %d", max)
 	}
@@ -72,6 +78,24 @@ func (c *Client) dc(ctx context.Context, dcID int, max int64, dialer mtproto.Dia
 	)
 
 	opts := c.opts
+	if mode == manager.ConnModeCDN {
+		// TDesktop-compatible gate: CDN connection is allowed only when keyset
+		// for requested CDN DC is present (or can be fetched).
+		cdnKeys, set := c.cachedCDNKeysForDC(dcID)
+		if !set || len(cdnKeys) == 0 {
+			fetched, err := c.fetchCDNKeysForDC(ctx, dcID)
+			if err != nil {
+				return nil, errors.Wrapf(err, "fetch CDN public keys for DC %d", dcID)
+			}
+			cdnKeys = fetched
+		}
+		if len(cdnKeys) == 0 {
+			return nil, errors.Errorf("no CDN public keys available for CDN DC %d", dcID)
+		}
+		// Keep CDN keys first and extend with bundled keys for fingerprint
+		// compatibility fallback, matching TDesktop key lookup behavior.
+		opts.PublicKeys = mergePublicKeys(cdnKeys, opts.PublicKeys)
+	}
 	// suppressSetup temporarily disables per-connection transfer hook while
 	// explicit first transfer below is running, avoiding duplicate import.
 	var suppressSetup atomic.Bool
@@ -79,32 +103,50 @@ func (c *Client) dc(ctx context.Context, dcID int, max int64, dialer mtproto.Dia
 	id := c.connsCounter.Inc()
 
 	c.sessionsMux.Lock()
-	session, ok := c.sessions[dcID]
+	sessions := c.sessions
+	if mode == manager.ConnModeCDN {
+		// Keep CDN auth key lifecycle separated from regular DC sessions.
+ sessions = c.cdnSessions + } + session, ok := sessions[dcID] if !ok { session = pool.NewSyncSession(pool.Session{DC: dcID}) - c.sessions[dcID] = session + sessions[dcID] = session } c.sessionsMux.Unlock() options, data := session.Options(opts) setup := manager.SetupCallback(nil) - if data.AuthKey.Zero() && c.session.Load().DC != dcID && !suppressSetup.Load() { + handler := c.asHandler() + if mode != manager.ConnModeCDN && + data.AuthKey.Zero() && + c.session.Load().DC != dcID && + !suppressSetup.Load() { // Non-main DC key must be authorized via auth.export/import after // local key generation. setup = c.dcTransferSetup(dcID) } + if mode == manager.ConnModeCDN { + // CDN pools do not process updates and use dedicated session store. + handler = c.asCDNHandler() + } options.Logger = c.log.Named("conn").With( zap.Int64("conn_id", id), zap.Int("dc_id", dcID), ) return c.create( - dialer, manager.ConnModeData, c.appID, + dialer, mode, c.appID, options, manager.ConnOptions{ DC: dcID, Device: c.device, - Handler: c.asHandler(), + Handler: handler, Setup: setup, OnDead: func(err error) { + if mode == manager.ConnModeCDN { + // CDN dead handler also manages CDN key invalidation. + c.handleCDNConnDead(dcID, err) + return + } c.handleDCConnDead(dcID, err) }, }, @@ -114,6 +156,12 @@ func (c *Client) dc(ctx context.Context, dcID int, max int64, dialer mtproto.Dia return nil, errors.Wrap(err, "create pool") } + if mode == manager.ConnModeCDN { + // No auth transfer for CDN mode: CDN API uses file tokens and does not + // require auth.export/import bootstrap. + return p, nil + } + // First transfer is done explicitly to preserve old behavior: return // transfer errors from DC pool creation. Setup callback remains enabled for // future reconnections when keys are re-generated inside the pool. @@ -136,7 +184,7 @@ func (c *Client) dc(ctx context.Context, dcID int, max int64, dialer mtproto.Dia // DC creates new multi-connection invoker to given DC. func (c *Client) DC(ctx context.Context, dc int, max int64) (CloseInvoker, error) { - return c.dc(ctx, dc, max, c.primaryDC(dc)) + return c.dc(ctx, dc, max, c.primaryDC(dc), manager.ConnModeData) } // MediaOnly creates new multi-connection invoker to given DC ID. @@ -144,5 +192,35 @@ func (c *Client) DC(ctx context.Context, dc int, max int64) (CloseInvoker, error func (c *Client) MediaOnly(ctx context.Context, dc int, max int64) (CloseInvoker, error) { return c.dc(ctx, dc, max, func(ctx context.Context) (transport.Conn, error) { return c.resolver.MediaOnly(ctx, dc, c.dcList()) - }) + }, manager.ConnModeData) +} + +// CDN creates new multi-connection invoker to given CDN DC ID. +// It connects to CDN DCs. +func (c *Client) CDN(ctx context.Context, dc int, max int64) (CloseInvoker, error) { + if max < 0 { + return nil, errors.Errorf("invalid max value %d", max) + } + need := normalizeCDNPoolMax(max) + + if cached, ok := c.cdnPools.acquire(dc, need); ok { + // Reuse existing pool to avoid extra TCP/MTProto handshakes. + return cached, nil + } + + // Keep shared CDN pools per DC with max-aware reuse. + created, err := c.dc(ctx, dc, need, func(ctx context.Context) (transport.Conn, error) { + return c.resolver.CDN(ctx, dc, c.dcList()) + }, manager.ConnModeCDN) + if err != nil { + return nil, err + } + + handle, reused := c.cdnPools.publishOrAcquire(dc, need, created) + if reused { + // Lost race: another goroutine already published suitable pool. 
+ _ = created.Close() + return handle, nil + } + return handle, nil } diff --git a/telegram/pool_cdn_test.go b/telegram/pool_cdn_test.go new file mode 100644 index 0000000000..13eed1fc76 --- /dev/null +++ b/telegram/pool_cdn_test.go @@ -0,0 +1,503 @@ +package telegram + +import ( + "context" + "crypto/rsa" + "errors" + "math/big" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/gotd/td/bin" + "github.com/gotd/td/mtproto" + "github.com/gotd/td/pool" + "github.com/gotd/td/telegram/internal/manager" + "github.com/gotd/td/tg" +) + +func newCDNPoolTestClient() *Client { + c := &Client{ + log: zap.NewNop(), + } + c.init() + c.ctx, c.cancel = context.WithCancel(context.Background()) + c.cfg.Store(tg.Config{ + DCOptions: []tg.DCOption{{ + ID: 203, + IPAddress: "127.0.0.1", + Port: 443, + CDN: true, + }}, + }) + // Skip network in tests: prefill CDN cache and keep at least one bundled key. + baseKey := PublicKey{RSA: &rsa.PublicKey{N: big.NewInt(251), E: 65537}} + c.opts.PublicKeys = []PublicKey{baseKey} + c.cdnKeysSet = true + c.cdnKeys = []PublicKey{baseKey} + c.cdnKeysByDC = map[int][]PublicKey{ + 203: {baseKey}, + } + + return c +} + +func unwrapCDNHandle(t *testing.T, inv CloseInvoker) *cdnPoolHandle { + t.Helper() + + h, ok := inv.(*cdnPoolHandle) + require.True(t, ok) + return h +} + +func cachedCDNHandleForTest(m *cdnPoolManager, conn CloseInvoker) CloseInvoker { + m.mux.Lock() + defer m.mux.Unlock() + return m.cachedHandleLocked(conn) +} + +func closeCachedCDNPools(c *Client) { + c.cdnPools.mux.Lock() + defer c.cdnPools.mux.Unlock() + + for _, pools := range c.cdnPools.conns { + for _, cached := range pools { + _ = cached.conn.Close() + } + } +} + +type countingCloseInvoker struct { + closed int +} + +func (s *countingCloseInvoker) Invoke(context.Context, bin.Encoder, bin.Decoder) error { + return nil +} + +func (s *countingCloseInvoker) Close() error { + s.closed++ + return nil +} + +type signalCloseInvoker struct { + once sync.Once + ch chan struct{} +} + +func (*signalCloseInvoker) Invoke(context.Context, bin.Encoder, bin.Decoder) error { + return nil +} + +func (s *signalCloseInvoker) Close() error { + s.once.Do(func() { + close(s.ch) + }) + return nil +} + +type idlePoolConn struct { + ready chan struct{} +} + +func newIdlePoolConn() *idlePoolConn { + ready := make(chan struct{}) + close(ready) + return &idlePoolConn{ready: ready} +} + +func (c *idlePoolConn) Run(ctx context.Context) error { + <-ctx.Done() + return ctx.Err() +} + +func (*idlePoolConn) Invoke(context.Context, bin.Encoder, bin.Decoder) error { + return nil +} + +func (*idlePoolConn) Ping(context.Context) error { + return nil +} + +func (c *idlePoolConn) Ready() <-chan struct{} { + return c.ready +} + +func TestClientCDNPoolCacheRespectsMax(t *testing.T) { + c := newCDNPoolTestClient() + defer c.cancel() + defer closeCachedCDNPools(c) + + first, err := c.CDN(context.Background(), 203, 1) + require.NoError(t, err) + second, err := c.CDN(context.Background(), 203, 8) + require.NoError(t, err) + require.NotSame(t, first, second) + + reused, err := c.CDN(context.Background(), 203, 2) + require.NoError(t, err) + require.NotSame(t, second, reused) + + firstShared := unwrapCDNHandle(t, first).conn + secondShared := unwrapCDNHandle(t, second).conn + reusedShared := unwrapCDNHandle(t, reused).conn + require.NotSame(t, firstShared, secondShared) + require.Same(t, secondShared, reusedShared) + + c.cdnPools.mux.Lock() + pools := append([]cachedCDNPool(nil), 
c.cdnPools.conns[203]...) + c.cdnPools.mux.Unlock() + require.Len(t, pools, 2) +} + +func TestNormalizeCDNPoolMax(t *testing.T) { + require.EqualValues(t, 0, normalizeCDNPoolMax(0)) + require.EqualValues(t, 1, normalizeCDNPoolMax(1)) + require.EqualValues(t, 2, normalizeCDNPoolMax(2)) + require.EqualValues(t, 4, normalizeCDNPoolMax(3)) + require.EqualValues(t, 4, normalizeCDNPoolMax(4)) + require.EqualValues(t, 8, normalizeCDNPoolMax(5)) +} + +func TestClientCDNPoolCacheNormalizesNearbyMax(t *testing.T) { + c := newCDNPoolTestClient() + defer c.cancel() + defer closeCachedCDNPools(c) + + first, err := c.CDN(context.Background(), 203, 3) + require.NoError(t, err) + second, err := c.CDN(context.Background(), 203, 4) + require.NoError(t, err) + require.NotSame(t, first, second) + + third, err := c.CDN(context.Background(), 203, 5) + require.NoError(t, err) + require.NotSame(t, second, third) + + firstShared := unwrapCDNHandle(t, first).conn + secondShared := unwrapCDNHandle(t, second).conn + thirdShared := unwrapCDNHandle(t, third).conn + require.Same(t, firstShared, secondShared) + require.NotSame(t, secondShared, thirdShared) + + c.cdnPools.mux.Lock() + pools := append([]cachedCDNPool(nil), c.cdnPools.conns[203]...) + c.cdnPools.mux.Unlock() + require.Len(t, pools, 2) +} + +func TestClientCDNPoolCloseKeepsCacheForReuse(t *testing.T) { + c := newCDNPoolTestClient() + defer c.cancel() + defer closeCachedCDNPools(c) + + first, err := c.CDN(context.Background(), 203, 4) + require.NoError(t, err) + firstShared := unwrapCDNHandle(t, first).conn + require.NoError(t, first.Close()) + + c.cdnPools.mux.Lock() + poolsAfterClose := append([]cachedCDNPool(nil), c.cdnPools.conns[203]...) + refsAfterClose := c.cdnPools.refs[firstShared] + c.cdnPools.mux.Unlock() + require.Len(t, poolsAfterClose, 1) + require.EqualValues(t, 1, refsAfterClose) + + second, err := c.CDN(context.Background(), 203, 4) + require.NoError(t, err) + secondShared := unwrapCDNHandle(t, second).conn + require.Same(t, firstShared, secondShared) + require.NoError(t, second.Close()) +} + +func TestClientCDNPoolHandleDoubleClose(t *testing.T) { + c := newCDNPoolTestClient() + defer c.cancel() + defer closeCachedCDNPools(c) + + h, err := c.CDN(context.Background(), 203, 4) + require.NoError(t, err) + require.NoError(t, h.Close()) + require.ErrorIs(t, h.Close(), errCDNPoolHandleDouble) +} + +func TestClientCDNPoolHandleCloseWaitsForLastHandle(t *testing.T) { + c := newCDNPoolTestClient() + defer c.cancel() + defer closeCachedCDNPools(c) + + const dcID = 203 + shared := &countingCloseInvoker{} + c.cdnPools.conns[dcID] = []cachedCDNPool{{ + conn: shared, + max: 4, + }} + + first := cachedCDNHandleForTest(&c.cdnPools, shared) + second := cachedCDNHandleForTest(&c.cdnPools, shared) + + require.NoError(t, first.Close()) + require.Equal(t, 0, shared.closed, "shared pool must stay alive while second handle is active") + + c.cdnPools.mux.Lock() + _, ok := pickCDNPool(c.cdnPools.conns[dcID], 1) + c.cdnPools.mux.Unlock() + require.True(t, ok, "cache entry must remain until last handle closes") + + require.NoError(t, second.Close()) + require.Equal(t, 0, shared.closed, "underlying pool must remain open in cache after last handle close") + + c.cdnPools.mux.Lock() + refs := c.cdnPools.refs[shared] + c.cdnPools.mux.Unlock() + require.EqualValues(t, 1, refs, "cache owner ref should remain for reuse") +} + +func TestCDNPoolManagerDrainIncludesPendingCloseQueue(t *testing.T) { + m := newCDNPoolManager() + + connA := &countingCloseInvoker{} + connB := 
&countingCloseInvoker{} + connC := &countingCloseInvoker{} + m.conns[203] = []cachedCDNPool{ + {conn: connA, max: 1}, + {conn: connB, max: 2}, + } + // connB is duplicated intentionally: drain() must deduplicate results. + m.closeQueue = []CloseInvoker{connB, connC} + + drained := m.drain() + require.Len(t, drained, 3) + require.Empty(t, m.conns) + require.Empty(t, m.refs) + require.Empty(t, m.closeQueue) +} + +func TestCDNPoolManagerQueueSaturationDoesNotSpawnDetachedClose(t *testing.T) { + m := newCDNPoolManager() + queued := &signalCloseInvoker{ch: make(chan struct{})} + + m.mux.Lock() + m.closeWorkers = maxCDNCloseWorkers + m.closeBusy = maxCDNCloseWorkers + m.closeQueue = make([]CloseInvoker, maxCDNCloseQueue) + for i := range m.closeQueue { + m.closeQueue[i] = &countingCloseInvoker{} + } + m.enqueueCloseLocked([]CloseInvoker{queued}) + require.Len(t, m.closeQueue, maxCDNCloseQueue) + m.mux.Unlock() + + select { + case <-queued.ch: + t.Fatal("unexpected detached close while workers are saturated") + case <-time.After(100 * time.Millisecond): + } + + m.mux.Lock() + require.Len(t, m.closeQueue, maxCDNCloseQueue) + m.mux.Unlock() +} + +type blockingSignalInvoker struct { + started chan struct{} + unlock chan struct{} + done chan struct{} + once sync.Once +} + +func (*blockingSignalInvoker) Invoke(context.Context, bin.Encoder, bin.Decoder) error { + return nil +} + +func (s *blockingSignalInvoker) Close() error { + s.once.Do(func() { + close(s.started) + }) + <-s.unlock + close(s.done) + return nil +} + +func TestCDNPoolManagerQueueOverflowEventuallyClosesPending(t *testing.T) { + m := newCDNPoolManager() + + blocked := &blockingSignalInvoker{ + started: make(chan struct{}), + unlock: make(chan struct{}), + done: make(chan struct{}), + } + overflow := &signalCloseInvoker{ch: make(chan struct{})} + + m.mux.Lock() + m.closeWorkers = 1 + m.closeBusy = 1 + m.closeQueue = make([]CloseInvoker, 0, maxCDNCloseQueue) + m.closeQueue = append(m.closeQueue, blocked) + for i := 1; i < maxCDNCloseQueue; i++ { + m.closeQueue = append(m.closeQueue, &countingCloseInvoker{}) + } + m.closing[blocked] = true + m.enqueueCloseLocked([]CloseInvoker{overflow}) + require.Len(t, m.closeQueue, maxCDNCloseQueue) + m.mux.Unlock() + + // Free one worker slot and start workers manually. + m.mux.Lock() + m.closeBusy = 0 + m.closeWorkers = 0 + m.mux.Unlock() + go m.runCloseWorker() + + select { + case <-blocked.started: + case <-time.After(time.Second): + t.Fatal("expected queued blocked close to start") + } + + close(blocked.unlock) + + select { + case <-overflow.ch: + case <-time.After(time.Second): + t.Fatal("expected overflow pending close to be processed after queue drains") + } +} + +func TestClientCDNUsesDCSpecificKeysOverBase(t *testing.T) { + c := newCDNPoolTestClient() + defer c.cancel() + + baseKey := PublicKey{RSA: &rsa.PublicKey{N: big.NewInt(257), E: 65537}} + cdnKey := PublicKey{RSA: &rsa.PublicKey{N: big.NewInt(263), E: 65537}} + + c.opts.PublicKeys = []PublicKey{baseKey} + c.cdnKeysSet = true + c.cdnKeys = []PublicKey{cdnKey} + c.cdnKeysByDC = map[int][]PublicKey{ + 203: {cdnKey}, + } + + captured := make(chan []PublicKey, 1) + c.create = func( + _ mtproto.Dialer, + mode manager.ConnMode, + _ int, + opts mtproto.Options, + _ manager.ConnOptions, + ) pool.Conn { + if mode == manager.ConnModeCDN { + captured <- append([]PublicKey(nil), opts.PublicKeys...) 
+ } + return newIdlePoolConn() + } + + inv, err := c.CDN(context.Background(), 203, 1) + require.NoError(t, err) + require.NotNil(t, inv) + defer func() { + require.NoError(t, inv.Close()) + }() + require.NoError(t, inv.Invoke(context.Background(), nil, nil)) + + got := <-captured + require.Equal(t, []PublicKey{cdnKey, baseKey}, got) +} + +func TestClientCDNWithoutDCSpecificKeysFailsFast(t *testing.T) { + c := newCDNPoolTestClient() + defer c.cancel() + + baseKey := PublicKey{RSA: &rsa.PublicKey{N: big.NewInt(269), E: 65537}} + c.opts.PublicKeys = []PublicKey{baseKey} + c.cdnKeysSet = true + c.cdnKeys = nil + c.cdnKeysByDC = map[int][]PublicKey{} + c.tg = tg.NewClient(InvokeFunc(func(context.Context, bin.Encoder, bin.Decoder) error { + return errors.New("cdn config unavailable") + })) + + var calls atomic.Int32 + c.create = func( + _ mtproto.Dialer, + _ manager.ConnMode, + _ int, + _ mtproto.Options, + _ manager.ConnOptions, + ) pool.Conn { + calls.Add(1) + return newIdlePoolConn() + } + + inv, err := c.CDN(context.Background(), 203, 1) + require.Error(t, err) + require.ErrorContains(t, err, "fetch CDN public keys for DC 203") + require.Nil(t, inv) + require.Zero(t, calls.Load()) +} + +func TestClientCDNWithoutAnyKeysFailsFast(t *testing.T) { + c := newCDNPoolTestClient() + defer c.cancel() + + c.opts.PublicKeys = nil + c.cdnKeysSet = true + c.cdnKeys = nil + c.cdnKeysByDC = map[int][]PublicKey{} + c.tg = tg.NewClient(InvokeFunc(func(context.Context, bin.Encoder, bin.Decoder) error { + return errors.New("cdn config unavailable") + })) + + var calls atomic.Int32 + c.create = func( + _ mtproto.Dialer, + _ manager.ConnMode, + _ int, + _ mtproto.Options, + _ manager.ConnOptions, + ) pool.Conn { + calls.Add(1) + return newIdlePoolConn() + } + + inv, err := c.CDN(context.Background(), 203, 1) + require.Error(t, err) + require.Nil(t, inv) + require.Zero(t, calls.Load()) +} + +func TestClientCDNFetchDCKeysErrorFailsFast(t *testing.T) { + c := newCDNPoolTestClient() + defer c.cancel() + + baseKey := PublicKey{RSA: &rsa.PublicKey{N: big.NewInt(271), E: 65537}} + c.opts.PublicKeys = []PublicKey{baseKey} + c.cdnKeysSet = false + c.cdnKeys = nil + c.cdnKeysByDC = nil + c.tg = tg.NewClient(InvokeFunc(func(context.Context, bin.Encoder, bin.Decoder) error { + return errors.New("cdn config unavailable") + })) + + var calls atomic.Int32 + c.create = func( + _ mtproto.Dialer, + _ manager.ConnMode, + _ int, + _ mtproto.Options, + _ manager.ConnOptions, + ) pool.Conn { + calls.Add(1) + return newIdlePoolConn() + } + + inv, err := c.CDN(context.Background(), 203, 1) + require.Error(t, err) + require.ErrorContains(t, err, "fetch CDN public keys for DC 203") + require.Nil(t, inv) + require.Zero(t, calls.Load()) +} diff --git a/telegram/pool_validation_test.go b/telegram/pool_validation_test.go new file mode 100644 index 0000000000..60454eec65 --- /dev/null +++ b/telegram/pool_validation_test.go @@ -0,0 +1,24 @@ +package telegram + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPoolMethodValidation(t *testing.T) { + c := &Client{} + + _, err := c.Pool(-1) + require.ErrorContains(t, err, "invalid max value -1") + + _, err = c.DC(context.Background(), 2, -1) + require.ErrorContains(t, err, "invalid max value -1") + + _, err = c.MediaOnly(context.Background(), 2, -1) + require.ErrorContains(t, err, "invalid max value -1") + + _, err = c.CDN(context.Background(), 203, -1) + require.ErrorContains(t, err, "invalid max value -1") +} diff --git a/telegram/session.go 
b/telegram/session.go index d55d074422..97a62a6a1c 100644 --- a/telegram/session.go +++ b/telegram/session.go @@ -98,19 +98,9 @@ func (c *Client) saveSession(cfg tg.Config, s mtproto.Session) error { } func (c *Client) onSession(cfg tg.Config, s mtproto.Session) error { - keyToStore := s.Key - if !s.PermKey.Zero() { - // Keep in-memory/persisted key format backward-compatible: one key slot. - keyToStore = s.PermKey - } - - c.sessionsMux.Lock() - c.sessions[cfg.ThisDC] = pool.NewSyncSession(pool.Session{ - DC: cfg.ThisDC, - Salt: s.Salt, - AuthKey: keyToStore, - }) - c.sessionsMux.Unlock() + sessionData := dcSessionFromMTProto(cfg.ThisDC, s) + // Track per-DC session in memory for pool reconnections/migrations. + c.storeDCSess(c.sessions, sessionData) primaryDC := c.session.Load().DC // Do not save session for non-primary DC. @@ -119,11 +109,7 @@ func (c *Client) onSession(cfg tg.Config, s mtproto.Session) error { } c.connMux.Lock() - c.session.Store(pool.Session{ - DC: cfg.ThisDC, - Salt: s.Salt, - AuthKey: keyToStore, - }) + c.session.Store(sessionData) c.cfg.Store(cfg) c.onReady() c.connMux.Unlock() @@ -134,3 +120,36 @@ func (c *Client) onSession(cfg tg.Config, s mtproto.Session) error { return nil } + +func (c *Client) onCDNSession(cfg tg.Config, s mtproto.Session) error { + // CDN sessions are isolated from regular DC map because lifecycle and reset + // triggers differ (fingerprint misses, CDN-specific reconnects). + c.storeDCSess(c.cdnSessions, dcSessionFromMTProto(cfg.ThisDC, s)) + return nil +} + +func (c *Client) storeDCSess(target map[int]*pool.SyncSession, data pool.Session) { + c.sessionsMux.Lock() + if existing, ok := target[data.DC]; ok { + existing.Store(data) + c.sessionsMux.Unlock() + return + } + target[data.DC] = pool.NewSyncSession(data) + c.sessionsMux.Unlock() +} + +func dcSessionFromMTProto(dc int, s mtproto.Session) pool.Session { + keyToStore := s.Key + if !s.PermKey.Zero() { + // Keep in-memory/persisted key format backward-compatible: one key slot. + // In PFS mode temp key rotates, so we pin permanent key. 
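+		// Key slots by mode (derived from the branch below):
+		//
+		//	PFS:     s.Key = temp (rotates), s.PermKey = permanent -> store PermKey
+		//	non-PFS: s.Key = permanent,      s.PermKey = zero      -> store Key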
+ keyToStore = s.PermKey + } + + return pool.Session{ + DC: dc, + Salt: s.Salt, + AuthKey: keyToStore, + } +} diff --git a/telegram/session_cdn_test.go b/telegram/session_cdn_test.go new file mode 100644 index 0000000000..5ac0e8709b --- /dev/null +++ b/telegram/session_cdn_test.go @@ -0,0 +1,73 @@ +package telegram + +import ( + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/gotd/td/crypto" + "github.com/gotd/td/mtproto" + "github.com/gotd/td/pool" + "github.com/gotd/td/tg" +) + +func TestClientOnCDNSessionStoresSeparateMap(t *testing.T) { + a := require.New(t) + c := &Client{ + log: zap.NewNop(), + } + c.init() + + const dcID = 7 + regularKey := crypto.Key{1}.WithID() + cdnKey := crypto.Key{2}.WithID() + + c.sessions[dcID] = pool.NewSyncSession(pool.Session{ + DC: dcID, + AuthKey: regularKey, + Salt: 11, + }) + + err := c.onCDNSession(tg.Config{ThisDC: dcID}, mtproto.Session{ + Key: cdnKey, + Salt: 22, + }) + a.NoError(err) + + regular := c.sessions[dcID].Load() + a.Equal(regularKey, regular.AuthKey) + a.Equal(int64(11), regular.Salt) + + cdn, ok := c.cdnSessions[dcID] + a.True(ok) + cdnData := cdn.Load() + a.Equal(cdnKey, cdnData.AuthKey) + a.Equal(int64(22), cdnData.Salt) +} + +func TestCDNHandlerUsesCDNSessionPath(t *testing.T) { + a := require.New(t) + c := &Client{ + log: zap.NewNop(), + } + c.init() + + const dcID = 8 + cdnKey := crypto.Key{3}.WithID() + + h := c.asCDNHandler() + err := h.OnSession(tg.Config{ThisDC: dcID}, mtproto.Session{ + Key: cdnKey, + Salt: 33, + }) + a.NoError(err) + + _, regularOk := c.sessions[dcID] + a.False(regularOk) + + cdn, cdnOK := c.cdnSessions[dcID] + a.True(cdnOK) + a.Equal(cdnKey, cdn.Load().AuthKey) + a.Equal(int64(33), cdn.Load().Salt) +} diff --git a/telegram/sub_conns.go b/telegram/sub_conns.go index 94d6e2dfd3..0b31dd8791 100644 --- a/telegram/sub_conns.go +++ b/telegram/sub_conns.go @@ -6,6 +6,7 @@ import ( "github.com/go-faster/errors" "github.com/gotd/td/bin" + "github.com/gotd/td/telegram/internal/manager" ) func (c *Client) invokeSub(ctx context.Context, dc int, input bin.Encoder, output bin.Decoder) error { @@ -17,7 +18,8 @@ func (c *Client) invokeSub(ctx context.Context, dc int, input bin.Encoder, outpu return conn.Invoke(ctx, input, output) } - conn, err := c.dc(ctx, dc, 1, c.primaryDC(dc)) + // Sub-invoker is regular data connection to target DC. + conn, err := c.dc(ctx, dc, 1, c.primaryDC(dc), manager.ConnModeData) if err != nil { c.subConnsMux.Unlock() return errors.Wrapf(err, "create connection to DC %d", dc)
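// End-to-end usage sketch (illustrative reviewer note, not part of the patch;
// assumes appID, appHash and ctx are defined by the caller):
//
//	client := telegram.NewClient(appID, appHash, telegram.Options{AllowCDN: true})
//	err := client.Run(ctx, func(ctx context.Context) error {
//		// The DC id and max come from an upload.fileCdnRedirect in practice;
//		// max is normalized up to a power of two and the pool is cached per DC.
//		inv, err := client.CDN(ctx, 203, 4)
//		if err != nil {
//			return err
//		}
//		defer func() { _ = inv.Close() }()
//		// inv is a CloseInvoker; wrap it with tg.NewClient(inv) to issue
//		// upload.getCdnFile requests against the CDN DC.
//		return nil
//	})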