diff --git a/decls/haskell_rules.bzl b/decls/haskell_rules.bzl index 765ef53ad..6fbd82da2 100644 --- a/decls/haskell_rules.bzl +++ b/decls/haskell_rules.bzl @@ -14,7 +14,6 @@ load("@prelude//linking:types.bzl", "Linkage") load(":common.bzl", "LinkableDepType", "buck", "prelude_rule") load(":haskell_common.bzl", "haskell_common") load(":native_common.bzl", "native_common") -load("@prelude//haskell/worker/worker.bzl", "worker_libs", "worker_srcs", "worker_flags") haskell_binary = prelude_rule( name = "haskell_binary", @@ -67,10 +66,7 @@ haskell_binary = prelude_rule( "linker_flags": attrs.list(attrs.arg(), default = []), "platform": attrs.option(attrs.string(), default = None), "platform_linker_flags": attrs.list(attrs.tuple(attrs.regex(), attrs.list(attrs.arg())), default = []), - "_worker_srcs": attrs.list(attrs.source(), default = worker_srcs), - "_worker_deps": attrs.list(attrs.dep(), default = ["@prelude//haskell/worker:{}".format(pkg) for pkg in worker_libs]), - "_worker_compiler_flags": attrs.list(attrs.string(), default = worker_flags), - "_worker_plugin": attrs.dep(default = "@prelude//haskell/worker:ghc-persistent-worker-plugin"), + "allow_worker": attrs.bool(default = True), } ), ) @@ -193,10 +189,7 @@ haskell_library = prelude_rule( "linker_flags": attrs.list(attrs.arg(), default = []), "platform": attrs.option(attrs.string(), default = None), "platform_linker_flags": attrs.list(attrs.tuple(attrs.regex(), attrs.list(attrs.arg())), default = []), - "_worker_srcs": attrs.list(attrs.source(), default = worker_srcs), - "_worker_deps": attrs.list(attrs.dep(), default = ["@prelude//haskell/worker:{}".format(pkg) for pkg in worker_libs]), - "_worker_compiler_flags": attrs.list(attrs.string(), default = worker_flags), - "_worker_plugin": attrs.dep(default = "@prelude//haskell/worker:ghc-persistent-worker-plugin"), + "allow_worker": attrs.bool(default = True), } ), ) diff --git a/haskell/compile.bzl b/haskell/compile.bzl index 19b8ceddf..68e0c31a6 100644 --- a/haskell/compile.bzl +++ b/haskell/compile.bzl @@ -437,20 +437,22 @@ def _common_compile_module_args( external_tool_paths: list[RunInfo], sources: list[Artifact], direct_deps_info: list[HaskellLibraryInfoTSet], + allow_worker: bool, pkgname: str | None = None, - worker_plugin: Dependency | None = None, ) -> CommonCompileModuleArgs: command = cmd_args(ghc_wrapper) command.add("--ghc", haskell_toolchain.compiler) command.add("--ghc-dir", haskell_toolchain.ghc_dir) - if haskell_toolchain.use_worker and worker_plugin != None: + if allow_worker and haskell_toolchain.use_worker and haskell_toolchain.use_worker_multiplexer: + if haskell_toolchain.worker_multiplexer_plugin == None: + fail("'worker_multiplexer_plugin' must be set on the toolchain if 'use_worker_multiplexer' is true") if pkgname == None: - warning("Module {} has no 'pkgname', plugin worker will break".format(label)) + warning("Module {} has no 'pkgname', worker multiplexer will break".format(label)) else: pkg_deps = resolved[haskell_toolchain.packages.dynamic] package_db = pkg_deps.providers[DynamicHaskellPackageDbInfo].packages - db = package_db[worker_plugin[HaskellToolchainLibrary].name] + db = package_db[haskell_toolchain.worker_multiplexer_plugin[HaskellToolchainLibrary].name] command.add("--plugin-db", db.value.db) command.add("--worker-target-id", pkgname) @@ -764,8 +766,8 @@ def _dynamic_do_compile_impl(actions, md_file, pkg_deps, arg, direct_deps_by_nam enable_profiling = arg.enable_profiling, link_style = arg.link_style, direct_deps_info = arg.direct_deps_info, + allow_worker = arg.allow_worker, pkgname = arg.pkgname, - worker_plugin = arg.worker_plugin, ) md = md_file.read_json() @@ -826,7 +828,6 @@ def compile( enable_profiling: bool, enable_haddock: bool, md_file: Artifact, - worker_plugin: Dependency | None, worker: WorkerInfo | None = None, pkgname: str | None = None) -> CompileResultInfo: artifact_suffix = get_artifact_suffix(link_style, enable_profiling) @@ -884,7 +885,7 @@ def compile( srcs_envs = ctx.attrs.srcs_envs, toolchain_deps_by_name = toolchain_deps_by_name, worker = worker, - worker_plugin = worker_plugin, + allow_worker = ctx.attrs.allow_worker, ), )) diff --git a/haskell/haskell.bzl b/haskell/haskell.bzl index 2918bcdb1..54cbcbe6d 100644 --- a/haskell/haskell.bzl +++ b/haskell/haskell.bzl @@ -648,7 +648,6 @@ def _build_haskell_lib( md_file = md_file, pkgname = pkgname, worker = _persistent_worker(ctx), - worker_plugin = ctx.attrs._worker_plugin if ctx.label.cell != "prelude" and ctx.attrs._haskell_toolchain[HaskellToolchainInfo].use_worker else None, ) solibs = {} artifact_suffix = get_artifact_suffix(link_style, enable_profiling) @@ -1224,7 +1223,6 @@ def haskell_binary_impl(ctx: AnalysisContext) -> list[Provider]: md_file = md_file, worker = _persistent_worker(ctx), pkgname = pkgname, - worker_plugin = ctx.attrs._worker_plugin if hasattr(ctx.attrs, "_worker_plugin") else None, ) haskell_toolchain = ctx.attrs._haskell_toolchain[HaskellToolchainInfo] @@ -1505,9 +1503,8 @@ worker = anon_rule( "srcs_deps": attrs.dict(attrs.string(), attrs.dep(), default = {}), "srcs_envs": attrs.dict(attrs.string(), attrs.string(), default = {}), "template_deps": attrs.list(attrs.dep(), default = []), - # N.B. the _worker_* attrs are only treated by the call site of the anon_target - "_worker_deps": attrs.default_only(attrs.list(attrs.dep(), default = [])), - "_worker_srcs": attrs.default_only(attrs.list(attrs.source(), default = [])), + # N.B. allow_worker is only treated by the call site of the anon_target + "allow_worker": attrs.bool(), } | haskell_common.use_argsfile_at_link_arg() | native_common.link_style(), @@ -1517,10 +1514,11 @@ worker = anon_rule( ) def _persistent_worker(ctx: AnalysisContext) -> WorkerInfo | None: - if ctx.label.cell == "prelude": + if not ctx.attrs.allow_worker: return None - if not ctx.attrs._haskell_toolchain[HaskellToolchainInfo].use_worker: + tc = ctx.attrs._haskell_toolchain[HaskellToolchainInfo] + if not tc.use_worker: return None worker_target = ctx.actions.anon_target( @@ -1530,14 +1528,14 @@ def _persistent_worker(ctx: AnalysisContext) -> WorkerInfo | None: "_generate_target_metadata": ctx.attrs._generate_target_metadata, "_ghc_wrapper": ctx.attrs._ghc_wrapper, "_haskell_toolchain": ctx.attrs._haskell_toolchain, - "deps": ctx.attrs._worker_deps, + "deps": tc.worker_deps, "link_style": "shared", "name": "prelude//haskell:worker", - "srcs": ctx.attrs._worker_srcs, - "compiler_flags": ctx.attrs._worker_compiler_flags + [ + "srcs": tc.worker_srcs_multiplexer if tc.use_worker_multiplexer else tc.worker_srcs, + "compiler_flags": tc.worker_compiler_flags + [ "-O2", ], - "linker_flags": ctx.attrs._worker_compiler_flags + [ + "linker_flags": [ "-dynamic", "-rtsopts=all", "-with-rtsopts=-K512M -H -I5 -T", @@ -1545,6 +1543,7 @@ def _persistent_worker(ctx: AnalysisContext) -> WorkerInfo | None: "-O2", ], "use_argsfile_at_link": False, + "allow_worker": False, }, ) return WorkerInfo(worker_target.artifact("worker")) diff --git a/haskell/toolchain.bzl b/haskell/toolchain.bzl index d4a3eae5b..80ef57b27 100644 --- a/haskell/toolchain.bzl +++ b/haskell/toolchain.bzl @@ -41,6 +41,12 @@ HaskellToolchainInfo = provider( "script_template_processor": provider_field(typing.Any, default = None), "packages": provider_field(typing.Any, default = None), "use_worker": provider_field(bool, default = False), + "use_worker_multiplexer": provider_field(bool, default = False), + "worker_multiplexer_plugin": provider_field(None | Dependency, default = None), + "worker_srcs": provider_field(typing.Any, default = []), + "worker_srcs_multiplexer": provider_field(typing.Any, default = []), + "worker_deps": provider_field(typing.Any, default = []), + "worker_compiler_flags": provider_field(typing.Any, default = []), "ghc_dir": provider_field(typing.Any, default = None), }, ) diff --git a/haskell/worker/BUCK b/haskell/worker/BUCK index 232b24429..9d8f52a26 100644 --- a/haskell/worker/BUCK +++ b/haskell/worker/BUCK @@ -1,3 +1,5 @@ load(":worker.bzl", "worker_libs") [haskell_toolchain_library(name = pkg, visibility = ["PUBLIC"]) for pkg in worker_libs] + +haskell_toolchain_library(name = "ghc-persistent-worker-plugin", visibility = ["PUBLIC"]) diff --git a/haskell/worker/worker.bzl b/haskell/worker/worker.bzl index 354039d62..a6f2293c8 100644 --- a/haskell/worker/worker.bzl +++ b/haskell/worker/worker.bzl @@ -8,7 +8,6 @@ worker_libs = [ "extra", "filepath", "ghc", - "ghc-persistent-worker-plugin", "grpc-haskell", "network", "process", @@ -21,7 +20,7 @@ worker_libs = [ "unix", ] -worker_srcs = [ +worker_srcs_shared = [ "@prelude//haskell/worker/impl/plugin/src:Internal/AbiHash.hs", "@prelude//haskell/worker/impl/plugin/src:Internal/Args.hs", "@prelude//haskell/worker/impl/plugin/src:Internal/Cache.hs", @@ -29,13 +28,20 @@ worker_srcs = [ "@prelude//haskell/worker/impl/plugin/src:Internal/Error.hs", "@prelude//haskell/worker/impl/plugin/src:Internal/Log.hs", "@prelude//haskell/worker/impl/plugin/src:Internal/Session.hs", + "@prelude//haskell/worker/impl/buck-worker:Args.hs", + "@prelude//haskell/worker/impl/buck-worker:BuckWorker.hs", +] + +worker_srcs = worker_srcs_shared + [ + "@prelude//haskell/worker/impl/buck-worker:Main.hs", +] + +worker_srcs_multiplexer = worker_srcs_shared + [ + "@prelude//haskell/worker/impl/comm/src:Message.hs", "@prelude//haskell/worker/impl/server/app:Server.hs", "@prelude//haskell/worker/impl/server/app:Pool.hs", "@prelude//haskell/worker/impl/server/app:Worker.hs", - "@prelude//haskell/worker/impl/comm/src:Message.hs", - "@prelude//haskell/worker/impl/buck-worker:Args.hs", "@prelude//haskell/worker/impl/buck-worker-2:Main.hs", - "@prelude//haskell/worker/impl/buck-worker:BuckWorker.hs", ] worker_flags = [ @@ -47,6 +53,5 @@ worker_flags = [ "-XDuplicateRecordFields", "-XOverloadedRecordDot", "-XStrictData", - # "-XNoFieldSelectors", "-XLambdaCase", ]