enhance jlfun2pyfun to use explicit (non-varargs) arguments when possible #795

Open
@leonardtschora

Hi everyone,
I am currently designing Scikit Learn estimators using Julia functions with the help of the @pydef macro.
The goal is to integrate my Julia code into ScikitLearn's interface for pipelines, grid search, etc.

I followed the tutorial from this homepage and everything worked fine up to a point.
I was able to instantiate my custom estimator, integrate it in a pipeline and call fit, transform, predict, etc... on it.

More information on Scikit Learn's estimator and how to design a custom one.

Here is a toy example of a custom estimator using a Julia function:

using ScikitLearn, PyCall, Statistics
using ScikitLearn.Pipelines: Pipeline
@sk_import base: BaseEstimator
@sk_import base: TransformerMixin

# Define a function in Julia to put into a ScikitLearn estimator
f(X, a, b) = @. a * X + b

"""
A Dummy transformer that applies the function f to the data
"""
@pydef mutable struct FunctionTransformer <: (TransformerMixin, BaseEstimator) 
    function __init__(self, b)
        self.b = b
    end

    function fit(self, X, y=nothing)
        self.a_ = mean(X)
        return self
    end

    function transform(self, X)
        return f(X, self.a_, self.b)
    end

    function fit_transform(self, X, y)
        self.fit(X, y)
        return self.transform(X)
    end
end

Here is the part which is working:

# Try the transformer
X = zeros(Int, 3, 2)
transformer  = FunctionTransformer(-1)
fit!(transformer, X)
transform(transformer, X)

# Try it in a pipeline
pipe = Pipeline([("Dummy", transformer)])
fit!(pipe, X)
transform(pipe, X)

The problem occurred when I tried to call the get_params() method on my estimator (it is required to use the estimator in a grid search). It is supposed to be inherited from the BaseEstimator class, so the developer shouldn't have to worry about it. However, I encountered the following error:

# Try getting params
transformer.get_params()

ERROR: PyError ($(Expr(:escape, :(ccall(#= C:\Users\user.julia\packages\PyCall\zqDXB\src\pyfncall.jl:43 =# @pysym(:PyObject_Call), PyPtr, (PyPtr, PyPtr, PyPtr), o, pyargsptr, kw))))) <class 'RuntimeError'>
RuntimeError("scikit-learn estimators should always specify their parameters in the signature of their __init__ (no varargs). <class 'FunctionTransformer'> with constructor (*args, **kwargs) doesn't follow this convention.")
File "C:\Users\user\AppData\Local\Programs\Python\Python37\lib\site-packages\sklearn\base.py", line 205, in get_params
for key in self._get_param_names():
File "C:\Users\user\AppData\Local\Programs\Python\Python37\lib\site-packages\sklearn\base.py", line 185, in _get_param_names
% (cls, init_signature))
Stacktrace:
[1] pyerr_check at C:\Users\user.julia\packages\PyCall\zqDXB\src\exception.jl:60 [inlined]
[2] pyerr_check at C:\Users\user.julia\packages\PyCall\zqDXB\src\exception.jl:64 [inlined]
[3] _handle_error(::String) at C:\Users\user.julia\packages\PyCall\zqDXB\src\exception.jl:81
[4] macro expansion at C:\Users\user.julia\packages\PyCall\zqDXB\src\exception.jl:95 [inlined]
[5] #110 at C:\Users\user.julia\packages\PyCall\zqDXB\src\pyfncall.jl:43 [inlined]
[6] disable_sigint at .\c.jl:446 [inlined]
[7] __pycall! at C:\Users\user.julia\packages\PyCall\zqDXB\src\pyfncall.jl:42 [inlined]
[8] _pycall!(::PyObject, ::PyObject, ::Tuple{}, ::Int64, ::Ptr{Nothing}) at C:\Users\user.julia\packages\PyCall\zqDXB\src\pyfncall.jl:29
[9] #call#117 at C:\Users\user.julia\packages\PyCall\zqDXB\src\pyfncall.jl:11 [inlined]
[10] (::PyObject)() at C:\Users\user.julia\packages\PyCall\zqDXB\src\pyfncall.jl:86
[11] top-level scope at REPL[20]:1
[12] eval(::Module, ::Any) at .\boot.jl:330
[13] eval_user_input(::Any, ::REPL.REPLBackend) at C:\cygwin\home\Administrator\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.2\REPL\src\REPL.jl:86
[14] run_backend(::REPL.REPLBackend) at C:\Users\user.julia\packages\Revise\Pcs5V\src\Revise.jl:1073
[15] top-level scope at REPL[1]:0

Reading Python's error message, it seems that the problem comes from the Python class's constructor, which is generated automatically by @pydef. The get_params() function expects a class whose constructor parameters are all specified explicitly, with no *args and no **kwargs.
In my Julia definition of the constructor, there are no varargs or keyword arguments.
I ran a quick check to inspect the Python class's constructor signature:

py"""
import inspect
cls = $FunctionTransformer
print(inspect.signature(getattr(cls.__init__, 'deprecated_original', cls.__init__)))
"""

(*args, **kwargs)

But it seems that the generated Python class's constructor takes (*args, **kwargs), which causes this issue!
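For readers who want to reproduce the behaviour without Julia, here is a minimal pure-Python sketch. `param_names` is a simplified stand-in for the check named in the traceback, not scikit-learn's actual code, and the classes `Explicit`, `Forwarding`, and `Patched` are hypothetical. The last class illustrates one possible direction, under the assumption that a generated wrapper could carry an explicit `__signature__`, which `inspect.signature` honours.

```python
import inspect


def param_names(cls):
    """Simplified stand-in for the constructor check (not sklearn's real code)."""
    init = getattr(cls.__init__, "deprecated_original", cls.__init__)
    if init is object.__init__:
        return []
    sig = inspect.signature(init)
    # Ignore `self` and **kwargs, as the error message suggests the check does.
    params = [p for p in sig.parameters.values()
              if p.name != "self" and p.kind != p.VAR_KEYWORD]
    for p in params:
        if p.kind == p.VAR_POSITIONAL:
            raise RuntimeError(
                f"{cls.__name__} with constructor {sig} uses varargs")
    return sorted(p.name for p in params)


class Explicit:
    """Constructor with an explicit signature: passes the check."""
    def __init__(self, b):
        self.b = b


def _forwarding_init(*args, **kwargs):
    # Hypothetical wrapper that hides the real parameters behind varargs.
    self, b = args
    self.b = b


class Forwarding:
    __init__ = _forwarding_init


def _patched_init(*args, **kwargs):
    # Same varargs wrapper as above ...
    self, b = args
    self.b = b


# ... but with an explicit signature attached for introspection (assumption:
# a wrapper generator could do this when the parameter names are known).
_patched_init.__signature__ = inspect.Signature([
    inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD),
    inspect.Parameter("b", inspect.Parameter.POSITIONAL_OR_KEYWORD),
])


class Patched:
    __init__ = _patched_init


if __name__ == "__main__":
    print(param_names(Explicit))
    print(param_names(Patched))
    try:
        param_names(Forwarding)
    except RuntimeError as err:
        print("rejected:", err)
```

The `Forwarding` class is rejected in the same way as the estimator above, while `Patched` is accepted even though its underlying function still takes `*args, **kwargs`.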

I could, of course, work around this problem by defining my own get_params() method, but on the one hand this would be very tedious, since I have several estimators, each with multiple parameters, and on the other hand this small detail might cause unsolvable problems for someone else.
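To make the cost of that workaround concrete, here is a rough Python sketch of what a hand-written get_params() could look like. `ManualParamsMixin`, `_param_names`, and `FunctionTransformerSketch` are hypothetical names used for illustration only; they are not part of scikit-learn or PyCall, and the equivalent would have to be repeated (or shared) across every estimator.

```python
class ManualParamsMixin:
    """Hypothetical helper: each subclass lists its constructor parameters by name."""

    _param_names = ()

    def get_params(self, deep=True):
        # `deep` is accepted for interface compatibility but unused here,
        # since this sketch has no nested estimators.
        return {name: getattr(self, name) for name in self._param_names}

    def set_params(self, **params):
        for name, value in params.items():
            if name not in self._param_names:
                raise ValueError(
                    f"Invalid parameter {name!r} for {type(self).__name__}")
            setattr(self, name, value)
        return self


class FunctionTransformerSketch(ManualParamsMixin):
    """Stand-in for an estimator whose constructor only exposes varargs."""

    _param_names = ("b",)  # must be kept in sync with the constructor by hand

    def __init__(self, *args, **kwargs):
        self.b = kwargs["b"] if "b" in kwargs else args[0]


if __name__ == "__main__":
    t = FunctionTransformerSketch(-1)
    print(t.get_params())
    print(t.set_params(b=2).get_params())
```

This only covers parameter introspection. Every estimator needs its own hand-maintained `_param_names`, which is the tedium described above, and the sketch does not check whether the rest of a grid search would then work.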

Thanks in advance for the help :)
