diff --git a/plugins/meta/vrf/vrf.go b/plugins/meta/vrf/vrf.go index f265c071f..042e72449 100644 --- a/plugins/meta/vrf/vrf.go +++ b/plugins/meta/vrf/vrf.go @@ -17,6 +17,7 @@ package main import ( "fmt" "math" + "strings" "github.com/vishvananda/netlink" ) @@ -104,6 +105,14 @@ func addInterface(vrf *netlink.Vrf, intf string) error { if err != nil { return fmt.Errorf("failed getting ipv6 addresses for %s", intf) } + + // Save routes that are setup for the interface, before setting master, + // because otherwise the routes will be deleted after interface is moved. + routes, err := netlink.RouteList(i, netlink.FAMILY_ALL) + if err != nil { + return fmt.Errorf("failed getting all routes for %s", intf) + } + err = netlink.LinkSetMaster(i, vrf) if err != nil { return fmt.Errorf("could not set vrf %s as master of %s: %v", vrf.Name, intf, err) @@ -130,6 +139,21 @@ CONTINUE: } } + // Apply all saved routes for the interface that was moved to the VRF + for _, route := range routes { + r := route + // Modify original table to vrf one, + // equivalent of 'ip route add
table '. + r.Table = int(vrf.Table) + err = netlink.RouteAdd(&r) + if err != nil { + // If route is already present, returned error is "file exists" + if !strings.Contains(fmt.Sprintf("%v", err), "file exists") { + return fmt.Errorf("error while adding route \"%s\": %v\n", r, err) + } + } + } + return nil } diff --git a/plugins/meta/vrf/vrf_test.go b/plugins/meta/vrf/vrf_test.go index 8eb2fbea3..b8e8ac046 100644 --- a/plugins/meta/vrf/vrf_test.go +++ b/plugins/meta/vrf/vrf_test.go @@ -17,10 +17,13 @@ package main import ( "encoding/json" "fmt" + "net" + "strings" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" "github.com/containernetworking/cni/pkg/skel" "github.com/containernetworking/cni/pkg/types" @@ -107,7 +110,7 @@ var _ = Describe("vrf plugin", func() { }, }) Expect(err).NotTo(HaveOccurred()) - _, err = netlink.LinkByName(IF0Name) + _, err = netlink.LinkByName(IF1Name) Expect(err).NotTo(HaveOccurred()) return nil }) @@ -177,6 +180,102 @@ var _ = Describe("vrf plugin", func() { Expect(err).NotTo(HaveOccurred()) }) + It("adds the interface and custom routing to new VRF", func() { + conf := configWithRouteFor("test", IF0Name, VRF0Name, "10.0.0.2/24", "10.10.10.0/24") + + By("Setting custom routing first", func() { + err := targetNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + + ipv4, err := types.ParseCIDR("10.0.0.2/24") + Expect(err).NotTo(HaveOccurred()) + Expect(ipv4).NotTo(BeNil()) + + _, routev4, err := net.ParseCIDR("10.10.10.0/24") + Expect(err).NotTo(HaveOccurred()) + + ipv6, err := types.ParseCIDR("abcd:1234:ffff::cdde/64") + Expect(err).NotTo(HaveOccurred()) + Expect(ipv6).NotTo(BeNil()) + + _, routev6, err := net.ParseCIDR("1111:dddd::/80") + Expect(err).NotTo(HaveOccurred()) + Expect(routev6).NotTo(BeNil()) + + link, err := netlink.LinkByName(IF0Name) + Expect(err).NotTo(HaveOccurred()) + + // Add IP addresses for network reachability + netlink.AddrAdd(link, &netlink.Addr{IPNet: ipv4}) + netlink.AddrAdd(link, &netlink.Addr{IPNet: ipv6}) + + ipAddrs, err := netlink.AddrList(link, netlink.FAMILY_V4) + Expect(err).NotTo(HaveOccurred()) + // Check if address was assigned properly + Expect(ipAddrs[0].IP.String()).To(Equal("10.0.0.2")) + + // Set interface UP, otherwise local route to 10.0.0.0/24 is not present + err = netlink.LinkSetUp(link) + Expect(err).NotTo(HaveOccurred()) + + // Add additional route to 10.10.10.0/24 via 10.0.0.1 gateway + r := netlink.Route{ + LinkIndex: link.Attrs().Index, + Src: ipv4.IP, + Dst: routev4, + Gw: net.ParseIP("10.0.0.1"), + } + err = netlink.RouteAdd(&r) + Expect(err).NotTo(HaveOccurred()) + + r6 := netlink.Route{ + LinkIndex: link.Attrs().Index, + Src: ipv6.IP, + Dst: routev6, + Gw: net.ParseIP("abcd:1234:ffff::1"), + } + err = netlink.RouteAdd(&r6) + Expect(err).NotTo(HaveOccurred()) + + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) + + args := &skel.CmdArgs{ + ContainerID: "dummy", + Netns: targetNS.Path(), + IfName: IF0Name, + StdinData: conf, + } + + err := originalNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + r, _, err := testutils.CmdAddWithArgs(args, func() error { + return cmdAdd(args) + }) + Expect(err).NotTo(HaveOccurred()) + + result, err := current.GetResult(r) + Expect(err).NotTo(HaveOccurred()) + + Expect(result.Interfaces).To(HaveLen(1)) + Expect(result.Interfaces[0].Name).To(Equal(IF0Name)) + Expect(result.Routes).To(HaveLen(1)) + Expect(result.Routes[0].Dst.IP.String()).To(Equal("10.10.10.0")) + return nil + }) + Expect(err).NotTo(HaveOccurred()) + + err = targetNS.Do(func(ns.NetNS) error { + defer GinkgoRecover() + checkInterfaceOnVRF(VRF0Name, IF0Name) + checkRoutesOnVRF(VRF0Name, IF0Name, "10.10.10.0/24", "1111:dddd::/80") + return nil + }) + Expect(err).NotTo(HaveOccurred()) + }) + It("fails if the interface already has a master set", func() { conf := configFor("test", IF0Name, VRF0Name, "10.0.0.2/24") @@ -690,6 +789,35 @@ func configWithTableFor(name, intf, vrf, ip string, tableID int) []byte { return []byte(conf) } +func configWithRouteFor(name, intf, vrf, ip, route string) []byte { + conf := fmt.Sprintf(`{ + "name": "%s", + "type": "vrf", + "cniVersion": "0.3.1", + "vrfName": "%s", + "prevResult": { + "interfaces": [ + {"name": "%s", "sandbox":"netns"} + ], + "ips": [ + { + "version": "4", + "address": "%s", + "gateway": "10.0.0.1", + "interface": 0 + } + ], + "routes": [ + { + "dst": "%s", + "gw": "10.0.0.1" + } + ] + } + }`, name, vrf, intf, ip, route) + return []byte(conf) +} + func checkInterfaceOnVRF(vrfName, intfName string) { vrf, err := netlink.LinkByName(vrfName) Expect(err).NotTo(HaveOccurred()) @@ -702,3 +830,41 @@ func checkInterfaceOnVRF(vrfName, intfName string) { Expect(err).NotTo(HaveOccurred()) Expect(master.Attrs().Name).To(Equal(vrfName)) } + +func checkRoutesOnVRF(vrfName, intfName string, routesToCheck ...string) { + vrf, err := netlink.LinkByName(vrfName) + Expect(err).NotTo(HaveOccurred()) + Expect(vrf).To(BeAssignableToTypeOf(&netlink.Vrf{})) + + link, err := netlink.LinkByName(intfName) + Expect(err).NotTo(HaveOccurred()) + + err = netlink.LinkSetUp(link) + Expect(err).NotTo(HaveOccurred()) + + ipAddrs, err := netlink.AddrList(link, netlink.FAMILY_V4) + Expect(err).NotTo(HaveOccurred()) + Expect(ipAddrs).To(HaveLen(1)) + Expect(ipAddrs[0].IP.String()).To(Equal("10.0.0.2")) + + // Need to read all tables, so cannot use RouteList + routeFilter := &netlink.Route{ + LinkIndex: link.Attrs().Index, + Table: unix.RT_TABLE_UNSPEC, + } + + routes, err := netlink.RouteListFiltered(netlink.FAMILY_ALL, + routeFilter, + netlink.RT_FILTER_OIF|netlink.RT_FILTER_TABLE) + Expect(err).NotTo(HaveOccurred()) + + routesRead := []string{} + for _, route := range routes { + routesRead = append(routesRead, route.String()) + Expect(uint32(route.Table)).To(Equal(vrf.(*netlink.Vrf).Table)) + } + routesStr := strings.Join(routesRead, "\n") + for _, route := range routesToCheck { + Expect(routesStr).To(ContainSubstring(route)) + } +}