Skip to content

Commit de3d3a7

Browse files
authored
Allow scala_junit_tests targets to specify test environment variables (#1384)
* Add env attr and TestingEnvironment provider to scala_junit_tests * Add tests for scala_junit_tests env attr
1 parent 3dd5d81 commit de3d3a7

File tree

7 files changed

+73
-0
lines changed

7 files changed

+73
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Adds testing.TestEnvironment provider if "env" attr is specified
2+
# https://bazel.build/rules/lib/testing#TestEnvironment
3+
4+
def phase_test_environment(ctx, p):
5+
test_env = ctx.attr.env
6+
7+
if test_env:
8+
return struct(
9+
external_providers = {
10+
"TestingEnvironment": testing.TestEnvironment(test_env),
11+
},
12+
)
13+
14+
return struct()

scala/private/phases/phase_write_executable.bzl

+1
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def _write_executable_windows(ctx, executable, rjars, main_class, jvm_flags, wra
9090
inputs = [cpfile],
9191
executable = ctx.attr._exe.files_to_run.executable,
9292
arguments = [executable.path, ctx.workspace_name, java_for_exe, main_class, cpfile.path, jvm_flags_str],
93+
env = ctx.attr.env,
9394
mnemonic = "ExeLauncher",
9495
progress_message = "Creating exe launcher",
9596
)

scala/private/phases/phases.bzl

+4
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ load("@io_bazel_rules_scala//scala/private:phases/phase_merge_jars.bzl", _phase_
6464
load("@io_bazel_rules_scala//scala/private:phases/phase_jvm_flags.bzl", _phase_jvm_flags = "phase_jvm_flags")
6565
load("@io_bazel_rules_scala//scala/private:phases/phase_coverage_runfiles.bzl", _phase_coverage_runfiles = "phase_coverage_runfiles")
6666
load("@io_bazel_rules_scala//scala/private:phases/phase_scalafmt.bzl", _phase_scalafmt = "phase_scalafmt")
67+
load("@io_bazel_rules_scala//scala/private:phases/phase_test_environment.bzl", _phase_test_environment = "phase_test_environment")
6768

6869
# API
6970
run_phases = _run_phases
@@ -136,5 +137,8 @@ phase_runfiles_common = _phase_runfiles_common
136137
# default_info
137138
phase_default_info = _phase_default_info
138139

140+
# test_environment
141+
phase_test_environment = _phase_test_environment
142+
139143
# scalafmt
140144
phase_scalafmt = _phase_scalafmt

scala/private/rules/scala_junit_test.bzl

+3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ load(
2323
"phase_merge_jars",
2424
"phase_runfiles_common",
2525
"phase_scalac_provider",
26+
"phase_test_environment",
2627
"phase_write_executable_junit_test",
2728
"phase_write_manifest",
2829
"run_phases",
@@ -52,6 +53,7 @@ def _scala_junit_test_impl(ctx):
5253
("jvm_flags", phase_jvm_flags),
5354
("write_executable", phase_write_executable_junit_test),
5455
("default_info", phase_default_info),
56+
("test_environment", phase_test_environment),
5557
],
5658
)
5759

@@ -75,6 +77,7 @@ _scala_junit_test_attrs = {
7577
default = Label("@bazel_tools//tools/jdk:current_java_runtime"),
7678
providers = [java_common.JavaRuntimeInfo],
7779
),
80+
"env": attr.string_dict(default = {}),
7881
"_junit_classpath": attr.label(
7982
default = Label("@io_bazel_rules_scala//testing/toolchain:junit_classpath"),
8083
),

test/BUILD

+19
Original file line numberDiff line numberDiff line change
@@ -777,6 +777,25 @@ scala_junit_test(
777777
runtime_deps = [":JunitRuntimePlatform"],
778778
)
779779

780+
scala_junit_test(
781+
name = "JunitNoTestEnvironmentTest",
782+
size = "small",
783+
srcs = ["src/main/scala/scalarules/test/junit/JunitNoTestEnvironmentTest.scala"],
784+
suffixes = ["Test"],
785+
deps = ["@io_bazel_rules_scala_junit_junit"],
786+
)
787+
788+
scala_junit_test(
789+
name = "JunitSetTestEnvironmentTest",
790+
size = "small",
791+
srcs = ["src/main/scala/scalarules/test/junit/JunitSetTestEnvironmentTest.scala"],
792+
env = {
793+
"my_env_var": "my_value",
794+
},
795+
suffixes = ["Test"],
796+
deps = ["@io_bazel_rules_scala_junit_junit"],
797+
)
798+
780799
py_binary(
781800
name = "py_resource_binary",
782801
srcs = ["py_resource.py"],
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package scalarules.test.junit
2+
3+
import org.junit.Assert.fail
4+
import org.junit.Test
5+
6+
class JunitNoTestEnvironmentTest {
7+
8+
@Test
9+
def testUnsetEnvVarIsNull: Unit = {
10+
System.getenv("my_unset_env_var") match {
11+
case null => ()
12+
case x => fail(s"Unexpectedly obtained my_unset_env_var=$x")
13+
}
14+
}
15+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package scalarules.test.junit
2+
3+
import org.junit.Assert
4+
import org.junit.Assert.fail
5+
import org.junit.Test
6+
7+
class JunitSetTestEnvironmentTest {
8+
9+
@Test
10+
def testSetEnvVarEqualsValue: Unit = {
11+
System.getenv("my_unset_env_var") match {
12+
case null => ()
13+
case x => fail(s"Unexpectedly obtained my_unset_env_var=$x")
14+
}
15+
Assert.assertEquals(System.getenv("my_env_var"), "my_value")
16+
}
17+
}

0 commit comments

Comments
 (0)