Hi everyone,
I am currently designing scikit-learn estimators from Julia functions with the help of PyCall's @pydef macro.
The goal is to integrate my Julia code into ScikitLearn's interface for pipelines, grid search, etc.
I followed the tutorial from this homepage and everything worked fine up to a point.
I was able to instantiate my custom estimator, integrate it into a pipeline, and call fit, transform, predict, etc. on it.
More information on scikit-learn estimators and on how to design a custom one is available in the scikit-learn documentation.
Here is a toy example of a custom estimator using a Julia function:
```julia
using ScikitLearn, PyCall, Statistics
using ScikitLearn.Pipelines: Pipeline
@sk_import base: BaseEstimator
@sk_import base: TransformerMixin

# Define a function in Julia to put into a ScikitLearn estimator
f(X, a, b) = @. a * X + b

"""
A dummy transformer that applies the function f to the data
"""
@pydef mutable struct FunctionTransformer <: (TransformerMixin, BaseEstimator)
    function __init__(self, b)
        self.b = b
    end

    function fit(self, X, y=nothing)
        self.a_ = mean(X)
        return self
    end

    function transform(self, X)
        return f(X, self.a_, self.b)
    end

    # y defaults to nothing so that fit_transform(X) also works
    function fit_transform(self, X, y=nothing)
        self.fit(X, y)
        return self.transform(X)
    end
end
```
Here is the part that works:
```julia
# Try the transformer
X = zeros(Int, 3, 2)
transformer = FunctionTransformer(-1)
fit!(transformer, X)
transform(transformer, X)

# Try it in a pipeline
pipe = Pipeline([("Dummy", transformer)])
fit!(pipe, X)
transform(pipe, X)
```
The problem occurred when I tried to call the get_params() method on my estimator (it is needed to use the estimator in a grid search). This method is supposed to be inherited from the BaseEstimator class, so the developer shouldn't have to worry about it. However, I encountered the following error:
```julia
# Try getting params
transformer.get_params()
```
```
ERROR: PyError ($(Expr(:escape, :(ccall(#= C:\Users\user\.julia\packages\PyCall\zqDXB\src\pyfncall.jl:43 =# @pysym(:PyObject_Call), PyPtr, (PyPtr, PyPtr, PyPtr), o, pyargsptr, kw))))) <class 'RuntimeError'>
RuntimeError("scikit-learn estimators should always specify their parameters in the signature of their init (no varargs). <class 'FunctionTransformer'> with constructor (*args, **kwargs) doesn't follow this convention.")
  File "C:\Users\user\AppData\Local\Programs\Python\Python37\lib\site-packages\sklearn\base.py", line 205, in get_params
    for key in self._get_param_names():
  File "C:\Users\user\AppData\Local\Programs\Python\Python37\lib\site-packages\sklearn\base.py", line 185, in _get_param_names
    % (cls, init_signature))
Stacktrace:
 [1] pyerr_check at C:\Users\user\.julia\packages\PyCall\zqDXB\src\exception.jl:60 [inlined]
 [2] pyerr_check at C:\Users\user\.julia\packages\PyCall\zqDXB\src\exception.jl:64 [inlined]
 [3] _handle_error(::String) at C:\Users\user\.julia\packages\PyCall\zqDXB\src\exception.jl:81
 [4] macro expansion at C:\Users\user\.julia\packages\PyCall\zqDXB\src\exception.jl:95 [inlined]
 [5] #110 at C:\Users\user\.julia\packages\PyCall\zqDXB\src\pyfncall.jl:43 [inlined]
 [6] disable_sigint at .\c.jl:446 [inlined]
 [7] __pycall! at C:\Users\user\.julia\packages\PyCall\zqDXB\src\pyfncall.jl:42 [inlined]
 [8] _pycall!(::PyObject, ::PyObject, ::Tuple{}, ::Int64, ::Ptr{Nothing}) at C:\Users\user\.julia\packages\PyCall\zqDXB\src\pyfncall.jl:29
 [9] #call#117 at C:\Users\user\.julia\packages\PyCall\zqDXB\src\pyfncall.jl:11 [inlined]
 [10] (::PyObject)() at C:\Users\user\.julia\packages\PyCall\zqDXB\src\pyfncall.jl:86
 [11] top-level scope at REPL[20]:1
 [12] eval(::Module, ::Any) at .\boot.jl:330
 [13] eval_user_input(::Any, ::REPL.REPLBackend) at C:\cygwin\home\Administrator\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.2\REPL\src\REPL.jl:86
 [14] run_backend(::REPL.REPLBackend) at C:\Users\user\.julia\packages\Revise\Pcs5V\src\Revise.jl:1073
 [15] top-level scope at REPL[1]:0
```
Reading the Python error message, it seemed to me that the problem came from the signature of the Python class's constructor, which @pydef generates automatically. get_params() expects a class whose constructor parameters are all named explicitly in its signature (no varargs such as *args).
Yet my Julia definition of the constructor, __init__(self, b), takes no variadic arguments at all.
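For reference, the check that fails can be sketched in plain Python. The function below is a simplified re-implementation modeled on what BaseEstimator._get_param_names in sklearn/base.py appears to do; the name get_param_names and the two example classes are mine, not scikit-learn's, and only the standard inspect module is used:

```python
import inspect


def get_param_names(cls):
    """Simplified sketch of scikit-learn's constructor-signature check
    (modeled on BaseEstimator._get_param_names, not the library's code)."""
    # scikit-learn first looks for a 'deprecated_original' attribute
    init = getattr(cls.__init__, "deprecated_original", cls.__init__)
    if init is object.__init__:
        # No explicit constructor: the estimator has no parameters
        return []
    # Drop 'self' and any **kwargs parameter before checking the rest
    params = [
        p for p in inspect.signature(init).parameters.values()
        if p.name != "self" and p.kind != p.VAR_KEYWORD
    ]
    for p in params:
        if p.kind == p.VAR_POSITIONAL:
            # A *args parameter makes the parameter names unknowable
            raise RuntimeError(
                f"{cls} with constructor {inspect.signature(init)} uses varargs"
            )
    return sorted(p.name for p in params)


class Explicit:
    # Constructor with an explicit parameter, like __init__(self, b)
    def __init__(self, b):
        self.b = b


class Generic:
    # Generic constructor, like the one reported for the @pydef class
    def __init__(self, *args, **kwargs):
        pass


print(get_param_names(Explicit))  # ['b']
try:
    get_param_names(Generic)
except RuntimeError as err:
    print("RuntimeError:", err)
```

In this logic only a *args parameter triggers the error; a **kwargs parameter is silently skipped, so it is the *args part of the generated constructor that gets rejected.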
I ran a quick check of the parameters of the generated Python class's constructor:
```julia
py"""
import inspect
cls = $FunctionTransformer
print(inspect.signature(getattr(cls.__init__, 'deprecated_original', cls.__init__)))
"""
```
which prints:
```
(*args, **kwargs)
```
So the generated Python class's constructor does take (*args, **kwargs), which causes this issue!
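A (*args, **kwargs) signature is what inspect reports for any generic forwarding function, whatever the function it forwards to accepts. The sketch below uses hypothetical stand-ins (real_init, generic_init), not the objects @pydef actually creates, to show that inspect.signature honours an explicit __signature__ attribute. Whether the constructor generated by @pydef could carry such an attribute is an assumption I have not verified:

```python
import inspect


def real_init(self, b):
    # Hypothetical stand-in for the Julia __init__(self, b)
    self.b = b


def generic_init(*args, **kwargs):
    # Hypothetical stand-in for a generated forwarding constructor:
    # it accepts anything and forwards to the real implementation
    return real_init(*args, **kwargs)


print(inspect.signature(generic_init))  # (*args, **kwargs)

# inspect.signature returns an explicit __signature__ attribute if present,
# so the forwarding function can advertise the real parameters
generic_init.__signature__ = inspect.signature(real_init)
print(inspect.signature(generic_init))  # (self, b)
```

functools.wraps would have a similar effect, since inspect.signature follows the __wrapped__ attribute it sets.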
I could, of course, work around this problem by redefining get_params() myself, but on the one hand this would be very tedious, since I have several estimators, each with multiple parameters, and on the other hand this small detail might cause unsolvable problems for someone else.
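To limit the tedium, such an override could at least be written once and shared. The sketch below is a plain-Python illustration with hypothetical names (DeclaredParamsMixin, _param_names, ScaleShift); it is not a scikit-learn API, and I have not checked how it interacts with @pydef or with scikit-learn's clone:

```python
class DeclaredParamsMixin:
    """Hypothetical helper: each estimator lists its constructor
    parameters once in _param_names instead of hand-writing get_params."""

    _param_names = ()  # overridden by each estimator class

    def get_params(self, deep=True):
        # 'deep' is accepted for interface compatibility but ignored:
        # nested estimators are not expanded in this sketch
        return {name: getattr(self, name) for name in self._param_names}

    def set_params(self, **params):
        for name, value in params.items():
            if name not in self._param_names:
                raise ValueError(
                    f"Invalid parameter {name!r} for {type(self).__name__}"
                )
            setattr(self, name, value)
        return self


class ScaleShift(DeclaredParamsMixin):
    # Hypothetical estimator with parameters a and b
    _param_names = ("a", "b")

    def __init__(self, *args, **kwargs):
        # Generic signature, like the generated constructor: map positional
        # arguments onto the declared names, then apply keyword arguments
        values = dict(zip(self._param_names, args))
        values.update(kwargs)
        for name in self._param_names:
            setattr(self, name, values[name])


est = ScaleShift(b=-1, a=2)
print(est.get_params())  # {'a': 2, 'b': -1}
```

Here the parameter list comes from the declared names rather than from the constructor's signature, so a generic (*args, **kwargs) constructor no longer matters for get_params().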
Thanks in advance for the help :)