diff --git a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt new file mode 100644 index 00000000..cf8da7d3 --- /dev/null +++ b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt @@ -0,0 +1,444 @@ +// Copyright (C) 2025 Slack Technologies, LLC +// SPDX-License-Identifier: Apache-2.0 +package slack.lint.compose + +import com.android.tools.lint.checks.DataFlowAnalyzer +import com.android.tools.lint.client.api.UElementHandler +import com.android.tools.lint.detector.api.Category +import com.android.tools.lint.detector.api.Detector +import com.android.tools.lint.detector.api.Issue +import com.android.tools.lint.detector.api.JavaContext +import com.android.tools.lint.detector.api.Severity +import com.android.tools.lint.detector.api.SourceCodeScanner +import com.android.tools.lint.detector.api.TextFormat +import com.android.tools.lint.detector.api.asCall +import com.intellij.psi.PsiArrayType +import com.intellij.psi.PsiClass +import com.intellij.psi.PsiClassType +import com.intellij.psi.PsiMethod +import com.intellij.psi.PsiParameter +import com.intellij.psi.PsiPrimitiveType +import com.intellij.psi.PsiType +import org.jetbrains.uast.UCallExpression +import org.jetbrains.uast.UElement +import org.jetbrains.uast.UExpression +import org.jetbrains.uast.UReturnExpression +import org.jetbrains.uast.UastEmptyExpression +import org.jetbrains.uast.skipParenthesizedExprDown +import org.jetbrains.uast.tryResolve +import org.jetbrains.uast.tryResolveNamed +import org.jetbrains.uast.unwrapReferenceNameElement +import org.jetbrains.uast.visitor.AbstractUastVisitor +import slack.lint.util.Name +import slack.lint.util.Package +import slack.lint.util.asClass +import slack.lint.util.implements +import slack.lint.util.isBoxedPrimitive +import slack.lint.util.isInPackageName +import slack.lint.util.sourceImplementation + +private const val REMEMBER_SAVEABLE_METHOD_NAME = "rememberSaveable" + +private val ComposeRuntimePackageName = Package("androidx.compose.runtime") +private val ComposeSaveablePackageName = Package("androidx.compose.runtime.saveable") +private val SaverFQN = Name(ComposeSaveablePackageName, "Saver").javaFqn +private val AUTO_SAVER = UastEmptyExpression(null) + +// todo +// - Collections (listOf, mapOf, etc) checks +// - ArrayList check? +// Think the savers likely go in a different detector +// - MapSaver check +// - ListSaver check +class RememberSaveableAcceptableDetector : Detector(), SourceCodeScanner { + + override fun getApplicableUastTypes(): List> = + listOf(UCallExpression::class.java) + + override fun createUastHandler(context: JavaContext): UElementHandler = + RememberSaveableElementHandler(context) + + companion object { + const val MESSAGE = "remember" + const val ISSUE_ID = "RememberSaveableTypeMustBeAcceptable" + const val BRIEF_DESCRIPTION = "Brief description" + const val EXPLANATION = """Full explanation""" + val ISSUE: Issue = + Issue.create( + ISSUE_ID, + BRIEF_DESCRIPTION, + EXPLANATION, + Category.CORRECTNESS, + 10, + Severity.ERROR, + sourceImplementation(), + ) + } +} + +private class RememberSaveableElementHandler(val context: JavaContext) : UElementHandler() { + override fun visitCallExpression(node: UCallExpression) { + if (node.methodName != REMEMBER_SAVEABLE_METHOD_NAME) return + val evaluator = context.evaluator + val method = node.resolve() + val returnType = node.returnType + if ( + method != null && returnType != null && method.isInPackageName(ComposeSaveablePackageName) + ) { + val arguments = evaluator.computeArgumentMapping(node, method) + val saver = arguments.getSaver() + if (saver == AUTO_SAVER) { + visitAutoSaver(node, returnType, arguments) + } else { + visitCustomSaver(node, saver) + } + } + } + + private fun visitAutoSaver( + node: UCallExpression, + returnType: PsiType, + arguments: Map, + ) { + val init = arguments.getInit() + // If there is no init expression or if the return is an acceptable type, just return. + if (init == null || returnType.isAcceptableType()) { + return + } + // Check what is created in the init expression. + val (allAcceptableMutableStates, customPolices) = + returnsKnownMutableState(returnType, init, context) + if (allAcceptableMutableStates) { + // Found a known parcelable mutable state! + // Report the custom policy error when using the auto saver. + // todo Report error message "If you use a custom SnapshotMutationPolicy for your + // MutableState you have to write a custom Saver" + customPolices.forEach { + context.report( + RememberSaveableAcceptableDetector.Companion.ISSUE, + context.getLocation(it), + RememberSaveableAcceptableDetector.Companion.ISSUE.getBriefDescription(TextFormat.TEXT) + + " (Custom policy)", + ) + } + return + } + + if (returnsLambdaExpression(returnType, init)) { + // todo Report specific error language about kotlin Lambdas + context.report( + RememberSaveableAcceptableDetector.Companion.ISSUE, + context.getLocation(node), + RememberSaveableAcceptableDetector.Companion.ISSUE.getBriefDescription(TextFormat.TEXT) + + " (LAMBDA)", + ) + } else { + // "The default implementation only supports types which can be stored inside the Bundle. + // Please consider implementing a custom Saver for this class and pass it to + // rememberSaveable()." + context.report( + RememberSaveableAcceptableDetector.Companion.ISSUE, + context.getLocation(node), + RememberSaveableAcceptableDetector.Companion.ISSUE.getBriefDescription(TextFormat.TEXT), + ) + } + } + + private fun visitCustomSaver(node: UCallExpression, saver: UExpression) { + // todo Custom saver, check the return type of the save call and report an error. + val unwrapedElement = unwrapReferenceNameElement(saver) + val visitor = CustomSaverVisitor() + unwrapedElement?.accept(visitor) + if (visitor.saveableType?.isAcceptableType() == false) { + context.report( + RememberSaveableAcceptableDetector.Companion.ISSUE, + context.getLocation(node), + RememberSaveableAcceptableDetector.Companion.ISSUE.getBriefDescription(TextFormat.TEXT), + ) + } + } +} + +private class CustomSaverVisitor : AbstractUastVisitor() { + + var saveableType: PsiType? = null + + override fun visitExpression(node: UExpression): Boolean { + this.saveableType = node.getExpressionType()?.asClass()?.let { unwrapSaveableType(it) } + return saveableType != null || super.visitExpression(node) + } + + private fun unwrapSaveableType(element: PsiClassType): PsiType? { + val resolved = element.resolve() + return when { + resolved == null -> null + resolved.qualifiedName == SaverFQN -> element.parameters[1] + else -> resolved.superTypes.firstNotNullOfOrNull { type -> unwrapSaveableType(type) } + } + } +} + +private fun Map.getSaver(): UExpression { + val saver = + firstNotNullOfOrNull { (expression, parameter) -> + if (parameter.name == "saver") expression else null + } + ?.skipParenthesizedExprDown() + val resolved = saver?.tryResolve() + if ( + resolved is PsiMethod && + resolved.name == "autoSaver" && + resolved.isInPackageName(ComposeSaveablePackageName) + ) { + return AUTO_SAVER + } + return saver ?: AUTO_SAVER +} + +private fun Map.getInit(): UExpression? { + val init = firstNotNullOfOrNull { (expression, parameter) -> + if (parameter.name == "init") { + expression + } else null + } + return init +} + +/** + * The default Android `mutableStateOf` is `ParcelableSnapshotMutableState`, so check if all the + * return statements are default `mutableStateOf` calls. + */ +private fun returnsKnownMutableState( + returnType: PsiType, + expression: UExpression, + context: JavaContext, +): ReturnsKnownMutableStateResult { + return if (returnType.isAcceptableMutableStateClass()) { + val visitor = MutableStateOfVisitor(context) + expression.accept(visitor) + val allAcceptableMutableStates = visitor.checkAllReturned(expression) + ReturnsKnownMutableStateResult(allAcceptableMutableStates, visitor.customPolices) + } else { + ReturnsKnownMutableStateResult(false, emptyList()) + } +} + +private data class ReturnsKnownMutableStateResult( + val allAcceptableMutableStates: Boolean, + val customPolices: List, +) + +private fun PsiType?.isAcceptableMutableStateClass(): Boolean { + val psiClassType = asClass() + val isMutableState = + psiClassType + ?.resolve() + ?.implements("${ComposeRuntimePackageName.javaPackageName}.MutableState") == true + return isMutableState && psiClassType.parameters.all { it.isAcceptableType() } +} + +/** + * Visitor that tracks `mutableStateOf` function calls within an expression to determine if all such + * calls are returned from their containing expressions. + * + * This visitor is used to verify that when `rememberSaveable` contains `mutableStateOf` calls, + * those calls are actually being returned (and thus can be properly saved/restored by the + * auto-saver mechanism). + * + * @param context The JavaContext for the current lint analysis + */ +private class MutableStateOfVisitor(private val context: JavaContext) : AbstractUastVisitor() { + + /** List of all `mutableStateOf` function calls found during traversal */ + val mutableStateOfs = mutableListOf() + val customPolices = mutableListOf() + + /** + * Visits call expressions and tracks those that are known mutable state functions. + * + * @param node The call expression to examine + * @return true if the node is a known mutable state function, false otherwise + */ + override fun visitCallExpression(node: UCallExpression): Boolean = + if (isKnownMutableStateFunction(node, context)) { + mutableStateOfs.add(node) + true + } else false + + /** + * Checks if all tracked `mutableStateOf` calls are returned from their containing expressions. + * + * @return true if all tracked calls are returned, false otherwise + */ + fun checkAllReturned(expression: UExpression): Boolean { + return mutableStateOfs.isNotEmpty() && + mutableStateOfs.all { tracked -> + val visitor = ReturnsTracker(tracked) + expression.accept(visitor) + visitor.returned + } + } + + /** + * Check if the call is one of the acceptable mutable state functions, like `mutableStateOf` or + * `mutableFloatStateOf`. + */ + private fun isKnownMutableStateFunction(node: UElement, context: JavaContext): Boolean { + val resolved = node.tryResolve() + val isKnownMutableState = + resolved is PsiMethod && + resolved.isInPackageName(ComposeRuntimePackageName) && + resolved.name in AcceptableMutableStateMethods + return if (isKnownMutableState) { + // Check the type of mutableStateOf() + if (resolved.name == "mutableStateOf") { + val call = node.asCall() + val mutableClass = call?.returnType?.asClass() + val typeIsAcceptable = mutableClass?.parameters?.all { it.isAcceptableType() } == true + // Side effect to track custom policies + if (!hasAcceptablePolicy(call, context) && call != null) { + customPolices += call + } + typeIsAcceptable + } else { + // Known mutable[Primitive]StateOf() + true + } + } else false + } + + private class ReturnsTracker(tracked: UElement, var returned: Boolean = false) : + DataFlowAnalyzer(listOf(tracked)) { + override fun returns(expression: UReturnExpression) { + returned = true + } + } +} + +/** + * Checks if the mutableStateOf call uses an acceptable SnapshotMutationPolicy. Acceptable policies + * are: neverEqualPolicy, structuralEqualityPolicy, referentialEqualityPolicy + */ +private fun hasAcceptablePolicy(call: UCallExpression?, context: JavaContext): Boolean { + val method = call?.resolve() ?: return true // Default policy is acceptable + val arguments = context.evaluator.computeArgumentMapping(call, method) + + // Find the policy argument by parameter name + val policyArgument = + arguments + .firstNotNullOfOrNull { (expression, parameter) -> + if (parameter.name == "policy") expression else null + } + ?.skipParenthesizedExprDown() + ?.tryResolveNamed() + + // Check if the policy is acceptable + val isAcceptablePolicy = + policyArgument is PsiMethod && + policyArgument.isInPackageName(ComposeRuntimePackageName) && + policyArgument.name in AcceptablePolicyMethods + + // If no policy is specified, assume its the default + return policyArgument == null || isAcceptablePolicy +} + +/** + * > Lambdas in Kotlin implement Serializable, but will crash if you really try to save them. We + * > check for both Function and Serializable (see kotlin.jvm.internal.Lambda) to support custom + * > user defined classes implementing Function interface. + * + * https://github.com/androidx/androidx/blob/989d1e676252c69e1b9b2e0639c3dba039e7ac99/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt#L122-L124 + */ +private fun returnsLambdaExpression(returnType: PsiType, expression: UExpression): Boolean { + val isFunction = returnType.asClass()?.resolve()?.implements("kotlin.Function") == true + return if (isFunction) { + val visitor = ReturnsLambdaVisitor() + expression.accept(visitor) + visitor.returnedLambda + } else false +} + +private class ReturnsLambdaVisitor(var returnedLambda: Boolean = false) : AbstractUastVisitor() { + + override fun visitReturnExpression(node: UReturnExpression): Boolean { + val returnValue = node.returnExpression + if (returnValue != null && isLambdaExpression(returnValue)) { + returnedLambda = true + } + return super.visitReturnExpression(node) + } + + private fun isLambdaExpression(expression: UExpression): Boolean { + // Check if this is a lambda expression by looking at the expression type + val expressionType = expression.getExpressionType() + return expressionType?.asClass()?.resolve()?.implements("kotlin.Function") == true + } +} + +private fun PsiType.isAcceptableType(): Boolean { + return when (this) { + is PsiPrimitiveType -> true + is PsiArrayType -> componentType.isAcceptableType() + is PsiClassType -> { + val resolved = resolve() ?: return true // Can't resolve class type treat as acceptable? + resolved.isAcceptableClassType() && parameters.all { it.isAcceptableType() } + } + else -> false + } +} + +private fun PsiClass.isAcceptableClassType(): Boolean { + return isBoxedPrimitive() || AcceptableClasses.any { implements(it) } +} + +/** + * Based on this set used by `canBeSavedToBundle()`: + * ``` + * private val AcceptableClasses = + * arrayOf( + * Serializable::class.java, + * Parcelable::class.java, + * String::class.java, + * SparseArray::class.java, + * Binder::class.java, + * Size::class.java, + * SizeF::class.java, + * ) + * ``` + * + * https://github.com/androidx/androidx/blob/989d1e676252c69e1b9b2e0639c3dba039e7ac99/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt#L151-L160 + */ +private val AcceptableClasses = + setOf( + "java.io.Serializable", + "android.os.Parcelable", + "java.lang.String", + "android.util.SparseArray", + "android.os.Binder", + "android.util.Size", + "android.util.SizeF", + ) + +private val AcceptableMutableStateMethods = + setOf( + "mutableStateOf", + "mutableIntStateOf", + "mutableFloatStateOf", + "mutableDoubleStateOf", + "mutableLongStateOf", + ) + +/** + * Based on this check performed in `canBeSavedToBundle()`: + * ``` + * if ( + * value.policy === neverEqualPolicy() || + * value.policy === structuralEqualityPolicy() || + * value.policy === referentialEqualityPolicy() + * ) { + * ``` + * + * https://github.com/androidx/androidx/blob/989d1e676252c69e1b9b2e0639c3dba039e7ac99/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt#L112-L114 + */ +private val AcceptablePolicyMethods = + setOf("neverEqualPolicy", "structuralEqualityPolicy", "referentialEqualityPolicy") diff --git a/slack-lint-checks/src/main/java/slack/lint/util/LintUtils.kt b/slack-lint-checks/src/main/java/slack/lint/util/LintUtils.kt index f73d3967..c994efb5 100644 --- a/slack-lint-checks/src/main/java/slack/lint/util/LintUtils.kt +++ b/slack-lint-checks/src/main/java/slack/lint/util/LintUtils.kt @@ -366,3 +366,6 @@ internal val PsiMethod.returnsUnit */ internal val PsiType?.isVoidOrUnit get() = this == PsiTypes.voidType() || this?.canonicalText == "kotlin.Unit" + +/** Cast as a [PsiClassType] if possible. */ +internal fun PsiType?.asClass(): PsiClassType? = this as? PsiClassType diff --git a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt new file mode 100644 index 00000000..0a9239da --- /dev/null +++ b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt @@ -0,0 +1,711 @@ +// Copyright (C) 2025 Slack Technologies, LLC +// SPDX-License-Identifier: Apache-2.0 +package slack.lint.compose + +import com.android.tools.lint.checks.infrastructure.LintDetectorTest.java +import com.android.tools.lint.checks.infrastructure.LintDetectorTest.kotlin +import com.android.tools.lint.checks.infrastructure.TestFile +import org.junit.Ignore +import org.junit.Test +import slack.lint.BaseSlackLintTest + +class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { + + override fun getDetector() = RememberSaveableAcceptableDetector() + + override fun getIssues() = listOf(RememberSaveableAcceptableDetector.ISSUE) + + @Test + fun acceptableTypes_primitives_noError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val stringValue = rememberSaveable { "test" } + val intValue = rememberSaveable { 42 } + val booleanValue = rememberSaveable { true } + val floatValue = rememberSaveable { 1.0f } + val doubleValue = rememberSaveable { 1.0 } + val longValue = rememberSaveable { 1L } + } + + """ + .trimIndent() + ) + .indented() + + test(source).expectClean() + } + + @Test + fun acceptableTypes_arrays_noError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val stringArray = rememberSaveable { arrayOf("test") } + val intArray = rememberSaveable { intArrayOf(42) } + val booleanArray = rememberSaveable { booleanArrayOf(true) } + val floatArray = rememberSaveable { floatArrayOf(1.0f) } + val doubleArray = rememberSaveable { doubleArrayOf(1.0) } + val longArray = rememberSaveable { longArrayOf(1L) } + val parcelableArray = rememberSaveable { arrayOf() } + } + + """ + .trimIndent() + ) + .indented() + + test(source).expectClean() + } + + @Test + fun acceptableTypes_classes_noError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + import java.util.ArrayList + + @Composable + fun TestComposable() { + val serializableValue = rememberSaveable { TestSerializable() } + val parcelableValue = rememberSaveable { TestParcelable() } + val arrayListValue = rememberSaveable { ArrayList() } + } + + """ + .trimIndent() + ) + .indented() + + test(source).expectClean() + } + + @Test + fun acceptableTypes_mutableStates_noError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.mutableDoubleStateOf + import androidx.compose.runtime.mutableFloatStateOf + import androidx.compose.runtime.mutableIntStateOf + import androidx.compose.runtime.mutableLongStateOf + import androidx.compose.runtime.mutableStateOf + import androidx.compose.runtime.structuralEqualityPolicy + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val mutableState = rememberSaveable { mutableStateOf("value") } + val mutableStateWithPolicy = rememberSaveable { mutableStateOf("value", structuralEqualityPolicy()) } + val mutableIntState = rememberSaveable { mutableIntStateOf(1) } + val mutableFloatState = rememberSaveable { mutableFloatStateOf(1f) } + val mutableDoubleState = rememberSaveable { mutableDoubleStateOf(1.0) } + val mutableLongState = rememberSaveable { mutableLongStateOf(1L) } + val mutableTestParcelableState = rememberSaveable { + mutableStateOf(TestParcelable()).apply { value = TestParcelable() } + } + val mutableTestSerializableState = rememberSaveable { + mutableStateOf(TestSerializable()).apply { value = TestSerializable() } + } + } + + """ + .trimIndent() + ) + .indented() + + test(source).expectClean() + } + + @Test + fun acceptableTypes_autoSaver_noError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.mutableStateOf + import androidx.compose.runtime.saveable.autoSaver + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val mutableStateLabeled = + rememberSaveable(saver = autoSaver(), init = { mutableStateOf("value") }) + } + + """ + .trimIndent() + ) + .indented() + + test(source).expectClean() + } + + @Test + fun acceptableTypes_mutableStateWithAcceptablePolicy_noError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.mutableStateOf + import androidx.compose.runtime.neverEqualPolicy + import androidx.compose.runtime.referentialEqualityPolicy + import androidx.compose.runtime.structuralEqualityPolicy + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val structuralPolicy = rememberSaveable { mutableStateOf("value", structuralEqualityPolicy()) } + val referentialPolicy = rememberSaveable { mutableStateOf("value", referentialEqualityPolicy()) } + val neverEqualPolicy = rememberSaveable { mutableStateOf("value", neverEqualPolicy()) } + val defaultPolicy = rememberSaveable { mutableStateOf("value") } + } + + """ + .trimIndent() + ) + .indented() + + test(source).expectClean() + } + + @Test + fun unacceptableTypes_mutableStateWithCustomPolicy_hasError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.mutableStateOf + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + fun customPolicy(): SnapshotMutationPolicy = TODO() + + val customPolicyState = rememberSaveable { mutableStateOf("value", customPolicy()) } + } + + """ + .trimIndent() + ) + .indented() + + test(source) + .expect( + """ + src/test/test.kt:11: Error: Brief description (Custom policy) [RememberSaveableTypeMustBeAcceptable] + val customPolicyState = rememberSaveable { mutableStateOf("value", customPolicy()) } + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + 1 error + """ + .trimIndent() + ) + } + + @Test + fun unacceptableTypes_mutableStateWithCustomPolicy_noError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.mutableStateOf + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + fun customPolicy(): SnapshotMutationPolicy = TODO() + val saver = + Saver, String>( + save = { it.value }, + restore = { mutableStateOf(value = it, customPolicy()) }, + ) + val customPolicyState = + rememberSaveable(saver = saver) { + mutableStateOf(value = "value", policy = customPolicy()) + } + } + """ + .trimIndent() + ) + .indented() + + test(source).expectClean() + } + + @Ignore // TODO implement + @Test + fun acceptableTypes_collections_wrong_saver_error() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val listValue = rememberSaveable { listOf("test") } + val mapValue = rememberSaveable { mapOf("key" to "value") } + } + + """ + .trimIndent() + ) + .indented() + + test(source).expectClean() + } + + @Test + fun acceptableTypes_lambdas_noError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val lambdaInvokedValue = rememberSaveable { { "value" }() } + } + + """ + .trimIndent() + ) + .indented() + + test(source).expectClean() + } + + @Test + fun acceptableTypes_lambda_errors() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val lambdaValue = rememberSaveable { { "value" } } + } + + """ + .trimIndent() + ) + .indented() + + test(source) + .expect( + """ + src/test/test.kt:8: Error: Brief description (LAMBDA) [RememberSaveableTypeMustBeAcceptable] + val lambdaValue = rememberSaveable { { "value" } } + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + 1 error + """ + .trimIndent() + ) + } + + @Test + fun acceptableTypes_multiReturnLambdas_errors() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + import kotlin.random.Random + + @Composable + fun TestComposable() { + val lambdaBlockValue = rememberSaveable { + val a = { "value" } + val b = { "other" } + val c = if (Random.nextInt() == 42) 1 else 2 + if (c == 1) { + a + } else { + ({ "other" }) + } + } + } + + """ + .trimIndent() + ) + .indented() + + test(source) + .expect( + """ + src/test/test.kt:9: Error: Brief description (LAMBDA) [RememberSaveableTypeMustBeAcceptable] + val lambdaBlockValue = rememberSaveable { + ^ + 1 error + """ + .trimIndent() + ) + } + + @Test + fun customSaver_acceptableTypes_primitives_noError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.Saver + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val string1Saver = Saver(save = { it }, restore = { it }) + val string3Saver = + object : Saver { + override fun restore(value: String) = value + + override fun SaverScope.save(value: String) = value + } + val string1Value = rememberSaveable(saver = string1Saver) { "test" } + val string2Value = rememberSaveable(saver = Saver(save = { it }, restore = { it })) { "test" } + val string3Value = rememberSaveable(saver = string3Saver) { "test" } + val intSaver = Saver(save = { it }, restore = { it }) + val intValue = rememberSaveable(saver = intSaver) { 42 } + val booleanSaver = Saver(save = { it }, restore = { it }) + val booleanValue = rememberSaveable(saver = booleanSaver) { true } + val floatSaver = Saver(save = { it }, restore = { it }) + val floatValue = rememberSaveable { 1.0f } + val doubleSaver = Saver(save = { it }, restore = { it }) + val doubleValue = rememberSaveable(saver = doubleSaver) { 1.0 } + val longSaver = Saver(save = { it }, restore = { it }) + val longValue = rememberSaveable { 1L } + } + + """ + .trimIndent() + ) + .indented() + + test(source).expectClean() + } + + @Test + fun customSaver_acceptableTypes_arrays_noError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.Saver + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val stringSaver = Saver, Array>(save = { it }, restore = { it }) + val stringArray = rememberSaveable(saver = stringSaver) { arrayOf("test") } + val intSaver = Saver(save = { it }, restore = { it }) + val intArray = rememberSaveable(saver = intSaver) { intArrayOf(42) } + val booleanSaver = Saver(save = { it }, restore = { it }) + val booleanArray = rememberSaveable(saver = booleanSaver) { booleanArrayOf(true) } + val floatSaver = Saver(save = { it }, restore = { it }) + val floatArray = rememberSaveable(saver = floatSaver) { floatArrayOf(1.0f) } + val doubleSaver = Saver(save = { it }, restore = { it }) + val doubleArray = rememberSaveable(saver = doubleSaver) { doubleArrayOf(1.0) } + val longSaver = Saver(save = { it }, restore = { it }) + val longArray = rememberSaveable(saver = longSaver) { longArrayOf(1L) } + val parcelableSaver = + Saver, Array>(save = { it }, restore = { it }) + val parcelableArray = rememberSaveable(saver = parcelableSaver) { arrayOf() } + } + + """ + .trimIndent() + ) + .indented() + + test(source).expectClean() + } + + @Test + fun customSaver_acceptableTypes_classes_noError() { + val saver = + kotlin( + """ + package test.saver + + import androidx.compose.runtime.saveable.Saver + import test.TestParcelable + + class CustomTestParcelableSaver : Saver { + override fun restore(value: TestParcelable) = value + + override fun SaverScope.save(value: TestParcelable) = value + } + + class CustomTestSerializableSaver : Saver { + override fun restore(value: TestSerializable) = value + + override fun SaverScope.save(value: TestSerializable) = value + } + + class CustomArrayListSaver : Saver, ArrayList> { + override fun restore(value: ArrayList) = value + + override fun SaverScope.save(value: ArrayList) = value + } + + """ + .trimIndent() + ) + .indented() + + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + import test.saver.CustomTestParcelableSaver + import test.saver.CustomTestSerializableSaver + + @Composable + fun TestComposable() { + val serializableValue = + rememberSaveable(saver = CustomTestSerializableSaver()) { TestSerializable() } + val parcelableValue = + rememberSaveable(saver = CustomTestParcelableSaver()) { TestParcelable() } + val arrayListValue = rememberSaveable(saver = CustomArrayListSaver()) { ArrayList() } + } + + """ + .trimIndent() + ) + .indented() + + test(saver, source).expectClean() + } + + private fun test(vararg source: TestFile) = + lint() + .files( + JAVA_IO, + JAVA_UTIL, + ANDROID_OS, + ANDROID_UTIL, + COMPOSE_RUNTIME, + COMPOSE_SAVEABLE, + TEST_SAVEABLES, + *source, + ) + .run() +} + +private val JAVA_IO = + java( + """ + package java.io; + + public interface Serializable { + } + + """ + .trimIndent() + ) + .indented() + +private val JAVA_UTIL = + java( + """ + package java.util; + + import java.io.Serializable; + + public class ArrayList implements Serializable {} + + """ + .trimIndent() + ) + .indented() + +private val ANDROID_OS = + java( + """ + package android.os; + + public interface Parcelable {} + + public final class Bundle implements Parcelable {} + + + """ + .trimIndent() + ) + +private val ANDROID_UTIL = + java( + """ + package android.util; + + public class SparseArray {} + public final class Size {} + public final class SizeF implements Parcelable {} + + """ + .trimIndent() + ) + +private val COMPOSE_RUNTIME = + kotlin( + """ + package androidx.compose.runtime + + annotation class Composable + + annotation class Stable + + interface SnapshotMutationPolicy + + interface State { + val value: T + } + + interface MutableState : State { + override var value: T + } + + fun mutableStateOf( + value: T, + policy: SnapshotMutationPolicy = structuralEqualityPolicy(), + ): MutableState = TODO() + + interface MutableIntState : IntState, MutableState { + override var value: Int + override var intValue: Int + } + + fun mutableIntStateOf(value: Int): MutableIntState = TODO() + + interface MutableFloatState : FloatState, MutableState { + override var value: Float + override var floatValue: Float + } + + fun mutableFloatStateOf(value: Float): MutableFloatState = TODO() + + interface MutableDoubleState : MutableState { + override var value: Double + override var doubleValue: Double + } + + fun mutableDoubleStateOf(value: Double): MutableDoubleState = TODO() + + interface MutableLongState : MutableState { + override var value: Long + override var longValue: Long + } + + fun mutableLongStateOf(value: Long): MutableLongState = TODO() + + fun structuralEqualityPolicy(): SnapshotMutationPolicy = TODO() + + fun referentialEqualityPolicy(): SnapshotMutationPolicy = TODO() + + fun neverEqualPolicy(): SnapshotMutationPolicy = TODO() + + """ + .trimIndent() + ) + .indented() + +private val COMPOSE_SAVEABLE = + kotlin( + """ + package androidx.compose.runtime.saveable + + interface Saver { + fun save(value: Original): Saveable? + fun restore(value: Saveable): Original? + } + + public fun Saver( + save: SaverScope.(value: Original) -> Saveable?, + restore: (value: Saveable) -> Original?, + ): Saver = TODO() + + fun autoSaver(): Saver = TODO() + + @Composable + fun rememberSaveable( + vararg inputs: Any?, + saver: Saver = autoSaver(), + key: String? = null, + init: () -> T + ): T = TODO() + + @Composable + fun rememberSaveable( + vararg inputs: Any?, + stateSaver: Saver, + key: String? = null, + init: () -> MutableState + ): MutableState = TODO() + + """ + .trimIndent() + ) + .indented() + +@Suppress("") +private val TEST_SAVEABLES = + kotlin( + """ + package test + + import android.os.Parcelable + import java.io.Serializable + + class TestSerializable : Serializable + class TestParcelable : Parcelable + """ + .trimIndent() + ) + .indented()