|
14 | 14 | package awssession |
15 | 15 |
|
16 | 16 | import ( |
| 17 | + "context" |
17 | 18 | "fmt" |
18 | 19 | "net/http" |
19 | 20 | "os" |
20 | 21 |
|
| 22 | + "github.com/aws/aws-sdk-go-v2/aws" |
| 23 | + "github.com/aws/aws-sdk-go-v2/aws/retry" |
| 24 | + "github.com/aws/aws-sdk-go-v2/config" |
| 25 | + "github.com/aws/aws-sdk-go-v2/service/ec2" |
| 26 | + "github.com/aws/smithy-go" |
| 27 | + smithymiddleware "github.com/aws/smithy-go/middleware" |
| 28 | + smithyhttp "github.com/aws/smithy-go/transport/http" |
| 29 | + |
21 | 30 | "strconv" |
22 | 31 | "time" |
23 | 32 |
|
24 | 33 | "github.com/aws/amazon-vpc-cni-k8s/pkg/utils/logger" |
25 | 34 | "github.com/aws/amazon-vpc-cni-k8s/utils" |
26 | | - "github.com/aws/aws-sdk-go/aws" |
27 | | - "github.com/aws/aws-sdk-go/aws/endpoints" |
28 | | - "github.com/aws/aws-sdk-go/aws/request" |
29 | | - "github.com/aws/aws-sdk-go/aws/session" |
30 | | - "github.com/aws/aws-sdk-go/service/ec2" |
31 | 35 | ) |
32 | 36 |
|
33 | 37 | // Http client timeout env for sessions |
@@ -58,43 +62,84 @@ func getHTTPTimeout() time.Duration { |
58 | 62 | return httpTimeoutValue |
59 | 63 | } |
60 | 64 |
|
61 | | -// New will return an session for service clients |
62 | | -func New() *session.Session { |
63 | | - awsCfg := aws.Config{ |
64 | | - MaxRetries: aws.Int(maxRetries), |
65 | | - HTTPClient: &http.Client{ |
66 | | - Timeout: getHTTPTimeout(), |
67 | | - }, |
68 | | - STSRegionalEndpoint: endpoints.RegionalSTSEndpoint, |
| 65 | +// New will return aws.Config to be used by Service Clients. |
| 66 | +func New(ctx context.Context) (aws.Config, error) { |
| 67 | + customHTTPClient := &http.Client{ |
| 68 | + Timeout: getHTTPTimeout()} |
| 69 | + optFns := []func(*config.LoadOptions) error{ |
| 70 | + config.WithHTTPClient(customHTTPClient), |
| 71 | + config.WithRetryMaxAttempts(maxRetries), |
| 72 | + config.WithRetryer(func() aws.Retryer { |
| 73 | + return retry.NewStandard() |
| 74 | + }), |
| 75 | + injectUserAgent, |
69 | 76 | } |
70 | 77 |
|
71 | 78 | endpoint := os.Getenv("AWS_EC2_ENDPOINT") |
72 | 79 | if endpoint != "" { |
73 | | - customResolver := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { |
74 | | - if service == ec2.EndpointsID { |
75 | | - return endpoints.ResolvedEndpoint{ |
76 | | - URL: endpoint, |
77 | | - }, nil |
78 | | - } |
79 | | - return endpoints.DefaultResolver().EndpointFor(service, region, optFns...) |
80 | | - } |
81 | | - awsCfg.EndpointResolver = endpoints.ResolverFunc(customResolver) |
| 80 | + optFns = append(optFns, config.WithEndpointResolver(aws.EndpointResolverFunc( |
| 81 | + func(service, region string) (aws.Endpoint, error) { |
| 82 | + if service == ec2.ServiceID { |
| 83 | + return aws.Endpoint{ |
| 84 | + URL: endpoint, |
| 85 | + }, nil |
| 86 | + } |
| 87 | + // Fall back to default resolution |
| 88 | + return aws.Endpoint{}, &aws.EndpointNotFoundError{} |
| 89 | + }))) |
| 90 | + |
82 | 91 | } |
83 | 92 |
|
84 | | - sess := session.Must(session.NewSession(&awsCfg)) |
85 | | - //injecting session handler info |
86 | | - injectUserAgent(&sess.Handlers) |
| 93 | + cfg, err := config.LoadDefaultConfig(ctx, optFns...) |
| 94 | + |
| 95 | + if err != nil { |
| 96 | + return aws.Config{}, fmt.Errorf("failed to load AWS config: %w", err) |
| 97 | + } |
87 | 98 |
|
88 | | - return sess |
| 99 | + return cfg, nil |
89 | 100 | } |
90 | 101 |
|
91 | 102 | // injectUserAgent will inject app specific user-agent into awsSDK |
92 | | -func injectUserAgent(handlers *request.Handlers) { |
| 103 | +func injectUserAgent(loadOptions *config.LoadOptions) error { |
93 | 104 | version := utils.GetEnv(envVpcCniVersion, "") |
94 | | - handlers.Build.PushFrontNamed(request.NamedHandler{ |
95 | | - Name: fmt.Sprintf("%s/user-agent", "amazon-vpc-cni-k8s"), |
96 | | - Fn: request.MakeAddToUserAgentHandler( |
97 | | - "amazon-vpc-cni-k8s", |
98 | | - "version/"+version), |
| 105 | + userAgent := fmt.Sprintf("amazon-vpc-cni-k8s/version/%s", version) |
| 106 | + |
| 107 | + loadOptions.APIOptions = append(loadOptions.APIOptions, func(stack *smithymiddleware.Stack) error { |
| 108 | + return stack.Build.Add(&addUserAgentMiddleware{ |
| 109 | + userAgent: userAgent, |
| 110 | + }, smithymiddleware.After) |
99 | 111 | }) |
| 112 | + |
| 113 | + return nil |
| 114 | +} |
| 115 | + |
| 116 | +type addUserAgentMiddleware struct { |
| 117 | + userAgent string |
| 118 | +} |
| 119 | + |
| 120 | +func (m *addUserAgentMiddleware) HandleBuild(ctx context.Context, in smithymiddleware.BuildInput, next smithymiddleware.BuildHandler) (out smithymiddleware.BuildOutput, metadata smithymiddleware.Metadata, err error) { |
| 121 | + // Simply pass through to the next handler in the middleware chain |
| 122 | + return next.HandleBuild(ctx, in) |
| 123 | +} |
| 124 | + |
| 125 | +func (m *addUserAgentMiddleware) ID() string { |
| 126 | + return "AddUserAgent" |
| 127 | +} |
| 128 | + |
| 129 | +func (m *addUserAgentMiddleware) HandleFinalize(ctx context.Context, in smithymiddleware.FinalizeInput, next smithymiddleware.FinalizeHandler) ( |
| 130 | + out smithymiddleware.FinalizeOutput, metadata smithymiddleware.Metadata, err error) { |
| 131 | + req, ok := in.Request.(*smithyhttp.Request) |
| 132 | + if !ok { |
| 133 | + return out, metadata, &smithy.SerializationError{Err: fmt.Errorf("unknown request type %T", in.Request)} |
| 134 | + } |
| 135 | + |
| 136 | + userAgent := req.Header.Get("User-Agent") |
| 137 | + if userAgent == "" { |
| 138 | + userAgent = m.userAgent |
| 139 | + } else { |
| 140 | + userAgent += " " + m.userAgent |
| 141 | + } |
| 142 | + req.Header.Set("User-Agent", userAgent) |
| 143 | + |
| 144 | + return next.HandleFinalize(ctx, in) |
100 | 145 | } |
0 commit comments