n_iters_ in SVR is always saved as 10,000 #1712

Open
caspimoshe opened this issue Feb 14, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@caspimoshe

Describe the bug
The n_iter_ value saved after fitting appears to always be 10000.

To Reproduce
```python
import numpy as np
from sklearnex import patch_sklearn

patch_sklearn()
from sklearn.svm import SVR

svr = SVR()
X = np.random.randn(100, 5)
y = np.mean(X, axis=1)
svr.fit(X, y)
print('svr.n_iter_: ', svr.n_iter_)
```



Environment:

  • OS: Windows 11
  • Python: 3.8
  • scikit-learn==1.2.2
  • scikit-learn-intelex: 2024.1.0
@caspimoshe caspimoshe added the bug Something isn't working label Feb 14, 2024
@Alexsandruss
Contributor

This is expected behavior for now. A warning for the n_iter_ property might be added.
https://github.com/intel/scikit-learn-intelex/blob/bfa470b1ed71e7b52f7f36a05cb2b6f00deb2047/onedal/svm/svm.py#L218-L221
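A warning of that kind could look roughly like the minimal sketch below. It is not sklearnex code: the class HypotheticalSVRResult, its placeholder_n_iter argument and the message text are all hypothetical, and it only shows a plain Python property that emits a UserWarning every time the stored placeholder is read.

```python
import warnings


class HypotheticalSVRResult:
    """Hypothetical stand-in for a fitted estimator whose solver does not
    report its real iteration count (all names here are illustrative)."""

    def __init__(self, placeholder_n_iter):
        # Value stored in place of the real iteration count.
        self._n_iter = placeholder_n_iter

    @property
    def n_iter_(self):
        # Warn on every access that the value is a placeholder, not the
        # number of iterations the solver actually performed.
        warnings.warn(
            "n_iter_ is a placeholder value and does not reflect the number "
            "of iterations actually performed.",
            UserWarning,
            stacklevel=2,
        )
        return self._n_iter


result = HypotheticalSVRResult(placeholder_n_iter=10000)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is recorded
    value = result.n_iter_
print(value, len(caught), caught[0].category.__name__)
```

Using a property keeps the attribute name unchanged for existing callers while still flagging that the number is not meaningful.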

@caspimoshe
Author

Thanks for the fast response, Alex.
My main issue is that the max_iter parameter seems to have no effect.

Attaching code that shows it:
With the patch, the code ran for ~0.84 seconds and managed to fit the data:
```python
import time

import numpy as np
from sklearnex import patch_sklearn

patch_sklearn()
from sklearn.svm import SVR

patch_svr = SVR(max_iter=2)
np.random.seed(0)
X = np.random.randn(int(1e5), 5)
y = np.mean(X, axis=1)

tic = time.time()
patch_svr.fit(X, y)
print("Time to fit SVR with patch:", time.time() - tic)

score = patch_svr.score(X, y)
print("Score:", score)
```

Without the patch, the model terminated early, after 0.34 seconds, and did not fit the data:
```python
import time

import numpy as np
# No sklearnex patching here: this runs stock scikit-learn's SVR.
from sklearn.svm import SVR

svr = SVR(max_iter=2)
np.random.seed(0)
X = np.random.randn(int(1e5), 5)
y = np.mean(X, axis=1)

tic = time.time()
svr.fit(X, y)
print("Time to fit SVR without patch:", time.time() - tic)

score = svr.score(X, y)
print("Score:", score)
```

@Alexsandruss
Contributor

Alexsandruss commented Feb 15, 2024

sklearn and sklearnex use different SVM implementations, so different behavior for the same number of iterations is expected.
SVM training stops according to two parameters, whichever is reached first: the maximum number of iterations (max_iter) and the threshold/tolerance for the stopping criterion (tol). sklearnex SVM may therefore stop at the same point for widely different max_iter values (as long as they are at or above that stopping point), because the stopping criterion triggers first. Because of the workaround mentioned earlier, it is impossible to know how many iterations were actually performed. If tol is set extremely close to 0 (for example, 1e-32), the stopping criterion becomes unreachable and SVM training time becomes strictly proportional to max_iter.
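To make the interplay of max_iter and tol concrete, here is a minimal sketch that uses plain gradient descent on f(x) = x**2 rather than an SVM solver. The function minimize_quadratic and its step size lr=0.25 are hypothetical, chosen so that every update halves x exactly in floating point. The loop stops when the gradient magnitude drops below tol or when max_iter updates have been made, whichever comes first.

```python
def minimize_quadratic(x0, max_iter, tol, lr=0.25):
    """Gradient descent on f(x) = x**2 with two stopping conditions.

    Returns (x, n_iter), where n_iter is the number of updates performed.
    """
    x = x0
    for n_iter in range(max_iter):
        grad = 2.0 * x            # derivative of x**2
        if abs(grad) < tol:       # stopping criterion triggered first
            return x, n_iter
        x -= lr * grad            # with lr=0.25 this halves x exactly
    return x, max_iter            # iteration budget exhausted


# Reachable tol: the stopping point is the same for very different max_iter.
for max_iter in (1000, 10000):
    _, n_iter = minimize_quadratic(1.0, max_iter=max_iter, tol=1e-3)
    print(f"max_iter={max_iter:>5}  tol=1e-3   n_iter={n_iter}")

# Practically unreachable tol: the number of updates equals max_iter.
for max_iter in (50, 100):
    _, n_iter = minimize_quadratic(1.0, max_iter=max_iter, tol=1e-32)
    print(f"max_iter={max_iter:>5}  tol=1e-32  n_iter={n_iter}")
```

With tol=1e-3 both runs stop after 11 updates at the same point, while with tol=1e-32 the number of updates (and so the run time) tracks max_iter.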
