diff --git a/pkg/cim/disk.go b/pkg/cim/disk.go index 0298b793..58c8f376 100644 --- a/pkg/cim/disk.go +++ b/pkg/cim/disk.go @@ -23,6 +23,17 @@ const ( // GPTPartitionTypeMicrosoftReserved is the GUID for Microsoft Reserved Partition (MSR) // Reserved by Windows for system use GPTPartitionTypeMicrosoftReserved = "{e3c9e316-0b5c-4db8-817d-f92df00215ae}" + + // ErrorCodeCreatePartitionAccessPathAlreadyInUse is the error code (42002) returned when the driver letter failed to assign after partition created + ErrorCodeCreatePartitionAccessPathAlreadyInUse = 42002 +) + +var ( + DiskSelectorListForDiskNumberAndLocation = []string{"Number", "Location"} + DiskSelectorListForPartitionStyle = []string{"PartitionStyle"} + DiskSelectorListForPathAndSerialNumber = []string{"Path", "SerialNumber"} + DiskSelectorListForIsOffline = []string{"IsOffline"} + DiskSelectorListForSize = []string{"Size"} ) // QueryDiskByNumber retrieves disk information for a specific disk identified by its number. @@ -77,6 +88,60 @@ func ListDisks(selectorList []string) ([]*storage.MSFT_Disk, error) { return disks, nil } +// InitializeDisk initializes a RAW disk with a particular partition style. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/initialize-msft-disk +// for the WMI method definition. +func InitializeDisk(disk *storage.MSFT_Disk, partitionStyle int) (int, error) { + result, err := disk.InvokeMethodWithReturn("Initialize", int32(partitionStyle)) + return int(result), err +} + +// RefreshDisk Refreshes the cached disk layout information. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-disk-refresh +// for the WMI method definition. +func RefreshDisk(disk *storage.MSFT_Disk) (int, string, error) { + var status string + result, err := disk.InvokeMethodWithReturn("Refresh", &status) + return int(result), status, err +} + +// CreatePartition creates a partition on a disk. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/createpartition-msft-disk +// for the WMI method definition. +func CreatePartition(disk *storage.MSFT_Disk, params ...interface{}) (int, error) { + result, err := disk.InvokeMethodWithReturn("CreatePartition", params...) + return int(result), err +} + +// SetDiskState takes a disk online or offline. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-disk-online and +// https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-disk-offline +// for the WMI method definition. +func SetDiskState(disk *storage.MSFT_Disk, online bool) (int, string, error) { + method := "Offline" + if online { + method = "Online" + } + + var status string + result, err := disk.InvokeMethodWithReturn(method, &status) + return int(result), status, err +} + +// RescanDisks rescans all changes by updating the internal cache of software objects (that is, Disks, Partitions, Volumes) +// for the storage setting. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-storagesetting-updatehoststoragecache +// for the WMI method definition. +func RescanDisks() (int, error) { + result, _, err := InvokeCimMethod(WMINamespaceStorage, "MSFT_StorageSetting", "UpdateHostStorageCache", nil) + return result, err +} + // GetDiskNumber returns the number of a disk. func GetDiskNumber(disk *storage.MSFT_Disk) (uint32, error) { number, err := disk.GetProperty("Number") @@ -85,3 +150,41 @@ func GetDiskNumber(disk *storage.MSFT_Disk) (uint32, error) { } return uint32(number.(int32)), err } + +// GetDiskLocation returns the location of a disk. +func GetDiskLocation(disk *storage.MSFT_Disk) (string, error) { + return disk.GetPropertyLocation() +} + +// GetDiskPartitionStyle returns the partition style of a disk. +func GetDiskPartitionStyle(disk *storage.MSFT_Disk) (int32, error) { + retValue, err := disk.GetProperty("PartitionStyle") + if err != nil { + return 0, err + } + return retValue.(int32), err +} + +// IsDiskOffline returns whether a disk is offline. +func IsDiskOffline(disk *storage.MSFT_Disk) (bool, error) { + return disk.GetPropertyIsOffline() +} + +// GetDiskSize returns the size of a disk. +func GetDiskSize(disk *storage.MSFT_Disk) (int64, error) { + sz, err := disk.GetProperty("Size") + if err != nil { + return -1, err + } + return strconv.ParseInt(sz.(string), 10, 64) +} + +// GetDiskPath returns the path of a disk. +func GetDiskPath(disk *storage.MSFT_Disk) (string, error) { + return disk.GetPropertyPath() +} + +// GetDiskSerialNumber returns the serial number of a disk. +func GetDiskSerialNumber(disk *storage.MSFT_Disk) (string, error) { + return disk.GetPropertySerialNumber() +} diff --git a/pkg/cim/smb.go b/pkg/cim/smb.go index 5868d456..2850ab78 100644 --- a/pkg/cim/smb.go +++ b/pkg/cim/smb.go @@ -4,6 +4,8 @@ package cim import ( + "strings" + "github.com/microsoft/wmi/pkg/base/query" cim "github.com/microsoft/wmi/pkg/wmiinstance" ) @@ -17,8 +19,24 @@ const ( SmbMappingStatusConnecting SmbMappingStatusReconnecting SmbMappingStatusUnavailable + + credentialDelimiter = ":" ) +// escapeQueryParameter escapes a parameter for WMI Queries +func escapeQueryParameter(s string) string { + s = strings.ReplaceAll(s, "'", "''") + s = strings.ReplaceAll(s, "\\", "\\\\") + return s +} + +func escapeUserName(userName string) string { + // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L169-L170 + userName = strings.ReplaceAll(userName, "\\", "\\\\") + userName = strings.ReplaceAll(userName, credentialDelimiter, "\\"+credentialDelimiter) + return userName +} + // QuerySmbGlobalMappingByRemotePath retrieves the SMB global mapping from its remote path. // // The equivalent WMI query is: @@ -28,7 +46,7 @@ const ( // Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping // for the WMI class definition. func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, error) { - smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath) + smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", escapeQueryParameter(remotePath)) instances, err := QueryInstances(WMINamespaceSmb, smbQuery) if err != nil { return nil, err @@ -37,12 +55,22 @@ func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, err return instances[0], err } -// RemoveSmbGlobalMappingByRemotePath removes a SMB global mapping matching to the remote path. +// GetSmbGlobalMappingStatus returns the status of an SMB global mapping. +func GetSmbGlobalMappingStatus(inst *cim.WmiInstance) (int32, error) { + statusProp, err := inst.GetProperty("Status") + if err != nil { + return SmbMappingStatusUnavailable, err + } + + return statusProp.(int32), nil +} + +// RemoveSmbGlobalMappingByRemotePath removes an SMB global mapping matching to the remote path. // // Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping // for the WMI class definition. func RemoveSmbGlobalMappingByRemotePath(remotePath string) error { - smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath) + smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", escapeQueryParameter(remotePath)) instances, err := QueryInstances(WMINamespaceSmb, smbQuery) if err != nil { return err @@ -51,3 +79,22 @@ func RemoveSmbGlobalMappingByRemotePath(remotePath string) error { _, err = instances[0].InvokeMethod("Remove", true) return err } + +// NewSmbGlobalMapping creates a new SMB global mapping to the remote path. +// +// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping +// for the WMI class definition. +func NewSmbGlobalMapping(remotePath, username, password string, requirePrivacy bool) (int, error) { + params := map[string]interface{}{ + "RemotePath": remotePath, + "RequirePrivacy": requirePrivacy, + } + if username != "" { + // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L166-L178 + // on how SMB credential is handled in PowerShell + params["Credential"] = escapeUserName(username) + credentialDelimiter + password + } + + result, _, err := InvokeCimMethod(WMINamespaceSmb, "MSFT_SmbGlobalMapping", "Create", params) + return result, err +} diff --git a/pkg/cim/volume.go b/pkg/cim/volume.go index b77fd6d6..0c880fe1 100644 --- a/pkg/cim/volume.go +++ b/pkg/cim/volume.go @@ -7,9 +7,23 @@ import ( "fmt" "strconv" + "github.com/go-ole/go-ole" "github.com/microsoft/wmi/pkg/base/query" "github.com/microsoft/wmi/pkg/errors" "github.com/microsoft/wmi/server2019/root/microsoft/windows/storage" + "k8s.io/klog/v2" +) + +const ( + FileSystemUnknown = 0 +) + +var ( + VolumeSelectorListForFileSystemType = []string{"FileSystemType"} + VolumeSelectorListForStats = []string{"UniqueId", "SizeRemaining", "Size"} + VolumeSelectorListUniqueID = []string{"UniqueId"} + + PartitionSelectorListObjectID = []string{"ObjectId"} ) // QueryVolumeByUniqueID retrieves a specific volume by its unique identifier, @@ -78,6 +92,68 @@ func ListVolumes(selectorList []string) ([]*storage.MSFT_Volume, error) { return volumes, nil } +// FormatVolume formats the specified volume. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/format-msft-volume +// for the WMI method definition. +func FormatVolume(volume *storage.MSFT_Volume, params ...interface{}) (int, error) { + result, err := volume.InvokeMethodWithReturn("Format", params...) + return int(result), err +} + +// FlushVolume flushes the cached data in the volume's file system to disk. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-volume-flush +// for the WMI method definition. +func FlushVolume(volume *storage.MSFT_Volume) (int, error) { + result, err := volume.Flush() + return int(result), err +} + +// GetVolumeUniqueID returns the unique ID (object ID) of a volume. +func GetVolumeUniqueID(volume *storage.MSFT_Volume) (string, error) { + return volume.GetPropertyUniqueId() +} + +// GetVolumeFileSystemType returns the file system type of a volume. +func GetVolumeFileSystemType(volume *storage.MSFT_Volume) (int32, error) { + fsType, err := volume.GetProperty("FileSystemType") + if err != nil { + return 0, err + } + return fsType.(int32), nil +} + +// GetVolumeSize returns the size of a volume. +func GetVolumeSize(volume *storage.MSFT_Volume) (int64, error) { + volumeSizeVal, err := volume.GetProperty("Size") + if err != nil { + return -1, err + } + + volumeSize, err := strconv.ParseInt(volumeSizeVal.(string), 10, 64) + if err != nil { + return -1, err + } + + return volumeSize, err +} + +// GetVolumeSizeRemaining returns the remaining size of a volume. +func GetVolumeSizeRemaining(volume *storage.MSFT_Volume) (int64, error) { + volumeSizeRemainingVal, err := volume.GetProperty("SizeRemaining") + if err != nil { + return -1, err + } + + volumeSizeRemaining, err := strconv.ParseInt(volumeSizeRemainingVal.(string), 10, 64) + if err != nil { + return -1, err + } + + return volumeSizeRemaining, err +} + // ListPartitionsOnDisk retrieves all partitions or a partition with the specified number on a disk. // // The equivalent WMI query is: @@ -245,3 +321,78 @@ func GetPartitionDiskNumber(part *storage.MSFT_Partition) (uint32, error) { return uint32(diskNumber.(int32)), nil } + +// SetPartitionState takes a partition online or offline. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-online and +// https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-offline +// for the WMI method definition. +func SetPartitionState(part *storage.MSFT_Partition, online bool) (int, string, error) { + method := "Offline" + if online { + method = "Online" + } + + var status string + result, err := part.InvokeMethodWithReturn(method, &status) + return int(result), status, err +} + +// GetPartitionSupportedSize retrieves the minimum and maximum sizes that the partition can be resized to using the ResizePartition method. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-getsupportedsizes +// for the WMI method definition. +func GetPartitionSupportedSize(part *storage.MSFT_Partition) (result int, sizeMin, sizeMax int64, status string, err error) { + sizeMin = -1 + sizeMax = -1 + + var sizeMinVar, sizeMaxVar ole.VARIANT + invokeResult, err := part.InvokeMethodWithReturn("GetSupportedSize", &sizeMinVar, &sizeMaxVar, &status) + if invokeResult != 0 || err != nil { + result = int(invokeResult) + } + klog.V(5).Infof("got sizeMin (%v) sizeMax (%v) from partition (%v), status: %s", sizeMinVar, sizeMaxVar, part, status) + + sizeMin, err = strconv.ParseInt(sizeMinVar.ToString(), 10, 64) + if err != nil { + return + } + + sizeMax, err = strconv.ParseInt(sizeMaxVar.ToString(), 10, 64) + return +} + +// ResizePartition resizes a partition. +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partition-resize +// for the WMI method definition. +func ResizePartition(part *storage.MSFT_Partition, size int64) (int, string, error) { + var status string + result, err := part.InvokeMethodWithReturn("Resize", strconv.Itoa(int(size)), &status) + return int(result), status, err +} + +// GetPartitionSize returns the size of a partition. +func GetPartitionSize(part *storage.MSFT_Partition) (int64, error) { + sizeProp, err := part.GetProperty("Size") + if err != nil { + return -1, err + } + + size, err := strconv.ParseInt(sizeProp.(string), 10, 64) + if err != nil { + return -1, err + } + + return size, err +} + +// FilterForPartitionOnDisk creates a WMI query filter to query a disk by its number. +func FilterForPartitionOnDisk(diskNumber uint32) *query.WmiQueryFilter { + return query.NewWmiQueryFilter("DiskNumber", strconv.Itoa(int(diskNumber)), query.Equals) +} + +// FilterForPartitionsOfTypeNormal creates a WMI query filter for all non-reserved partitions. +func FilterForPartitionsOfTypeNormal() *query.WmiQueryFilter { + return query.NewWmiQueryFilter("GptType", GPTPartitionTypeMicrosoftReserved, query.NotEquals) +} diff --git a/pkg/cim/wmi.go b/pkg/cim/wmi.go index 1dacce8a..ba75f747 100644 --- a/pkg/cim/wmi.go +++ b/pkg/cim/wmi.go @@ -9,7 +9,7 @@ import ( "github.com/go-ole/go-ole" "github.com/go-ole/go-ole/oleutil" "github.com/microsoft/wmi/pkg/base/query" - "github.com/microsoft/wmi/pkg/errors" + wmierrors "github.com/microsoft/wmi/pkg/errors" cim "github.com/microsoft/wmi/pkg/wmiinstance" "k8s.io/klog/v2" ) @@ -61,7 +61,7 @@ func QueryFromWMI(namespace string, query *query.WmiQuery, handler InstanceHandl } if len(instances) == 0 { - return errors.NotFound + return wmierrors.NotFound } var cont bool @@ -95,7 +95,7 @@ func executeClassMethodParam(classInst *cim.WmiInstance, method *cim.WmiMethod, iDispatchInstance := classInst.GetIDispatch() if iDispatchInstance == nil { - return nil, errors.Wrapf(errors.InvalidInput, "InvalidInstance") + return nil, wmierrors.Wrapf(wmierrors.InvalidInput, "InvalidInstance") } rawResult, err := iDispatchInstance.GetProperty("Methods_") if err != nil { @@ -235,10 +235,15 @@ func InvokeCimMethod(namespace, class, methodName string, inputParameters map[st return int(result.ReturnValue), outputParameters, nil } +// IsNotFound returns true if it's a "not found" error. +func IsNotFound(err error) bool { + return wmierrors.IsNotFound(err) +} + // IgnoreNotFound returns nil if the error is nil or a "not found" error, // otherwise returns the original error. func IgnoreNotFound(err error) error { - if err == nil || errors.IsNotFound(err) { + if err == nil || IsNotFound(err) { return nil } return err diff --git a/pkg/os/disk/api.go b/pkg/os/disk/api.go index dc8637fd..b366b977 100644 --- a/pkg/os/disk/api.go +++ b/pkg/os/disk/api.go @@ -3,14 +3,12 @@ package disk import ( "encoding/hex" "fmt" - "strconv" "strings" "syscall" "unsafe" "github.com/kubernetes-csi/csi-proxy/pkg/cim" shared "github.com/kubernetes-csi/csi-proxy/pkg/shared/disk" - "github.com/microsoft/wmi/pkg/base/query" "k8s.io/klog/v2" ) @@ -67,19 +65,19 @@ func New() DiskAPI { // as the value. The DiskLocation struct has various fields like the Adapter, Bus, Target and LUNID. func (imp DiskAPI) ListDiskLocations() (map[uint32]shared.DiskLocation, error) { // "location": "PCI Slot 3 : Adapter 0 : Port 0 : Target 1 : LUN 0" - disks, err := cim.ListDisks([]string{"Number", "Location"}) + disks, err := cim.ListDisks(cim.DiskSelectorListForDiskNumberAndLocation) if err != nil { return nil, fmt.Errorf("could not query disk locations") } m := make(map[uint32]shared.DiskLocation) for _, disk := range disks { - num, err := disk.GetProperty("Number") + num, err := cim.GetDiskNumber(disk) if err != nil { return m, fmt.Errorf("failed to query disk number: %v, %w", disk, err) } - location, err := disk.GetPropertyLocation() + location, err := cim.GetDiskLocation(disk) if err != nil { return m, fmt.Errorf("failed to query disk location: %v, %w", disk, err) } @@ -107,7 +105,7 @@ func (imp DiskAPI) ListDiskLocations() (map[uint32]shared.DiskLocation, error) { } if found { - m[uint32(num.(int32))] = d + m[num] = d } } } @@ -116,7 +114,7 @@ func (imp DiskAPI) ListDiskLocations() (map[uint32]shared.DiskLocation, error) { } func (imp DiskAPI) Rescan() error { - result, _, err := cim.InvokeCimMethod(cim.WMINamespaceStorage, "MSFT_StorageSetting", "UpdateHostStorageCache", nil) + result, err := cim.RescanDisks() if err != nil { return fmt.Errorf("error updating host storage cache output. result: %d, err: %v", result, err) } @@ -124,18 +122,16 @@ func (imp DiskAPI) Rescan() error { } func (imp DiskAPI) IsDiskInitialized(diskNumber uint32) (bool, error) { - var partitionStyle int32 - disk, err := cim.QueryDiskByNumber(diskNumber, []string{"PartitionStyle"}) + disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForPartitionStyle) if err != nil { - return false, fmt.Errorf("error checking initialized status of disk %d. %v", diskNumber, err) + return false, fmt.Errorf("error checking initialized status of disk %d: %v", diskNumber, err) } - retValue, err := disk.GetProperty("PartitionStyle") + partitionStyle, err := cim.GetDiskPartitionStyle(disk) if err != nil { - return false, fmt.Errorf("failed to query partition style of disk %d: %w", diskNumber, err) + return false, fmt.Errorf("failed to query partition style of disk %d: %v", diskNumber, err) } - partitionStyle = retValue.(int32) return partitionStyle != cim.PartitionStyleUnknown, nil } @@ -145,7 +141,7 @@ func (imp DiskAPI) InitializeDisk(diskNumber uint32) error { return fmt.Errorf("failed to initializing disk %d. error: %w", diskNumber, err) } - result, err := disk.InvokeMethodWithReturn("Initialize", int32(cim.PartitionStyleGPT)) + result, err := cim.InitializeDisk(disk, cim.PartitionStyleGPT) if result != 0 || err != nil { return fmt.Errorf("failed to initializing disk %d: result %d, error: %w", diskNumber, result, err) } @@ -154,9 +150,7 @@ func (imp DiskAPI) InitializeDisk(diskNumber uint32) error { } func (imp DiskAPI) BasicPartitionsExist(diskNumber uint32) (bool, error) { - partitions, err := cim.ListPartitionsWithFilters(nil, - query.NewWmiQueryFilter("DiskNumber", strconv.Itoa(int(diskNumber)), query.Equals), - query.NewWmiQueryFilter("GptType", cim.GPTPartitionTypeMicrosoftReserved, query.NotEquals)) + partitions, err := cim.ListPartitionsWithFilters(nil, cim.FilterForPartitionOnDisk(diskNumber), cim.FilterForPartitionsOfTypeNormal()) if cim.IgnoreNotFound(err) != nil { return false, fmt.Errorf("error checking presence of partitions on disk %d:, %v", diskNumber, err) } @@ -170,8 +164,8 @@ func (imp DiskAPI) CreateBasicPartition(diskNumber uint32) error { return err } - result, err := disk.InvokeMethodWithReturn( - "CreatePartition", + result, err := cim.CreatePartition( + disk, nil, // Size true, // UseMaximumSize nil, // Offset @@ -183,20 +177,16 @@ func (imp DiskAPI) CreateBasicPartition(diskNumber uint32) error { false, // IsHidden false, // IsActive, ) - // 42002 is returned by driver letter failed to assign after partition - if (result != 0 && result != 42002) || err != nil { + if (result != 0 && result != cim.ErrorCodeCreatePartitionAccessPathAlreadyInUse) || err != nil { return fmt.Errorf("error creating partition on disk %d. result: %d, err: %v", diskNumber, result, err) } - var status string - result, err = disk.InvokeMethodWithReturn("Refresh", &status) + result, _, err = cim.RefreshDisk(disk) if result != 0 || err != nil { return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err) } - partitions, err := cim.ListPartitionsWithFilters(nil, - query.NewWmiQueryFilter("DiskNumber", strconv.Itoa(int(diskNumber)), query.Equals), - query.NewWmiQueryFilter("GptType", cim.GPTPartitionTypeMicrosoftReserved, query.NotEquals)) + partitions, err := cim.ListPartitionsWithFilters(nil, cim.FilterForPartitionOnDisk(diskNumber), cim.FilterForPartitionsOfTypeNormal()) if err != nil { return fmt.Errorf("error query basic partition on disk %d:, %v", diskNumber, err) } @@ -206,13 +196,12 @@ func (imp DiskAPI) CreateBasicPartition(diskNumber uint32) error { } partition := partitions[0] - result, err = partition.InvokeMethodWithReturn("Online", status) + result, status, err := cim.SetPartitionState(partition, true) if result != 0 || err != nil { return fmt.Errorf("error bring partition %v on disk %d online. result: %d, status %s, err: %v", partition, diskNumber, result, status, err) } - err = partition.Refresh() - return err + return nil } func (imp DiskAPI) GetDiskNumberByName(page83ID string) (uint32, error) { @@ -272,13 +261,13 @@ func (imp DiskAPI) GetDiskPage83ID(disk syscall.Handle) (string, error) { } func (imp DiskAPI) GetDiskNumberWithID(page83ID string) (uint32, error) { - disks, err := cim.ListDisks([]string{"Path", "SerialNumber"}) + disks, err := cim.ListDisks(cim.DiskSelectorListForPathAndSerialNumber) if err != nil { return 0, err } for _, disk := range disks { - path, err := disk.GetPropertyPath() + path, err := cim.GetDiskPath(disk) if err != nil { return 0, fmt.Errorf("failed to query disk path: %v, %w", disk, err) } @@ -319,19 +308,19 @@ func (imp DiskAPI) GetDiskNumberAndPage83ID(path string) (uint32, string, error) // ListDiskIDs - constructs a map with the disk number as the key and the DiskID structure // as the value. The DiskID struct has a field for the page83 ID. func (imp DiskAPI) ListDiskIDs() (map[uint32]shared.DiskIDs, error) { - disks, err := cim.ListDisks([]string{"Path", "SerialNumber"}) + disks, err := cim.ListDisks(cim.DiskSelectorListForPathAndSerialNumber) if err != nil { return nil, err } m := make(map[uint32]shared.DiskIDs) for _, disk := range disks { - path, err := disk.GetPropertyPath() + path, err := cim.GetDiskPath(disk) if err != nil { return m, fmt.Errorf("failed to query disk path: %v, %w", disk, err) } - sn, err := disk.GetPropertySerialNumber() + sn, err := cim.GetDiskSerialNumber(disk) if err != nil { return m, fmt.Errorf("failed to query disk serial number: %v, %w", disk, err) } @@ -351,56 +340,49 @@ func (imp DiskAPI) ListDiskIDs() (map[uint32]shared.DiskIDs, error) { func (imp DiskAPI) GetDiskStats(diskNumber uint32) (int64, error) { // TODO: change to uint64 as it does not make sense to use int64 for size - var size int64 - disk, err := cim.QueryDiskByNumber(diskNumber, []string{"Size"}) + disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForSize) if err != nil { return -1, err } - sz, err := disk.GetProperty("Size") + size, err := cim.GetDiskSize(disk) if err != nil { return -1, fmt.Errorf("failed to query size of disk %d. %v", diskNumber, err) } - size, err = strconv.ParseInt(sz.(string), 10, 64) return size, err } func (imp DiskAPI) SetDiskState(diskNumber uint32, isOnline bool) error { - disk, err := cim.QueryDiskByNumber(diskNumber, []string{"IsOffline"}) + disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForIsOffline) if err != nil { return err } - offline, err := disk.GetPropertyIsOffline() + isOffline, err := cim.IsDiskOffline(disk) if err != nil { return fmt.Errorf("error setting disk %d attach state. error: %v", diskNumber, err) } - if isOnline == !offline { + if isOnline == !isOffline { return nil } - method := "Offline" - if isOnline { - method = "Online" - } - - result, err := disk.InvokeMethodWithReturn(method) + result, _, err := cim.SetDiskState(disk, isOnline) if result != 0 || err != nil { - return fmt.Errorf("setting disk %d attach state %s: result %d, error: %w", diskNumber, method, result, err) + return fmt.Errorf("setting disk %d attach state (isOnline: %v): result %d, error: %w", diskNumber, isOnline, result, err) } return nil } func (imp DiskAPI) GetDiskState(diskNumber uint32) (bool, error) { - disk, err := cim.QueryDiskByNumber(diskNumber, []string{"IsOffline"}) + disk, err := cim.QueryDiskByNumber(diskNumber, cim.DiskSelectorListForIsOffline) if err != nil { return false, err } - isOffline, err := disk.GetPropertyIsOffline() + isOffline, err := cim.IsDiskOffline(disk) if err != nil { return false, fmt.Errorf("error parsing disk %d state. error: %v", diskNumber, err) } diff --git a/pkg/os/filesystem/api.go b/pkg/os/filesystem/api.go index a2fc4d26..458fd89f 100644 --- a/pkg/os/filesystem/api.go +++ b/pkg/os/filesystem/api.go @@ -112,13 +112,18 @@ func (filesystemAPI) IsSymlink(tgt string) (bool, error) { // This code is similar to k8s.io/kubernetes/pkg/util/mount except the pathExists usage. // Also in a remote call environment the os error cannot be passed directly back, hence the callers // are expected to perform the isExists check before calling this call in CSI proxy. - stat, err := os.Lstat(tgt) + isSymlink, err := utils.IsPathSymlink(tgt) if err != nil { return false, err } - // If its a link and it points to an existing file then its a mount point. - if stat.Mode()&os.ModeSymlink != 0 { + // mounted folder created by SetVolumeMountPoint may still report ModeSymlink == 0 + mountedFolder, err := utils.IsMountedFolder(tgt) + if err != nil { + return false, err + } + + if isSymlink || mountedFolder { target, err := os.Readlink(tgt) if err != nil { return false, fmt.Errorf("readlink error: %v", err) diff --git a/pkg/os/smb/api.go b/pkg/os/smb/api.go index 20b9544e..f0d28da4 100644 --- a/pkg/os/smb/api.go +++ b/pkg/os/smb/api.go @@ -3,15 +3,9 @@ package smb import ( "fmt" "strings" - "syscall" "github.com/kubernetes-csi/csi-proxy/pkg/cim" "github.com/kubernetes-csi/csi-proxy/pkg/utils" - "golang.org/x/sys/windows" -) - -const ( - credentialDelimiter = ":" ) type API interface { @@ -33,61 +27,23 @@ func New(requirePrivacy bool) *SmbAPI { } } -func remotePathForQuery(remotePath string) string { - return strings.ReplaceAll(remotePath, "\\", "\\\\") -} - -func escapeUserName(userName string) string { - // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L169-L170 - escaped := strings.ReplaceAll(userName, "\\", "\\\\") - escaped = strings.ReplaceAll(escaped, credentialDelimiter, "\\"+credentialDelimiter) - return escaped -} - -func createSymlink(link, target string, isDir bool) error { - linkPtr, err := syscall.UTF16PtrFromString(link) - if err != nil { - return err - } - targetPtr, err := syscall.UTF16PtrFromString(target) - if err != nil { - return err - } - - var flags uint32 - if isDir { - flags = windows.SYMBOLIC_LINK_FLAG_DIRECTORY - } - - err = windows.CreateSymbolicLink( - linkPtr, - targetPtr, - flags, - ) - return err -} - func (*SmbAPI) IsSmbMapped(remotePath string) (bool, error) { - inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePathForQuery(remotePath)) + inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePath) if err != nil { return false, cim.IgnoreNotFound(err) } - status, err := inst.GetProperty("Status") + status, err := cim.GetSmbGlobalMappingStatus(inst) if err != nil { return false, err } - return status.(int32) == cim.SmbMappingStatusOK, nil + return status == cim.SmbMappingStatusOK, nil } // NewSmbLink - creates a directory symbolic link to the remote share. // The os.Symlink was having issue for cases where the destination was an SMB share - the container -// runtime would complain stating "Access Denied". Because of this, we had to perform -// this operation with powershell commandlet creating an directory softlink. -// Since os.Symlink is currently being used in working code paths, no attempt is made in -// alpha to merge the paths. -// TODO (for beta release): Merge the link paths - os.Symlink and Powershell link path. +// runtime would complain stating "Access Denied". func (*SmbAPI) NewSmbLink(remotePath, localPath string) error { if !strings.HasSuffix(remotePath, "\\") { // Golang has issues resolving paths mapped to file shares if they do not end in a trailing \ @@ -97,7 +53,7 @@ func (*SmbAPI) NewSmbLink(remotePath, localPath string) error { longRemotePath := utils.EnsureLongPath(remotePath) longLocalPath := utils.EnsureLongPath(localPath) - err := createSymlink(longLocalPath, longRemotePath, true) + err := utils.CreateSymlink(longLocalPath, longRemotePath, true) if err != nil { return fmt.Errorf("error linking %s to %s. err: %v", remotePath, localPath, err) } @@ -106,17 +62,7 @@ func (*SmbAPI) NewSmbLink(remotePath, localPath string) error { } func (api *SmbAPI) NewSmbGlobalMapping(remotePath, username, password string) error { - params := map[string]interface{}{ - "RemotePath": remotePath, - "RequirePrivacy": api.RequirePrivacy, - } - if username != "" { - // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L166-L178 - // on how SMB credential is handled in PowerShell - params["Credential"] = escapeUserName(username) + credentialDelimiter + password - } - - result, _, err := cim.InvokeCimMethod(cim.WMINamespaceSmb, "MSFT_SmbGlobalMapping", "Create", params) + result, err := cim.NewSmbGlobalMapping(remotePath, username, password, api.RequirePrivacy) if err != nil { return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err) } @@ -125,7 +71,7 @@ func (api *SmbAPI) NewSmbGlobalMapping(remotePath, username, password string) er } func (*SmbAPI) RemoveSmbGlobalMapping(remotePath string) error { - err := cim.RemoveSmbGlobalMappingByRemotePath(remotePathForQuery(remotePath)) + err := cim.RemoveSmbGlobalMappingByRemotePath(remotePath) if err != nil { return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err) } diff --git a/pkg/os/volume/api.go b/pkg/os/volume/api.go index fcd2e6f8..ea667798 100644 --- a/pkg/os/volume/api.go +++ b/pkg/os/volume/api.go @@ -2,21 +2,21 @@ package volume import ( "fmt" - "os" "path/filepath" "regexp" - "strconv" "strings" - "github.com/go-ole/go-ole" "github.com/kubernetes-csi/csi-proxy/pkg/cim" "github.com/kubernetes-csi/csi-proxy/pkg/utils" - wmierrors "github.com/microsoft/wmi/pkg/errors" "github.com/pkg/errors" "golang.org/x/sys/windows" "k8s.io/klog/v2" ) +const ( + minimumResizeSize = 100 * 1024 * 1024 +) + // API exposes the internal volume operations available in the server type API interface { // ListVolumesOnDisk lists volumes on a disk identified by a `diskNumber` and optionally a partition identified by `partitionNumber`. @@ -69,7 +69,7 @@ func New() VolumeAPI { // ListVolumesOnDisk - returns back list of volumes(volumeIDs) in a disk and a partition. func (VolumeAPI) ListVolumesOnDisk(diskNumber uint32, partitionNumber uint32) (volumeIDs []string, err error) { - partitions, err := cim.ListPartitionsOnDisk(diskNumber, partitionNumber, []string{"ObjectId"}) + partitions, err := cim.ListPartitionsOnDisk(diskNumber, partitionNumber, cim.PartitionSelectorListObjectID) if err != nil { return nil, errors.Wrapf(err, "failed to list partition on disk %d", diskNumber) } @@ -80,9 +80,9 @@ func (VolumeAPI) ListVolumesOnDisk(diskNumber uint32, partitionNumber uint32) (v } for _, volume := range volumes { - uniqueID, err := volume.GetPropertyUniqueId() + uniqueID, err := cim.GetVolumeUniqueID(volume) if err != nil { - return nil, errors.Wrapf(err, "failed to list volumes") + return nil, errors.Wrapf(err, "failed to get unique ID for volume %v", volume) } volumeIDs = append(volumeIDs, uniqueID) } @@ -97,8 +97,7 @@ func (VolumeAPI) FormatVolume(volumeID string) (err error) { return fmt.Errorf("error formatting volume (%s). error: %v", volumeID, err) } - result, err := volume.InvokeMethodWithReturn( - "Format", + result, err := cim.FormatVolume(volume, "NTFS", // Format, "", // FileSystemLabel, nil, // AllocationUnitSize, @@ -113,7 +112,6 @@ func (VolumeAPI) FormatVolume(volumeID string) (err error) { if result != 0 || err != nil { return fmt.Errorf("error formatting volume (%s). result: %d, error: %v", volumeID, result, err) } - // TODO: Do we need to handle anything for len(out) == 0 return nil } @@ -124,18 +122,17 @@ func (VolumeAPI) WriteVolumeCache(volumeID string) (err error) { // IsVolumeFormatted - Check if the volume is formatted with the pre specified filesystem(typically ntfs). func (VolumeAPI) IsVolumeFormatted(volumeID string) (bool, error) { - volume, err := cim.QueryVolumeByUniqueID(volumeID, []string{"FileSystemType"}) + volume, err := cim.QueryVolumeByUniqueID(volumeID, cim.VolumeSelectorListForFileSystemType) if err != nil { return false, fmt.Errorf("error checking if volume (%s) is formatted. error: %v", volumeID, err) } - fsType, err := volume.GetProperty("FileSystemType") + fsType, err := cim.GetVolumeFileSystemType(volume) if err != nil { return false, fmt.Errorf("failed to query volume file system type (%s): %w", volumeID, err) } - const FileSystemUnknown = 0 - return fsType.(int32) != FileSystemUnknown, nil + return fsType != cim.FileSystemUnknown, nil } // MountVolume - mounts a volume to a path. This is done using Win32 API SetVolumeMountPoint for presenting the volume via a path. @@ -194,36 +191,25 @@ func (VolumeAPI) ResizeVolume(volumeID string, size int64) error { // If size is 0 then we will resize to the maximum size possible, otherwise just resize to size if size == 0 { - var sizeMin, sizeMax ole.VARIANT + var result int var status string - result, err := part.InvokeMethodWithReturn("GetSupportedSize", &sizeMin, &sizeMax, &status) + result, _, finalSize, status, err = cim.GetPartitionSupportedSize(part) if result != 0 || err != nil { - return fmt.Errorf("error getting sizeMin, sizeMax from volume(%s). result: %d, status: %s, error: %v", volumeID, result, status, err) + return fmt.Errorf("error getting sizeMin, sizeMax from volume (%s). result: %d, status: %s, error: %v", volumeID, result, status, err) } - klog.V(5).Infof("got sizeMin(%v) sizeMax(%v) from volume(%s), status: %s", sizeMin, sizeMax, volumeID, status) - finalSizeStr := sizeMax.ToString() - finalSize, err = strconv.ParseInt(finalSizeStr, 10, 64) - if err != nil { - return fmt.Errorf("error parsing the sizeMax of volume (%s) with error (%v)", volumeID, err) - } } else { finalSize = size } - currentSizeVal, err := part.GetProperty("Size") + currentSize, err := cim.GetPartitionSize(part) if err != nil { return fmt.Errorf("error getting the current size of volume (%s) with error (%v)", volumeID, err) } - currentSize, err := strconv.ParseInt(currentSizeVal.(string), 10, 64) - if err != nil { - return fmt.Errorf("error parsing the current size of volume (%s) with error (%v)", volumeID, err) - } - // only resize if finalSize - currentSize is greater than 100MB - if finalSize-currentSize < 100*1024*1024 { - klog.V(2).Infof("minimum resize difference(1GB) not met, skipping resize. volumeID=%s currentSize=%d finalSize=%d", volumeID, currentSize, finalSize) + if finalSize-currentSize < minimumResizeSize { + klog.V(2).Infof("minimum resize difference (100MB) not met, skipping resize. volumeID=%s currentSize=%d finalSize=%d", volumeID, currentSize, finalSize) return nil } @@ -233,9 +219,7 @@ func (VolumeAPI) ResizeVolume(volumeID string, size int64) error { return nil } - var status string - result, err := part.InvokeMethodWithReturn("Resize", strconv.Itoa(int(finalSize)), &status) - + result, _, err := cim.ResizePartition(part, finalSize) if result != 0 || err != nil { return fmt.Errorf("error resizing volume (%s). size:%v, finalSize %v, error: %v", volumeID, size, finalSize, err) } @@ -247,10 +231,10 @@ func (VolumeAPI) ResizeVolume(volumeID string, size int64) error { disk, err := cim.QueryDiskByNumber(diskNumber, nil) if err != nil { - return fmt.Errorf("error parsing disk number of volume (%s). error: %v", volumeID, err) + return fmt.Errorf("error query disk of volume (%s). error: %v", volumeID, err) } - result, err = disk.InvokeMethodWithReturn("Refresh", &status) + result, _, err = cim.RefreshDisk(disk) if result != 0 || err != nil { return fmt.Errorf("error rescan disk (%d). result %d, error: %v", diskNumber, result, err) } @@ -260,31 +244,21 @@ func (VolumeAPI) ResizeVolume(volumeID string, size int64) error { // GetVolumeStats - retrieves the volume stats for a given volume func (VolumeAPI) GetVolumeStats(volumeID string) (int64, int64, error) { - volume, err := cim.QueryVolumeByUniqueID(volumeID, []string{"UniqueId", "SizeRemaining", "Size"}) + volume, err := cim.QueryVolumeByUniqueID(volumeID, cim.VolumeSelectorListForStats) if err != nil { return -1, -1, fmt.Errorf("error getting capacity and used size of volume (%s). error: %v", volumeID, err) } - volumeSizeVal, err := volume.GetProperty("Size") + volumeSize, err := cim.GetVolumeSize(volume) if err != nil { return -1, -1, fmt.Errorf("failed to query volume size (%s): %w", volumeID, err) } - volumeSize, err := strconv.ParseInt(volumeSizeVal.(string), 10, 64) - if err != nil { - return -1, -1, fmt.Errorf("failed to parse volume size (%s): %w", volumeID, err) - } - - volumeSizeRemainingVal, err := volume.GetProperty("SizeRemaining") + volumeSizeRemaining, err := cim.GetVolumeSizeRemaining(volume) if err != nil { return -1, -1, fmt.Errorf("failed to query volume remaining size (%s): %w", volumeID, err) } - volumeSizeRemaining, err := strconv.ParseInt(volumeSizeRemainingVal.(string), 10, 64) - if err != nil { - return -1, -1, fmt.Errorf("failed to parse volume remaining size (%s): %w", volumeID, err) - } - volumeUsedSize := volumeSize - volumeSizeRemaining return volumeSize, volumeUsedSize, nil } @@ -297,12 +271,12 @@ func (VolumeAPI) GetDiskNumberFromVolumeID(volumeID string) (uint32, error) { return 0, err } - diskNumber, err := part.GetProperty("DiskNumber") + diskNumber, err := cim.GetPartitionDiskNumber(part) if err != nil { return 0, fmt.Errorf("error query disk number of volume (%s). error: %v", volumeID, err) } - return uint32(diskNumber.(int32)), nil + return diskNumber, nil } // GetVolumeIDFromTargetPath - gets the volume ID given a mount point, the function is recursive until it find a volume or errors out @@ -316,7 +290,7 @@ func (VolumeAPI) GetVolumeIDFromTargetPath(mount string) (string, error) { } func getTarget(mount string) (string, error) { - mountedFolder, err := isMountedFolder(mount) + mountedFolder, err := utils.IsMountedFolder(mount) if err != nil { return "", err } @@ -356,7 +330,7 @@ func (VolumeAPI) GetClosestVolumeIDFromTargetPath(targetPath string) (string, er } // findClosestVolume finds the closest volume id for a given target path -// by following symlinks and moving up in the filesystem, if after moving up in the filesystem +// by following symlinks and moving up in the filesystem. if after moving up in the filesystem // we get to a DriveLetter then the volume corresponding to this drive letter is returned instead. func findClosestVolume(path string) (string, error) { candidatePath := path @@ -365,22 +339,20 @@ func findClosestVolume(path string) (string, error) { // while trying to follow symlinks // // The maximum path length in Windows is 260, it could be possible to end - // up in a sceneario where we do more than 256 iterations (e.g. by following symlinks from + // up in a scenario where we do more than 256 iterations (e.g. by following symlinks from // a place high in the hierarchy to a nested sibling location many times) // https://docs.microsoft.com/en-us/windows/win32/fileio/naming-a-file#:~:text=In%20editions%20of%20Windows%20before,required%20to%20remove%20the%20limit. // // The number of iterations is 256, which is similar to the number of iterations in filepath-securejoin // https://github.com/cyphar/filepath-securejoin/blob/64536a8a66ae59588c981e2199f1dcf410508e07/join.go#L51 for i := 0; i < 256; i += 1 { - fi, err := os.Lstat(candidatePath) + isSymlink, err := utils.IsPathSymlink(candidatePath) if err != nil { return "", err } - // for windows NTFS, check if the path is symlink instead of directory. - isSymlink := fi.Mode()&os.ModeSymlink != 0 || fi.Mode()&os.ModeIrregular != 0 // mounted folder created by SetVolumeMountPoint may still report ModeSymlink == 0 - mountedFolder, err := isMountedFolder(candidatePath) + mountedFolder, err := utils.IsMountedFolder(candidatePath) if err != nil { return "", err } @@ -417,51 +389,18 @@ func findClosestVolume(path string) (string, error) { return "", fmt.Errorf("failed to find the closest volume for path=%s", path) } -// isMountedFolder checks whether the `path` is a mounted folder. -func isMountedFolder(path string) (bool, error) { - // https://learn.microsoft.com/en-us/windows/win32/fileio/determining-whether-a-directory-is-a-volume-mount-point - utf16Path, _ := windows.UTF16PtrFromString(path) - attrs, err := windows.GetFileAttributes(utf16Path) - if err != nil { - return false, err - } - - if (attrs & windows.FILE_ATTRIBUTE_REPARSE_POINT) == 0 { - return false, nil - } - - var findData windows.Win32finddata - findHandle, err := windows.FindFirstFile(utf16Path, &findData) - if err != nil && !errors.Is(err, windows.ERROR_NO_MORE_FILES) { - return false, err - } - - for err == nil { - if findData.Reserved0&windows.IO_REPARSE_TAG_MOUNT_POINT != 0 { - return true, nil - } - - err = windows.FindNextFile(findHandle, &findData) - if err != nil && !errors.Is(err, windows.ERROR_NO_MORE_FILES) { - return false, err - } - } - - return false, nil -} - // getVolumeForDriveLetter gets a volume from a drive letter (e.g. C:/). func getVolumeForDriveLetter(path string) (string, error) { if len(path) != 1 { return "", fmt.Errorf("the path %s is not a valid drive letter", path) } - volume, err := cim.GetVolumeByDriveLetter(path, []string{"UniqueId"}) + volume, err := cim.GetVolumeByDriveLetter(path, cim.VolumeSelectorListUniqueID) if err != nil { return "", nil } - uniqueID, err := volume.GetPropertyUniqueId() + uniqueID, err := cim.GetVolumeUniqueID(volume) if err != nil { return "", fmt.Errorf("error query unique ID of volume (%v). error: %v", volume, err) } @@ -470,12 +409,12 @@ func getVolumeForDriveLetter(path string) (string, error) { } func writeCache(volumeID string) error { - volume, err := cim.QueryVolumeByUniqueID(volumeID, []string{}) - if err != nil && !wmierrors.IsNotFound(err) { + volume, err := cim.QueryVolumeByUniqueID(volumeID, nil) + if err != nil { return fmt.Errorf("error writing volume (%s) cache. error: %v", volumeID, err) } - result, err := volume.Flush() + result, err := cim.FlushVolume(volume) if result != 0 || err != nil { return fmt.Errorf("error writing volume (%s) cache. result: %d, error: %v", volumeID, result, err) } diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index bfe446f7..90c21125 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -3,7 +3,6 @@ package utils import ( "fmt" "os" - "os/exec" "strings" "github.com/pkg/errors" @@ -25,14 +24,6 @@ func EnsureLongPath(path string) string { return path } -func RunPowershellCmd(command string, envs ...string) ([]byte, error) { - cmd := exec.Command("powershell", "-Mta", "-NoProfile", "-Command", command) - cmd.Env = append(os.Environ(), envs...) - klog.V(8).Infof("Executing command: %q", cmd.String()) - out, err := cmd.CombinedOutput() - return out, err -} - func IsPathValid(path string) (bool, error) { pathString, err := windows.UTF16PtrFromString(path) if err != nil { @@ -52,3 +43,69 @@ func IsPathValid(path string) (bool, error) { klog.V(6).Infof("Path %s attribute: %d", path, attrs) return attrs != windows.INVALID_FILE_ATTRIBUTES, nil } + +// IsMountedFolder checks whether the `path` is a mounted folder. +func IsMountedFolder(path string) (bool, error) { + // https://learn.microsoft.com/en-us/windows/win32/fileio/determining-whether-a-directory-is-a-volume-mount-point + utf16Path, _ := windows.UTF16PtrFromString(path) + attrs, err := windows.GetFileAttributes(utf16Path) + if err != nil { + return false, err + } + + if (attrs & windows.FILE_ATTRIBUTE_REPARSE_POINT) == 0 { + return false, nil + } + + var findData windows.Win32finddata + findHandle, err := windows.FindFirstFile(utf16Path, &findData) + if err != nil && !errors.Is(err, windows.ERROR_NO_MORE_FILES) { + return false, err + } + + for err == nil { + if findData.Reserved0&windows.IO_REPARSE_TAG_MOUNT_POINT != 0 { + return true, nil + } + + err = windows.FindNextFile(findHandle, &findData) + if err != nil && !errors.Is(err, windows.ERROR_NO_MORE_FILES) { + return false, err + } + } + + return false, nil +} + +func IsPathSymlink(path string) (bool, error) { + fi, err := os.Lstat(path) + if err != nil { + return false, err + } + // for windows NTFS, check if the path is symlink instead of directory. + isSymlink := fi.Mode()&os.ModeSymlink != 0 || fi.Mode()&os.ModeIrregular != 0 + return isSymlink, nil +} + +func CreateSymlink(link, target string, isDir bool) error { + linkPtr, err := windows.UTF16PtrFromString(link) + if err != nil { + return err + } + targetPtr, err := windows.UTF16PtrFromString(target) + if err != nil { + return err + } + + var flags uint32 + if isDir { + flags = windows.SYMBOLIC_LINK_FLAG_DIRECTORY + } + + err = windows.CreateSymbolicLink( + linkPtr, + targetPtr, + flags, + ) + return err +}