Initial conversion to PythonCall
tylerjthomas9 committed Feb 28, 2023
1 parent 9d9f4dc commit dfbef0d
Showing 16 changed files with 165 additions and 111 deletions.
16 changes: 4 additions & 12 deletions .github/workflows/ci.yml
@@ -1,8 +1,5 @@
name: CI

env:
PYTHON: Conda

on:
pull_request:
branches:
@@ -23,8 +20,10 @@ jobs:
version:
- '1.6'
- '1'
- 'nightly'
os:
- ubuntu-latest
- windows-latest
arch:
- x64
steps:
@@ -44,15 +43,8 @@ jobs:
${{ runner.os }}-test-
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
# The following is needed for Julia <=0.8.4 on Linux OS
# due to old version of libstdcxx used by Julia
- name: "Export LD_LIBRARY_PATH environment variable"
if: ${{matrix.version == '1.6'}}
run: echo "LD_LIBRARY_PATH=/home/runner/.julia/conda/3/x86_64/lib" >> $GITHUB_ENV
- uses: julia-actions/julia-runtest@v1
env:
LD_LIBRARY_PATH: /home/runner/.julia/conda/3/x86_64/lib
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v1
with:
file: lcov.info
file: lcov.info
1 change: 1 addition & 0 deletions .gitignore
@@ -1,2 +1,3 @@
Manifest.toml
*DS_Store
.CondaPkg/*
4 changes: 4 additions & 0 deletions CondaPkg.toml
@@ -0,0 +1,4 @@

[deps.scikit-learn]
channel = "conda-forge"
version = ">=1"
8 changes: 4 additions & 4 deletions Project.toml
@@ -5,13 +5,13 @@ version = "0.3.0"

[deps]
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
ScikitLearn = "3646fa90-6ef7-5e7e-9f22-8aca16db6324"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
MLJModelInterface = "1.4"
PyCall = "1"
ScikitLearn = "0.7"
PythonCall = "0.9"
Tables = "1"
julia = "1.6"

[extras]
35 changes: 15 additions & 20 deletions src/MLJScikitLearnInterface.jl
@@ -6,14 +6,9 @@ import MLJModelInterface:
Table, Continuous, Count, Finite, OrderedFactor, Multiclass, Unknown
const MMI = MLJModelInterface

import ScikitLearn
function __init__()
ScikitLearn.Skcore.import_sklearn()
end
const SK = ScikitLearn

# Note: PyCall is already imported as part of ScikitLearn so this is cheap
import PyCall: ispynull, PyNULL, pyimport
include("ScikitLearnAPI.jl")
const SK = ScikitLearnAPI
import PythonCall: pyisnull, PyNULL, pyimport, pycopy!, pynew, pyconvert

# ------------------------------------------------------------------------
# NOTE: the next few lines of wizardry and their call should not be
@@ -22,14 +17,14 @@ import PyCall: ispynull, PyNULL, pyimport
# from which much of this stems.

# supervised
const SKLM = PyNULL()
const SKGP = PyNULL()
const SKEN = PyNULL()
const SKDU = PyNULL()
const SKNB = PyNULL()
const SKNE = PyNULL()
const SKDA = PyNULL()
const SKSV = PyNULL()
const SKLM = pynew()
const SKGP = pynew()
const SKEN = pynew()
const SKDU = pynew()
const SKNB = pynew()
const SKNE = pynew()
const SKDA = pynew()
const SKSV = pynew()
sklm(m) = (:SKLM, :linear_model, m)
skgp(m) = (:SKGP, :gaussian_process, m)
sken(m) = (:SKEN, :ensemble, m)
@@ -40,20 +35,20 @@ skda(m) = (:SKDA, :discriminant_analysis, m)
sksv(m) = (:SKSV, :svm, m)

# unsupervised
const SKCL = PyNULL()
const SKCL = pynew()
skcl(m) = (:SKCL, :cluster, m)

# Generic loader (see _skmodel_fit_* in macros)
ski!(sks, mdl) = copy!(sks, pyimport("sklearn.$mdl"))
ski!(sks, mdl) = pycopy!(sks, pyimport("sklearn.$mdl"))
# ------------------------------------------------------------------------

const Option{T} = Union{Nothing,T}

# recurrent information for traits
const PKG_NAME = "ScikitLearn"
const PKG_NAME = "MLJScikitLearnInterface"
const API_PKG_NAME = "MLJScikitLearnInterface"
const SK_UUID = "3646fa90-6ef7-5e7e-9f22-8aca16db6324"
const SK_URL = "https://github.com/cstjean/ScikitLearn.jl"
const SK_URL = "https://github.com/JuliaAI/MLJScikitLearnInterface.jl"
const SK_LIC = "BSD"

const CV = "with built-in cross-validation"
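
With the switch to PythonCall, the sentinel sub-module references are created with pynew() instead of PyCall's PyNULL(), checked with pyisnull instead of ispynull, and filled in with pycopy! instead of copy!. A minimal sketch of the lazy-import pattern used by ski! (the names SKLM_DEMO and lazy_import! are illustrative, not part of the package):

    using PythonCall: pynew, pyisnull, pycopy!, pyimport

    # empty placeholder created at load time; no Python object attached yet
    const SKLM_DEMO = pynew()

    # import the sklearn submodule into the placeholder on first use
    lazy_import!(ref, mod) = pyisnull(ref) && pycopy!(ref, pyimport("sklearn.$mod"))

    lazy_import!(SKLM_DEMO, "linear_model")
    ridge = SKLM_DEMO.Ridge(alpha=1.0)        # construct a scikit-learn estimator through PythonCall
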
17 changes: 10 additions & 7 deletions src/macros.jl
@@ -138,15 +138,17 @@ function _skmodel_fit_reg(modelname, params)
sksym, skmod, mdl = $(Symbol(modelname, "_"))
# retrieve the namespace, if it's not there yet, import it
parent = eval(sksym)
ispynull(parent) && ski!(parent, skmod)
pyisnull(parent) && ski!(parent, skmod)
# retrieve the effective ScikitLearn constructor
skconstr = getproperty(parent, mdl)
# build the scikitlearn model passing all the parameters
skmodel = skconstr(
$((Expr(:kw, p, :(model.$p)) for p in params)...))
# --------------------------------------------------------------
# fit and organise results
fitres = SK.fit!(skmodel, Xmatrix, yplain)
X_py = ScikitLearnAPI.numpy.array(Xmatrix)
y_py = ScikitLearnAPI.numpy.array(yplain)
fitres = SK.fit!(skmodel, X_py, y_py)
# TODO: we may want to use the report later on
report = NamedTuple()
# the first nothing is so that we can use the same predict for
@@ -170,7 +172,7 @@ function _skmodel_fit_clf(modelname, params)
# See _skmodel_fit_reg, same story
sksym, skmod, mdl = $(Symbol(modelname, "_"))
parent = eval(sksym)
ispynull(parent) && ski!(parent, skmod)
pyisnull(parent) && ski!(parent, skmod)
skconstr = getproperty(parent, mdl)
skmodel = skconstr(
$((Expr(:kw, p, :(model.$p)) for p in params)...))
@@ -249,7 +251,7 @@ function _skmodel_fit_uns(modelname, params)
# See _skmodel_fit_reg, same story
sksym, skmod, mdl = $(Symbol(modelname, "_"))
parent = eval(sksym)
ispynull(parent) && ski!(parent, skmod)
pyisnull(parent) && ski!(parent, skmod)
skconstr = getproperty(parent, mdl)
skmodel = skconstr(
$((Expr(:kw, p, :(model.$p)) for p in params)...))
@@ -285,7 +287,8 @@ there is one supported.
macro sku_inverse_transform(modelname)
quote
function MMI.inverse_transform(::$modelname, fitres, X)
X = SK.inverse_transform(fitres, MMI.matrix(X))
X_py = ScikitLearnAPI.numpy.array(MMI.matrix(X))
X = SK.inverse_transform(fitres, X_py)
MMI.table(X)
end
end
@@ -313,10 +316,10 @@ macro sku_predict(modelname)
if sm in (:Birch, :KMeans, :MiniBatchKMeans)
catv = MMI.categorical(1:m.n_clusters)
elseif sm == :AffinityPropagation
nc = length(fitres.cluster_centers_indices_)
nc = length(pyconvert(Array, fitres.cluster_centers_))
catv = MMI.categorical(1:nc)
elseif sm == :MeanShift
nc = size(fitres.cluster_centers_, 1)
nc = size(pyconvert(Array, fitres.cluster_centers_), 1)
catv = MMI.categorical(1:nc)
else
throw(ArgumentError("Model $sm does not support `predict`."))
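
The fit macros now wrap the Julia data as numpy arrays (X_py, y_py) before calling into scikit-learn, since the fit! wrapper forwards arguments to Python without the implicit conversions the PyCall-based ScikitLearn.jl performed. A standalone sketch of the same wrapping outside the macro (illustrative only):

    using PythonCall

    np = pyimport("numpy")
    sk_lm = pyimport("sklearn.linear_model")      # assumes scikit-learn is installed via CondaPkg

    X = rand(100, 3); y = rand(100)               # Julia matrix and vector
    X_py = np.array(X)                            # explicit conversion, mirroring X_py/y_py in _skmodel_fit_reg
    y_py = np.array(y)

    lr = sk_lm.LinearRegression()
    lr.fit(X_py, y_py)
    coefs = pyconvert(Vector{Float64}, lr.coef_)  # bring results back to Julia
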
52 changes: 52 additions & 0 deletions src/models/ScikitLearnAPI.jl
@@ -0,0 +1,52 @@

module ScikitLearnAPI


using Tables
using PythonCall

const numpy = PythonCall.pynew()
const sklearn = PythonCall.pynew()

function __init__()
PythonCall.pycopy!(numpy, pyimport("numpy"))
PythonCall.pycopy!(sklearn, pyimport("sklearn"))
end

# convert return values back to Julia
tweak_rval(x) = x
function tweak_rval(x::Py)
if pyisinstance(x, numpy.ndarray)
return pyconvert(Array, x)
else
return pyconvert(Any, x)
end
end

################################################################################
# Julia => Python
################################################################################
api_map = Dict(:decision_function => :decision_function,
:fit_predict! => :fit_predict,
:fit_transform! => :fit_transform,
:get_feature_names => :get_feature_names,
:get_params => :get_params,
:predict => :predict,
:predict_proba => :predict_proba,
:predict_log_proba => :predict_log_proba,
:partial_fit! => :partial_fit,
:score_samples => :score_samples,
:sample => :sample,
:score => :score,
:transform => :transform,
:inverse_transform => :inverse_transform,
:set_params! => :set_params)

for (jl_fun, py_fun) in api_map
@eval $jl_fun(py_estimator::Py, args...; kwargs...) =
tweak_rval(py_estimator.$(py_fun)(args...; kwargs...))
end

fit!(py_estimator::Py, args...; kwargs...) = py_estimator.fit(args...; kwargs...)

end
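
ScikitLearnAPI recreates the small slice of the ScikitLearn.jl API that the interface relies on: each generated wrapper forwards to the corresponding Python method, and tweak_rval converts ndarray results back to Julia arrays. A hedged usage sketch (assuming scikit-learn is available through CondaPkg):

    using PythonCall

    skcluster = pyimport("sklearn.cluster")
    est = skcluster.KMeans(n_clusters=3, n_init=10)

    X_py = ScikitLearnAPI.numpy.array(rand(50, 2))
    ScikitLearnAPI.fit!(est, X_py)                # fit! returns the raw Py estimator
    labels = ScikitLearnAPI.predict(est, X_py)    # ndarray converted to a Julia array by tweak_rval
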
