diff --git a/docs/IMPLEMENTATION.md b/docs/IMPLEMENTATION.md index 3326f2f2..3468e3c4 100644 --- a/docs/IMPLEMENTATION.md +++ b/docs/IMPLEMENTATION.md @@ -121,6 +121,20 @@ func CallMethod(disp *ole.IDispatch, name string, params ...interface{}) (result } ``` +### Association + +Association can be used to retrieve all instances that are associated with +a particular source instance. + +There are a few Association classes in WMI. + +For example, association class [MSFT_PartitionToVolume](https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partitiontovolume) +can be used to retrieve a volume (`MSFT_Volume`) from a partition (`MSFT_Partition`), and vice versa. + +```go +collection, err := part.GetAssociated("MSFT_PartitionToVolume", "MSFT_Volume", "Volume", "Partition") +``` + ## Debug with PowerShell @@ -181,6 +195,13 @@ PS C:\Users\Administrator> $vol.FileSystem NTFS ``` +### Association + +```powershell +PS C:\Users\Administrator> $partition = (Get-CimInstance -Namespace root\Microsoft\Windows\Storage -ClassName MSFT_Partition -Filter "DiskNumber = 0")[0] +PS C:\Users\Administrator> Get-CimAssociatedInstance -InputObject $partition -Association MSFT_PartitionToVolume +``` + ### Call Class Method You may get Class Methods for a single CIM class using `$class.CimClassMethods`. diff --git a/pkg/cim/volume.go b/pkg/cim/volume.go index 085289ba..b77fd6d6 100644 --- a/pkg/cim/volume.go +++ b/pkg/cim/volume.go @@ -9,7 +9,6 @@ import ( "github.com/microsoft/wmi/pkg/base/query" "github.com/microsoft/wmi/pkg/errors" - cim "github.com/microsoft/wmi/pkg/wmiinstance" "github.com/microsoft/wmi/server2019/root/microsoft/windows/storage" ) @@ -129,131 +128,84 @@ func ListPartitionsWithFilters(selectorList []string, filters ...*query.WmiQuery return partitions, nil } -// ListPartitionToVolumeMappings builds a mapping between partition and volume with partition Object ID as the key. -// -// The equivalent WMI query is: -// -// SELECT [selectors] FROM MSFT_PartitionToVolume -// -// Partition | Volume -// --------- | ------ -// MSFT_Partition (ObjectId = "{1}\\WIN-8E2EVAQ9QSB\ROOT/Microsoft/Win...) | MSFT_Volume (ObjectId = "{1}\\WIN-8E2EVAQ9QS... -// -// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partitiontovolume -// for the WMI class definition. -func ListPartitionToVolumeMappings() (map[string]string, error) { - return ListWMIInstanceMappings(WMINamespaceStorage, "MSFT_PartitionToVolume", nil, - mappingObjectRefIndexer("Partition", "MSFT_Partition", "ObjectId"), - mappingObjectRefIndexer("Volume", "MSFT_Volume", "ObjectId"), - ) -} - -// ListVolumeToPartitionMappings builds a mapping between volume and partition with volume Object ID as the key. -// -// The equivalent WMI query is: +// FindPartitionsByVolume finds all partitions associated with the given volumes +// using MSFT_PartitionToVolume association. // -// SELECT [selectors] FROM MSFT_PartitionToVolume +// WMI association MSFT_PartitionToVolume: // -// Partition | Volume -// --------- | ------ -// MSFT_Partition (ObjectId = "{1}\\WIN-8E2EVAQ9QSB\ROOT/Microsoft/Win...) | MSFT_Volume (ObjectId = "{1}\\WIN-8E2EVAQ9QS... +// Partition | Volume +// --------- | ------ +// MSFT_Partition (ObjectId = "{1}\\WIN-8E2EVAQ9QSB\ROOT/Microsoft/Win...) | MSFT_Volume (ObjectId = "{1}\\WIN-8E2EVAQ9QS... // // Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partitiontovolume // for the WMI class definition. -func ListVolumeToPartitionMappings() (map[string]string, error) { - return ListWMIInstanceMappings(WMINamespaceStorage, "MSFT_PartitionToVolume", nil, - mappingObjectRefIndexer("Volume", "MSFT_Volume", "ObjectId"), - mappingObjectRefIndexer("Partition", "MSFT_Partition", "ObjectId"), - ) -} - -// FindPartitionsByVolume finds all partitions associated with the given volumes -// using partition-to-volume mapping. -func FindPartitionsByVolume(partitions []*storage.MSFT_Partition, volumes []*storage.MSFT_Volume) ([]*storage.MSFT_Partition, error) { - var partitionInstances []*cim.WmiInstance - for _, part := range partitions { - partitionInstances = append(partitionInstances, part.WmiInstance) - } - - var volumeInstances []*cim.WmiInstance - for _, volume := range volumes { - volumeInstances = append(volumeInstances, volume.WmiInstance) - } - - partitionToVolumeMappings, err := ListPartitionToVolumeMappings() - if err != nil { - return nil, err - } - - filtered, err := FindInstancesByObjectIDMapping(partitionInstances, volumeInstances, partitionToVolumeMappings) - if err != nil { - return nil, err - } - +func FindPartitionsByVolume(volumes []*storage.MSFT_Volume) ([]*storage.MSFT_Partition, error) { var result []*storage.MSFT_Partition - for _, instance := range filtered { - part, err := storage.NewMSFT_PartitionEx1(instance) + for _, vol := range volumes { + collection, err := vol.GetAssociated("MSFT_PartitionToVolume", "MSFT_Partition", "Partition", "Volume") if err != nil { - return nil, fmt.Errorf("failed to query partition %v. error: %v", instance, err) + return nil, fmt.Errorf("failed to query associated partition for %v. error: %v", vol, err) } - result = append(result, part) + for _, instance := range collection { + part, err := storage.NewMSFT_PartitionEx1(instance) + if err != nil { + return nil, fmt.Errorf("failed to query partition %v. error: %v", instance, err) + } + + result = append(result, part) + } } return result, nil } // FindVolumesByPartition finds all volumes associated with the given partitions -// using volume-to-partition mapping. -func FindVolumesByPartition(volumes []*storage.MSFT_Volume, partitions []*storage.MSFT_Partition) ([]*storage.MSFT_Volume, error) { - var volumeInstances []*cim.WmiInstance - for _, volume := range volumes { - volumeInstances = append(volumeInstances, volume.WmiInstance) - } - - var partitionInstances []*cim.WmiInstance - for _, part := range partitions { - partitionInstances = append(partitionInstances, part.WmiInstance) - } - - volumeToPartitionMappings, err := ListVolumeToPartitionMappings() - if err != nil { - return nil, err - } - - filtered, err := FindInstancesByObjectIDMapping(volumeInstances, partitionInstances, volumeToPartitionMappings) - if err != nil { - return nil, err - } - +// using MSFT_PartitionToVolume association. +// +// WMI association MSFT_PartitionToVolume: +// +// Partition | Volume +// --------- | ------ +// MSFT_Partition (ObjectId = "{1}\\WIN-8E2EVAQ9QSB\ROOT/Microsoft/Win...) | MSFT_Volume (ObjectId = "{1}\\WIN-8E2EVAQ9QS... +// +// Refer to https://learn.microsoft.com/en-us/windows-hardware/drivers/storage/msft-partitiontovolume +// for the WMI class definition. +func FindVolumesByPartition(partitions []*storage.MSFT_Partition) ([]*storage.MSFT_Volume, error) { var result []*storage.MSFT_Volume - for _, instance := range filtered { - volume, err := storage.NewMSFT_VolumeEx1(instance) + for _, part := range partitions { + collection, err := part.GetAssociated("MSFT_PartitionToVolume", "MSFT_Volume", "Volume", "Partition") if err != nil { - return nil, fmt.Errorf("failed to query volume %v. error: %v", instance, err) + return nil, fmt.Errorf("failed to query associated volumes for %v. error: %v", part, err) } - result = append(result, volume) + for _, instance := range collection { + volume, err := storage.NewMSFT_VolumeEx1(instance) + if err != nil { + return nil, fmt.Errorf("failed to query volume %v. error: %v", instance, err) + } + + result = append(result, volume) + } } return result, nil } // GetPartitionByVolumeUniqueID retrieves a specific partition from a volume identified by its unique ID. -func GetPartitionByVolumeUniqueID(volumeID string, partitionSelectorList []string) (*storage.MSFT_Partition, error) { +func GetPartitionByVolumeUniqueID(volumeID string) (*storage.MSFT_Partition, error) { volume, err := QueryVolumeByUniqueID(volumeID, []string{"ObjectId"}) if err != nil { return nil, err } - partitions, err := ListPartitionsWithFilters(partitionSelectorList) + result, err := FindPartitionsByVolume([]*storage.MSFT_Volume{volume}) if err != nil { return nil, err } - result, err := FindPartitionsByVolume(partitions, []*storage.MSFT_Volume{volume}) - if err != nil { - return nil, err + if len(result) == 0 { + return nil, errors.NotFound } return result[0], nil @@ -269,12 +221,7 @@ func GetVolumeByDriveLetter(driveLetter string, partitionSelectorList []string) return nil, err } - volumes, err := ListVolumes(partitionSelectorList) - if err != nil { - return nil, err - } - - result, err := FindVolumesByPartition(volumes, partitions) + result, err := FindVolumesByPartition(partitions) if err != nil { return nil, err } diff --git a/pkg/cim/wmi.go b/pkg/cim/wmi.go index 81e17701..1dacce8a 100644 --- a/pkg/cim/wmi.go +++ b/pkg/cim/wmi.go @@ -5,7 +5,6 @@ package cim import ( "fmt" - "strings" "github.com/go-ole/go-ole" "github.com/go-ole/go-ole/oleutil" @@ -16,21 +15,18 @@ import ( ) const ( - WMINamespaceRoot = "Root\\CimV2" + WMINamespaceCimV2 = "Root\\CimV2" WMINamespaceStorage = "Root\\Microsoft\\Windows\\Storage" WMINamespaceSmb = "Root\\Microsoft\\Windows\\Smb" ) type InstanceHandler func(instance *cim.WmiInstance) (bool, error) -// An InstanceIndexer provides index key to a WMI Instance in a map -type InstanceIndexer func(instance *cim.WmiInstance) (string, error) - // NewWMISession creates a new local WMI session for the given namespace, defaulting // to root namespace if none specified. func NewWMISession(namespace string) (*cim.WmiSession, error) { if namespace == "" { - namespace = WMINamespaceRoot + namespace = WMINamespaceCimV2 } sessionManager := cim.NewWmiSessionManager() @@ -247,122 +243,3 @@ func IgnoreNotFound(err error) error { } return err } - -// parseObjectRef extracts the object ID from a WMI object reference string. -// The result string is in this format -// {1}\\WIN-8E2EVAQ9QSB\ROOT/Microsoft/Windows/Storage/Providers_v2\WSP_Partition.ObjectId="{b65bb3cd-da86-11ee-854b-806e6f6e6963}:PR:{00000000-0000-0000-0000-100000000000}\\?\scsi#disk&ven_vmware&prod_virtual_disk#4&2c28f6c4&0&000000#{53f56307-b6bf-11d0-94f2-00a0c91efb8b}" -// from an escape string -func parseObjectRef(input, objectClass, refName string) (string, error) { - tokens := strings.Split(input, fmt.Sprintf("%s.%s=", objectClass, refName)) - if len(tokens) < 2 { - return "", fmt.Errorf("invalid object ID value: %s", input) - } - - objectID := tokens[1] - objectID = strings.ReplaceAll(objectID, "\\\"", "\"") - objectID = strings.ReplaceAll(objectID, "\\\\", "\\") - objectID = objectID[1 : len(objectID)-1] - return objectID, nil -} - -// ListWMIInstanceMappings queries WMI instances and creates a map using custom indexing functions -// to extract keys and values from each instance. -func ListWMIInstanceMappings(namespace, mappingClassName string, selectorList []string, keyIndexer InstanceIndexer, valueIndexer InstanceIndexer) (map[string]string, error) { - q := query.NewWmiQueryWithSelectList(mappingClassName, selectorList) - mappingInstances, err := QueryInstances(namespace, q) - if err != nil { - return nil, err - } - - result := make(map[string]string) - for _, mapping := range mappingInstances { - key, err := keyIndexer(mapping) - if err != nil { - return nil, err - } - - value, err := valueIndexer(mapping) - if err != nil { - return nil, err - } - - result[key] = value - } - - return result, nil -} - -// FindInstancesByMapping filters instances based on a mapping relationship, -// matching instances through custom indexing and mapping functions. -func FindInstancesByMapping(instanceToFind []*cim.WmiInstance, instanceToFindIndex InstanceIndexer, associatedInstances []*cim.WmiInstance, associatedInstanceIndexer InstanceIndexer, instanceMappings map[string]string) ([]*cim.WmiInstance, error) { - associatedInstanceObjectIDMapping := map[string]*cim.WmiInstance{} - for _, inst := range associatedInstances { - key, err := associatedInstanceIndexer(inst) - if err != nil { - return nil, err - } - - associatedInstanceObjectIDMapping[key] = inst - } - - var filtered []*cim.WmiInstance - for _, inst := range instanceToFind { - key, err := instanceToFindIndex(inst) - if err != nil { - return nil, err - } - - valueObjectID, ok := instanceMappings[key] - if !ok { - continue - } - - _, ok = associatedInstanceObjectIDMapping[strings.ToUpper(valueObjectID)] - if !ok { - continue - } - filtered = append(filtered, inst) - } - - if len(filtered) == 0 { - return nil, errors.NotFound - } - - return filtered, nil -} - -// mappingObjectRefIndexer indexes an WMI object by the Object ID reference from a specified property. -func mappingObjectRefIndexer(propertyName, className, refName string) InstanceIndexer { - return func(instance *cim.WmiInstance) (string, error) { - valueVal, err := instance.GetProperty(propertyName) - if err != nil { - return "", err - } - - refValue, err := parseObjectRef(valueVal.(string), className, refName) - return strings.ToUpper(refValue), err - } -} - -// stringPropertyIndexer indexes a WMI object from a string property. -func stringPropertyIndexer(propertyName string) InstanceIndexer { - return func(instance *cim.WmiInstance) (string, error) { - valueVal, err := instance.GetProperty(propertyName) - if err != nil { - return "", err - } - - return strings.ToUpper(valueVal.(string)), err - } -} - -var ( - // objectIDPropertyIndexer indexes a WMI object from its ObjectId property. - objectIDPropertyIndexer = stringPropertyIndexer("ObjectId") -) - -// FindInstancesByObjectIDMapping filters instances based on ObjectId mapping -// between two sets of WMI instances. -func FindInstancesByObjectIDMapping(instanceToFind []*cim.WmiInstance, associatedInstances []*cim.WmiInstance, instanceMappings map[string]string) ([]*cim.WmiInstance, error) { - return FindInstancesByMapping(instanceToFind, objectIDPropertyIndexer, associatedInstances, objectIDPropertyIndexer, instanceMappings) -} diff --git a/pkg/os/volume/api.go b/pkg/os/volume/api.go index 5bdf0e04..fcd2e6f8 100644 --- a/pkg/os/volume/api.go +++ b/pkg/os/volume/api.go @@ -74,17 +74,12 @@ func (VolumeAPI) ListVolumesOnDisk(diskNumber uint32, partitionNumber uint32) (v return nil, errors.Wrapf(err, "failed to list partition on disk %d", diskNumber) } - volumes, err := cim.ListVolumes([]string{"ObjectId", "UniqueId"}) - if err != nil { - return nil, errors.Wrapf(err, "failed to list volumes") - } - - filtered, err := cim.FindVolumesByPartition(volumes, partitions) + volumes, err := cim.FindVolumesByPartition(partitions) if cim.IgnoreNotFound(err) != nil { return nil, errors.Wrapf(err, "failed to list volumes on disk %d", diskNumber) } - for _, volume := range filtered { + for _, volume := range volumes { uniqueID, err := volume.GetPropertyUniqueId() if err != nil { return nil, errors.Wrapf(err, "failed to list volumes") @@ -192,7 +187,7 @@ func (VolumeAPI) UnmountVolume(volumeID, path string) error { func (VolumeAPI) ResizeVolume(volumeID string, size int64) error { var err error var finalSize int64 - part, err := cim.GetPartitionByVolumeUniqueID(volumeID, nil) + part, err := cim.GetPartitionByVolumeUniqueID(volumeID) if err != nil { return err } @@ -297,7 +292,7 @@ func (VolumeAPI) GetVolumeStats(volumeID string) (int64, int64, error) { // GetDiskNumberFromVolumeID - gets the disk number where the volume is. func (VolumeAPI) GetDiskNumberFromVolumeID(volumeID string) (uint32, error) { // get the size and sizeRemaining for the volume - part, err := cim.GetPartitionByVolumeUniqueID(volumeID, []string{"DiskNumber"}) + part, err := cim.GetPartitionByVolumeUniqueID(volumeID) if err != nil { return 0, err }