diff --git a/modules/core/shared/src/main/scala/data/Completion.scala b/modules/core/shared/src/main/scala/data/Completion.scala index 345a0595..97f36346 100644 --- a/modules/core/shared/src/main/scala/data/Completion.scala +++ b/modules/core/shared/src/main/scala/data/Completion.scala @@ -72,6 +72,10 @@ object Completion { case object DropPolicy extends Completion case object Comment extends Completion case object Analyze extends Completion + case object AlterDefaultPrivileges extends Completion + case object GrantRole extends Completion + case object RevokeRole extends Completion + // more ... // weird Redshift variations diff --git a/modules/core/shared/src/main/scala/net/message/CommandComplete.scala b/modules/core/shared/src/main/scala/net/message/CommandComplete.scala index cbbdc54e..bee536fc 100644 --- a/modules/core/shared/src/main/scala/net/message/CommandComplete.scala +++ b/modules/core/shared/src/main/scala/net/message/CommandComplete.scala @@ -111,6 +111,9 @@ object CommandComplete { case "ALTER POLICY" => apply(Completion.AlterPolicy) case "DROP POLICY" => apply(Completion.DropPolicy) case "ANALYZE" => apply(Completion.Analyze) + case "ALTER DEFAULT PRIVILEGES" => apply(Completion.AlterDefaultPrivileges) + case "GRANT ROLE" => apply(Completion.GrantRole) + case "REVOKE ROLE" => apply(Completion.RevokeRole) // more .. fill in as we hit them // weird Redshift variations diff --git a/modules/tests/shared/src/test/scala/CommandTest.scala b/modules/tests/shared/src/test/scala/CommandTest.scala index 061e7a4a..8c73a4a3 100644 --- a/modules/tests/shared/src/test/scala/CommandTest.scala +++ b/modules/tests/shared/src/test/scala/CommandTest.scala @@ -664,4 +664,23 @@ class CommandTest extends SkunkTest { } yield "ok" } + sessionTestWithCleanup("grant role, revoke role") { s => + for { + _ <- s.execute(createRole) + _ <- s.execute(sql"""CREATE ROLE skunk_role2""".command) + c <- s.execute(sql"""GRANT skunk_role2 TO skunk_role""".command) + _ <- assertEqual("completion", c, Completion.GrantRole) + c <- s.execute(sql"""REVOKE skunk_role2 FROM skunk_role""".command) + _ <- assertEqual("completion", c, Completion.RevokeRole) + } yield "ok" + }(sql"""DROP ROLE IF EXISTS skunk_role2""".command, dropRole) + + sessionTest("alter default privileges") { s => + for { + c <- s.execute(sql"""ALTER DEFAULT PRIVILEGES GRANT SELECT ON TABLES TO PUBLIC""".command) + _ <- assertEqual("completion", c, Completion.AlterDefaultPrivileges) + _ <- s.assertHealthy + } yield "ok" + } + } diff --git a/modules/tests/shared/src/test/scala/SkunkTest.scala b/modules/tests/shared/src/test/scala/SkunkTest.scala index 3818b671..1e4ee72b 100644 --- a/modules/tests/shared/src/test/scala/SkunkTest.scala +++ b/modules/tests/shared/src/test/scala/SkunkTest.scala @@ -4,15 +4,16 @@ package tests -import cats.effect.{ IO, Resource } -import cats.syntax.all._ -import skunk.Session -import skunk.data._ -import skunk.codec.all._ -import skunk.implicits._ +import cats.effect.{IO, Resource} +import cats.syntax.all.* +import skunk.{Command, Session, Void} +import skunk.data.* +import skunk.codec.all.* +import skunk.implicits.* import skunk.util.Typer import natchez.Trace.Implicits.noop import munit.Location + import scala.concurrent.duration.Duration abstract class SkunkTest(debug: Boolean = false, strategy: Typer.Strategy = Typer.Strategy.BuiltinsOnly) extends ffstest.FTest { @@ -34,6 +35,18 @@ abstract class SkunkTest(debug: Boolean = false, strategy: Typer.Strategy = Type def sessionTest[A](name: String, readTimeout: Duration = Duration.Inf)(fa: Session[IO] => IO[A])(implicit loc: Location): Unit = test(name)(session(readTimeout).use(fa)) + def sessionTestWithCleanup[A](name: String, + readTimeout: Duration = Duration.Inf) + (fa: Session[IO] => IO[A]) + (cleanup: Command[Void]*) + (implicit loc: Location): Unit = + test(name) { + session(readTimeout).use { s => + (fa(s) >> s.assertHealthy) + .guarantee(cleanup.toList.traverse_(s.execute)) + } + } + def pooled(readTimeout: Duration): Resource[IO, Resource[IO, Session[IO]]] = Session.pooled( host = "localhost",