Skip to content

Commit

Permalink
attach worker config to the toolchain
Browse files Browse the repository at this point in the history
  • Loading branch information
tek committed Nov 8, 2024
1 parent ed8370b commit f25e09e
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 33 deletions.
11 changes: 2 additions & 9 deletions decls/haskell_rules.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ load("@prelude//linking:types.bzl", "Linkage")
load(":common.bzl", "LinkableDepType", "buck", "prelude_rule")
load(":haskell_common.bzl", "haskell_common")
load(":native_common.bzl", "native_common")
load("@prelude//haskell/worker/worker.bzl", "worker_libs", "worker_srcs", "worker_flags")

haskell_binary = prelude_rule(
name = "haskell_binary",
Expand Down Expand Up @@ -67,10 +66,7 @@ haskell_binary = prelude_rule(
"linker_flags": attrs.list(attrs.arg(), default = []),
"platform": attrs.option(attrs.string(), default = None),
"platform_linker_flags": attrs.list(attrs.tuple(attrs.regex(), attrs.list(attrs.arg())), default = []),
"_worker_srcs": attrs.list(attrs.source(), default = worker_srcs),
"_worker_deps": attrs.list(attrs.dep(), default = ["@prelude//haskell/worker:{}".format(pkg) for pkg in worker_libs]),
"_worker_compiler_flags": attrs.list(attrs.string(), default = worker_flags),
"_worker_plugin": attrs.dep(default = "@prelude//haskell/worker:ghc-persistent-worker-plugin"),
"allow_worker": attrs.bool(default = True),
}
),
)
Expand Down Expand Up @@ -193,10 +189,7 @@ haskell_library = prelude_rule(
"linker_flags": attrs.list(attrs.arg(), default = []),
"platform": attrs.option(attrs.string(), default = None),
"platform_linker_flags": attrs.list(attrs.tuple(attrs.regex(), attrs.list(attrs.arg())), default = []),
"_worker_srcs": attrs.list(attrs.source(), default = worker_srcs),
"_worker_deps": attrs.list(attrs.dep(), default = ["@prelude//haskell/worker:{}".format(pkg) for pkg in worker_libs]),
"_worker_compiler_flags": attrs.list(attrs.string(), default = worker_flags),
"_worker_plugin": attrs.dep(default = "@prelude//haskell/worker:ghc-persistent-worker-plugin"),
"allow_worker": attrs.bool(default = True),
}
),
)
Expand Down
15 changes: 8 additions & 7 deletions haskell/compile.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -437,19 +437,21 @@ def _common_compile_module_args(
external_tool_paths: list[RunInfo],
sources: list[Artifact],
direct_deps_info: list[HaskellLibraryInfoTSet],
allow_worker: bool,
pkgname: str | None = None,
worker_plugin: Dependency | None = None,
) -> CommonCompileModuleArgs:
command = cmd_args(ghc_wrapper)
command.add("--ghc", haskell_toolchain.compiler)
command.add("--ghc-dir", haskell_toolchain.ghc_dir)

if haskell_toolchain.use_worker and worker_plugin != None:
if allow_worker and haskell_toolchain.use_worker and haskell_toolchain.use_worker_multiplexer:
if haskell_toolchain.worker_multiplexer_plugin == None:
fail("'worker_multiplexer_plugin' must be set on the toolchain if 'use_worker_multiplexer' is true")
if pkgname == None:
warning("Module {} has no 'pkgname', plugin worker will break".format(label))
warning("Module {} has no 'pkgname', worker multiplexer will break".format(label))
else:
package_db = pkg_deps.providers[DynamicHaskellPackageDbInfo].packages
db = package_db[worker_plugin[HaskellToolchainLibrary].name]
db = package_db[haskell_toolchain.worker_multiplexer_plugin[HaskellToolchainLibrary].name]
command.add("--plugin-db", db.value.db)
command.add("--worker-target-id", pkgname)

Expand Down Expand Up @@ -763,8 +765,8 @@ def _dynamic_do_compile_impl(actions, md_file, pkg_deps, arg, direct_deps_by_nam
enable_profiling = arg.enable_profiling,
link_style = arg.link_style,
direct_deps_info = arg.direct_deps_info,
allow_worker = arg.allow_worker,
pkgname = arg.pkgname,
worker_plugin = arg.worker_plugin,
)

md = md_file.read_json()
Expand Down Expand Up @@ -825,7 +827,6 @@ def compile(
enable_profiling: bool,
enable_haddock: bool,
md_file: Artifact,
worker_plugin: Dependency | None,
worker: WorkerInfo | None = None,
pkgname: str | None = None) -> CompileResultInfo:
artifact_suffix = get_artifact_suffix(link_style, enable_profiling)
Expand Down Expand Up @@ -883,7 +884,7 @@ def compile(
srcs_envs = ctx.attrs.srcs_envs,
toolchain_deps_by_name = toolchain_deps_by_name,
worker = worker,
worker_plugin = worker_plugin,
allow_worker = ctx.attrs.allow_worker,
),
))

Expand Down
21 changes: 10 additions & 11 deletions haskell/haskell.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,6 @@ def _build_haskell_lib(
md_file = md_file,
pkgname = pkgname,
worker = _persistent_worker(ctx),
worker_plugin = ctx.attrs._worker_plugin if ctx.label.cell != "prelude" and ctx.attrs._haskell_toolchain[HaskellToolchainInfo].use_worker else None,
)
solibs = {}
artifact_suffix = get_artifact_suffix(link_style, enable_profiling)
Expand Down Expand Up @@ -1224,7 +1223,6 @@ def haskell_binary_impl(ctx: AnalysisContext) -> list[Provider]:
md_file = md_file,
worker = _persistent_worker(ctx),
pkgname = pkgname,
worker_plugin = ctx.attrs._worker_plugin if hasattr(ctx.attrs, "_worker_plugin") else None,
)

haskell_toolchain = ctx.attrs._haskell_toolchain[HaskellToolchainInfo]
Expand Down Expand Up @@ -1505,9 +1503,8 @@ worker = anon_rule(
"srcs_deps": attrs.dict(attrs.string(), attrs.dep(), default = {}),
"srcs_envs": attrs.dict(attrs.string(), attrs.string(), default = {}),
"template_deps": attrs.list(attrs.dep(), default = []),
# N.B. the _worker_* attrs are only treated by the call site of the anon_target
"_worker_deps": attrs.default_only(attrs.list(attrs.dep(), default = [])),
"_worker_srcs": attrs.default_only(attrs.list(attrs.source(), default = [])),
# N.B. allow_worker is only treated by the call site of the anon_target
"allow_worker": attrs.bool(),
}
| haskell_common.use_argsfile_at_link_arg()
| native_common.link_style(),
Expand All @@ -1517,10 +1514,11 @@ worker = anon_rule(
)

def _persistent_worker(ctx: AnalysisContext) -> WorkerInfo | None:
if ctx.label.cell == "prelude":
if not ctx.attrs.allow_worker:
return None

if not ctx.attrs._haskell_toolchain[HaskellToolchainInfo].use_worker:
tc = ctx.attrs._haskell_toolchain[HaskellToolchainInfo]
if not tc.use_worker:
return None

worker_target = ctx.actions.anon_target(
Expand All @@ -1530,21 +1528,22 @@ def _persistent_worker(ctx: AnalysisContext) -> WorkerInfo | None:
"_generate_target_metadata": ctx.attrs._generate_target_metadata,
"_ghc_wrapper": ctx.attrs._ghc_wrapper,
"_haskell_toolchain": ctx.attrs._haskell_toolchain,
"deps": ctx.attrs._worker_deps,
"deps": tc.worker_deps,
"link_style": "shared",
"name": "prelude//haskell:worker",
"srcs": ctx.attrs._worker_srcs,
"compiler_flags": ctx.attrs._worker_compiler_flags + [
"srcs": tc.worker_srcs_multiplexer if tc.use_worker_multiplexer else tc.worker_srcs,
"compiler_flags": tc.worker_compiler_flags + [
"-O2",
],
"linker_flags": ctx.attrs._worker_compiler_flags + [
"linker_flags": [
"-dynamic",
"-rtsopts=all",
"-with-rtsopts=-K512M -H -I5 -T",
"-threaded",
"-O2",
],
"use_argsfile_at_link": False,
"allow_worker": False,
},
)
return WorkerInfo(worker_target.artifact("worker"))
Expand Down
6 changes: 6 additions & 0 deletions haskell/toolchain.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ HaskellToolchainInfo = provider(
"script_template_processor": provider_field(typing.Any, default = None),
"packages": provider_field(typing.Any, default = None),
"use_worker": provider_field(bool, default = False),
"use_worker_multiplexer": provider_field(bool, default = False),
"worker_multiplexer_plugin": provider_field(None | Dependency, default = None),
"worker_srcs": provider_field(typing.Any, default = []),
"worker_srcs_multiplexer": provider_field(typing.Any, default = []),
"worker_deps": provider_field(typing.Any, default = []),
"worker_compiler_flags": provider_field(typing.Any, default = []),
"ghc_dir": provider_field(typing.Any, default = None),
},
)
Expand Down
2 changes: 2 additions & 0 deletions haskell/worker/BUCK
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
load(":worker.bzl", "worker_libs")

[haskell_toolchain_library(name = pkg, visibility = ["PUBLIC"]) for pkg in worker_libs]

haskell_toolchain_library(name = "ghc-persistent-worker-plugin", visibility = ["PUBLIC"])
17 changes: 11 additions & 6 deletions haskell/worker/worker.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ worker_libs = [
"extra",
"filepath",
"ghc",
"ghc-persistent-worker-plugin",
"grpc-haskell",
"network",
"process",
Expand All @@ -21,21 +20,28 @@ worker_libs = [
"unix",
]

worker_srcs = [
worker_srcs_shared = [
"@prelude//haskell/worker/impl/plugin/src:Internal/AbiHash.hs",
"@prelude//haskell/worker/impl/plugin/src:Internal/Args.hs",
"@prelude//haskell/worker/impl/plugin/src:Internal/Cache.hs",
"@prelude//haskell/worker/impl/plugin/src:Internal/Compile.hs",
"@prelude//haskell/worker/impl/plugin/src:Internal/Error.hs",
"@prelude//haskell/worker/impl/plugin/src:Internal/Log.hs",
"@prelude//haskell/worker/impl/plugin/src:Internal/Session.hs",
"@prelude//haskell/worker/impl/buck-worker:Args.hs",
"@prelude//haskell/worker/impl/buck-worker:BuckWorker.hs",
]

worker_srcs = worker_srcs_shared + [
"@prelude//haskell/worker/impl/buck-worker:Main.hs",
]

worker_srcs_multiplexer = worker_srcs_shared + [
"@prelude//haskell/worker/impl/comm/src:Message.hs",
"@prelude//haskell/worker/impl/server/app:Server.hs",
"@prelude//haskell/worker/impl/server/app:Pool.hs",
"@prelude//haskell/worker/impl/server/app:Worker.hs",
"@prelude//haskell/worker/impl/comm/src:Message.hs",
"@prelude//haskell/worker/impl/buck-worker:Args.hs",
"@prelude//haskell/worker/impl/buck-worker-2:Main.hs",
"@prelude//haskell/worker/impl/buck-worker:BuckWorker.hs",
]

worker_flags = [
Expand All @@ -47,6 +53,5 @@ worker_flags = [
"-XDuplicateRecordFields",
"-XOverloadedRecordDot",
"-XStrictData",
# "-XNoFieldSelectors",
"-XLambdaCase",
]

0 comments on commit f25e09e

Please sign in to comment.