Skip to content

Commit a1227ab

Browse files
committed
implement NRI plugin server to inject management CDI devices
Signed-off-by: Tariq Ibrahim <[email protected]>
1 parent 786aa3b commit a1227ab

File tree

588 files changed

+165216
-22
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

588 files changed

+165216
-22
lines changed

cmd/nvidia-ctk-installer/container/container.go

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,15 @@ type Options struct {
4949
// mount.
5050
ExecutablePath string
5151
// EnableCDI indicates whether CDI should be enabled.
52-
EnableCDI bool
53-
RuntimeName string
54-
RuntimeDir string
55-
SetAsDefault bool
56-
RestartMode string
57-
HostRootMount string
52+
EnableCDI bool
53+
EnableNRI bool
54+
RuntimeName string
55+
RuntimeDir string
56+
SetAsDefault bool
57+
RestartMode string
58+
HostRootMount string
59+
NRIPluginIndex string
60+
NRISocket string
5861

5962
ConfigSources []string
6063
}
@@ -128,6 +131,10 @@ func (o Options) UpdateConfig(cfg engine.Interface) error {
128131
cfg.EnableCDI()
129132
}
130133

134+
if o.EnableNRI {
135+
cfg.EnableNRI()
136+
}
137+
131138
return nil
132139
}
133140

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
/**
2+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package nri
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"os"
23+
"strings"
24+
25+
"github.com/containerd/nri/pkg/api"
26+
nriplugin "github.com/containerd/nri/pkg/stub"
27+
28+
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
29+
)
30+
31+
// Compile-time interface checks
32+
var (
33+
_ nriplugin.Plugin = (*Plugin)(nil)
34+
)
35+
36+
// nriCDIDeviceKey is the prefix of the pod annotation key used to request CDI
// devices for a container.
const nriCDIDeviceKey = "nvidia.cdi.k8s.io"

// defaultNRISocket is the path of the NRI socket used when no socket path is
// supplied.
const defaultNRISocket = "/var/run/nri/nri.sock"
42+
43+
type Plugin struct {
44+
logger logger.Interface
45+
46+
stub nriplugin.Stub
47+
}
48+
49+
// NewPlugin creates a new NRI plugin for injecting CDI devices
50+
func NewPlugin(logger logger.Interface) *Plugin {
51+
return &Plugin{
52+
logger: logger,
53+
}
54+
}
55+
56+
// CreateContainer handles container creation requests.
57+
func (p *Plugin) CreateContainer(_ context.Context, pod *api.PodSandbox, ctr *api.Container) (*api.ContainerAdjustment, []*api.ContainerUpdate, error) {
58+
adjust := &api.ContainerAdjustment{}
59+
60+
if err := p.injectCDIDevices(pod, ctr, adjust); err != nil {
61+
return nil, nil, err
62+
}
63+
64+
return adjust, nil, nil
65+
}
66+
67+
func (p *Plugin) injectCDIDevices(pod *api.PodSandbox, ctr *api.Container, a *api.ContainerAdjustment) error {
68+
devices, err := parseCDIDevices(ctr.Name, pod.Annotations)
69+
if err != nil {
70+
return err
71+
}
72+
73+
if len(devices) == 0 {
74+
p.logger.Debugf("%s: no CDI devices annotated...", containerName(pod, ctr))
75+
return nil
76+
}
77+
78+
for _, name := range devices {
79+
a.AddCDIDevice(
80+
&api.CDIDevice{
81+
Name: name,
82+
},
83+
)
84+
p.logger.Infof("%s: injected CDI device %q...", containerName(pod, ctr), name)
85+
}
86+
87+
return nil
88+
}
89+
90+
func parseCDIDevices(ctr string, annotations map[string]string) ([]string, error) {
91+
annotation := getAnnotation(annotations, nriCDIDeviceKey, ctr)
92+
if len(annotation) == 0 {
93+
return nil, nil
94+
}
95+
96+
cdiDevices := strings.Split(annotation, ",")
97+
return cdiDevices, nil
98+
}
99+
100+
// getAnnotation returns the value of the annotation "<key>/container.<ctr>",
// or the empty string when the annotation is not present.
func getAnnotation(annotations map[string]string, key, ctr string) string {
	// Indexing a map (including a nil map) with a missing key yields the zero
	// value, which for strings is "".
	return annotations[fmt.Sprintf("%s/container.%s", key, ctr)]
}
108+
109+
// Construct a container name for log messages.
110+
func containerName(pod *api.PodSandbox, container *api.Container) string {
111+
if pod != nil {
112+
return pod.Name + "/" + container.Name
113+
}
114+
return container.Name
115+
}
116+
117+
// Start starts the NRI plugin
118+
func (p *Plugin) Start(ctx context.Context, nriSocketPath, nriPluginIdx string) error {
119+
if len(nriSocketPath) == 0 {
120+
nriSocketPath = defaultNRISocket
121+
}
122+
_, err := os.Stat(nriSocketPath)
123+
if err != nil {
124+
return fmt.Errorf("failed to find valid nri socket in %s: %w", nriSocketPath, err)
125+
}
126+
127+
pluginOpts := []nriplugin.Option{
128+
nriplugin.WithPluginIdx(nriPluginIdx),
129+
nriplugin.WithSocketPath(nriSocketPath),
130+
}
131+
if p.stub, err = nriplugin.New(p, pluginOpts...); err != nil {
132+
return fmt.Errorf("failed to initialise plugin at %s: %w", nriSocketPath, err)
133+
}
134+
err = p.stub.Start(ctx)
135+
if err != nil {
136+
return fmt.Errorf("plugin exited with error: %w", err)
137+
}
138+
return nil
139+
}
140+
141+
// Stop stops the NRI plugin
142+
func (p *Plugin) Stop() {
143+
if p != nil && p.stub != nil {
144+
p.stub.Stop()
145+
}
146+
}

cmd/nvidia-ctk-installer/container/runtime/runtime.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ const (
3434
// defaultRuntimeName specifies the NVIDIA runtime to be used as the default runtime if setting the default runtime is enabled
3535
defaultRuntimeName = "nvidia"
3636
defaultHostRootMount = "/host"
37+
defaultNRIPluginIdx = "10"
38+
defaultNRISocket = "/var/run/nri/nri.sock"
3739

3840
runtimeSpecificDefault = "RUNTIME_SPECIFIC_DEFAULT"
3941
)
@@ -94,6 +96,27 @@ func Flags(opts *Options) []cli.Flag {
9496
Destination: &opts.EnableCDI,
9597
Sources: cli.EnvVars("RUNTIME_ENABLE_CDI"),
9698
},
99+
&cli.BoolFlag{
100+
Name: "enable-nri-in-runtime",
101+
Usage: "Enable NRI in the configured runtime",
102+
Destination: &opts.EnableNRI,
103+
Value: true,
104+
Sources: cli.EnvVars("RUNTIME_ENABLE_NRI"),
105+
},
106+
&cli.StringFlag{
107+
Name: "nri-plugin-index",
108+
Usage: "Specify the plugin index to register to NRI",
109+
Value: defaultNRIPluginIdx,
110+
Destination: &opts.NRIPluginIndex,
111+
Sources: cli.EnvVars("RUNTIME_NRI_PLUGIN_INDEX"),
112+
},
113+
&cli.StringFlag{
114+
Name: "nri-socket",
115+
Usage: "Specify the path to the NRI socket file to register the NRI plugin server",
116+
Value: defaultNRISocket,
117+
Destination: &opts.NRISocket,
118+
Sources: cli.EnvVars("RUNTIME_NRI_SOCKET"),
119+
},
97120
&cli.StringFlag{
98121
Name: "host-root",
99122
Usage: "Specify the path to the host root to be used when restarting the runtime using systemd",

cmd/nvidia-ctk-installer/main.go

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@ import (
77
"os/signal"
88
"path/filepath"
99
"syscall"
10+
"time"
1011

1112
"github.com/urfave/cli/v3"
1213
"golang.org/x/sys/unix"
1314

1415
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime"
16+
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/container/runtime/nri"
1517
"github.com/NVIDIA/nvidia-container-toolkit/cmd/nvidia-ctk-installer/toolkit"
1618
"github.com/NVIDIA/nvidia-container-toolkit/internal/info"
1719
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
@@ -26,6 +28,9 @@ const (
2628
toolkitSubDir = "toolkit"
2729

2830
defaultRuntime = "docker"
31+
32+
retryBackoff = 2 * time.Second
33+
maxRetryAttempts = 5
2934
)
3035

3136
var availableRuntimes = map[string]struct{}{"docker": {}, "crio": {}, "containerd": {}}
@@ -73,7 +78,7 @@ type app struct {
7378
toolkit *toolkit.Installer
7479
}
7580

76-
// NewApp creates the CLI app fro the specified options.
81+
// NewApp creates the CLI app from the specified options.
7782
func NewApp(logger logger.Interface) *cli.Command {
7883
a := app{
7984
logger: logger,
@@ -93,8 +98,8 @@ func (a app) build() *cli.Command {
9398
Before: func(ctx context.Context, cmd *cli.Command) (context.Context, error) {
9499
return ctx, a.Before(cmd, &options)
95100
},
96-
Action: func(_ context.Context, cmd *cli.Command) error {
97-
return a.Run(cmd, &options)
101+
Action: func(ctx context.Context, cmd *cli.Command) error {
102+
return a.Run(ctx, cmd, &options)
98103
},
99104
Flags: []cli.Flag{
100105
&cli.BoolFlag{
@@ -194,7 +199,7 @@ func (a *app) validateFlags(c *cli.Command, o *options) error {
194199
// Run installs the NVIDIA Container Toolkit and updates the requested runtime.
195200
// If the application is run as a daemon, the application waits and unconfigures
196201
// the runtime on termination.
197-
func (a *app) Run(c *cli.Command, o *options) error {
202+
func (a *app) Run(ctx context.Context, c *cli.Command, o *options) error {
198203
err := a.initialize(o.pidFile)
199204
if err != nil {
200205
return fmt.Errorf("unable to initialize: %v", err)
@@ -222,6 +227,14 @@ func (a *app) Run(c *cli.Command, o *options) error {
222227
}
223228

224229
if !o.noDaemon {
230+
if o.runtimeOptions.EnableNRI {
231+
nriPlugin, err := a.startNRIPluginServer(ctx, o.runtimeOptions)
232+
if err != nil {
233+
a.logger.Errorf("unable to start NRI plugin server: %v", err)
234+
}
235+
defer nriPlugin.Stop()
236+
}
237+
225238
err = a.waitForSignal()
226239
if err != nil {
227240
return fmt.Errorf("unable to wait for signal: %v", err)
@@ -287,6 +300,31 @@ func (a *app) waitForSignal() error {
287300
return nil
288301
}
289302

303+
func (a *app) startNRIPluginServer(ctx context.Context, opts runtime.Options) (*nri.Plugin, error) {
304+
a.logger.Infof("Starting the NRI Plugin server....")
305+
306+
plugin := nri.NewPlugin(a.logger)
307+
retriable := func() error {
308+
return plugin.Start(ctx, opts.NRISocket, opts.NRIPluginIndex)
309+
}
310+
var err error
311+
for i := 0; i < maxRetryAttempts; i++ {
312+
err = retriable()
313+
if err == nil {
314+
break
315+
}
316+
if i == maxRetryAttempts-1 {
317+
break
318+
}
319+
time.Sleep(retryBackoff)
320+
}
321+
if err != nil {
322+
a.logger.Errorf("Max retries reached %d/%d, aborting", maxRetryAttempts, maxRetryAttempts)
323+
return nil, err
324+
}
325+
return plugin, nil
326+
}
327+
290328
func (a *app) shutdown(pidFile string) {
291329
a.logger.Infof("Shutting Down")
292330

cmd/nvidia-ctk-installer/main_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ version = 2
444444
"--pid-file=" + filepath.Join(testRoot, "toolkit.pid"),
445445
"--restart-mode=none",
446446
"--toolkit-source-root=" + filepath.Join(artifactRoot, "deb"),
447+
"--enable-nri-in-runtime=false",
447448
}
448449

449450
err := app.Run(context.Background(), append(testArgs, tc.args...))

go.mod

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ go 1.25.0
55
require (
66
github.com/NVIDIA/go-nvlib v0.9.1-0.20251202135446-d0f42ba016dd
77
github.com/NVIDIA/go-nvml v0.13.0-1
8+
github.com/containerd/nri v0.10.1-0.20251120153915-7d8611f87ad7
89
github.com/google/uuid v1.6.0
910
github.com/moby/sys/mountinfo v0.7.2
1011
github.com/moby/sys/reexec v0.1.0
@@ -25,18 +26,25 @@ require (
2526

2627
require (
2728
cyphar.com/go-pathrs v0.2.1 // indirect
29+
github.com/containerd/log v0.1.0 // indirect
30+
github.com/containerd/ttrpc v1.2.7 // indirect
2831
github.com/cyphar/filepath-securejoin v0.6.0 // indirect
2932
github.com/davecgh/go-spew v1.1.1 // indirect
3033
github.com/fsnotify/fsnotify v1.7.0 // indirect
34+
github.com/golang/protobuf v1.5.3 // indirect
3135
github.com/hashicorp/errwrap v1.1.0 // indirect
32-
github.com/kr/pretty v0.3.1 // indirect
36+
github.com/knqyf263/go-plugin v0.9.0 // indirect
37+
github.com/kr/text v0.2.0 // indirect
3338
github.com/moby/sys/capability v0.4.0 // indirect
3439
github.com/opencontainers/cgroups v0.0.4 // indirect
3540
github.com/opencontainers/runtime-tools v0.9.1-0.20251114084447-edf4cb3d2116 // indirect
3641
github.com/pmezard/go-difflib v1.0.0 // indirect
3742
github.com/rogpeppe/go-internal v1.11.0 // indirect
43+
github.com/tetratelabs/wazero v1.9.0 // indirect
3844
github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect
39-
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
45+
google.golang.org/genproto/googleapis/rpc v0.0.0-20230731190214-cbb8c96f2d6d // indirect
46+
google.golang.org/grpc v1.57.1 // indirect
47+
google.golang.org/protobuf v1.36.5 // indirect
4048
gopkg.in/yaml.v3 v3.0.1 // indirect
4149
sigs.k8s.io/yaml v1.4.0 // indirect
4250
)

0 commit comments

Comments
 (0)