From 5726292167809f5b7f619139a576c14eccae3e4c Mon Sep 17 00:00:00 2001 From: Niclas Schad Date: Thu, 7 May 2026 14:04:59 +0200 Subject: [PATCH 01/10] Add support for MutableCSINodeAllocatableCount The CSI driver lists all PCIe devices that are not of type VIRTIO_BLOCK_DEVICE and subtracts them from the theoretical maximum, so Kubernetes can report a correct dynamic max volume count that can be attached for each node. Signed-off-by: Niclas Schad --- pkg/csi/blockstorage/controllerserver.go | 4 ++ pkg/csi/blockstorage/nodeserver.go | 12 +++- pkg/csi/blockstorage/utils.go | 2 +- pkg/csi/blockstorage/utils_test.go | 10 +-- pkg/csi/util/mount/mount_darwin.go | 5 ++ pkg/csi/util/mount/mount_linux.go | 80 ++++++++++++++++++++++++ pkg/stackit/stackiterrors/errors.go | 14 ++++- 7 files changed, 118 insertions(+), 9 deletions(-) diff --git a/pkg/csi/blockstorage/controllerserver.go b/pkg/csi/blockstorage/controllerserver.go index 8de6237e..ef962edd 100644 --- a/pkg/csi/blockstorage/controllerserver.go +++ b/pkg/csi/blockstorage/controllerserver.go @@ -370,6 +370,10 @@ func (cs *controllerServer) ControllerPublishVolume(ctx context.Context, req *cs _, err = cloud.AttachVolume(ctx, instanceID, volumeID) if err != nil { + // Trigger's an immediate `NodeGetInfo` RPC call when MutableCSINodeAllocatableCount is enabled + if stackiterrors.IsTooManyDevicesError(err) { + return nil, status.Errorf(codes.ResourceExhausted, "[ControllerPublishVolume] Node can't accept any more volumes %v. 
All PCIe lanes are exhausted!", err) + } klog.Errorf("Failed to AttachVolume: %v", err) return nil, status.Errorf(codes.Internal, "[ControllerPublishVolume] Attach Volume failed with error %v", err) } diff --git a/pkg/csi/blockstorage/nodeserver.go b/pkg/csi/blockstorage/nodeserver.go index 648e5df3..0b390d2b 100644 --- a/pkg/csi/blockstorage/nodeserver.go +++ b/pkg/csi/blockstorage/nodeserver.go @@ -308,8 +308,16 @@ func (ns *nodeServer) NodeGetInfo(ctx context.Context, _ *csi.NodeGetInfoRequest } maxVolumesPerNode := DetermineMaxVolumesByFlavor(flavor) - // Subtract 1 for root disk and another for configDrive/spare - maxVolumesPerNode -= 2 + + // Subtract already mounted Volumes + emptyPCIeRootPorts, err := mount.CountNonVirtioBlockDevices() + if err != nil { + klog.Errorf("[NodeGetInfo] unable to retrieve PCIe root ports %v", err) + emptyPCIeRootPorts = 0 + } + + maxVolumesPerNode -= emptyPCIeRootPorts + klog.V(4).Infof("Determined %d PCIe ports occupied by non virtio block devices", emptyPCIeRootPorts) klog.V(4).Infof("Determined node to support %d volumes", maxVolumesPerNode) nodeInfo := &csi.NodeGetInfoResponse{ diff --git a/pkg/csi/blockstorage/utils.go b/pkg/csi/blockstorage/utils.go index aaafc864..eacb77f7 100644 --- a/pkg/csi/blockstorage/utils.go +++ b/pkg/csi/blockstorage/utils.go @@ -97,7 +97,7 @@ func DetermineMaxVolumesByFlavor(flavor string) int64 { return 159 default: // All other flavors can mount 28 volumes - return 25 + return 28 } } diff --git a/pkg/csi/blockstorage/utils_test.go b/pkg/csi/blockstorage/utils_test.go index f9261de4..9d505950 100644 --- a/pkg/csi/blockstorage/utils_test.go +++ b/pkg/csi/blockstorage/utils_test.go @@ -12,14 +12,14 @@ var _ = Describe("Util Test", func() { maxVolumes := DetermineMaxVolumesByFlavor(flavor) Expect(maxVolumes).To(Equal(int64(expectedMaxVolumes))) }, - Entry("Intel 3rd Gen", "c3i.2", 25), - Entry("Intel 2rd Gen", "c2i.2", 25), - Entry("Intel 1st Gen", "c1.2", 25), - Entry("AMD 1st Gen without 
overprovisioning", "s1a.8d", 25), + Entry("Intel 3rd Gen", "c3i.2", 28), + Entry("Intel 2rd Gen", "c2i.2", 28), + Entry("Intel 1st Gen", "c1.2", 28), + Entry("AMD 1st Gen without overprovisioning", "s1a.8d", 28), Entry("AMD 2nd Gen without overprovisioning", "s2a.8d", 159), Entry("Nvidia GPU", "n2.14d.g1", 10), Entry("Nvidia GPU", "n2.56d.g4", 10), - Entry("ARM Gen1Link without CPU-overprovisioning ARM Gen1", "g1r.4d", 25), + Entry("ARM Gen1Link without CPU-overprovisioning ARM Gen1", "g1r.4d", 28), ) }) }) diff --git a/pkg/csi/util/mount/mount_darwin.go b/pkg/csi/util/mount/mount_darwin.go index 122f4c1c..389fd6cb 100644 --- a/pkg/csi/util/mount/mount_darwin.go +++ b/pkg/csi/util/mount/mount_darwin.go @@ -17,3 +17,8 @@ func newDeviceStats(statfs *unix.Statfs_t) *DeviceStats { UsedInodes: int64(statfs.Files) - int64(statfs.Ffree), } } + +func CountNonVirtioBlockDevices() (int64, error) { + // not implemented + return 0, nil +} diff --git a/pkg/csi/util/mount/mount_linux.go b/pkg/csi/util/mount/mount_linux.go index b525b753..f8925708 100644 --- a/pkg/csi/util/mount/mount_linux.go +++ b/pkg/csi/util/mount/mount_linux.go @@ -4,6 +4,15 @@ package mount import "golang.org/x/sys/unix" +var ( + pciAddressRegex = regexp.MustCompile(`^[0-9a-fA-F]{4}:[0-9a-fA-F]{2}:[0-9a-fA-F]{2}\.[0-9a-fA-F]$`) +) + +const ( + RedhatVendor = "0x1af4" + VirtioBlockDevice = "0x1042" +) + func newDeviceStats(statfs *unix.Statfs_t) *DeviceStats { return &DeviceStats{ Block: false, @@ -17,3 +26,74 @@ func newDeviceStats(statfs *unix.Statfs_t) *DeviceStats { UsedInodes: int64(statfs.Files) - int64(statfs.Ffree), } } + +// CountNonVirtioBlockDevices returns the number of PCIe Root ports who +// are currently occupied by anything else than an VIRTIO 1.0 Block Device +// returns zero when something went wrong +func CountNonVirtioBlockDevices() (int64, error) { + const pciPath = "/sys/bus/pci/devices" + + // Get all PCI devices + devices, err := os.ReadDir(pciPath) + if err != nil { + return 0, 
fmt.Errorf("failed to read PCI bus: %w", err) + } + + pcieSlotsOccupiedByNonBlockDevice := 0 + + for _, dev := range devices { + devPath := filepath.Join(pciPath, dev.Name()) + + // 1. Identify if it's a Root Port / Bridge + // We check the 'class' file. PCI Bridge class code starts with 0x0604 + classBuf, err := os.ReadFile(filepath.Join(devPath, "class")) + if err != nil { + klog.Errorf("failed to read PCI device class %s : %v", devPath, err) + continue + } + class := strings.TrimSpace(string(classBuf)) + + // Class 0x060400 is a PCI-to-PCI bridge (standard for Root Ports) + if strings.HasPrefix(class, "0x0604") { + // 2. Check if the port has downstream devices + // If the bridge has children, they appear as subdirectories + // matching the PCI address format (e.g., 0000:01:00.0) + files, err2 := os.ReadDir(devPath) + if err2 != nil { + klog.Errorf("failed to read dir %s : %v", devPath, err2) + } + for _, file := range files { + // Ignore PCI bus directories such as pci001 pci002 and pci010 + // Devices must follow format + if pciAddressRegex.MatchString(file.Name()) { + isNonBlockDevice := IsNonBlockDevice(devPath, file) + if isNonBlockDevice { + pcieSlotsOccupiedByNonBlockDevice++ + } + break + } + } + } else { + klog.V(4).Infof("skipping class %s: path: %s", class, devPath) + } + } + + return int64(pcieSlotsOccupiedByNonBlockDevice), nil +} + +func IsNonBlockDevice(devPath string, file os.DirEntry) bool { + var isNonBlockDevice bool + pciDevicePath := filepath.Join(devPath, file.Name()) + vendorBuf, err := os.ReadFile(filepath.Join(pciDevicePath, "vendor")) + if err != nil { + klog.Errorf("failed to read PCI device vendor %s : %v", pciDevicePath, err) + } + deviceBuf, err := os.ReadFile(filepath.Join(pciDevicePath, "device")) + if err != nil { + klog.Errorf("failed to read PCI device file %s : %v", pciDevicePath, err) + } + if strings.TrimSpace(string(vendorBuf)) == RedhatVendor && strings.TrimSpace(string(deviceBuf)) != VirtioBlockDevice { + isNonBlockDevice 
= true + } + return isNonBlockDevice +} diff --git a/pkg/stackit/stackiterrors/errors.go b/pkg/stackit/stackiterrors/errors.go index ae19b7d7..0e37be49 100644 --- a/pkg/stackit/stackiterrors/errors.go +++ b/pkg/stackit/stackiterrors/errors.go @@ -4,9 +4,10 @@ import ( "errors" "fmt" "net/http" + "strings" oapiError "github.com/stackitcloud/stackit-sdk-go/core/oapierror" - wait "github.com/stackitcloud/stackit-sdk-go/services/iaas/v2api/wait" + "github.com/stackitcloud/stackit-sdk-go/services/iaas/v2api/wait" ) var ErrNotFound = errors.New("failed to find object") @@ -20,6 +21,17 @@ func IsNotFound(err error) bool { return oAPIError.StatusCode == http.StatusNotFound } +func IsTooManyDevicesError(err error) bool { + var oAPIError *oapiError.GenericOpenAPIError + if ok := errors.As(err, &oAPIError); !ok { + return false + } + + // TODO: Improve this if possible + return oAPIError.StatusCode == http.StatusForbidden && + strings.Contains(oAPIError.ErrorMessage, "maximum allowed number of disk devices") +} + func IgnoreNotFound(err error) error { if IsNotFound(err) { return nil From a2bc65cffd61e8d16799915d1ecb7f65be17e846 Mon Sep 17 00:00:00 2001 From: Niclas Schad Date: Fri, 15 May 2026 14:39:28 +0200 Subject: [PATCH 02/10] fix imports for linux Signed-off-by: Niclas Schad --- pkg/csi/util/mount/mount_linux.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pkg/csi/util/mount/mount_linux.go b/pkg/csi/util/mount/mount_linux.go index f8925708..c6259b47 100644 --- a/pkg/csi/util/mount/mount_linux.go +++ b/pkg/csi/util/mount/mount_linux.go @@ -2,7 +2,16 @@ package mount -import "golang.org/x/sys/unix" +import ( + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + + "golang.org/x/sys/unix" + "k8s.io/klog/v2" +) var ( pciAddressRegex = regexp.MustCompile(`^[0-9a-fA-F]{4}:[0-9a-fA-F]{2}:[0-9a-fA-F]{2}\.[0-9a-fA-F]$`) From 101b73158dccba178a4e4f3c7a89879dd759fd83 Mon Sep 17 00:00:00 2001 From: Niclas Schad Date: Tue, 19 May 2026 13:54:12 
+0200 Subject: [PATCH 03/10] subtract one from maxVolumes for root partition Signed-off-by: Niclas Schad --- pkg/csi/blockstorage/nodeserver.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pkg/csi/blockstorage/nodeserver.go b/pkg/csi/blockstorage/nodeserver.go index 0b390d2b..44e5c450 100644 --- a/pkg/csi/blockstorage/nodeserver.go +++ b/pkg/csi/blockstorage/nodeserver.go @@ -320,6 +320,9 @@ func (ns *nodeServer) NodeGetInfo(ctx context.Context, _ *csi.NodeGetInfoRequest klog.V(4).Infof("Determined %d PCIe ports occupied by non virtio block devices", emptyPCIeRootPorts) klog.V(4).Infof("Determined node to support %d volumes", maxVolumesPerNode) + // always subtract one for every SKE node, because they always have a root partition + maxVolumesPerNode -= 1 + nodeInfo := &csi.NodeGetInfoResponse{ NodeId: nodeID, MaxVolumesPerNode: maxVolumesPerNode, From e6e983c2fa750643676903cbf4efa7e09028186a Mon Sep 17 00:00:00 2001 From: Niclas Schad Date: Mon, 1 Jun 2026 11:57:50 +0200 Subject: [PATCH 04/10] WIP Signed-off-by: Niclas Schad --- pkg/csi/blockstorage/nodeserver.go | 22 ++++++------ pkg/csi/blockstorage/utils.go | 4 +-- pkg/csi/util/mount/mount_darwin.go | 6 +++- pkg/csi/util/mount/mount_linux.go | 57 +++++++++++++++++++++--------- 4 files changed, 58 insertions(+), 31 deletions(-) diff --git a/pkg/csi/blockstorage/nodeserver.go b/pkg/csi/blockstorage/nodeserver.go index 44e5c450..9ef5c3ab 100644 --- a/pkg/csi/blockstorage/nodeserver.go +++ b/pkg/csi/blockstorage/nodeserver.go @@ -302,26 +302,24 @@ func (ns *nodeServer) NodeGetInfo(ctx context.Context, _ *csi.NodeGetInfoRequest return nil, status.Errorf(codes.Internal, "[NodeGetInfo] unable to retrieve instance id of node %v", err) } - flavor, err := ns.Metadata.GetFlavor(ctx) - if err != nil { - return nil, status.Errorf(codes.Internal, "[NodeGetInfo] unable to retrieve flavor of node %v", err) - } - - maxVolumesPerNode := DetermineMaxVolumesByFlavor(flavor) + //flavor, err := ns.Metadata.GetFlavor(ctx) + 
//if err != nil { + // return nil, status.Errorf(codes.Internal, "[NodeGetInfo] unable to retrieve flavor of node %v", err) + //} // Subtract already mounted Volumes - emptyPCIeRootPorts, err := mount.CountNonVirtioBlockDevices() + emptyPCIeRootPorts, err := mount.CountFreePCIeSlots() if err != nil { klog.Errorf("[NodeGetInfo] unable to retrieve PCIe root ports %v", err) emptyPCIeRootPorts = 0 } - maxVolumesPerNode -= emptyPCIeRootPorts - klog.V(4).Infof("Determined %d PCIe ports occupied by non virtio block devices", emptyPCIeRootPorts) - klog.V(4).Infof("Determined node to support %d volumes", maxVolumesPerNode) + vols, err := mount.CountLocalCSIVolumes(driverName) + if err != nil { + klog.Errorf("[NodeGetInfo] unable to retrieve volume count %v", err) + } - // always subtract one for every SKE node, because they always have a root partition - maxVolumesPerNode -= 1 + maxVolumesPerNode := emptyPCIeRootPorts + vols nodeInfo := &csi.NodeGetInfoResponse{ NodeId: nodeID, diff --git a/pkg/csi/blockstorage/utils.go b/pkg/csi/blockstorage/utils.go index eacb77f7..a14cdafb 100644 --- a/pkg/csi/blockstorage/utils.go +++ b/pkg/csi/blockstorage/utils.go @@ -90,8 +90,8 @@ func DetermineMaxVolumesByFlavor(flavor string) int64 { // The following numbers were specified by the IaaS team. They are based on actual tests. 
switch { case strings.HasPrefix(flavor, "n"): - // Flavors starting with 'n' are nvidia GPU flavors, all GPU VM's can only mount 10 volumes - return 10 + // Flavors starting with 'n' are nvidia GPU flavors + return 13 case strings.HasSuffix(flavorParts[0], "2a"): // AMD 2nd Gen return 159 diff --git a/pkg/csi/util/mount/mount_darwin.go b/pkg/csi/util/mount/mount_darwin.go index 389fd6cb..07dcba11 100644 --- a/pkg/csi/util/mount/mount_darwin.go +++ b/pkg/csi/util/mount/mount_darwin.go @@ -18,7 +18,11 @@ func newDeviceStats(statfs *unix.Statfs_t) *DeviceStats { } } -func CountNonVirtioBlockDevices() (int64, error) { +func CountLocalCSIVolumes(_ string) (int64, error) { // not implemented return 0, nil } + +func CountFreePCIeSlots() (int64, error) { + return 0, nil +} diff --git a/pkg/csi/util/mount/mount_linux.go b/pkg/csi/util/mount/mount_linux.go index c6259b47..efa5a6bd 100644 --- a/pkg/csi/util/mount/mount_linux.go +++ b/pkg/csi/util/mount/mount_linux.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "regexp" + "slices" "strings" "golang.org/x/sys/unix" @@ -36,10 +37,9 @@ func newDeviceStats(statfs *unix.Statfs_t) *DeviceStats { } } -// CountNonVirtioBlockDevices returns the number of PCIe Root ports who -// are currently occupied by anything else than an VIRTIO 1.0 Block Device -// returns zero when something went wrong -func CountNonVirtioBlockDevices() (int64, error) { +// CountFreePCIeSlots returns the number of PCIe Root ports who +// are currently not occupied by anything. 
+func CountFreePCIeSlots() (int64, error) { const pciPath = "/sys/bus/pci/devices" // Get all PCI devices @@ -48,7 +48,7 @@ func CountNonVirtioBlockDevices() (int64, error) { return 0, fmt.Errorf("failed to read PCI bus: %w", err) } - pcieSlotsOccupiedByNonBlockDevice := 0 + freePCIeSlots := 0 for _, dev := range devices { devPath := filepath.Join(pciPath, dev.Name()) @@ -71,23 +71,48 @@ func CountNonVirtioBlockDevices() (int64, error) { if err2 != nil { klog.Errorf("failed to read dir %s : %v", devPath, err2) } - for _, file := range files { - // Ignore PCI bus directories such as pci001 pci002 and pci010 - // Devices must follow format - if pciAddressRegex.MatchString(file.Name()) { - isNonBlockDevice := IsNonBlockDevice(devPath, file) - if isNonBlockDevice { - pcieSlotsOccupiedByNonBlockDevice++ - } - break - } + hasDownStreamFolder := slices.ContainsFunc(files, func(s os.DirEntry) bool { + return pciAddressRegex.MatchString(s.Name()) + }) + if !hasDownStreamFolder { + freePCIeSlots += 1 } } else { klog.V(4).Infof("skipping class %s: path: %s", class, devPath) } } - return int64(pcieSlotsOccupiedByNonBlockDevice), nil + return int64(freePCIeSlots), nil +} + +// CountLocalCSIVolumes tries to count how many volumes are mounted for a given driverName. 
+func CountLocalCSIVolumes(driverName string) (int64, error) { + const kubeletDir = "/var/lib/kubelet" + volumeCount := 0 + // The path where Kubelet mounts global tracking directories for a specific CSI driver + targetDir := filepath.Join(kubeletDir, "plugins", "kubernetes.io", "csi", driverName) + + if _, err := os.Stat(targetDir); os.IsNotExist(err) { + return 0, nil + } else if err != nil { + return 0, fmt.Errorf("failed to check directory: %w", err) + } + + volumes, err := os.ReadDir(targetDir) + if err != nil { + return 0, fmt.Errorf("failed to read dir %s: %w", targetDir, err) + } + for _, vol := range volumes { + // Check if volume has a "globalmount" dir to determine if it's mounted correctly + globalMountPath := filepath.Join(vol.Name(), "globalmount") + if _, err := os.Stat(globalMountPath); os.IsNotExist(err) { + continue + } + + volumeCount++ + } + + return int64(volumeCount), nil } func IsNonBlockDevice(devPath string, file os.DirEntry) bool { From 7f0c7bf881c84a9d6c494e2d1440c5f567a6af79 Mon Sep 17 00:00:00 2001 From: Niclas Schad Date: Mon, 1 Jun 2026 15:38:30 +0200 Subject: [PATCH 05/10] parse Body instead of ErrorMessage field in IsTooManyDevicesError() Signed-off-by: Niclas Schad --- pkg/stackit/stackiterrors/errors.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/stackit/stackiterrors/errors.go b/pkg/stackit/stackiterrors/errors.go index 0e37be49..98b0f528 100644 --- a/pkg/stackit/stackiterrors/errors.go +++ b/pkg/stackit/stackiterrors/errors.go @@ -29,7 +29,7 @@ func IsTooManyDevicesError(err error) bool { // TODO: Improve this if possible return oAPIError.StatusCode == http.StatusForbidden && - strings.Contains(oAPIError.ErrorMessage, "maximum allowed number of disk devices") + strings.Contains(string(oAPIError.Body), "maximum allowed number of disk devices") } func IgnoreNotFound(err error) error { From 16bfec030cba8350b6b80d98f98ec3b9055ae477 Mon Sep 17 00:00:00 2001 From: Niclas Schad Date: Mon, 1 Jun 2026 15:54:36 
+0200 Subject: [PATCH 06/10] cleanup Signed-off-by: Niclas Schad --- pkg/csi/blockstorage/nodeserver.go | 7 +------ pkg/csi/blockstorage/utils.go | 17 ----------------- pkg/csi/blockstorage/utils_test.go | 25 ------------------------- pkg/csi/util/mount/mount_linux.go | 23 +++-------------------- 4 files changed, 4 insertions(+), 68 deletions(-) delete mode 100644 pkg/csi/blockstorage/utils_test.go diff --git a/pkg/csi/blockstorage/nodeserver.go b/pkg/csi/blockstorage/nodeserver.go index 9ef5c3ab..268692bc 100644 --- a/pkg/csi/blockstorage/nodeserver.go +++ b/pkg/csi/blockstorage/nodeserver.go @@ -302,12 +302,6 @@ func (ns *nodeServer) NodeGetInfo(ctx context.Context, _ *csi.NodeGetInfoRequest return nil, status.Errorf(codes.Internal, "[NodeGetInfo] unable to retrieve instance id of node %v", err) } - //flavor, err := ns.Metadata.GetFlavor(ctx) - //if err != nil { - // return nil, status.Errorf(codes.Internal, "[NodeGetInfo] unable to retrieve flavor of node %v", err) - //} - - // Subtract already mounted Volumes emptyPCIeRootPorts, err := mount.CountFreePCIeSlots() if err != nil { klog.Errorf("[NodeGetInfo] unable to retrieve PCIe root ports %v", err) @@ -319,6 +313,7 @@ func (ns *nodeServer) NodeGetInfo(ctx context.Context, _ *csi.NodeGetInfoRequest klog.Errorf("[NodeGetInfo] unable to retrieve volume count %v", err) } + // maxVolumesPerNode is the result of all free/empty PCIClassBridgePCI ports plus all already mounted volumes. 
maxVolumesPerNode := emptyPCIeRootPorts + vols nodeInfo := &csi.NodeGetInfoResponse{ diff --git a/pkg/csi/blockstorage/utils.go b/pkg/csi/blockstorage/utils.go index a14cdafb..8fc02143 100644 --- a/pkg/csi/blockstorage/utils.go +++ b/pkg/csi/blockstorage/utils.go @@ -84,23 +84,6 @@ func ParseEndpoint(ep string) (proto, addr string, err error) { return "", "", fmt.Errorf("invalid endpoint: %v", ep) } -func DetermineMaxVolumesByFlavor(flavor string) int64 { - flavorParts := strings.Split(flavor, ".") - - // The following numbers were specified by the IaaS team. They are based on actual tests. - switch { - case strings.HasPrefix(flavor, "n"): - // Flavors starting with 'n' are nvidia GPU flavors - return 13 - case strings.HasSuffix(flavorParts[0], "2a"): - // AMD 2nd Gen - return 159 - default: - // All other flavors can mount 28 volumes - return 28 - } -} - func logGRPC(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { callID := serverGRPCEndpointCallCounter.Add(1) diff --git a/pkg/csi/blockstorage/utils_test.go b/pkg/csi/blockstorage/utils_test.go deleted file mode 100644 index 9d505950..00000000 --- a/pkg/csi/blockstorage/utils_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package blockstorage - -import ( - . "github.com/onsi/ginkgo/v2" - . 
"github.com/onsi/gomega" -) - -var _ = Describe("Util Test", func() { - - Context("DetermineMaxVolumesByFlavor", func() { - DescribeTable("should return the correct maximum volume count for different flavors", func(flavor string, expectedMaxVolumes int) { - maxVolumes := DetermineMaxVolumesByFlavor(flavor) - Expect(maxVolumes).To(Equal(int64(expectedMaxVolumes))) - }, - Entry("Intel 3rd Gen", "c3i.2", 28), - Entry("Intel 2rd Gen", "c2i.2", 28), - Entry("Intel 1st Gen", "c1.2", 28), - Entry("AMD 1st Gen without overprovisioning", "s1a.8d", 28), - Entry("AMD 2nd Gen without overprovisioning", "s2a.8d", 159), - Entry("Nvidia GPU", "n2.14d.g1", 10), - Entry("Nvidia GPU", "n2.56d.g4", 10), - Entry("ARM Gen1Link without CPU-overprovisioning ARM Gen1", "g1r.4d", 28), - ) - }) -}) diff --git a/pkg/csi/util/mount/mount_linux.go b/pkg/csi/util/mount/mount_linux.go index efa5a6bd..88e4dc3a 100644 --- a/pkg/csi/util/mount/mount_linux.go +++ b/pkg/csi/util/mount/mount_linux.go @@ -19,8 +19,8 @@ var ( ) const ( - RedhatVendor = "0x1af4" - VirtioBlockDevice = "0x1042" + // PCIClassBridgePCI Linux constant: https://github.com/torvalds/linux/blob/e43ffb69e0438cddd72aaa30898b4dc446f664f8/include/linux/pci_ids.h#L62 + PCIClassBridgePCI = "0x0604" ) func newDeviceStats(statfs *unix.Statfs_t) *DeviceStats { @@ -63,7 +63,7 @@ func CountFreePCIeSlots() (int64, error) { class := strings.TrimSpace(string(classBuf)) // Class 0x060400 is a PCI-to-PCI bridge (standard for Root Ports) - if strings.HasPrefix(class, "0x0604") { + if strings.HasPrefix(class, PCIClassBridgePCI) { // 2. 
Check if the port has downstream devices // If the bridge has children, they appear as subdirectories // matching the PCI address format (e.g., 0000:01:00.0) @@ -114,20 +114,3 @@ func CountLocalCSIVolumes(driverName string) (int64, error) { return int64(volumeCount), nil } - -func IsNonBlockDevice(devPath string, file os.DirEntry) bool { - var isNonBlockDevice bool - pciDevicePath := filepath.Join(devPath, file.Name()) - vendorBuf, err := os.ReadFile(filepath.Join(pciDevicePath, "vendor")) - if err != nil { - klog.Errorf("failed to read PCI device vendor %s : %v", pciDevicePath, err) - } - deviceBuf, err := os.ReadFile(filepath.Join(pciDevicePath, "device")) - if err != nil { - klog.Errorf("failed to read PCI device file %s : %v", pciDevicePath, err) - } - if strings.TrimSpace(string(vendorBuf)) == RedhatVendor && strings.TrimSpace(string(deviceBuf)) != VirtioBlockDevice { - isNonBlockDevice = true - } - return isNonBlockDevice -} From ed9e7bb97e676fd42f143c960a8acacd8af04419 Mon Sep 17 00:00:00 2001 From: Niclas Schad Date: Wed, 10 Jun 2026 10:23:50 +0200 Subject: [PATCH 07/10] address linter feedback Signed-off-by: Niclas Schad --- pkg/csi/util/mount/mount_linux.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/csi/util/mount/mount_linux.go b/pkg/csi/util/mount/mount_linux.go index 88e4dc3a..afd261d1 100644 --- a/pkg/csi/util/mount/mount_linux.go +++ b/pkg/csi/util/mount/mount_linux.go @@ -75,7 +75,7 @@ func CountFreePCIeSlots() (int64, error) { return pciAddressRegex.MatchString(s.Name()) }) if !hasDownStreamFolder { - freePCIeSlots += 1 + freePCIeSlots++ } } else { klog.V(4).Infof("skipping class %s: path: %s", class, devPath) From e10488967e70d43799b97e6e34f1f757a1d138ee Mon Sep 17 00:00:00 2001 From: Felix Breuer Date: Wed, 10 Jun 2026 17:21:16 +0200 Subject: [PATCH 08/10] add unit tests and restructure into helper function Signed-off-by: Felix Breuer --- pkg/csi/blockstorage/nodeserver.go | 32 ++--- pkg/csi/util/mount/mount_helper.go | 
60 +++++++++ pkg/csi/util/mount/mount_helper_test.go | 162 ++++++++++++++++++++++++ pkg/csi/util/mount/mount_linux.go | 91 +------------ pkg/stackit/stackiterrors/errors.go | 26 ++-- 5 files changed, 264 insertions(+), 107 deletions(-) create mode 100644 pkg/csi/util/mount/mount_helper.go create mode 100644 pkg/csi/util/mount/mount_helper_test.go diff --git a/pkg/csi/blockstorage/nodeserver.go b/pkg/csi/blockstorage/nodeserver.go index 268692bc..d7c00a63 100644 --- a/pkg/csi/blockstorage/nodeserver.go +++ b/pkg/csi/blockstorage/nodeserver.go @@ -302,23 +302,9 @@ func (ns *nodeServer) NodeGetInfo(ctx context.Context, _ *csi.NodeGetInfoRequest return nil, status.Errorf(codes.Internal, "[NodeGetInfo] unable to retrieve instance id of node %v", err) } - emptyPCIeRootPorts, err := mount.CountFreePCIeSlots() - if err != nil { - klog.Errorf("[NodeGetInfo] unable to retrieve PCIe root ports %v", err) - emptyPCIeRootPorts = 0 - } - - vols, err := mount.CountLocalCSIVolumes(driverName) - if err != nil { - klog.Errorf("[NodeGetInfo] unable to retrieve volume count %v", err) - } - - // maxVolumesPerNode is the result of all free/empty PCIClassBridgePCI ports plus all already mounted volumes. 
- maxVolumesPerNode := emptyPCIeRootPorts + vols - nodeInfo := &csi.NodeGetInfoResponse{ NodeId: nodeID, - MaxVolumesPerNode: maxVolumesPerNode, + MaxVolumesPerNode: ns.calculateMaxVolumesPerNode(), } zone, err := ns.Metadata.GetAvailabilityZone(ctx) @@ -336,6 +322,22 @@ func (ns *nodeServer) NodeGetInfo(ctx context.Context, _ *csi.NodeGetInfoRequest return nodeInfo, nil } +func (ns *nodeServer) calculateMaxVolumesPerNode() int64 { + freePCIeRootPorts, err := mount.CountFreePCIeSlots() + if err != nil { + klog.Errorf("[NodeGetInfo] unable to retrieve PCIe root ports: %v", err) + freePCIeRootPorts = 0 + } + + mountedCSIVolumes, err := mount.CountLocalCSIVolumes(driverName) + if err != nil { + klog.Errorf("[NodeGetInfo] unable to retrieve volume count: %v", err) + mountedCSIVolumes = 0 + } + + return freePCIeRootPorts + mountedCSIVolumes +} + func (ns *nodeServer) NodeGetCapabilities(_ context.Context, req *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) { klog.V(5).Infof("NodeGetCapabilities called with req: %#v", req) diff --git a/pkg/csi/util/mount/mount_helper.go b/pkg/csi/util/mount/mount_helper.go new file mode 100644 index 00000000..9c8264c8 --- /dev/null +++ b/pkg/csi/util/mount/mount_helper.go @@ -0,0 +1,60 @@ +package mount + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "k8s.io/klog/v2" +) + +const ( + // pciClassBridgePCI matches the Linux PCI-to-PCI bridge class prefix. 
+ pciClassBridgePCI = "0x0604" + globalMountDir = "globalmount" +) + +func countFreePCIeSlotsAt(devicesPath string) (int64, error) { + devices, err := os.ReadDir(devicesPath) + if err != nil { + return 0, fmt.Errorf("failed to read PCI bus: %w", err) + } + + var freePCIeSlots int64 + + for _, dev := range devices { + devPath := filepath.Join(devicesPath, dev.Name()) + + classBuf, err := os.ReadFile(filepath.Join(devPath, "class")) + if err != nil { + klog.Errorf("failed to read PCI device class %s: %v", devPath, err) + continue + } + + class := strings.TrimSpace(string(classBuf)) + if !strings.HasPrefix(class, pciClassBridgePCI) { + continue + } + + children, err := filepath.Glob(filepath.Join(devPath, "????:??:??.?")) + if err != nil { + return 0, fmt.Errorf("failed to glob PCI children for %s: %w", devPath, err) + } + + if len(children) == 0 { + freePCIeSlots++ + } + } + + return freePCIeSlots, nil +} + +func countLocalCSIVolumesAt(driverPluginDir string) (int64, error) { + volumeMounts, err := filepath.Glob(filepath.Join(driverPluginDir, "*", globalMountDir)) + if err != nil { + return 0, fmt.Errorf("failed to glob CSI volume mounts in %s: %w", driverPluginDir, err) + } + + return int64(len(volumeMounts)), nil +} diff --git a/pkg/csi/util/mount/mount_helper_test.go b/pkg/csi/util/mount/mount_helper_test.go new file mode 100644 index 00000000..519b440d --- /dev/null +++ b/pkg/csi/util/mount/mount_helper_test.go @@ -0,0 +1,162 @@ +package mount + +import ( + "os" + "path/filepath" + "testing" +) + +func TestCountFreePCIeSlotsAtMissingRoot(t *testing.T) { + t.Parallel() + + _, err := countFreePCIeSlotsAt(filepath.Join(t.TempDir(), "missing")) + if err == nil { + t.Fatal("countFreePCIeSlotsAt() error = nil, want error") + } +} + +func TestCountFreePCIeSlotsAtCountsOnlyFreeBridgeSlots(t *testing.T) { + t.Parallel() + + devicesPath := t.TempDir() + + createPCIDevice(t, devicesPath, "0000:00:00.0", "0x060400") + createPCIDevice(t, devicesPath, "0000:00:01.0", 
"0x060400", "0000:01:00.0") + createPCIDevice(t, devicesPath, "0000:00:02.0", "0x010000", "0000:02:00.0") + + count, err := countFreePCIeSlotsAt(devicesPath) + if err != nil { + t.Fatalf("countFreePCIeSlotsAt() error = %v", err) + } + + if count != 1 { + t.Fatalf("countFreePCIeSlotsAt() = %d, want 1", count) + } +} + +func TestCountFreePCIeSlotsAtSkipsDevicesWithoutClass(t *testing.T) { + t.Parallel() + + devicesPath := t.TempDir() + + createPCIDevice(t, devicesPath, "0000:00:00.0", "0x060400") + mustMkdirAll(t, filepath.Join(devicesPath, "0000:00:01.0")) + + count, err := countFreePCIeSlotsAt(devicesPath) + if err != nil { + t.Fatalf("countFreePCIeSlotsAt() error = %v", err) + } + + if count != 1 { + t.Fatalf("countFreePCIeSlotsAt() = %d, want 1", count) + } +} + +func TestCountFreePCIeSlotsAtIgnoresNonPCIChildren(t *testing.T) { + t.Parallel() + + devicesPath := t.TempDir() + devPath := filepath.Join(devicesPath, "0000:00:00.0") + mustMkdirAll(t, devPath) + mustWriteFile(t, filepath.Join(devPath, "class"), "0x060400") + mustMkdirAll(t, filepath.Join(devPath, "driver")) + mustMkdirAll(t, filepath.Join(devPath, "not-a-pci-child")) + + count, err := countFreePCIeSlotsAt(devicesPath) + if err != nil { + t.Fatalf("countFreePCIeSlotsAt() error = %v", err) + } + + if count != 1 { + t.Fatalf("countFreePCIeSlotsAt() = %d, want 1", count) + } +} + +func TestCountLocalCSIVolumesAtMissingDir(t *testing.T) { + t.Parallel() + + count, err := countLocalCSIVolumesAt(filepath.Join(t.TempDir(), "missing")) + if err != nil { + t.Fatalf("countLocalCSIVolumesAt() error = %v", err) + } + + if count != 0 { + t.Fatalf("countLocalCSIVolumesAt() = %d, want 0", count) + } +} + +func TestCountLocalCSIVolumesAtCountsOnlyGlobalMountDirs(t *testing.T) { + t.Parallel() + + driverPluginDir := t.TempDir() + + mustMkdirAll(t, filepath.Join(driverPluginDir, "volume-a", globalMountDir)) + mustMkdirAll(t, filepath.Join(driverPluginDir, "volume-b", globalMountDir)) + mustMkdirAll(t, 
filepath.Join(driverPluginDir, "volume-c", "not-a-globalmount")) + + count, err := countLocalCSIVolumesAt(driverPluginDir) + if err != nil { + t.Fatalf("countLocalCSIVolumesAt() error = %v", err) + } + + if count != 2 { + t.Fatalf("countLocalCSIVolumesAt() = %d, want 2", count) + } +} + +func TestCountLocalCSIVolumesAtEmptyDir(t *testing.T) { + t.Parallel() + + count, err := countLocalCSIVolumesAt(t.TempDir()) + if err != nil { + t.Fatalf("countLocalCSIVolumesAt() error = %v", err) + } + + if count != 0 { + t.Fatalf("countLocalCSIVolumesAt() = %d, want 0", count) + } +} + +func TestCountLocalCSIVolumesAtReturnsZeroWhenDriverPathIsFile(t *testing.T) { + t.Parallel() + + driverPluginDir := filepath.Join(t.TempDir(), "driver") + mustWriteFile(t, driverPluginDir, "not a directory") + + count, err := countLocalCSIVolumesAt(driverPluginDir) + if err != nil { + t.Fatalf("countLocalCSIVolumesAt() error = %v", err) + } + + if count != 0 { + t.Fatalf("countLocalCSIVolumesAt() = %d, want 0", count) + } +} + +func createPCIDevice(t *testing.T, rootPath, deviceName, class string, children ...string) { + t.Helper() + + devPath := filepath.Join(rootPath, deviceName) + mustMkdirAll(t, devPath) + mustWriteFile(t, filepath.Join(devPath, "class"), class) + + for _, child := range children { + mustMkdirAll(t, filepath.Join(devPath, child)) + } +} + +func mustMkdirAll(t *testing.T, path string) { + t.Helper() + + if err := os.MkdirAll(path, 0o755); err != nil { + t.Fatalf("MkdirAll(%q) error = %v", path, err) + } +} + +func mustWriteFile(t *testing.T, path string, content string) { + t.Helper() + + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("WriteFile(%q) error = %v", path, err) + } +} diff --git a/pkg/csi/util/mount/mount_linux.go b/pkg/csi/util/mount/mount_linux.go index afd261d1..4bdfe55b 100644 --- a/pkg/csi/util/mount/mount_linux.go +++ b/pkg/csi/util/mount/mount_linux.go @@ -3,24 +3,14 @@ package mount import ( - "fmt" - "os" "path/filepath" - 
"regexp" - "slices" - "strings" "golang.org/x/sys/unix" - "k8s.io/klog/v2" -) - -var ( - pciAddressRegex = regexp.MustCompile(`^[0-9a-fA-F]{4}:[0-9a-fA-F]{2}:[0-9a-fA-F]{2}\.[0-9a-fA-F]$`) ) const ( - // PCIClassBridgePCI Linux constant: https://github.com/torvalds/linux/blob/e43ffb69e0438cddd72aaa30898b4dc446f664f8/include/linux/pci_ids.h#L62 - PCIClassBridgePCI = "0x0604" + pciDevicesPath = "/sys/bus/pci/devices" + kubeletDir = "/var/lib/kubelet" ) func newDeviceStats(statfs *unix.Statfs_t) *DeviceStats { @@ -37,80 +27,13 @@ func newDeviceStats(statfs *unix.Statfs_t) *DeviceStats { } } -// CountFreePCIeSlots returns the number of PCIe Root ports who -// are currently not occupied by anything. +// CountFreePCIeSlots returns the number of PCIe root ports that are not occupied. func CountFreePCIeSlots() (int64, error) { - const pciPath = "/sys/bus/pci/devices" - - // Get all PCI devices - devices, err := os.ReadDir(pciPath) - if err != nil { - return 0, fmt.Errorf("failed to read PCI bus: %w", err) - } - - freePCIeSlots := 0 - - for _, dev := range devices { - devPath := filepath.Join(pciPath, dev.Name()) - - // 1. Identify if it's a Root Port / Bridge - // We check the 'class' file. PCI Bridge class code starts with 0x0604 - classBuf, err := os.ReadFile(filepath.Join(devPath, "class")) - if err != nil { - klog.Errorf("failed to read PCI device class %s : %v", devPath, err) - continue - } - class := strings.TrimSpace(string(classBuf)) - - // Class 0x060400 is a PCI-to-PCI bridge (standard for Root Ports) - if strings.HasPrefix(class, PCIClassBridgePCI) { - // 2. 
Check if the port has downstream devices - // If the bridge has children, they appear as subdirectories - // matching the PCI address format (e.g., 0000:01:00.0) - files, err2 := os.ReadDir(devPath) - if err2 != nil { - klog.Errorf("failed to read dir %s : %v", devPath, err2) - } - hasDownStreamFolder := slices.ContainsFunc(files, func(s os.DirEntry) bool { - return pciAddressRegex.MatchString(s.Name()) - }) - if !hasDownStreamFolder { - freePCIeSlots++ - } - } else { - klog.V(4).Infof("skipping class %s: path: %s", class, devPath) - } - } - - return int64(freePCIeSlots), nil + return countFreePCIeSlotsAt(pciDevicesPath) } -// CountLocalCSIVolumes tries to count how many volumes are mounted for a given driverName. +// CountLocalCSIVolumes counts staged CSI volumes for the given driver. func CountLocalCSIVolumes(driverName string) (int64, error) { - const kubeletDir = "/var/lib/kubelet" - volumeCount := 0 - // The path where Kubelet mounts global tracking directories for a specific CSI driver - targetDir := filepath.Join(kubeletDir, "plugins", "kubernetes.io", "csi", driverName) - - if _, err := os.Stat(targetDir); os.IsNotExist(err) { - return 0, nil - } else if err != nil { - return 0, fmt.Errorf("failed to check directory: %w", err) - } - - volumes, err := os.ReadDir(targetDir) - if err != nil { - return 0, fmt.Errorf("failed to read dir %s: %w", targetDir, err) - } - for _, vol := range volumes { - // Check if volume has a "globalmount" dir to determine if it's mounted correctly - globalMountPath := filepath.Join(vol.Name(), "globalmount") - if _, err := os.Stat(globalMountPath); os.IsNotExist(err) { - continue - } - - volumeCount++ - } - - return int64(volumeCount), nil + driverPluginDir := filepath.Join(kubeletDir, "plugins", "kubernetes.io", "csi", driverName) + return countLocalCSIVolumesAt(driverPluginDir) } diff --git a/pkg/stackit/stackiterrors/errors.go b/pkg/stackit/stackiterrors/errors.go index 98b0f528..1b1f127a 100644 --- 
a/pkg/stackit/stackiterrors/errors.go +++ b/pkg/stackit/stackiterrors/errors.go @@ -10,11 +10,13 @@ import ( "github.com/stackitcloud/stackit-sdk-go/services/iaas/v2api/wait" ) +const tooManyDiskDevicesMessageFragment = "maximum allowed number of disk devices" + var ErrNotFound = errors.New("failed to find object") func IsNotFound(err error) bool { - var oAPIError *oapiError.GenericOpenAPIError - if ok := errors.As(err, &oAPIError); !ok { + oAPIError, ok := genericOpenAPIError(err) + if !ok { return false } @@ -22,14 +24,13 @@ func IsNotFound(err error) bool { } func IsTooManyDevicesError(err error) bool { - var oAPIError *oapiError.GenericOpenAPIError - if ok := errors.As(err, &oAPIError); !ok { + oAPIError, ok := genericOpenAPIError(err) + if !ok { return false } - // TODO: Improve this if possible return oAPIError.StatusCode == http.StatusForbidden && - strings.Contains(string(oAPIError.Body), "maximum allowed number of disk devices") + strings.Contains(string(oAPIError.Body), tooManyDiskDevicesMessageFragment) } func IgnoreNotFound(err error) error { @@ -52,10 +53,19 @@ func WrapErrorWithResponseID(err error, reqID string) error { } func IsInvalidError(err error) bool { - var oAPIError *oapiError.GenericOpenAPIError - if ok := errors.As(err, &oAPIError); !ok { + oAPIError, ok := genericOpenAPIError(err) + if !ok { return false } return oAPIError.StatusCode == http.StatusBadRequest } + +func genericOpenAPIError(err error) (*oapiError.GenericOpenAPIError, bool) { + var oAPIError *oapiError.GenericOpenAPIError + if ok := errors.As(err, &oAPIError); !ok { + return nil, false + } + + return oAPIError, true +} From 09065c91263e611095a10de5196ca078ef22a90f Mon Sep 17 00:00:00 2001 From: Felix Breuer Date: Thu, 11 Jun 2026 09:32:16 +0200 Subject: [PATCH 09/10] rewrite tests with ginkgo Signed-off-by: Felix Breuer --- pkg/csi/util/mount/mount_helper_test.go | 239 ++++++++++-------------- pkg/csi/util/mount/suite_test.go | 13 ++ 2 files changed, 107 insertions(+), 145 
deletions(-) create mode 100644 pkg/csi/util/mount/suite_test.go diff --git a/pkg/csi/util/mount/mount_helper_test.go b/pkg/csi/util/mount/mount_helper_test.go index 519b440d..e43624a8 100644 --- a/pkg/csi/util/mount/mount_helper_test.go +++ b/pkg/csi/util/mount/mount_helper_test.go @@ -3,160 +3,109 @@ package mount import ( "os" "path/filepath" - "testing" -) - -func TestCountFreePCIeSlotsAtMissingRoot(t *testing.T) { - t.Parallel() - - _, err := countFreePCIeSlotsAt(filepath.Join(t.TempDir(), "missing")) - if err == nil { - t.Fatal("countFreePCIeSlotsAt() error = nil, want error") - } -} - -func TestCountFreePCIeSlotsAtCountsOnlyFreeBridgeSlots(t *testing.T) { - t.Parallel() - - devicesPath := t.TempDir() - - createPCIDevice(t, devicesPath, "0000:00:00.0", "0x060400") - createPCIDevice(t, devicesPath, "0000:00:01.0", "0x060400", "0000:01:00.0") - createPCIDevice(t, devicesPath, "0000:00:02.0", "0x010000", "0000:02:00.0") - - count, err := countFreePCIeSlotsAt(devicesPath) - if err != nil { - t.Fatalf("countFreePCIeSlotsAt() error = %v", err) - } - - if count != 1 { - t.Fatalf("countFreePCIeSlotsAt() = %d, want 1", count) - } -} - -func TestCountFreePCIeSlotsAtSkipsDevicesWithoutClass(t *testing.T) { - t.Parallel() - - devicesPath := t.TempDir() - createPCIDevice(t, devicesPath, "0000:00:00.0", "0x060400") - mustMkdirAll(t, filepath.Join(devicesPath, "0000:00:01.0")) - - count, err := countFreePCIeSlotsAt(devicesPath) - if err != nil { - t.Fatalf("countFreePCIeSlotsAt() error = %v", err) - } - - if count != 1 { - t.Fatalf("countFreePCIeSlotsAt() = %d, want 1", count) - } -} - -func TestCountFreePCIeSlotsAtIgnoresNonPCIChildren(t *testing.T) { - t.Parallel() - - devicesPath := t.TempDir() - devPath := filepath.Join(devicesPath, "0000:00:00.0") - mustMkdirAll(t, devPath) - mustWriteFile(t, filepath.Join(devPath, "class"), "0x060400") - mustMkdirAll(t, filepath.Join(devPath, "driver")) - mustMkdirAll(t, filepath.Join(devPath, "not-a-pci-child")) - - count, err := 
countFreePCIeSlotsAt(devicesPath) - if err != nil { - t.Fatalf("countFreePCIeSlotsAt() error = %v", err) - } - - if count != 1 { - t.Fatalf("countFreePCIeSlotsAt() = %d, want 1", count) - } -} - -func TestCountLocalCSIVolumesAtMissingDir(t *testing.T) { - t.Parallel() - - count, err := countLocalCSIVolumesAt(filepath.Join(t.TempDir(), "missing")) - if err != nil { - t.Fatalf("countLocalCSIVolumesAt() error = %v", err) - } - - if count != 0 { - t.Fatalf("countLocalCSIVolumesAt() = %d, want 0", count) - } -} - -func TestCountLocalCSIVolumesAtCountsOnlyGlobalMountDirs(t *testing.T) { - t.Parallel() - - driverPluginDir := t.TempDir() - - mustMkdirAll(t, filepath.Join(driverPluginDir, "volume-a", globalMountDir)) - mustMkdirAll(t, filepath.Join(driverPluginDir, "volume-b", globalMountDir)) - mustMkdirAll(t, filepath.Join(driverPluginDir, "volume-c", "not-a-globalmount")) - - count, err := countLocalCSIVolumesAt(driverPluginDir) - if err != nil { - t.Fatalf("countLocalCSIVolumesAt() error = %v", err) - } - - if count != 2 { - t.Fatalf("countLocalCSIVolumesAt() = %d, want 2", count) - } -} - -func TestCountLocalCSIVolumesAtEmptyDir(t *testing.T) { - t.Parallel() - - count, err := countLocalCSIVolumesAt(t.TempDir()) - if err != nil { - t.Fatalf("countLocalCSIVolumesAt() error = %v", err) - } - - if count != 0 { - t.Fatalf("countLocalCSIVolumesAt() = %d, want 0", count) - } -} - -func TestCountLocalCSIVolumesAtReturnsZeroWhenDriverPathIsFile(t *testing.T) { - t.Parallel() - - driverPluginDir := filepath.Join(t.TempDir(), "driver") - mustWriteFile(t, driverPluginDir, "not a directory") - - count, err := countLocalCSIVolumesAt(driverPluginDir) - if err != nil { - t.Fatalf("countLocalCSIVolumesAt() error = %v", err) - } - - if count != 0 { - t.Fatalf("countLocalCSIVolumesAt() = %d, want 0", count) - } -} + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" +) -func createPCIDevice(t *testing.T, rootPath, deviceName, class string, children ...string) { - t.Helper() +var _ = Describe("Mount helpers", func() { + Describe("countFreePCIeSlotsAt", func() { + It("returns an error when the PCI devices root is missing", func() { + _, err := countFreePCIeSlotsAt(filepath.Join(GinkgoT().TempDir(), "missing")) + Expect(err).To(HaveOccurred()) + }) + + It("counts only free bridge-backed PCIe slots", func() { + devicesPath := GinkgoT().TempDir() + + createPCIDevice(devicesPath, "0000:00:00.0", "0x060400") + createPCIDevice(devicesPath, "0000:00:01.0", "0x060400", "0000:01:00.0") + createPCIDevice(devicesPath, "0000:00:02.0", "0x010000", "0000:02:00.0") + + count, err := countFreePCIeSlotsAt(devicesPath) + Expect(err).NotTo(HaveOccurred()) + Expect(count).To(Equal(int64(1))) + }) + + It("skips devices whose class cannot be read", func() { + devicesPath := GinkgoT().TempDir() + + createPCIDevice(devicesPath, "0000:00:00.0", "0x060400") + mustMkdirAll(filepath.Join(devicesPath, "0000:00:01.0")) + + count, err := countFreePCIeSlotsAt(devicesPath) + Expect(err).NotTo(HaveOccurred()) + Expect(count).To(Equal(int64(1))) + }) + + It("ignores non-PCI child entries when checking bridge occupancy", func() { + devicesPath := GinkgoT().TempDir() + devPath := filepath.Join(devicesPath, "0000:00:00.0") + mustMkdirAll(devPath) + mustWriteFile(filepath.Join(devPath, "class"), "0x060400") + mustMkdirAll(filepath.Join(devPath, "driver")) + mustMkdirAll(filepath.Join(devPath, "not-a-pci-child")) + + count, err := countFreePCIeSlotsAt(devicesPath) + Expect(err).NotTo(HaveOccurred()) + Expect(count).To(Equal(int64(1))) + }) + }) + + Describe("countLocalCSIVolumesAt", func() { + It("returns zero for a missing driver directory", func() { + count, err := countLocalCSIVolumesAt(filepath.Join(GinkgoT().TempDir(), "missing")) + Expect(err).NotTo(HaveOccurred()) + Expect(count).To(BeZero()) + }) + + It("counts only global mount 
directories", func() { + driverPluginDir := GinkgoT().TempDir() + + mustMkdirAll(filepath.Join(driverPluginDir, "volume-a", globalMountDir)) + mustMkdirAll(filepath.Join(driverPluginDir, "volume-b", globalMountDir)) + mustMkdirAll(filepath.Join(driverPluginDir, "volume-c", "not-a-globalmount")) + + count, err := countLocalCSIVolumesAt(driverPluginDir) + Expect(err).NotTo(HaveOccurred()) + Expect(count).To(Equal(int64(2))) + }) + + It("returns zero for an empty driver directory", func() { + count, err := countLocalCSIVolumesAt(GinkgoT().TempDir()) + Expect(err).NotTo(HaveOccurred()) + Expect(count).To(BeZero()) + }) + + It("returns zero when the driver path is a file", func() { + driverPluginDir := filepath.Join(GinkgoT().TempDir(), "driver") + mustWriteFile(driverPluginDir, "not a directory") + + count, err := countLocalCSIVolumesAt(driverPluginDir) + Expect(err).NotTo(HaveOccurred()) + Expect(count).To(BeZero()) + }) + }) +}) + +func createPCIDevice(rootPath, deviceName, class string, children ...string) { + GinkgoHelper() devPath := filepath.Join(rootPath, deviceName) - mustMkdirAll(t, devPath) - mustWriteFile(t, filepath.Join(devPath, "class"), class) + mustMkdirAll(devPath) + mustWriteFile(filepath.Join(devPath, "class"), class) for _, child := range children { - mustMkdirAll(t, filepath.Join(devPath, child)) + mustMkdirAll(filepath.Join(devPath, child)) } } -func mustMkdirAll(t *testing.T, path string) { - t.Helper() - - if err := os.MkdirAll(path, 0o755); err != nil { - t.Fatalf("MkdirAll(%q) error = %v", path, err) - } +func mustMkdirAll(path string) { + GinkgoHelper() + Expect(os.MkdirAll(path, 0o755)).To(Succeed()) } -func mustWriteFile(t *testing.T, path string, content string) { - t.Helper() - - if err := os.WriteFile(path, []byte(content), 0o644); err != nil { - t.Fatalf("WriteFile(%q) error = %v", path, err) - } +func mustWriteFile(path string, content string) { + GinkgoHelper() + Expect(os.WriteFile(path, []byte(content), 0o644)).To(Succeed()) } diff 
--git a/pkg/csi/util/mount/suite_test.go b/pkg/csi/util/mount/suite_test.go new file mode 100644 index 00000000..02e3ae62 --- /dev/null +++ b/pkg/csi/util/mount/suite_test.go @@ -0,0 +1,13 @@ +package mount + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestMount(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Mount Suite") +} From 9173dea6bfd980c8e9c7032efe49adb2662fa723 Mon Sep 17 00:00:00 2001 From: Felix Breuer Date: Thu, 11 Jun 2026 09:35:07 +0200 Subject: [PATCH 10/10] fix golint Signed-off-by: Felix Breuer --- pkg/csi/util/mount/mount_helper_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/csi/util/mount/mount_helper_test.go b/pkg/csi/util/mount/mount_helper_test.go index e43624a8..2c248e4d 100644 --- a/pkg/csi/util/mount/mount_helper_test.go +++ b/pkg/csi/util/mount/mount_helper_test.go @@ -105,7 +105,7 @@ func mustMkdirAll(path string) { Expect(os.MkdirAll(path, 0o755)).To(Succeed()) } -func mustWriteFile(path string, content string) { +func mustWriteFile(path, content string) { GinkgoHelper() Expect(os.WriteFile(path, []byte(content), 0o644)).To(Succeed()) }