diff --git a/pkg/cim/system.go b/pkg/cim/system.go index 3ab32af6..80181794 100644 --- a/pkg/cim/system.go +++ b/pkg/cim/system.go @@ -10,6 +10,22 @@ import ( "github.com/microsoft/wmi/server2019/root/cimv2" ) +var ( + BIOSSelectorList = []string{"SerialNumber"} + ServiceSelectorList = []string{"DisplayName", "State", "StartMode"} +) + +type ServiceInterface interface { + GetPropertyName() (string, error) + GetPropertyDisplayName() (string, error) + GetPropertyState() (string, error) + GetPropertyStartMode() (string, error) + GetDependents() ([]ServiceInterface, error) + StartService() (result uint32, err error) + StopService() (result uint32, err error) + Refresh() error +} + // QueryBIOSElement retrieves the BIOS element. // // The equivalent WMI query is: @@ -33,6 +49,11 @@ func QueryBIOSElement(selectorList []string) (*cimv2.CIM_BIOSElement, error) { return bios, err } +// GetBIOSSerialNumber returns the BIOS serial number. +func GetBIOSSerialNumber(bios *cimv2.CIM_BIOSElement) (string, error) { + return bios.GetPropertySerialNumber() +} + // QueryServiceByName retrieves a specific service by its name. // // The equivalent WMI query is: @@ -55,3 +76,60 @@ func QueryServiceByName(name string, selectorList []string) (*cimv2.Win32_Servic return service, err } + +// GetServiceName returns the name of a service. +func GetServiceName(service ServiceInterface) (string, error) { + return service.GetPropertyName() +} + +// GetServiceDisplayName returns the display name of a service. +func GetServiceDisplayName(service ServiceInterface) (string, error) { + return service.GetPropertyDisplayName() +} + +// GetServiceState returns the state of a service. +func GetServiceState(service ServiceInterface) (string, error) { + return service.GetPropertyState() +} + +// GetServiceStartMode returns the start mode of a service. +func GetServiceStartMode(service ServiceInterface) (string, error) { + return service.GetPropertyStartMode() +} + +// Win32Service wraps the WMI class Win32_Service (mainly for testing) +type Win32Service struct { + *cimv2.Win32_Service +} + +func (s *Win32Service) GetDependents() ([]ServiceInterface, error) { + collection, err := s.GetAssociated("Win32_DependentService", "Win32_Service", "Dependent", "Antecedent") + if err != nil { + return nil, err + } + + var result []ServiceInterface + for _, coll := range collection { + service, err := cimv2.NewWin32_ServiceEx1(coll) + if err != nil { + return nil, err + } + + result = append(result, &Win32Service{ + service, + }) + } + return result, nil +} + +type Win32ServiceFactory struct { +} + +func (impl Win32ServiceFactory) GetService(name string) (ServiceInterface, error) { + service, err := QueryServiceByName(name, ServiceSelectorList) + if err != nil { + return nil, err + } + + return &Win32Service{Win32_Service: service}, nil +} diff --git a/pkg/os/system/api.go b/pkg/os/system/api.go index a09c83a4..2a8ddaf8 100644 --- a/pkg/os/system/api.go +++ b/pkg/os/system/api.go @@ -2,10 +2,12 @@ package system import ( "fmt" + "time" "github.com/kubernetes-csi/csi-proxy/pkg/cim" "github.com/kubernetes-csi/csi-proxy/pkg/server/system/impl" - "github.com/kubernetes-csi/csi-proxy/pkg/utils" + "github.com/pkg/errors" + "k8s.io/klog/v2" ) // Implements the System OS API calls. All code here should be very simple @@ -24,6 +26,29 @@ type ServiceInfo struct { Status uint32 `json:"Status"` } +type stateCheckFunc func(cim.ServiceInterface, string) (bool, string, error) +type stateTransitionFunc func(cim.ServiceInterface) error + +const ( + // startServiceErrorCodeAccepted indicates the request is accepted + startServiceErrorCodeAccepted = 0 + + // startServiceErrorCodeAlreadyRunning indicates a service is already running + startServiceErrorCodeAlreadyRunning = 10 + + // stopServiceErrorCodeAccepted indicates the request is accepted + stopServiceErrorCodeAccepted = 0 + + // stopServiceErrorCodeStopPending indicates the request cannot be sent to the service because the state of the service is 0,1,2 (pending) + stopServiceErrorCodeStopPending = 5 + + // stopServiceErrorCodeDependentRunning indicates a service cannot be stopped as its dependents may still be running + stopServiceErrorCodeDependentRunning = 3 + + serviceStateRunning = "Running" + serviceStateStopped = "Stopped" +) + var ( startModeMappings = map[string]uint32{ "Boot": impl.START_TYPE_BOOT, @@ -33,16 +58,20 @@ var ( "Disabled": impl.START_TYPE_DISABLED, } - statusMappings = map[string]uint32{ - "Unknown": impl.SERVICE_STATUS_UNKNOWN, - "Stopped": impl.SERVICE_STATUS_STOPPED, - "Start Pending": impl.SERVICE_STATUS_START_PENDING, - "Stop Pending": impl.SERVICE_STATUS_STOP_PENDING, - "Running": impl.SERVICE_STATUS_RUNNING, - "Continue Pending": impl.SERVICE_STATUS_CONTINUE_PENDING, - "Pause Pending": impl.SERVICE_STATUS_PAUSE_PENDING, - "Paused": impl.SERVICE_STATUS_PAUSED, + stateMappings = map[string]uint32{ + "Unknown": impl.SERVICE_STATUS_UNKNOWN, + serviceStateStopped: impl.SERVICE_STATUS_STOPPED, + "Start Pending": impl.SERVICE_STATUS_START_PENDING, + "Stop Pending": impl.SERVICE_STATUS_STOP_PENDING, + serviceStateRunning: impl.SERVICE_STATUS_RUNNING, + "Continue Pending": impl.SERVICE_STATUS_CONTINUE_PENDING, + "Pause Pending": impl.SERVICE_STATUS_PAUSE_PENDING, + "Paused": impl.SERVICE_STATUS_PAUSED, } + + serviceStateCheckInternal = 200 * time.Millisecond + serviceStateCheckTimeout = 30 * time.Second + errTimedOut = errors.New("Timed out") ) func serviceStartModeToStartType(startMode string) uint32 { @@ -50,22 +79,40 @@ func serviceStartModeToStartType(startMode string) uint32 { } func serviceState(status string) uint32 { - return statusMappings[status] + return stateMappings[status] } -type APIImplementor struct{} +type ServiceManager interface { + WaitUntilServiceState(cim.ServiceInterface, stateTransitionFunc, stateCheckFunc, time.Duration, time.Duration) (string, error) + GetDependentsForService(string) ([]string, error) +} + +type ServiceFactory interface { + GetService(name string) (cim.ServiceInterface, error) +} + +type APIImplementor struct { + serviceFactory ServiceFactory + serviceManager ServiceManager +} func New() APIImplementor { - return APIImplementor{} + serviceFactory := cim.Win32ServiceFactory{} + return APIImplementor{ + serviceFactory: serviceFactory, + serviceManager: ServiceManagerImpl{ + serviceFactory: serviceFactory, + }, + } } func (APIImplementor) GetBIOSSerialNumber() (string, error) { - bios, err := cim.QueryBIOSElement([]string{"SerialNumber"}) + bios, err := cim.QueryBIOSElement(cim.BIOSSelectorList) if err != nil { return "", fmt.Errorf("failed to get BIOS element: %w", err) } - sn, err := bios.GetPropertySerialNumber() + sn, err := cim.GetBIOSSerialNumber(bios) if err != nil { return "", fmt.Errorf("failed to get BIOS serial number property: %w", err) } @@ -73,23 +120,23 @@ func (APIImplementor) GetBIOSSerialNumber() (string, error) { return sn, nil } -func (APIImplementor) GetService(name string) (*ServiceInfo, error) { - service, err := cim.QueryServiceByName(name, []string{"DisplayName", "State", "StartMode"}) +func (impl APIImplementor) GetService(name string) (*ServiceInfo, error) { + service, err := impl.serviceFactory.GetService(name) if err != nil { - return nil, fmt.Errorf("failed to get service %s: %w", name, err) + return nil, fmt.Errorf("failed to get service %s. error: %w", name, err) } - displayName, err := service.GetPropertyDisplayName() + displayName, err := cim.GetServiceDisplayName(service) if err != nil { return nil, fmt.Errorf("failed to get displayName property of service %s: %w", name, err) } - state, err := service.GetPropertyState() + state, err := cim.GetServiceState(service) if err != nil { return nil, fmt.Errorf("failed to get state property of service %s: %w", name, err) } - startMode, err := service.GetPropertyStartMode() + startMode, err := cim.GetServiceStartMode(service) if err != nil { return nil, fmt.Errorf("failed to get startMode property of service %s: %w", name, err) } @@ -101,24 +148,198 @@ func (APIImplementor) GetService(name string) (*ServiceInfo, error) { }, nil } -func (APIImplementor) StartService(name string) error { - // Note: both StartService and StopService are not implemented by WMI - script := `Start-Service -Name $env:ServiceName` - cmdEnv := fmt.Sprintf("ServiceName=%s", name) - out, err := utils.RunPowershellCmd(script, cmdEnv) +func (impl APIImplementor) StartService(name string) error { + startService := func(service cim.ServiceInterface) error { + retVal, err := service.StartService() + if err != nil || (retVal != startServiceErrorCodeAccepted && retVal != startServiceErrorCodeAlreadyRunning) { + return fmt.Errorf("error starting service name %s. return value: %d, error: %w", name, retVal, err) + } + return nil + } + serviceRunningCheck := func(service cim.ServiceInterface, state string) (bool, string, error) { + err := service.Refresh() + if err != nil { + return false, "", err + } + + newState, err := cim.GetServiceState(service) + if err != nil { + return false, state, err + } + + klog.V(6).Infof("service (%v) state check: %s => %s", service, state, newState) + return state == serviceStateRunning, newState, err + } + + service, err := impl.serviceFactory.GetService(name) if err != nil { - return fmt.Errorf("error starting service name=%s. cmd: %s, output: %s, error: %v", name, script, string(out), err) + return fmt.Errorf("failed to get service %s. error: %w", name, err) + } + + state, err := impl.serviceManager.WaitUntilServiceState(service, startService, serviceRunningCheck, serviceStateCheckInternal, serviceStateCheckTimeout) + if err != nil && !errors.Is(err, errTimedOut) { + return fmt.Errorf("failed to wait for service %s state change. error: %w", name, err) + } + + if state != serviceStateRunning { + return fmt.Errorf("timed out waiting for service %s to become running", name) } return nil } -func (APIImplementor) StopService(name string, force bool) error { - script := `Stop-Service -Name $env:ServiceName -Force:$([System.Convert]::ToBoolean($env:Force))` - out, err := utils.RunPowershellCmd(script, fmt.Sprintf("ServiceName=%s", name), fmt.Sprintf("Force=%t", force)) +func (impl APIImplementor) stopSingleService(name string) (bool, error) { + var dependentRunning bool + stopService := func(service cim.ServiceInterface) error { + retVal, err := service.StopService() + if err != nil || (retVal != stopServiceErrorCodeAccepted && retVal != stopServiceErrorCodeStopPending) { + if retVal == stopServiceErrorCodeDependentRunning { + dependentRunning = true + return fmt.Errorf("error stopping service %s as dependent services are not stopped", name) + } + return fmt.Errorf("error stopping service %s. return value: %d, error: %w", name, retVal, err) + } + return nil + } + serviceStoppedCheck := func(service cim.ServiceInterface, state string) (bool, string, error) { + err := service.Refresh() + if err != nil { + return false, "", fmt.Errorf("error refresh service %s instance. error: %w", name, err) + } + + newState, err := cim.GetServiceState(service) + if err != nil { + return false, state, fmt.Errorf("error getting service %s state. error: %w", name, err) + } + + klog.V(6).Infof("service (%v) state check: %s => %s", service, state, newState) + return newState == serviceStateStopped, newState, nil + } + + service, err := impl.serviceFactory.GetService(name) + if err != nil { + return dependentRunning, fmt.Errorf("failed to get service %s. error: %w", name, err) + } + + state, err := impl.serviceManager.WaitUntilServiceState(service, stopService, serviceStoppedCheck, serviceStateCheckInternal, serviceStateCheckTimeout) + if err != nil && !errors.Is(err, errTimedOut) { + return dependentRunning, fmt.Errorf("error stopping service name %s. current state: %s", name, state) + } + + if state != serviceStateStopped { + return dependentRunning, fmt.Errorf("timed out waiting for service %s to stop", name) + } + + return dependentRunning, nil +} + +func (impl APIImplementor) StopService(name string, force bool) error { + dependentRunning, err := impl.stopSingleService(name) + if err == nil { + return nil + } + if !dependentRunning || !force { + return fmt.Errorf("failed to stop service %s. error: %w", name, err) + } + + serviceNames, err := impl.serviceManager.GetDependentsForService(name) if err != nil { - return fmt.Errorf("error stopping service name=%s. cmd: %s, output: %s, error: %v", name, script, string(out), err) + return fmt.Errorf("error getting dependent services for service name %s", name) + } + + for _, serviceName := range serviceNames { + _, err = impl.stopSingleService(serviceName) + if err != nil { + return fmt.Errorf("failed to stop service %s. error: %w", name, err) + } } return nil } + +type ServiceManagerImpl struct { + serviceFactory ServiceFactory +} + +func (impl ServiceManagerImpl) WaitUntilServiceState(service cim.ServiceInterface, stateTransition stateTransitionFunc, stateCheck stateCheckFunc, interval time.Duration, timeout time.Duration) (string, error) { + done, state, err := stateCheck(service, "") + if err != nil { + return state, fmt.Errorf("service %v state check failed: %w", service, err) + } + if done { + return state, nil + } + + // Perform transition if not already in desired state + if err := stateTransition(service); err != nil { + return state, fmt.Errorf("service %v state transition failed: %w", service, err) + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + timeoutChan := time.After(timeout) + + for { + select { + case <-ticker.C: + klog.V(6).Infof("Checking service (%v) state...", service) + done, state, err = stateCheck(service, state) + if err != nil { + return state, fmt.Errorf("service %v state check failed: %w", service, err) + } + if done { + klog.V(6).Infof("service (%v) state is %s and transition done.", service, state) + return state, nil + } + case <-timeoutChan: + done, state, err = stateCheck(service, state) + return state, errTimedOut + } + } +} + +func (impl ServiceManagerImpl) GetDependentsForService(name string) ([]string, error) { + var serviceNames []string + var servicesToCheck []cim.ServiceInterface + servicesByName := map[string]string{} + + service, err := impl.serviceFactory.GetService(name) + if err != nil { + return serviceNames, fmt.Errorf("failed to get service %s. error: %w", name, err) + } + + servicesToCheck = append(servicesToCheck, service) + i := 0 + for i < len(servicesToCheck) { + service = servicesToCheck[i] + i += 1 + + serviceName, err := cim.GetServiceName(service) + if err != nil { + return serviceNames, fmt.Errorf("error getting service name %v. error: %w", service, err) + } + + currentState, err := cim.GetServiceState(service) + if err != nil { + return serviceNames, fmt.Errorf("error getting service %s state. error: %w", serviceName, err) + } + + if currentState != serviceStateRunning { + continue + } + + servicesByName[serviceName] = serviceName + // prepend the current service to the front + serviceNames = append([]string{serviceName}, serviceNames...) + + dependents, err := service.GetDependents() + if err != nil { + return serviceNames, fmt.Errorf("error getting service %s dependents. error: %w", serviceName, err) + } + + servicesToCheck = append(servicesToCheck, dependents...) + } + + return serviceNames, nil +} diff --git a/pkg/os/system/api_test.go b/pkg/os/system/api_test.go new file mode 100644 index 00000000..7b2992bb --- /dev/null +++ b/pkg/os/system/api_test.go @@ -0,0 +1,224 @@ +package system + +import ( + "fmt" + "os" + "strings" + "testing" + "time" + + "github.com/kubernetes-csi/csi-proxy/pkg/cim" + "github.com/pkg/errors" +) + +type MockService struct { + Name string + DisplayName string + State string + StartMode string + Dependents []cim.ServiceInterface + + StartResult uint32 + StopResult uint32 + + Err error +} + +func (m *MockService) GetPropertyName() (string, error) { + return m.Name, m.Err +} + +func (m *MockService) GetPropertyDisplayName() (string, error) { + return m.DisplayName, m.Err +} + +func (m *MockService) GetPropertyState() (string, error) { + return m.State, m.Err +} + +func (m *MockService) GetPropertyStartMode() (string, error) { + return m.StartMode, m.Err +} + +func (m *MockService) GetDependents() ([]cim.ServiceInterface, error) { + return m.Dependents, m.Err +} + +func (m *MockService) StartService() (uint32, error) { + m.State = "Running" + return m.StartResult, m.Err +} + +func (m *MockService) StopService() (uint32, error) { + m.State = "Stopped" + return m.StopResult, m.Err +} + +func (m *MockService) Refresh() error { + return nil +} + +var _ cim.ServiceInterface = &MockService{} + +type MockServiceFactory struct { + Services map[string]cim.ServiceInterface + Err error +} + +func (f *MockServiceFactory) GetService(name string) (cim.ServiceInterface, error) { + svc, ok := f.Services[name] + if !ok { + return nil, fmt.Errorf("service not found: %s", name) + } + return svc, f.Err +} + +var _ ServiceFactory = &MockServiceFactory{} + +func TestWaitUntilServiceState_Success(t *testing.T) { + svc := &MockService{Name: "svc", State: "Stopped"} + + stateChanged := false + + stateCheck := func(_ cim.ServiceInterface, _ string) (bool, string, error) { + if stateChanged { + svc.State = serviceStateRunning + return true, svc.State, nil + } + return false, svc.State, nil + } + + stateTransition := func(_ cim.ServiceInterface) error { + stateChanged = true + return nil + } + + impl := ServiceManagerImpl{} + state, err := impl.WaitUntilServiceState(svc, stateTransition, stateCheck, 10*time.Millisecond, 500*time.Millisecond) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if state != serviceStateRunning { + t.Fatalf("expected state %q, got %q", serviceStateRunning, state) + } +} + +func TestWaitUntilServiceState_Timeout(t *testing.T) { + svc := &MockService{Name: "svc", State: "Stopped"} + + stateCheck := func(_ cim.ServiceInterface, _ string) (bool, string, error) { + return false, svc.State, nil + } + + stateTransition := func(_ cim.ServiceInterface) error { + return nil + } + + impl := ServiceManagerImpl{} + state, err := impl.WaitUntilServiceState(svc, stateTransition, stateCheck, 10*time.Millisecond, 50*time.Millisecond) + if !errors.Is(err, errTimedOut) { + t.Fatalf("expected timeout error, got %v", err) + } + if state != svc.State { + t.Fatalf("expected state %q, got %q", svc.State, state) + } +} + +func TestWaitUntilServiceState_TransitionFails(t *testing.T) { + svc := &MockService{Name: "svc", State: "Stopped"} + + stateCheck := func(_ cim.ServiceInterface, _ string) (bool, string, error) { + return false, svc.State, nil + } + + stateTransition := func(_ cim.ServiceInterface) error { + return fmt.Errorf("transition failed") + } + + impl := ServiceManagerImpl{} + _, err := impl.WaitUntilServiceState(svc, stateTransition, stateCheck, 10*time.Millisecond, 50*time.Millisecond) + if err == nil || !strings.Contains(err.Error(), "transition failed") { + t.Fatalf("expected transition error, got %v", err) + } +} + +func TestGetDependentsForService(t *testing.T) { + // Construct the dependency tree + svcC := &MockService{Name: "C", State: serviceStateRunning} + svcB := &MockService{Name: "B", State: serviceStateRunning, Dependents: []cim.ServiceInterface{svcC}} + svcA := &MockService{Name: "A", State: serviceStateRunning, Dependents: []cim.ServiceInterface{svcB}} + + factory := &MockServiceFactory{ + Services: map[string]cim.ServiceInterface{ + "A": svcA, + "B": svcB, + "C": svcC, + }, + } + + impl := ServiceManagerImpl{ + serviceFactory: factory, + } + + names, err := impl.GetDependentsForService("A") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := []string{"C", "B", "A"} + if len(names) != len(expected) { + t.Fatalf("expected %d services, got %d", len(expected), len(names)) + } + for i, name := range expected { + if names[i] != name { + t.Errorf("expected %s at position %d, got %s", name, i, names[i]) + } + } +} + +func TestGetDependentsForService_SkipsNonRunning(t *testing.T) { + svcB := &MockService{Name: "B", State: "Stopped"} + svcA := &MockService{Name: "A", State: serviceStateRunning, Dependents: []cim.ServiceInterface{svcB}} + + factory := &MockServiceFactory{ + Services: map[string]cim.ServiceInterface{ + "A": svcA, + "B": svcB, + }, + } + + impl := ServiceManagerImpl{ + serviceFactory: factory, + } + + names, err := impl.GetDependentsForService("A") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := []string{"A"} // B is skipped due to stopped state + if len(names) != len(expected) { + t.Fatalf("expected %d services, got %d", len(expected), len(names)) + } +} + +func TestGetDependenciesForService_Winmgmt(t *testing.T) { + if strings.ToLower(os.Getenv("TEST_MULTI_SERVICE_DEPENDENTS")) != "true" { + t.Skipf("Test skipped") + } + + impl := ServiceManagerImpl{ + serviceFactory: cim.Win32ServiceFactory{}, + } + + serviceName := "Winmgmt" + names, err := impl.GetDependentsForService(serviceName) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expected := 4 + if len(names) != expected || names[len(names)-1] != serviceName { + t.Fatalf("expected %d services, got %d", expected, len(names)) + } +}