diff --git a/nexthop.go b/nexthop.go index 4f961665..f4407230 100644 --- a/nexthop.go +++ b/nexthop.go @@ -13,15 +13,25 @@ type Nexthop struct { OIF uint32 Gateway net.IP Protocol RouteProtocol + Encap Encap } func (h *Nexthop) String() string { + if h == nil { + return "" + } elems := []string{ "ID: " + strconv.FormatUint(uint64(h.ID), 10), "Blackhole: " + strconv.FormatBool(h.Blackhole), "OIF: " + strconv.FormatUint(uint64(h.OIF), 10), "Gateway: " + h.Gateway.String(), "Protocol: " + h.Protocol.String(), + "Encap: " + func() string { + if h.Encap != nil { + return h.Encap.String() + } + return "" + }(), } return fmt.Sprintf("{%s}", strings.Join(elems, " ")) } diff --git a/nexthop_linux.go b/nexthop_linux.go index d9f0c6d6..8023ffcd 100644 --- a/nexthop_linux.go +++ b/nexthop_linux.go @@ -112,9 +112,10 @@ var nexthopAttrHandlers = map[uint16]struct { encode func(*Nexthop) *nl.RtAttr // decode decodes the corresponding attribute from RtAttr into Nexthop // It must perform bounds check for the given attribute's data and does - // nothing if the attribute encoding is invalid. - decode func(*Nexthop, *nl.RtAttr) - // match reports whether the given Nexthop + // nothing if the attribute encoding is invalid. The third parameter is + // the read-only full list of RtAttr for possible multi-attribute + // decoding. + decode func(*Nexthop, *nl.RtAttr, []*nl.RtAttr) }{ unix.NHA_ID: { encode: func(nh *Nexthop) *nl.RtAttr { @@ -125,7 +126,7 @@ var nexthopAttrHandlers = map[uint16]struct { } return nil }, - decode: func(nh *Nexthop, attr *nl.RtAttr) { + decode: func(nh *Nexthop, attr *nl.RtAttr, _ []*nl.RtAttr) { if len(attr.Data) < 4 { return } @@ -139,7 +140,7 @@ var nexthopAttrHandlers = map[uint16]struct { } return nil }, - decode: func(nh *Nexthop, attr *nl.RtAttr) { + decode: func(nh *Nexthop, attr *nl.RtAttr, _ []*nl.RtAttr) { nh.Blackhole = true }, }, @@ -152,7 +153,7 @@ var nexthopAttrHandlers = map[uint16]struct { } return nil }, - decode: func(nh *Nexthop, attr *nl.RtAttr) { + decode: func(nh *Nexthop, attr *nl.RtAttr, _ []*nl.RtAttr) { if len(attr.Data) < 4 { return } @@ -169,13 +170,55 @@ var nexthopAttrHandlers = map[uint16]struct { } return nil }, - decode: func(nh *Nexthop, attr *nl.RtAttr) { + decode: func(nh *Nexthop, attr *nl.RtAttr, _ []*nl.RtAttr) { if len(attr.Data) != 0 { nh.Gateway = make(net.IP, len(attr.Data)) copy(nh.Gateway, attr.Data) } }, }, + unix.NHA_ENCAP_TYPE: { + encode: func(nh *Nexthop) *nl.RtAttr { + if nh.Encap != nil { + b := make([]byte, 2) + native.PutUint16(b, uint16(nh.Encap.Type())) + return nl.NewRtAttr(unix.NHA_ENCAP_TYPE, b) + } + return nil + }, + }, + unix.NHA_ENCAP: { + encode: func(nh *Nexthop) *nl.RtAttr { + if nh.Encap != nil { + data, err := nh.Encap.Encode() + if err != nil { + return nil + } + return nl.NewRtAttr(unix.NHA_ENCAP|unix.NLA_F_NESTED, data) + } + return nil + }, + decode: func(nh *Nexthop, attr *nl.RtAttr, allAttrs []*nl.RtAttr) { + typ := nl.LWTUNNEL_ENCAP_NONE + for _, a := range allAttrs { + if a.Type == unix.NHA_ENCAP_TYPE { + if len(a.Data) < 2 { + return + } + typ = int(native.Uint16(a.Data[0:2])) + break + } + } + if typ == nl.LWTUNNEL_ENCAP_NONE { + return + } + e, err := decodeEncap(typ, attr.Data) + if err != nil { + return + } + nh.Encap = e + }, + }, } // encodeNexthopAttrs encodes the attributes in the Nexthop into the slice of @@ -206,7 +249,7 @@ func decodeNexthopAttrs(nh *Nexthop, attrs []*nl.RtAttr) { if !found || handler.decode == nil { continue } - handler.decode(nh, attr) + handler.decode(nh, attr, attrs) } } @@ -255,6 +298,8 @@ func prepareNewNexthop(nh *Nexthop, req *nl.NetlinkRequest, msg *nl.Nhmsg) error unix.NHA_BLACKHOLE, unix.NHA_OIF, unix.NHA_GATEWAY, + unix.NHA_ENCAP_TYPE, + unix.NHA_ENCAP, })...) msg.Family = deriveFamilyFromNexthop(nh) diff --git a/nexthop_test.go b/nexthop_test.go index ebe1ae80..e877e78a 100644 --- a/nexthop_test.go +++ b/nexthop_test.go @@ -8,9 +8,19 @@ import ( "slices" "testing" + "github.com/vishvananda/netlink/nl" "golang.org/x/sys/unix" ) +func TestNexthopString(t *testing.T) { + var nh *Nexthop = nil + // Ensure calling String() does not panic with nil receiver + t.Log(nh.String()) + // Ensure calling String() does not panic with empty fields + nh = &Nexthop{} + t.Log(nh.String()) +} + func TestNexthopAddListDelReplace(t *testing.T) { t.Cleanup(setUpNetlinkTest(t)) @@ -153,3 +163,60 @@ func TestNexthopAddListDelReplace(t *testing.T) { t.Fatalf("Nexthop Gateway mismatch: expected %s, got %s", nh2.Gateway, resNH2.Gateway) } } + +func TestNexthopEncap(t *testing.T) { + t.Cleanup(setUpNetlinkTest(t)) + + // get loopback interface + loop, err := LinkByName("lo") + if err != nil { + t.Fatal(err) + } + + // bring the interface up + if err = LinkSetUp(loop); err != nil { + t.Fatal(err) + } + + nh := &Nexthop{ + ID: 1, + OIF: uint32(loop.Attrs().Index), + Encap: &SEG6Encap{ + Mode: nl.SEG6_IPTUN_MODE_ENCAP, + Segments: []net.IP{ + net.ParseIP("2001:db8:1234::"), + }, + }, + } + + if err = NexthopAdd(nh); err != nil { + t.Fatal(err) + } + + nhs, err := NexthopList() + if err != nil { + t.Fatal(err) + } + if len(nhs) != 1 { + t.Fatalf("Expected 1 nexthop, got %d", len(nhs)) + } + + // Check we can read what we wrote + resNH := nhs[0] + if resNH.Encap == nil { + t.Fatal("Nexthop Encap is nil") + } + seg6Encap, ok := resNH.Encap.(*SEG6Encap) + if !ok { + t.Fatalf("Nexthop Encap is not SEG6Encap, got %T", resNH.Encap) + } + if seg6Encap.Mode != nl.SEG6_IPTUN_MODE_ENCAP { + t.Fatalf("Nexthop Encap Mode mismatch: expected %d, got %d", nl.SEG6_IPTUN_MODE_ENCAP, seg6Encap.Mode) + } + if len(seg6Encap.Segments) != 1 { + t.Fatalf("Expected 1 segment, got %d", len(seg6Encap.Segments)) + } + if !seg6Encap.Segments[0].Equal(net.ParseIP("2001:db8:1234::")) { + t.Fatalf("Nexthop Encap Segment mismatch: expected %s, got %s", "2001:db8:1234::", seg6Encap.Segments[0]) + } +} diff --git a/route_linux.go b/route_linux.go index 1f99a17d..a6cda1c3 100644 --- a/route_linux.go +++ b/route_linux.go @@ -1569,33 +1569,9 @@ func deserializeRoute(m []byte) (Route, error) { if len(encap.Value) != 0 && len(encapType.Value) != 0 { typ := int(native.Uint16(encapType.Value[0:2])) - var e Encap - switch typ { - case nl.LWTUNNEL_ENCAP_MPLS: - e = &MPLSEncap{} - if err := e.Decode(encap.Value); err != nil { - return route, err - } - case nl.LWTUNNEL_ENCAP_SEG6: - e = &SEG6Encap{} - if err := e.Decode(encap.Value); err != nil { - return route, err - } - case nl.LWTUNNEL_ENCAP_SEG6_LOCAL: - e = &SEG6LocalEncap{} - if err := e.Decode(encap.Value); err != nil { - return route, err - } - case nl.LWTUNNEL_ENCAP_BPF: - e = &BpfEncap{} - if err := e.Decode(encap.Value); err != nil { - return route, err - } - case nl.LWTUNNEL_ENCAP_IP6: - e = &IP6tnlEncap{} - if err := e.Decode(encap.Value); err != nil { - return route, err - } + e, err := decodeEncap(typ, encap.Value) + if err != nil { + return route, err } route.Encap = e } @@ -1603,6 +1579,38 @@ func deserializeRoute(m []byte) (Route, error) { return route, nil } +func decodeEncap(typ int, v []byte) (Encap, error) { + var e Encap + switch typ { + case nl.LWTUNNEL_ENCAP_MPLS: + e = &MPLSEncap{} + if err := e.Decode(v); err != nil { + return nil, err + } + case nl.LWTUNNEL_ENCAP_SEG6: + e = &SEG6Encap{} + if err := e.Decode(v); err != nil { + return nil, err + } + case nl.LWTUNNEL_ENCAP_SEG6_LOCAL: + e = &SEG6LocalEncap{} + if err := e.Decode(v); err != nil { + return nil, err + } + case nl.LWTUNNEL_ENCAP_BPF: + e = &BpfEncap{} + if err := e.Decode(v); err != nil { + return nil, err + } + case nl.LWTUNNEL_ENCAP_IP6: + e = &IP6tnlEncap{} + if err := e.Decode(v); err != nil { + return nil, err + } + } + return e, nil +} + // RouteGetOptions contains a set of options to use with // RouteGetWithOptions type RouteGetOptions struct {