diff --git a/conn_darwin.go b/conn_darwin.go new file mode 100644 index 0000000..2e0e2fc --- /dev/null +++ b/conn_darwin.go @@ -0,0 +1,62 @@ +//go:build darwin +// +build darwin + +package vsock + +import ( + "context" + + "github.com/mdlayher/socket" + "golang.org/x/sys/unix" +) + +// A conn is the net.Conn implementation for connection-oriented VM sockets. +// We can use socket.Conn directly on Linux to implement all of the necessary +// methods. +type conn = socket.Conn + +// dial is the entry point for Dial on Linux. +func dial(ctx context.Context, cid, port uint32, _ *Config) (*Conn, error) { + // TODO(mdlayher): Config default nil check and initialize. Pass options to + // socket.Config where necessary. + + c, err := socket.Socket(unix.AF_VSOCK, unix.SOCK_STREAM, 0, "vsock", nil) + if err != nil { + return nil, err + } + + sa := &unix.SockaddrVM{CID: cid, Port: port} + rsa, err := c.Connect(ctx, sa) + if err != nil { + _ = c.Close() + return nil, err + } + + // TODO(mdlayher): getpeername(2) appears to return nil in the GitHub CI + // environment, so in the event of a nil sockaddr, fall back to the previous + // method of synthesizing the remote address. + if rsa == nil { + rsa = sa + } + + lsa, err := c.Getsockname() + if err != nil { + _ = c.Close() + return nil, err + } + + lsavm := lsa.(*unix.SockaddrVM) + rsavm := rsa.(*unix.SockaddrVM) + + return &Conn{ + c: c, + local: &Addr{ + ContextID: lsavm.CID, + Port: lsavm.Port, + }, + remote: &Addr{ + ContextID: rsavm.CID, + Port: rsavm.Port, + }, + }, nil +} diff --git a/conn_linux.go b/conn_linux.go index 6029d54..0453edf 100644 --- a/conn_linux.go +++ b/conn_linux.go @@ -16,7 +16,7 @@ import ( type conn = socket.Conn // dial is the entry point for Dial on Linux. -func dial(cid, port uint32, _ *Config) (*Conn, error) { +func dial(ctx context.Context, cid, port uint32, _ *Config) (*Conn, error) { // TODO(mdlayher): Config default nil check and initialize. Pass options to // socket.Config where necessary. @@ -26,7 +26,7 @@ func dial(cid, port uint32, _ *Config) (*Conn, error) { } sa := &unix.SockaddrVM{CID: cid, Port: port} - rsa, err := c.Connect(context.Background(), sa) + rsa, err := c.Connect(ctx, sa) if err != nil { _ = c.Close() return nil, err diff --git a/fd_darwin.go b/fd_darwin.go new file mode 100644 index 0000000..4955766 --- /dev/null +++ b/fd_darwin.go @@ -0,0 +1,37 @@ +package vsock + +import ( + "fmt" + + "golang.org/x/sys/unix" +) + +// contextID retrieves the local context ID for this system. +func contextID() (uint32, error) { + if fd, err := unix.Socket(unix.AF_VSOCK, unix.SOCK_STREAM, 0); err != nil { + return 2, nil + } else { + defer unix.Close(fd) + + cid, err := unix.IoctlGetInt(fd, unix.IOCTL_VM_SOCKETS_GET_LOCAL_CID) + + return uint32(cid), err + } +} + +// isErrno determines if an error a matches UNIX error number. +func isErrno(err error, errno int) bool { + switch errno { + case ebadf: + return err == unix.EBADF + case enotconn: + return err == unix.ENOTCONN + default: + panicf("vsock: isErrno called with unhandled error number parameter: %d", errno) + return false + } +} + +func panicf(format string, a ...interface{}) { + panic(fmt.Sprintf(format, a...)) +} diff --git a/go.mod b/go.mod index 2cf1b59..8a838c7 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,9 @@ module github.com/mdlayher/vsock go 1.20 require ( - github.com/google/go-cmp v0.5.9 - github.com/mdlayher/socket v0.4.1 - golang.org/x/net v0.9.0 - golang.org/x/sync v0.1.0 - golang.org/x/sys v0.7.0 + github.com/google/go-cmp v0.6.0 + github.com/mdlayher/socket v0.5.1 + golang.org/x/net v0.33.0 + golang.org/x/sync v0.10.0 + golang.org/x/sys v0.28.0 ) diff --git a/go.sum b/go.sum index 0723272..6d94c24 100644 --- a/go.sum +++ b/go.sum @@ -1,16 +1,10 @@ -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/mdlayher/socket v0.4.0 h1:280wsy40IC9M9q1uPGcLBwXpcTQDtoGwVt+BNoITxIw= -github.com/mdlayher/socket v0.4.0/go.mod h1:xxFqz5GRCUN3UEOm9CZqEJsAbe1C8OwSK46NlmWuVoc= -github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= -github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= -golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU= -golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= -golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= -golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= -golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= -golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/listener_darwin.go b/listener_darwin.go new file mode 100644 index 0000000..e6f7f8e --- /dev/null +++ b/listener_darwin.go @@ -0,0 +1,133 @@ +//go:build darwin +// +build darwin + +package vsock + +import ( + "context" + "net" + "os" + "time" + + "github.com/mdlayher/socket" + "golang.org/x/sys/unix" +) + +var _ net.Listener = &listener{} + +// A listener is the net.Listener implementation for connection-oriented +// VM sockets. +type listener struct { + c *socket.Conn + addr *Addr +} + +// Addr and Close implement the net.Listener interface for listener. +func (l *listener) Addr() net.Addr { return l.addr } +func (l *listener) Close() error { return l.c.Close() } +func (l *listener) SetDeadline(t time.Time) error { return l.c.SetDeadline(t) } + +// Accept accepts a single connection from the listener, and sets up +// a net.Conn backed by conn. +func (l *listener) Accept() (net.Conn, error) { + c, rsa, err := l.c.Accept(context.Background(), 0) + if err != nil { + return nil, err + } + + savm := rsa.(*unix.SockaddrVM) + remote := &Addr{ + ContextID: savm.CID, + Port: savm.Port, + } + + return &Conn{ + c: c, + local: l.addr, + remote: remote, + }, nil +} + +// name is the socket name passed to package socket. +const name = "vsock" + +// listen is the entry point for Listen on Linux. +func listen(cid, port uint32, _ *Config) (*Listener, error) { + // TODO(mdlayher): Config default nil check and initialize. Pass options to + // socket.Config where necessary. + + c, err := socket.Socket(unix.AF_VSOCK, unix.SOCK_STREAM, 0, name, nil) + if err != nil { + return nil, err + } + + // Be sure to close the Conn if any of the system calls fail before we + // return the Conn to the caller. + + if port == 0 { + port = unix.VMADDR_PORT_ANY + } + + if err := c.Bind(&unix.SockaddrVM{CID: cid, Port: port}); err != nil { + _ = c.Close() + return nil, err + } + + if err := c.Listen(unix.SOMAXCONN); err != nil { + _ = c.Close() + return nil, err + } + + l, err := newListener(c) + if err != nil { + _ = c.Close() + return nil, err + } + + return l, nil +} + +// fileListener is the entry point for FileListener on Linux. +func fileListener(f *os.File) (*Listener, error) { + c, err := socket.FileConn(f, name) + if err != nil { + return nil, err + } + + l, err := newListener(c) + if err != nil { + _ = c.Close() + return nil, err + } + + return l, nil +} + +// newListener creates a Listener from a raw socket.Conn. +func newListener(c *socket.Conn) (*Listener, error) { + lsa, err := c.Getsockname() + if err != nil { + return nil, err + } + + // Now that the library can also accept arbitrary os.Files, we have to + // verify the address family so we don't accidentally create a + // *vsock.Listener backed by TCP or some other socket type. + lsavm, ok := lsa.(*unix.SockaddrVM) + if !ok { + // All errors should wrapped with os.SyscallError. + return nil, os.NewSyscallError("listen", unix.EINVAL) + } + + addr := &Addr{ + ContextID: lsavm.CID, + Port: lsavm.Port, + } + + return &Listener{ + l: &listener{ + c: c, + addr: addr, + }, + }, nil +} diff --git a/vsock.go b/vsock.go index 7876393..6c71d8d 100644 --- a/vsock.go +++ b/vsock.go @@ -1,11 +1,12 @@ package vsock import ( - "errors" + "context" "fmt" "io" "net" "os" + "runtime" "strings" "syscall" "time" @@ -54,6 +55,10 @@ const ( opWrite = "write" ) +// errUnimplemented is returned by all functions on platforms that +// cannot make use of VM sockets. +var errUnimplemented = fmt.Errorf("vsock: not implemented on %s", runtime.GOOS) + // TODO(mdlayher): plumb through socket.Config.NetNS if it makes sense. // Config contains options for a Conn or Listener. @@ -176,7 +181,21 @@ func (l *Listener) opError(op string, err error) error { // When the connection is no longer needed, Close must be called to free // resources. func Dial(contextID, port uint32, cfg *Config) (*Conn, error) { - c, err := dial(contextID, port, cfg) + return dial(context.Background(), contextID, port, cfg) +} + +// DialWithContext connects to the address on the named network using +// the provided context. +// +// The provided Context must be non-nil. If the context expires before +// the connection is complete, an error is returned. Once successfully +// connected, any expiration of the context will not affect the +// connection. +// +// See func Dial for a description of the contextID and port +// parameters. +func DialWithContext(ctx context.Context, contextID, port uint32, cfg *Config) (*Conn, error) { + c, err := dial(ctx, contextID, port, cfg) if err != nil { // No local address, but we have a remote address we can return. return nil, opError(opDial, err, nil, &Addr{ @@ -403,7 +422,7 @@ func opError(op string, err error, local, remote net.Addr) error { // // To rectify the differences, net.TCPConn uses an error with this text // from internal/poll for the backing file already being closed. - err = errors.New("use of closed network connection") + err = net.ErrClosed default: // Nothing to do, return this directly. } diff --git a/vsock_darwin_test.go b/vsock_darwin_test.go new file mode 100644 index 0000000..f966a0c --- /dev/null +++ b/vsock_darwin_test.go @@ -0,0 +1,267 @@ +package vsock + +import ( + "errors" + "io" + "net" + "os" + "testing" + + "github.com/google/go-cmp/cmp" + "golang.org/x/sys/unix" +) + +func Test_opError(t *testing.T) { + // The default op for empty op fields. + const defaultOp = "read" + + var ( + // Unfortunate, but string matching it is for now. + errClosed = errors.New("use of closed network connection") + + local = &Addr{ + ContextID: Host, + Port: 1024, + } + + remote = &Addr{ + ContextID: 3, + Port: 2048, + } + ) + + tests := []struct { + name string + op string + err error + local net.Addr + remote net.Addr + want error + }{ + { + name: "nil error", + }, + { + name: "unknown", + err: errors.New("foo"), + want: &net.OpError{ + Err: errors.New("foo"), + }, + }, + { + name: "EOF", + err: io.EOF, + want: io.EOF, + }, + { + name: "ENOTCONN", + err: unix.ENOTCONN, + want: io.EOF, + }, + { + name: "PathError ENOTCONN", + err: &os.PathError{ + Err: unix.ENOTCONN, + }, + want: io.EOF, + }, + { + name: "ErrClosed", + err: os.ErrClosed, + want: &net.OpError{ + Err: errClosed, + }, + }, + { + name: "EBADF", + err: unix.EBADF, + want: &net.OpError{ + Err: errClosed, + }, + }, + { + name: "string use of closed", + err: errors.New("use of closed file"), + want: &net.OpError{ + Err: errClosed, + }, + }, + { + name: "op close", + op: opClose, + err: errClosed, + local: local, + remote: remote, + want: &net.OpError{ + Op: opClose, + Source: local, + Addr: remote, + Err: errClosed, + }, + }, + { + name: "op dial", + op: opDial, + err: errClosed, + local: local, + remote: remote, + want: &net.OpError{ + Op: opDial, + Source: local, + Addr: remote, + Err: errClosed, + }, + }, + { + name: "op raw-read", + op: opRawRead, + err: errClosed, + local: local, + remote: remote, + want: &net.OpError{ + Op: opRawRead, + Source: local, + Addr: remote, + Err: errClosed, + }, + }, + { + name: "op raw-write", + op: opRawWrite, + err: errClosed, + local: local, + remote: remote, + want: &net.OpError{ + Op: opRawWrite, + Source: local, + Addr: remote, + Err: errClosed, + }, + }, + { + name: "op read", + op: opRead, + err: errClosed, + local: local, + remote: remote, + want: &net.OpError{ + Op: opRead, + Source: local, + Addr: remote, + Err: errClosed, + }, + }, + { + name: "op write", + op: opWrite, + err: errClosed, + local: local, + remote: remote, + want: &net.OpError{ + Op: opWrite, + Source: local, + Addr: remote, + Err: errClosed, + }, + }, + { + name: "op accept", + op: opAccept, + err: errClosed, + local: local, + want: &net.OpError{ + Op: opAccept, + Addr: local, + Err: errClosed, + }, + }, + { + name: "op listen", + op: opListen, + err: errClosed, + local: local, + want: &net.OpError{ + Op: opListen, + Addr: local, + Err: errClosed, + }, + }, + { + name: "op raw-control", + op: opRawControl, + err: errClosed, + local: local, + want: &net.OpError{ + Op: opRawControl, + Addr: local, + Err: errClosed, + }, + }, + { + name: "op set", + op: opSet, + err: errClosed, + local: local, + want: &net.OpError{ + Op: opSet, + Addr: local, + Err: errClosed, + }, + }, + { + name: "op syscall-conn", + op: opSyscallConn, + err: errClosed, + local: local, + want: &net.OpError{ + Op: opSyscallConn, + Addr: local, + Err: errClosed, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + op := tt.op + if op == "" { + op = defaultOp + } + + err := opError(op, tt.err, tt.local, tt.remote) + if err == nil { + if tt.want != nil { + t.Fatal("expected an output error, but none occurred") + } + + return + } + + // Populate sane defaults to save some typing. + want := tt.want + if nerr, ok := tt.want.(*net.OpError); ok { + if nerr.Op == "" { + nerr.Op = defaultOp + } + + if nerr.Net == "" { + nerr.Net = network + } + + want = nerr + } + + if diff := cmp.Diff(want, err, cmp.Comparer(errorsEqual)); diff != "" { + t.Fatalf("unexpected error (-want +got):\n%s", diff) + } + }) + } +} + +func errorsEqual(x, y error) bool { + if x == nil || y == nil { + return x == nil && y == nil + } + + return x.Error() == y.Error() +} diff --git a/vsock_others.go b/vsock_others.go index 5c1e88e..fe25690 100644 --- a/vsock_others.go +++ b/vsock_others.go @@ -1,21 +1,15 @@ -//go:build !linux -// +build !linux +//go:build !linux && !darwin +// +build !linux,!darwin package vsock import ( - "fmt" "net" "os" - "runtime" "syscall" "time" ) -// errUnimplemented is returned by all functions on platforms that -// cannot make use of VM sockets. -var errUnimplemented = fmt.Errorf("vsock: not implemented on %s", runtime.GOOS) - func fileListener(_ *os.File) (*Listener, error) { return nil, errUnimplemented } func listen(_, _ uint32, _ *Config) (*Listener, error) { return nil, errUnimplemented } diff --git a/vsock_others_test.go b/vsock_others_test.go index 4d33df9..5d88715 100644 --- a/vsock_others_test.go +++ b/vsock_others_test.go @@ -1,5 +1,5 @@ -//go:build !linux -// +build !linux +//go:build !linux && !darwin +// +build !linux,!darwin package vsock