From bcfb0efd4a6f8b4fe927fb8eb2e27797e28b70b3 Mon Sep 17 00:00:00 2001 From: Josh Stagg Date: Wed, 30 Jul 2025 19:46:29 -0700 Subject: [PATCH 01/10] wip --- .../RememberSaveableAcceptableDetector.kt | 233 +++++++++++++++ .../RememberSaveableAcceptableDetectorTest.kt | 274 ++++++++++++++++++ 2 files changed, 507 insertions(+) create mode 100644 slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt create mode 100644 slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt diff --git a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt new file mode 100644 index 00000000..42b06110 --- /dev/null +++ b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt @@ -0,0 +1,233 @@ +// Copyright (C) 2025 Slack Technologies, LLC +// SPDX-License-Identifier: Apache-2.0 +package slack.lint.compose + +import com.android.tools.lint.client.api.UElementHandler +import com.android.tools.lint.detector.api.Category +import com.android.tools.lint.detector.api.Detector +import com.android.tools.lint.detector.api.Issue +import com.android.tools.lint.detector.api.JavaContext +import com.android.tools.lint.detector.api.JavaContext.Companion.getMethodName +import com.android.tools.lint.detector.api.Severity +import com.android.tools.lint.detector.api.SourceCodeScanner +import com.android.tools.lint.detector.api.TextFormat +import com.android.tools.lint.detector.api.asCall +import com.intellij.psi.PsiArrayType +import com.intellij.psi.PsiClass +import com.intellij.psi.PsiClassType +import com.intellij.psi.PsiMethod +import com.intellij.psi.PsiParameter +import com.intellij.psi.PsiPrimitiveType +import com.intellij.psi.PsiType +import org.jetbrains.kotlin.asJava.elements.KtLightMember +import org.jetbrains.kotlin.asJava.elements.KtLightMethod +import org.jetbrains.uast.UBlockExpression +import org.jetbrains.uast.UCallExpression +import org.jetbrains.uast.UElement +import org.jetbrains.uast.UExpression +import org.jetbrains.uast.ULambdaExpression +import org.jetbrains.uast.UReturnExpression +import org.jetbrains.uast.kotlin.KotlinUImplicitReturnExpression +import org.jetbrains.uast.tryResolve +import slack.lint.util.Package +import slack.lint.util.implements +import slack.lint.util.isBoxedPrimitive +import slack.lint.util.isInPackageName +import slack.lint.util.sourceImplementation + +private val RememberSaveablePackageName = Package("androidx.compose.runtime.saveable") +private const val RememberSaveableMethodName = "rememberSaveable" + +// todo Rewrite this so its checking this instead +// Android only source set +// rememberSaveable with autoSaver -> Check value type is in "AcceptableClasses" +// rememberSaveable with custom saver -> Check Saver.save is in in "AcceptableClasses" + +class RememberSaveableAcceptableDetector : Detector(), SourceCodeScanner { + + override fun getApplicableUastTypes(): List> = + listOf(UCallExpression::class.java) + + override fun createUastHandler(context: JavaContext) = + object : UElementHandler() { + override fun visitCallExpression(node: UCallExpression) { + if (node.methodName != RememberSaveableMethodName) return + val evaluator = context.evaluator + val method = node.resolve() + val returnType = node.returnType + if ( + method == null || + returnType == null || + !method.isInPackageName(RememberSaveablePackageName) + ) { + return + } + + val arguments = evaluator.computeArgumentMapping(node, method) + + // todo Check for custom saver or included saver, null is autoSaver + val saver = arguments.getSaver() + val initReturnType = arguments.getInit() + // Auto saver checks + if (saver == null && returnType.isAcceptableType()) { + return + } + val source = saver?.sourcePsi + if (saver != null && source != null) { + val resolve = context.evaluator.resolve(source) + saver.getExpressionType() + getMethodName(saver) + } + context.report(ISSUE, context.getLocation(node), ISSUE.getBriefDescription(TextFormat.TEXT)) + } + } + + companion object { + const val MESSAGE = "remember" + const val ISSUE_ID = "RememberSaveableTypeMustBeAcceptable" + const val BRIEF_DESCRIPTION = "todo" + const val EXPLANATION = """todo""" + + val ISSUE: Issue = + Issue.create( + ISSUE_ID, + BRIEF_DESCRIPTION, + EXPLANATION, + Category.CORRECTNESS, + 10, + Severity.ERROR, + sourceImplementation(), + ) + } +} + +private fun Map.getSaver(): UExpression? { + val saver = firstNotNullOfOrNull { (expression, parameter) -> + if (parameter.name == "saver") { + expression + } else null + } + val resolved = saver?.tryResolve() + if ( + resolved is PsiMethod && + resolved.name == "autoSaver" && + resolved.isInPackageName(RememberSaveablePackageName) + ) { + return null + } + return saver +} + +private fun Map.getInit(): UExpression? { + val init = firstNotNullOfOrNull { (expression, parameter) -> + if (parameter.name == "init") { + expression + } else null + } + resolveLambdaType(init) + // return valueArguments.filterIsInstance().find { arg -> + // arg.getArgumentName()?.referenceExpression?.getReferencedName() == "init" + // } + return init +} + +private fun resolveLambdaType(init: UExpression?) { + if (init is ULambdaExpression) { + val body = init.body + when (body) { + is UBlockExpression -> { + val returnTypes = mutableListOf() + for (expresssion in body.expressions) { + if (expresssion is UReturnExpression) { + val returnExpression = expresssion.returnExpression + if (returnExpression is UCallExpression) { + val resolved = returnExpression.resolve() + if (resolved is KtLightMethod) { + resolved.body + } + } + // todo Handle objects + } + } + } + } + } +} + +private fun PsiType.isAcceptableType(): Boolean { + return when (this) { + is PsiPrimitiveType -> true + is PsiArrayType -> componentType.isAcceptableType() + is PsiClassType -> { + val resolved = resolve() ?: return true // Can't resolve class type treat as acceptable? + resolved.isAcceptableClassType() && parameters.all { it.isAcceptableType() } + } + + else -> false + } +} + +fun PsiClass.isAcceptableClassType(): Boolean { + return isBoxedPrimitive() || AcceptableClasses.any { implements(it) } +} + +// From +// https://github.com/androidx/androidx/blob/989d1e676252c69e1b9b2e0639c3dba039e7ac99/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt#L108 +/* +/** Checks that [value] can be stored inside [Bundle]. */ +private fun canBeSavedToBundle(value: Any): Boolean { + // SnapshotMutableStateImpl is Parcelable, but we do extra checks + if (value is SnapshotMutableState<*>) { + if ( + value.policy === neverEqualPolicy() || + value.policy === structuralEqualityPolicy() || + value.policy === referentialEqualityPolicy() + ) { + val stateValue = value.value + return if (stateValue == null) true else canBeSavedToBundle(stateValue) + } else { + return false + } + } + // lambdas in Kotlin implement Serializable, but will crash if you really try to save them. + // we check for both Function and Serializable (see kotlin.jvm.internal.Lambda) to support + // custom user defined classes implementing Function interface. + if (value is Function<*> && value is Serializable) { + return false + } + for (cl in AcceptableClasses) { + if (cl.isInstance(value)) { + return true + } + } + return false +} + */ + +// From +// https://github.com/androidx/androidx/blob/989d1e676252c69e1b9b2e0639c3dba039e7ac99/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt#L151 +/** + * Contains Classes which can be stored inside [Bundle]. + * + * Some of the classes are not added separately because: + * + * This classes implement Serializable: + * - Arrays (DoubleArray, BooleanArray, IntArray, LongArray, ByteArray, FloatArray, ShortArray, + * CharArray, Array) + * - ArrayList + * - Primitives (Boolean, Int, Long, Double, Float, Byte, Short, Char) will be boxed when casted to + * Any, and all the boxed classes implements Serializable. This class implements Parcelable: + * - Bundle + * + * Note: it is simplified copy of the array from SavedStateHandle (lifecycle-viewmodel-savedstate). + */ +private val AcceptableClasses = + arrayOf( + "java.io.Serializable", + "android.os.Parcelable", + "java.lang.String", + "android.util.SparseArray", + "android.os.Binder", + "android.util.Size", + "android.util.SizeF", + ) diff --git a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt new file mode 100644 index 00000000..5cd8a16c --- /dev/null +++ b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt @@ -0,0 +1,274 @@ +// Copyright (C) 2025 Slack Technologies, LLC +// SPDX-License-Identifier: Apache-2.0 +package slack.lint.compose + +import org.junit.Test +import slack.lint.BaseSlackLintTest + +class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { + + private val javaIo = + java( + """ + package java.io; + + public interface Serializable { + } + + """ + .trimIndent() + ) + .indented() + + private val javaUtil = + java( + """ + package java.util; + + import java.io.Serializable; + + public class ArrayList implements Serializable {} + + """ + .trimIndent() + ) + .indented() + + private val androidOs = + java( + """ + package android.os; + + public interface Parcelable {} + + public final class Bundle implements Parcelable {} + + + """ + .trimIndent() + ) + + private val androidUtil = + java( + """ + package android.util; + + public class SparseArray {} + public final class Size {} + public final class SizeF implements Parcelable {} + + """ + .trimIndent() + ) + + private val composeRuntime = + kotlin( + """ + package androidx.compose.runtime + + annotation class Composable + + annotation class Stable + + interface SnapshotMutationPolicy + + interface State { + val value: T + } + + interface MutableState : State { + override var value: T + } + + fun mutableStateOf( + value: T, + policy: SnapshotMutationPolicy = structuralEqualityPolicy(), + ): MutableState = createSnapshotMutableState(value, policy) + + fun createSnapshotMutableState( + value: T, + policy: SnapshotMutationPolicy, + ): SnapshotMutableState = ParcelableSnapshotMutableState(value, policy) + + private class ParcelableSnapshotMutableState( + val value: T, + val policy: SnapshotMutationPolicy, + ) : Parcelable + + fun mutableIntStateOf(value: Int): MutableIntState = createSnapshotMutableIntState(value) + + fun createSnapshotMutableIntState(value: Int): MutableIntState = + ParcelableSnapshotMutableIntState(value) + + private class ParcelableSnapshotMutableIntState(val value: Int) : Parcelable + + interface IntState : State { + override val value: Int + get() = intValue + + val intValue: Int + } + + interface MutableIntState : IntState, MutableState { + override var value: Int + override var intValue: Int + } + + fun mutableFloatStateOf(value: Float): MutableFloatState = TODO() + + interface MutableFloatState : FloatState, MutableState { + override var value: Float + get() = floatValue + set(value) { + floatValue = value + } + + override var floatValue: Float + } + + interface FloatState : State { + override val value: Float + get() = floatValue + + val floatValue: Float + } + + """ + .trimIndent() + ) + .indented() + + private val composeSaveable = + kotlin( + """ + package androidx.compose.runtime.saveable + + interface Saver { + fun save(value: Original): Saveable? + fun restore(value: Saveable): Original? + } + + fun autoSaver(): Saver = TODO() + + @Composable + fun rememberSaveable( + vararg inputs: Any?, + saver: Saver = autoSaver(), + key: String? = null, + init: () -> T + ): T = TODO() + + @Composable + fun rememberSaveable( + vararg inputs: Any?, + stateSaver: Saver, + key: String? = null, + init: () -> MutableState + ): MutableState = TODO() + + """ + .trimIndent() + ) + .indented() + + private val testSaveables = + kotlin( + """ + package test + + import android.os.Parcelable + import java.io.Serializable + + class TestSerializable : Serializable + class TestParcelable : Parcelable + """ + .trimIndent() + ) + .indented() + + override fun getDetector() = RememberSaveableAcceptableDetector() + + override fun getIssues() = listOf(RememberSaveableAcceptableDetector.ISSUE) + + @Test + fun acceptableTypes_autoSaver_noError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + import androidx.compose.runtime.saveable.autoSaver + import androidx.compose.runtime.mutableStateOf + import androidx.compose.runtime.mutableIntStateOf + import androidx.compose.runtime.mutableFloatStateOf + import java.util.ArrayList + + @Composable + fun TestComposable() { + + val mutableStateLabeled = rememberSaveable(saver = autoSaver(), init = { mutableStateOf("value") }) + // Acceptable primitive types + val stringValue = rememberSaveable { "test" } + val intValue = rememberSaveable { 42 } + val booleanValue = rememberSaveable { true } + val floatValue = rememberSaveable { 1.0f } + val doubleValue = rememberSaveable { 1.0 } + val longValue = rememberSaveable { 1L } + + // Acceptable array types + val stringArray = rememberSaveable { arrayOf("test") } + val intArray = rememberSaveable { intArrayOf(42) } + val booleanArray = rememberSaveable { booleanArrayOf(true) } + val floatArray = rememberSaveable { floatArrayOf(1.0f) } + val doubleArray = rememberSaveable { doubleArrayOf(1.0) } + val longArray = rememberSaveable { longArrayOf(1L) } + val parcelableArray = rememberSaveable { arrayOf() } + + // Acceptable class types + val serializableValue = rememberSaveable { TestSerializable() } + val parcelableValue = rememberSaveable { TestParcelable() } + val arrayListValue = rememberSaveable { ArrayList() } + + // Nullable acceptable types + val nullableString = rememberSaveable { null } + val nullableInt = rememberSaveable { null } + val nullableBoolean = rememberSaveable { null } + + // Mutable state types + // Check policy and internal type + val mutableStateValue = rememberSaveable { mutableStateOf("value") } + val mutableIntStateValue = rememberSaveable { mutableIntStateOf(1) } + val mutableFloatStateValue = rememberSaveable { mutableFloatStateOf(1f) } + + // Acceptable collections with acceptable types + + // Kotlin promises listOf is serializable on jvm + val listValue = rememberSaveable { listOf("test") } + // Fail without map saver + val mapValue = rememberSaveable { mapOf("key" to "value") } + // Check policy and internal type + val mutableStateValue = rememberSaveable { mutableStateOf("value") } + + } + """ + .trimIndent() + ) + .indented() + + lint() + .files( + javaIo, + javaUtil, + androidOs, + androidUtil, + composeRuntime, + composeSaveable, + testSaveables, + source, + ) + .run() + .expectClean() + } +} From 9aee9eec7099e4a3b75b314cfadf371cfbee97e1 Mon Sep 17 00:00:00 2001 From: Josh Stagg Date: Thu, 31 Jul 2025 11:09:11 -0700 Subject: [PATCH 02/10] Check mutable state returns --- .../RememberSaveableAcceptableDetector.kt | 127 +++++-- .../RememberSaveableAcceptableDetectorTest.kt | 317 +++++++++--------- 2 files changed, 248 insertions(+), 196 deletions(-) diff --git a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt index 42b06110..c5b46bb8 100644 --- a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt +++ b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt @@ -12,6 +12,7 @@ import com.android.tools.lint.detector.api.Severity import com.android.tools.lint.detector.api.SourceCodeScanner import com.android.tools.lint.detector.api.TextFormat import com.android.tools.lint.detector.api.asCall +import com.android.tools.lint.detector.api.isUnconditionalReturn import com.intellij.psi.PsiArrayType import com.intellij.psi.PsiClass import com.intellij.psi.PsiClassType @@ -19,24 +20,25 @@ import com.intellij.psi.PsiMethod import com.intellij.psi.PsiParameter import com.intellij.psi.PsiPrimitiveType import com.intellij.psi.PsiType -import org.jetbrains.kotlin.asJava.elements.KtLightMember -import org.jetbrains.kotlin.asJava.elements.KtLightMethod import org.jetbrains.uast.UBlockExpression import org.jetbrains.uast.UCallExpression import org.jetbrains.uast.UElement import org.jetbrains.uast.UExpression import org.jetbrains.uast.ULambdaExpression import org.jetbrains.uast.UReturnExpression -import org.jetbrains.uast.kotlin.KotlinUImplicitReturnExpression +import org.jetbrains.uast.UastEmptyExpression import org.jetbrains.uast.tryResolve +import org.jetbrains.uast.visitor.UastVisitor import slack.lint.util.Package import slack.lint.util.implements import slack.lint.util.isBoxedPrimitive import slack.lint.util.isInPackageName import slack.lint.util.sourceImplementation +private val ComposeRuntimePackageName = Package("androidx.compose.runtime") private val RememberSaveablePackageName = Package("androidx.compose.runtime.saveable") private const val RememberSaveableMethodName = "rememberSaveable" +private val AUTO_SAVER = UastEmptyExpression(null) // todo Rewrite this so its checking this instead // Android only source set @@ -50,6 +52,7 @@ class RememberSaveableAcceptableDetector : Detector(), SourceCodeScanner { override fun createUastHandler(context: JavaContext) = object : UElementHandler() { + @Suppress("ReturnCount") override fun visitCallExpression(node: UCallExpression) { if (node.methodName != RememberSaveableMethodName) return val evaluator = context.evaluator @@ -64,16 +67,20 @@ class RememberSaveableAcceptableDetector : Detector(), SourceCodeScanner { } val arguments = evaluator.computeArgumentMapping(node, method) - - // todo Check for custom saver or included saver, null is autoSaver val saver = arguments.getSaver() - val initReturnType = arguments.getInit() - // Auto saver checks - if (saver == null && returnType.isAcceptableType()) { + // With an auto saver return the check + if (saver == AUTO_SAVER && returnType.isAcceptableType()) { + return + } + // If there is no init expression just return. + val init = arguments.getInit() ?: return + val returnExpressions = resolveReturns(init) + if (returnsKnownMutableState(returnExpressions)) { return } - val source = saver?.sourcePsi - if (saver != null && source != null) { + + val source = saver.sourcePsi + if (source != null) { val resolve = context.evaluator.resolve(source) saver.getExpressionType() getMethodName(saver) @@ -101,11 +108,9 @@ class RememberSaveableAcceptableDetector : Detector(), SourceCodeScanner { } } -private fun Map.getSaver(): UExpression? { +private fun Map.getSaver(): UExpression { val saver = firstNotNullOfOrNull { (expression, parameter) -> - if (parameter.name == "saver") { - expression - } else null + if (parameter.name == "saver") expression else null } val resolved = saver?.tryResolve() if ( @@ -113,9 +118,9 @@ private fun Map.getSaver(): UExpression? { resolved.name == "autoSaver" && resolved.isInPackageName(RememberSaveablePackageName) ) { - return null + return AUTO_SAVER } - return saver + return saver ?: AUTO_SAVER } private fun Map.getInit(): UExpression? { @@ -124,36 +129,82 @@ private fun Map.getInit(): UExpression? { expression } else null } - resolveLambdaType(init) - // return valueArguments.filterIsInstance().find { arg -> - // arg.getArgumentName()?.referenceExpression?.getReferencedName() == "init" - // } return init } -private fun resolveLambdaType(init: UExpression?) { +private fun resolveReturns(init: UExpression?): List { + val returnExpressions = mutableListOf() if (init is ULambdaExpression) { val body = init.body when (body) { + // todo Handle function calls? is UBlockExpression -> { - val returnTypes = mutableListOf() - for (expresssion in body.expressions) { - if (expresssion is UReturnExpression) { - val returnExpression = expresssion.returnExpression - if (returnExpression is UCallExpression) { - val resolved = returnExpression.resolve() - if (resolved is KtLightMethod) { - resolved.body - } - } - // todo Handle objects + for (expression in body.expressions) { + // todo What are the other cases? + if (expression is UReturnExpression) { + returnExpressions += expression } } } } } + return returnExpressions +} + +private fun returnsKnownMutableState(returnExpressions: List): Boolean { + var areAllAcceptableMutableStates = false + for (returnExpression in returnExpressions) { + areAllAcceptableMutableStates = + returnExpression.isUnconditionalReturn() && + MutableStateVisitor().run { + returnExpression.accept(this) + hasAcceptableMutableState + } + if (!areAllAcceptableMutableStates) break + } + return areAllAcceptableMutableStates +} + +private class MutableStateVisitor : UastVisitor { + + private var visitedAcceptableMutableState: Boolean = false + var hasAcceptableMutableState = false + private set + + override fun afterVisitReturnExpression(node: UReturnExpression) { + super.afterVisitReturnExpression(node) + if (visitedAcceptableMutableState) { + hasAcceptableMutableState = + node.returnExpression?.getExpressionType()?.asClass()?.parameters?.all { + it.isAcceptableType() + } == true + } + } + + override fun visitElement(node: UElement): Boolean { + val resolved = node.tryResolve() + val isKnownMutableState = + resolved is PsiMethod && + resolved.isInPackageName(ComposeRuntimePackageName) && + resolved.name in AcceptableMutableStateMethods + if (isKnownMutableState) { + // Check the type in mutableStateOf() + if (resolved.name == "mutableStateOf") { + val mutableClass = node.asCall()?.returnType?.asClass() + mutableClass + ?.parameters + ?.all { it.isAcceptableType() } + ?.let { visitedAcceptableMutableState = it } + } else { + visitedAcceptableMutableState = true + } + } + return isKnownMutableState + } } +private fun PsiType?.asClass(): PsiClassType? = this as? PsiClassType + private fun PsiType.isAcceptableType(): Boolean { return when (this) { is PsiPrimitiveType -> true @@ -162,7 +213,6 @@ private fun PsiType.isAcceptableType(): Boolean { val resolved = resolve() ?: return true // Can't resolve class type treat as acceptable? resolved.isAcceptableClassType() && parameters.all { it.isAcceptableType() } } - else -> false } } @@ -222,7 +272,7 @@ private fun canBeSavedToBundle(value: Any): Boolean { * Note: it is simplified copy of the array from SavedStateHandle (lifecycle-viewmodel-savedstate). */ private val AcceptableClasses = - arrayOf( + setOf( "java.io.Serializable", "android.os.Parcelable", "java.lang.String", @@ -231,3 +281,12 @@ private val AcceptableClasses = "android.util.Size", "android.util.SizeF", ) + +private val AcceptableMutableStateMethods = + setOf( + "mutableStateOf", + "mutableIntStateOf", + "mutableFloatStateOf", + "mutableDoubleStateOf", + "mutableLongStateOf", + ) diff --git a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt index 5cd8a16c..21d93020 100644 --- a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt +++ b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt @@ -2,27 +2,128 @@ // SPDX-License-Identifier: Apache-2.0 package slack.lint.compose +import com.android.tools.lint.checks.infrastructure.LintDetectorTest.java +import com.android.tools.lint.checks.infrastructure.LintDetectorTest.kotlin import org.junit.Test import slack.lint.BaseSlackLintTest class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { - private val javaIo = - java( - """ + override fun getDetector() = RememberSaveableAcceptableDetector() + + override fun getIssues() = listOf(RememberSaveableAcceptableDetector.ISSUE) + + @Test + fun acceptableTypes_autoSaver_noError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.mutableDoubleStateOf + import androidx.compose.runtime.mutableFloatStateOf + import androidx.compose.runtime.mutableIntStateOf + import androidx.compose.runtime.mutableLongStateOf + import androidx.compose.runtime.mutableStateOf + import androidx.compose.runtime.saveable.autoSaver + import androidx.compose.runtime.saveable.rememberSaveable + import java.util.ArrayList + + @Composable + fun TestComposable() { + // Acceptable primitive types + val stringValue = rememberSaveable { "test" } + val intValue = rememberSaveable { 42 } + val booleanValue = rememberSaveable { true } + val floatValue = rememberSaveable { 1.0f } + val doubleValue = rememberSaveable { 1.0 } + val longValue = rememberSaveable { 1L } + + // Acceptable array types + val stringArray = rememberSaveable { arrayOf("test") } + val intArray = rememberSaveable { intArrayOf(42) } + val booleanArray = rememberSaveable { booleanArrayOf(true) } + val floatArray = rememberSaveable { floatArrayOf(1.0f) } + val doubleArray = rememberSaveable { doubleArrayOf(1.0) } + val longArray = rememberSaveable { longArrayOf(1L) } + val parcelableArray = rememberSaveable { arrayOf() } + + // Acceptable class types + val serializableValue = rememberSaveable { TestSerializable() } + val parcelableValue = rememberSaveable { TestParcelable() } + val arrayListValue = rememberSaveable { ArrayList() } + + // Nullable acceptable types + val nullableString = rememberSaveable { null } + val nullableInt = rememberSaveable { null } + val nullableBoolean = rememberSaveable { null } + + // Mutable state types + // Check policy and internal type? + val mutableState = rememberSaveable { mutableStateOf("value") } + val mutableIntState = rememberSaveable { mutableIntStateOf(1) } + val mutableFloatState = rememberSaveable { mutableFloatStateOf(1f) } + val mutableDoubleState = rememberSaveable { mutableDoubleStateOf(1.0) } + val mutableLongState = rememberSaveable { mutableLongStateOf(1L) } + val mutableTestParcelableState = rememberSaveable { + mutableStateOf(TestParcelable()).apply { value = TestParcelable() } + } + val mutableTestSerializableState = rememberSaveable { + mutableStateOf(TestSerializable()).apply { value = TestSerializable() } + } + + // AutoSaver is specified + val mutableStateLabeled = + rememberSaveable(saver = autoSaver(), init = { mutableStateOf("value") }) + + // Acceptable collections with acceptable types + + // Kotlin promises listOf is serializable on jvm + val listValue = rememberSaveable { listOf("test") } + // Fail without map saver + val mapValue = rememberSaveable { mapOf("key" to "value") } + // Check policy and internal type + val mutableStateValue = rememberSaveable { mutableStateOf("value") } +} + + """ + .trimIndent() + ) + .indented() + + lint() + .files( + JAVA_IO, + JAVA_UTIL, + ANDROID_OS, + ANDROID_UTIL, + COMPOSE_RUNTIME, + COMPOSE_SAVEABLE, + TEST_SAVEABLES, + source, + ) + .run() + .expectClean() + } +} + +private val JAVA_IO = + java( + """ package java.io; public interface Serializable { } """ - .trimIndent() - ) - .indented() + .trimIndent() + ) + .indented() - private val javaUtil = - java( - """ +private val JAVA_UTIL = + java( + """ package java.util; import java.io.Serializable; @@ -30,13 +131,13 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { public class ArrayList implements Serializable {} """ - .trimIndent() - ) - .indented() + .trimIndent() + ) + .indented() - private val androidOs = - java( - """ +private val ANDROID_OS = + java( + """ package android.os; public interface Parcelable {} @@ -45,12 +146,12 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { """ - .trimIndent() - ) + .trimIndent() + ) - private val androidUtil = - java( - """ +private val ANDROID_UTIL = + java( + """ package android.util; public class SparseArray {} @@ -58,12 +159,12 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { public final class SizeF implements Parcelable {} """ - .trimIndent() - ) + .trimIndent() + ) - private val composeRuntime = - kotlin( - """ +private val COMPOSE_RUNTIME = + kotlin( + """ package androidx.compose.runtime annotation class Composable @@ -83,64 +184,43 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { fun mutableStateOf( value: T, policy: SnapshotMutationPolicy = structuralEqualityPolicy(), - ): MutableState = createSnapshotMutableState(value, policy) - - fun createSnapshotMutableState( - value: T, - policy: SnapshotMutationPolicy, - ): SnapshotMutableState = ParcelableSnapshotMutableState(value, policy) - - private class ParcelableSnapshotMutableState( - val value: T, - val policy: SnapshotMutationPolicy, - ) : Parcelable - - fun mutableIntStateOf(value: Int): MutableIntState = createSnapshotMutableIntState(value) - - fun createSnapshotMutableIntState(value: Int): MutableIntState = - ParcelableSnapshotMutableIntState(value) - - private class ParcelableSnapshotMutableIntState(val value: Int) : Parcelable - - interface IntState : State { - override val value: Int - get() = intValue - - val intValue: Int - } + ): MutableState = TODO() interface MutableIntState : IntState, MutableState { override var value: Int override var intValue: Int } - fun mutableFloatStateOf(value: Float): MutableFloatState = TODO() + fun mutableIntStateOf(value: Int): MutableIntState = TODO() interface MutableFloatState : FloatState, MutableState { override var value: Float - get() = floatValue - set(value) { - floatValue = value - } - override var floatValue: Float } - interface FloatState : State { - override val value: Float - get() = floatValue + fun mutableFloatStateOf(value: Float): MutableFloatState = TODO() + + interface MutableDoubleState : MutableState { + override var value: Double + override var doubleValue: Double + } + + fun mutableDoubleStateOf(value: Double): MutableDoubleState = TODO() - val floatValue: Float + interface MutableLongState : MutableState { + override var value: Long + override var longValue: Long } + fun mutableLongStateOf(value: Long): MutableLongState = TODO() """ - .trimIndent() - ) - .indented() + .trimIndent() + ) + .indented() - private val composeSaveable = - kotlin( - """ +private val COMPOSE_SAVEABLE = + kotlin( + """ package androidx.compose.runtime.saveable interface Saver { @@ -167,13 +247,13 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { ): MutableState = TODO() """ - .trimIndent() - ) - .indented() + .trimIndent() + ) + .indented() - private val testSaveables = - kotlin( - """ +private val TEST_SAVEABLES = + kotlin( + """ package test import android.os.Parcelable @@ -182,93 +262,6 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { class TestSerializable : Serializable class TestParcelable : Parcelable """ - .trimIndent() - ) - .indented() - - override fun getDetector() = RememberSaveableAcceptableDetector() - - override fun getIssues() = listOf(RememberSaveableAcceptableDetector.ISSUE) - - @Test - fun acceptableTypes_autoSaver_noError() { - val source = - kotlin( - """ - package test - - import androidx.compose.runtime.Composable - import androidx.compose.runtime.saveable.rememberSaveable - import androidx.compose.runtime.saveable.autoSaver - import androidx.compose.runtime.mutableStateOf - import androidx.compose.runtime.mutableIntStateOf - import androidx.compose.runtime.mutableFloatStateOf - import java.util.ArrayList - - @Composable - fun TestComposable() { - - val mutableStateLabeled = rememberSaveable(saver = autoSaver(), init = { mutableStateOf("value") }) - // Acceptable primitive types - val stringValue = rememberSaveable { "test" } - val intValue = rememberSaveable { 42 } - val booleanValue = rememberSaveable { true } - val floatValue = rememberSaveable { 1.0f } - val doubleValue = rememberSaveable { 1.0 } - val longValue = rememberSaveable { 1L } - - // Acceptable array types - val stringArray = rememberSaveable { arrayOf("test") } - val intArray = rememberSaveable { intArrayOf(42) } - val booleanArray = rememberSaveable { booleanArrayOf(true) } - val floatArray = rememberSaveable { floatArrayOf(1.0f) } - val doubleArray = rememberSaveable { doubleArrayOf(1.0) } - val longArray = rememberSaveable { longArrayOf(1L) } - val parcelableArray = rememberSaveable { arrayOf() } - - // Acceptable class types - val serializableValue = rememberSaveable { TestSerializable() } - val parcelableValue = rememberSaveable { TestParcelable() } - val arrayListValue = rememberSaveable { ArrayList() } - - // Nullable acceptable types - val nullableString = rememberSaveable { null } - val nullableInt = rememberSaveable { null } - val nullableBoolean = rememberSaveable { null } - - // Mutable state types - // Check policy and internal type - val mutableStateValue = rememberSaveable { mutableStateOf("value") } - val mutableIntStateValue = rememberSaveable { mutableIntStateOf(1) } - val mutableFloatStateValue = rememberSaveable { mutableFloatStateOf(1f) } - - // Acceptable collections with acceptable types - - // Kotlin promises listOf is serializable on jvm - val listValue = rememberSaveable { listOf("test") } - // Fail without map saver - val mapValue = rememberSaveable { mapOf("key" to "value") } - // Check policy and internal type - val mutableStateValue = rememberSaveable { mutableStateOf("value") } - - } - """ - .trimIndent() - ) - .indented() - - lint() - .files( - javaIo, - javaUtil, - androidOs, - androidUtil, - composeRuntime, - composeSaveable, - testSaveables, - source, - ) - .run() - .expectClean() - } -} + .trimIndent() + ) + .indented() From 34f3f1759a74f0254c955484e1998bd3eed97001 Mon Sep 17 00:00:00 2001 From: Josh Stagg Date: Wed, 6 Aug 2025 19:48:34 -0700 Subject: [PATCH 03/10] Better `mutableStateOf` checking --- .../RememberSaveableAcceptableDetector.kt | 163 ++++++++++-------- .../RememberSaveableAcceptableDetectorTest.kt | 146 +++++++++------- 2 files changed, 166 insertions(+), 143 deletions(-) diff --git a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt index c5b46bb8..dd96c0dc 100644 --- a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt +++ b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt @@ -2,17 +2,16 @@ // SPDX-License-Identifier: Apache-2.0 package slack.lint.compose +import com.android.tools.lint.checks.DataFlowAnalyzer import com.android.tools.lint.client.api.UElementHandler import com.android.tools.lint.detector.api.Category import com.android.tools.lint.detector.api.Detector import com.android.tools.lint.detector.api.Issue import com.android.tools.lint.detector.api.JavaContext -import com.android.tools.lint.detector.api.JavaContext.Companion.getMethodName import com.android.tools.lint.detector.api.Severity import com.android.tools.lint.detector.api.SourceCodeScanner import com.android.tools.lint.detector.api.TextFormat import com.android.tools.lint.detector.api.asCall -import com.android.tools.lint.detector.api.isUnconditionalReturn import com.intellij.psi.PsiArrayType import com.intellij.psi.PsiClass import com.intellij.psi.PsiClassType @@ -20,15 +19,13 @@ import com.intellij.psi.PsiMethod import com.intellij.psi.PsiParameter import com.intellij.psi.PsiPrimitiveType import com.intellij.psi.PsiType -import org.jetbrains.uast.UBlockExpression import org.jetbrains.uast.UCallExpression import org.jetbrains.uast.UElement import org.jetbrains.uast.UExpression -import org.jetbrains.uast.ULambdaExpression import org.jetbrains.uast.UReturnExpression import org.jetbrains.uast.UastEmptyExpression import org.jetbrains.uast.tryResolve -import org.jetbrains.uast.visitor.UastVisitor +import org.jetbrains.uast.visitor.AbstractUastVisitor import slack.lint.util.Package import slack.lint.util.implements import slack.lint.util.isBoxedPrimitive @@ -68,22 +65,25 @@ class RememberSaveableAcceptableDetector : Detector(), SourceCodeScanner { val arguments = evaluator.computeArgumentMapping(node, method) val saver = arguments.getSaver() - // With an auto saver return the check + // With an auto saver check the return type. if (saver == AUTO_SAVER && returnType.isAcceptableType()) { return } // If there is no init expression just return. val init = arguments.getInit() ?: return - val returnExpressions = resolveReturns(init) - if (returnsKnownMutableState(returnExpressions)) { + // Check whats created in the init expression. + if (returnsKnownMutableState(returnType, init)) { + // Found a known parcelable mutable state. return } - - val source = saver.sourcePsi - if (source != null) { - val resolve = context.evaluator.resolve(source) - saver.getExpressionType() - getMethodName(saver) + if (returnsLambdaExpression(returnType, init)) { + // todo Report specific error language about kotlin Lambdas + context.report( + ISSUE, + context.getLocation(node), + ISSUE.getBriefDescription(TextFormat.TEXT) + " (LAMBDA)", + ) + return } context.report(ISSUE, context.getLocation(node), ISSUE.getBriefDescription(TextFormat.TEXT)) } @@ -92,8 +92,8 @@ class RememberSaveableAcceptableDetector : Detector(), SourceCodeScanner { companion object { const val MESSAGE = "remember" const val ISSUE_ID = "RememberSaveableTypeMustBeAcceptable" - const val BRIEF_DESCRIPTION = "todo" - const val EXPLANATION = """todo""" + const val BRIEF_DESCRIPTION = "Brief description" + const val EXPLANATION = """Full explanation""" val ISSUE: Issue = Issue.create( @@ -132,77 +132,88 @@ private fun Map.getInit(): UExpression? { return init } -private fun resolveReturns(init: UExpression?): List { - val returnExpressions = mutableListOf() - if (init is ULambdaExpression) { - val body = init.body - when (body) { - // todo Handle function calls? - is UBlockExpression -> { - for (expression in body.expressions) { - // todo What are the other cases? - if (expression is UReturnExpression) { - returnExpressions += expression - } +/** + * The default Android `mutableStateOf` is `ParcelableSnapshotMutableState`, so check if all the + * return statements are default `mutableStateOf` calls. + */ +private fun returnsKnownMutableState(returnType: PsiType, expression: UExpression): Boolean { + return if (returnType.isAcceptableMutableStateClass()) { + val visitor = MutableStateOfVisitor() + expression.accept(visitor) + val allReturned = + visitor.mutableStateOfs.isNotEmpty() && + visitor.mutableStateOfs.all { tracked -> + val returnsTracker = ReturnsTracker(tracked) + expression.accept(returnsTracker) + returnsTracker.returned } - } - } - } - return returnExpressions + allReturned + } else false } -private fun returnsKnownMutableState(returnExpressions: List): Boolean { - var areAllAcceptableMutableStates = false - for (returnExpression in returnExpressions) { - areAllAcceptableMutableStates = - returnExpression.isUnconditionalReturn() && - MutableStateVisitor().run { - returnExpression.accept(this) - hasAcceptableMutableState - } - if (!areAllAcceptableMutableStates) break +private fun PsiType?.isAcceptableMutableStateClass(): Boolean { + val psiClassType = asClass() + val isMutableState = + psiClassType + ?.resolve() + ?.implements("${ComposeRuntimePackageName.javaPackageName}.MutableState") == true + return isMutableState && psiClassType.parameters.all { it.isAcceptableType() } +} + +private fun isKnownMutableStateFunction(node: UElement): Boolean { + val resolved = node.tryResolve() + val isKnownMutableState = + resolved is PsiMethod && + resolved.isInPackageName(ComposeRuntimePackageName) && + resolved.name in AcceptableMutableStateMethods + if (isKnownMutableState) { + // Check the type of mutableStateOf() + if (resolved.name == "mutableStateOf") { + val mutableClass = node.asCall()?.returnType?.asClass() + mutableClass?.parameters?.all { it.isAcceptableType() } == true + } else { + // Known mutable[Primitive]StateOf() + true + } } - return areAllAcceptableMutableStates + return isKnownMutableState } -private class MutableStateVisitor : UastVisitor { +private class MutableStateOfVisitor : AbstractUastVisitor() { - private var visitedAcceptableMutableState: Boolean = false - var hasAcceptableMutableState = false - private set + val mutableStateOfs = mutableListOf() - override fun afterVisitReturnExpression(node: UReturnExpression) { - super.afterVisitReturnExpression(node) - if (visitedAcceptableMutableState) { - hasAcceptableMutableState = - node.returnExpression?.getExpressionType()?.asClass()?.parameters?.all { - it.isAcceptableType() - } == true - } - } + override fun visitCallExpression(node: UCallExpression): Boolean = + if (isKnownMutableStateFunction(node)) { + mutableStateOfs.add(node) + true + } else false +} - override fun visitElement(node: UElement): Boolean { - val resolved = node.tryResolve() - val isKnownMutableState = - resolved is PsiMethod && - resolved.isInPackageName(ComposeRuntimePackageName) && - resolved.name in AcceptableMutableStateMethods - if (isKnownMutableState) { - // Check the type in mutableStateOf() - if (resolved.name == "mutableStateOf") { - val mutableClass = node.asCall()?.returnType?.asClass() - mutableClass - ?.parameters - ?.all { it.isAcceptableType() } - ?.let { visitedAcceptableMutableState = it } - } else { - visitedAcceptableMutableState = true - } - } - return isKnownMutableState +private class ReturnsTracker(tracked: UElement, var returned: Boolean = false) : + DataFlowAnalyzer(listOf(tracked)) { + override fun returns(expression: UReturnExpression) { + returned = true } } +/** + * > Lambdas in Kotlin implement Serializable, but will crash if you really try to save them. We + * > check for both Function and Serializable (see kotlin.jvm.internal.Lambda) to support custom + * > user defined classes implementing Function interface. + * - From: + * https://github.com/androidx/androidx/blob/989d1e676252c69e1b9b2e0639c3dba039e7ac99/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt#L122-L124 + */ +private fun returnsLambdaExpression(returnType: PsiType, expression: UExpression): Boolean { + val isFunction = returnType.asClass()?.resolve()?.implements("kotlin.Function") == true + // todo Find all the function returns and see if any are lambda expressions + val visitor = ReturnsLambdaVisitor() + expression.accept(visitor) + return isFunction // && visitor.returned +} + +private class ReturnsLambdaVisitor(var returned: Boolean = false) : AbstractUastVisitor() {} + private fun PsiType?.asClass(): PsiClassType? = this as? PsiClassType private fun PsiType.isAcceptableType(): Boolean { @@ -217,7 +228,7 @@ private fun PsiType.isAcceptableType(): Boolean { } } -fun PsiClass.isAcceptableClassType(): Boolean { +private fun PsiClass.isAcceptableClassType(): Boolean { return isBoxedPrimitive() || AcceptableClasses.any { implements(it) } } diff --git a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt index 21d93020..f3e3f950 100644 --- a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt +++ b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt @@ -18,73 +18,85 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { val source = kotlin( """ - package test - - import androidx.compose.runtime.Composable - import androidx.compose.runtime.mutableDoubleStateOf - import androidx.compose.runtime.mutableFloatStateOf - import androidx.compose.runtime.mutableIntStateOf - import androidx.compose.runtime.mutableLongStateOf - import androidx.compose.runtime.mutableStateOf - import androidx.compose.runtime.saveable.autoSaver - import androidx.compose.runtime.saveable.rememberSaveable - import java.util.ArrayList - - @Composable - fun TestComposable() { - // Acceptable primitive types - val stringValue = rememberSaveable { "test" } - val intValue = rememberSaveable { 42 } - val booleanValue = rememberSaveable { true } - val floatValue = rememberSaveable { 1.0f } - val doubleValue = rememberSaveable { 1.0 } - val longValue = rememberSaveable { 1L } - - // Acceptable array types - val stringArray = rememberSaveable { arrayOf("test") } - val intArray = rememberSaveable { intArrayOf(42) } - val booleanArray = rememberSaveable { booleanArrayOf(true) } - val floatArray = rememberSaveable { floatArrayOf(1.0f) } - val doubleArray = rememberSaveable { doubleArrayOf(1.0) } - val longArray = rememberSaveable { longArrayOf(1L) } - val parcelableArray = rememberSaveable { arrayOf() } - - // Acceptable class types - val serializableValue = rememberSaveable { TestSerializable() } - val parcelableValue = rememberSaveable { TestParcelable() } - val arrayListValue = rememberSaveable { ArrayList() } - - // Nullable acceptable types - val nullableString = rememberSaveable { null } - val nullableInt = rememberSaveable { null } - val nullableBoolean = rememberSaveable { null } - - // Mutable state types - // Check policy and internal type? - val mutableState = rememberSaveable { mutableStateOf("value") } - val mutableIntState = rememberSaveable { mutableIntStateOf(1) } - val mutableFloatState = rememberSaveable { mutableFloatStateOf(1f) } - val mutableDoubleState = rememberSaveable { mutableDoubleStateOf(1.0) } - val mutableLongState = rememberSaveable { mutableLongStateOf(1L) } - val mutableTestParcelableState = rememberSaveable { - mutableStateOf(TestParcelable()).apply { value = TestParcelable() } - } - val mutableTestSerializableState = rememberSaveable { - mutableStateOf(TestSerializable()).apply { value = TestSerializable() } - } - - // AutoSaver is specified - val mutableStateLabeled = - rememberSaveable(saver = autoSaver(), init = { mutableStateOf("value") }) - - // Acceptable collections with acceptable types - - // Kotlin promises listOf is serializable on jvm - val listValue = rememberSaveable { listOf("test") } - // Fail without map saver - val mapValue = rememberSaveable { mapOf("key" to "value") } - // Check policy and internal type - val mutableStateValue = rememberSaveable { mutableStateOf("value") } +package test + +import androidx.compose.runtime.Composable +import androidx.compose.runtime.mutableDoubleStateOf +import androidx.compose.runtime.mutableFloatStateOf +import androidx.compose.runtime.mutableIntStateOf +import androidx.compose.runtime.mutableLongStateOf +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.saveable.autoSaver +import androidx.compose.runtime.saveable.rememberSaveable + +@Composable +fun TestComposable() { + // // Acceptable primitive types + // val stringValue = rememberSaveable { "test" } + // val intValue = rememberSaveable { 42 } + // val booleanValue = rememberSaveable { true } + // val floatValue = rememberSaveable { 1.0f } + // val doubleValue = rememberSaveable { 1.0 } + // val longValue = rememberSaveable { 1L } + // + // // Acceptable array types + // val stringArray = rememberSaveable { arrayOf("test") } + // val intArray = rememberSaveable { intArrayOf(42) } + // val booleanArray = rememberSaveable { booleanArrayOf(true) } + // val floatArray = rememberSaveable { floatArrayOf(1.0f) } + // val doubleArray = rememberSaveable { doubleArrayOf(1.0) } + // val longArray = rememberSaveable { longArrayOf(1L) } + // val parcelableArray = rememberSaveable { arrayOf() } + // + // // Acceptable class types + // val serializableValue = rememberSaveable { TestSerializable() } + // val parcelableValue = rememberSaveable { TestParcelable() } + // val arrayListValue = rememberSaveable { ArrayList() } + // + // // Nullable acceptable types + // val nullableString = rememberSaveable { null } + // val nullableInt = rememberSaveable { null } + // val nullableBoolean = rememberSaveable { null } + + // Mutable state types + // Check policy and internal type? + val mutableState = rememberSaveable { mutableStateOf("value") } + val mutableIntState = rememberSaveable { mutableIntStateOf(1) } + val mutableFloatState = rememberSaveable { mutableFloatStateOf(1f) } + val mutableDoubleState = rememberSaveable { mutableDoubleStateOf(1.0) } + val mutableLongState = rememberSaveable { mutableLongStateOf(1L) } + val mutableTestParcelableState = rememberSaveable { + mutableStateOf(TestParcelable()).apply { value = TestParcelable() } + } + val mutableTestSerializableState = rememberSaveable { + mutableStateOf(TestSerializable()).apply { value = TestSerializable() } + } + + // AutoSaver is specified + val mutableStateLabeled = + rememberSaveable(saver = autoSaver(), init = { mutableStateOf("value") }) + + // Acceptable collections with acceptable types + + // Kotlin promises listOf is serializable on jvm, need to verify? + val listValue = rememberSaveable { listOf("test") } + // Fail without map saver + val mapValue = rememberSaveable { mapOf("key" to "value") } + // Check policy and internal type + val mutableStateValue = rememberSaveable { mutableStateOf("value") } + // Check lambdas + val lambdaInvokedValue = rememberSaveable { { "value" }() } + val lambdaValue = rememberSaveable { { "value" } } + val lambdaBlockValue = rememberSaveable { + val a = { "value" } + val b = { "other" } + val c = if (Random() == 42) 1 else 2 + if (c == 1) { + a + } else { + ({ "other" }) + } + } } """ From 27f00bf06464b727bee7ad921dacfcd927e860a2 Mon Sep 17 00:00:00 2001 From: Josh Stagg Date: Tue, 2 Sep 2025 17:07:01 -0700 Subject: [PATCH 04/10] split up the test --- .../RememberSaveableAcceptableDetectorTest.kt | 317 ++++++++++++++---- 1 file changed, 253 insertions(+), 64 deletions(-) diff --git a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt index f3e3f950..3a83f5d4 100644 --- a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt +++ b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt @@ -14,7 +14,163 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { override fun getIssues() = listOf(RememberSaveableAcceptableDetector.ISSUE) @Test - fun acceptableTypes_autoSaver_noError() { + fun acceptableTypes_primitives_noError() { + val source = + kotlin( + """ +package test + +import androidx.compose.runtime.Composable +import androidx.compose.runtime.saveable.rememberSaveable + +@Composable +fun TestComposable() { + val stringValue = rememberSaveable { "test" } + val intValue = rememberSaveable { 42 } + val booleanValue = rememberSaveable { true } + val floatValue = rememberSaveable { 1.0f } + val doubleValue = rememberSaveable { 1.0 } + val longValue = rememberSaveable { 1L } +} + + """ + .trimIndent() + ) + .indented() + + lint() + .files( + JAVA_IO, + JAVA_UTIL, + ANDROID_OS, + ANDROID_UTIL, + COMPOSE_RUNTIME, + COMPOSE_SAVEABLE, + TEST_SAVEABLES, + source, + ) + .run() + .expectClean() + } + + @Test + fun acceptableTypes_arrays_noError() { + val source = + kotlin( + """ +package test + +import androidx.compose.runtime.Composable +import androidx.compose.runtime.saveable.rememberSaveable + +@Composable +fun TestComposable() { + val stringArray = rememberSaveable { arrayOf("test") } + val intArray = rememberSaveable { intArrayOf(42) } + val booleanArray = rememberSaveable { booleanArrayOf(true) } + val floatArray = rememberSaveable { floatArrayOf(1.0f) } + val doubleArray = rememberSaveable { doubleArrayOf(1.0) } + val longArray = rememberSaveable { longArrayOf(1L) } + val parcelableArray = rememberSaveable { arrayOf() } +} + + """ + .trimIndent() + ) + .indented() + + lint() + .files( + JAVA_IO, + JAVA_UTIL, + ANDROID_OS, + ANDROID_UTIL, + COMPOSE_RUNTIME, + COMPOSE_SAVEABLE, + TEST_SAVEABLES, + source, + ) + .run() + .expectClean() + } + + @Test + fun acceptableTypes_classes_noError() { + val source = + kotlin( + """ +package test + +import androidx.compose.runtime.Composable +import androidx.compose.runtime.saveable.rememberSaveable +import java.util.ArrayList + +@Composable +fun TestComposable() { + val serializableValue = rememberSaveable { TestSerializable() } + val parcelableValue = rememberSaveable { TestParcelable() } + val arrayListValue = rememberSaveable { ArrayList() } +} + + """ + .trimIndent() + ) + .indented() + + lint() + .files( + JAVA_IO, + JAVA_UTIL, + ANDROID_OS, + ANDROID_UTIL, + COMPOSE_RUNTIME, + COMPOSE_SAVEABLE, + TEST_SAVEABLES, + source, + ) + .run() + .expectClean() + } + + @Test + fun acceptableTypes_nullables_noError() { + val source = + kotlin( + """ +package test + +import androidx.compose.runtime.Composable +import androidx.compose.runtime.saveable.rememberSaveable + +@Composable +fun TestComposable() { + val nullableString = rememberSaveable { null } + val nullableInt = rememberSaveable { null } + val nullableBoolean = rememberSaveable { null } +} + + """ + .trimIndent() + ) + .indented() + + lint() + .files( + JAVA_IO, + JAVA_UTIL, + ANDROID_OS, + ANDROID_UTIL, + COMPOSE_RUNTIME, + COMPOSE_SAVEABLE, + TEST_SAVEABLES, + source, + ) + .run() + .expectClean() + } + + @Test + fun acceptableTypes_mutableStates_noError() { val source = kotlin( """ @@ -26,77 +182,110 @@ import androidx.compose.runtime.mutableFloatStateOf import androidx.compose.runtime.mutableIntStateOf import androidx.compose.runtime.mutableLongStateOf import androidx.compose.runtime.mutableStateOf -import androidx.compose.runtime.saveable.autoSaver import androidx.compose.runtime.saveable.rememberSaveable @Composable fun TestComposable() { - // // Acceptable primitive types - // val stringValue = rememberSaveable { "test" } - // val intValue = rememberSaveable { 42 } - // val booleanValue = rememberSaveable { true } - // val floatValue = rememberSaveable { 1.0f } - // val doubleValue = rememberSaveable { 1.0 } - // val longValue = rememberSaveable { 1L } - // - // // Acceptable array types - // val stringArray = rememberSaveable { arrayOf("test") } - // val intArray = rememberSaveable { intArrayOf(42) } - // val booleanArray = rememberSaveable { booleanArrayOf(true) } - // val floatArray = rememberSaveable { floatArrayOf(1.0f) } - // val doubleArray = rememberSaveable { doubleArrayOf(1.0) } - // val longArray = rememberSaveable { longArrayOf(1L) } - // val parcelableArray = rememberSaveable { arrayOf() } - // - // // Acceptable class types - // val serializableValue = rememberSaveable { TestSerializable() } - // val parcelableValue = rememberSaveable { TestParcelable() } - // val arrayListValue = rememberSaveable { ArrayList() } - // - // // Nullable acceptable types - // val nullableString = rememberSaveable { null } - // val nullableInt = rememberSaveable { null } - // val nullableBoolean = rememberSaveable { null } - - // Mutable state types - // Check policy and internal type? - val mutableState = rememberSaveable { mutableStateOf("value") } - val mutableIntState = rememberSaveable { mutableIntStateOf(1) } - val mutableFloatState = rememberSaveable { mutableFloatStateOf(1f) } - val mutableDoubleState = rememberSaveable { mutableDoubleStateOf(1.0) } - val mutableLongState = rememberSaveable { mutableLongStateOf(1L) } - val mutableTestParcelableState = rememberSaveable { - mutableStateOf(TestParcelable()).apply { value = TestParcelable() } + val mutableState = rememberSaveable { mutableStateOf("value") } + val mutableIntState = rememberSaveable { mutableIntStateOf(1) } + val mutableFloatState = rememberSaveable { mutableFloatStateOf(1f) } + val mutableDoubleState = rememberSaveable { mutableDoubleStateOf(1.0) } + val mutableLongState = rememberSaveable { mutableLongStateOf(1L) } + val mutableTestParcelableState = rememberSaveable { + mutableStateOf(TestParcelable()).apply { value = TestParcelable() } + } + val mutableTestSerializableState = rememberSaveable { + mutableStateOf(TestSerializable()).apply { value = TestSerializable() } + } +} + + """ + .trimIndent() + ) + .indented() + + lint() + .files( + JAVA_IO, + JAVA_UTIL, + ANDROID_OS, + ANDROID_UTIL, + COMPOSE_RUNTIME, + COMPOSE_SAVEABLE, + TEST_SAVEABLES, + source, + ) + .run() + .expectClean() } - val mutableTestSerializableState = rememberSaveable { - mutableStateOf(TestSerializable()).apply { value = TestSerializable() } + + @Test + fun acceptableTypes_autoSaver_noError() { + val source = + kotlin( + """ +package test + +import androidx.compose.runtime.Composable +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.saveable.autoSaver +import androidx.compose.runtime.saveable.rememberSaveable + +@Composable +fun TestComposable() { + val mutableStateLabeled = + rememberSaveable(saver = autoSaver(), init = { mutableStateOf("value") }) +} + + """ + .trimIndent() + ) + .indented() + + lint() + .files( + JAVA_IO, + JAVA_UTIL, + ANDROID_OS, + ANDROID_UTIL, + COMPOSE_RUNTIME, + COMPOSE_SAVEABLE, + TEST_SAVEABLES, + source, + ) + .run() + .expectClean() } - // AutoSaver is specified - val mutableStateLabeled = - rememberSaveable(saver = autoSaver(), init = { mutableStateOf("value") }) - - // Acceptable collections with acceptable types - - // Kotlin promises listOf is serializable on jvm, need to verify? - val listValue = rememberSaveable { listOf("test") } - // Fail without map saver - val mapValue = rememberSaveable { mapOf("key" to "value") } - // Check policy and internal type - val mutableStateValue = rememberSaveable { mutableStateOf("value") } - // Check lambdas - val lambdaInvokedValue = rememberSaveable { { "value" }() } - val lambdaValue = rememberSaveable { { "value" } } - val lambdaBlockValue = rememberSaveable { - val a = { "value" } - val b = { "other" } - val c = if (Random() == 42) 1 else 2 - if (c == 1) { - a - } else { - ({ "other" }) + @Test + fun acceptableTypes_collectionsAndLambdas_noError() { + val source = + kotlin( + """ +package test + +import androidx.compose.runtime.Composable +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.saveable.rememberSaveable +import kotlin.random.Random + +@Composable +fun TestComposable() { + val listValue = rememberSaveable { listOf("test") } + val mapValue = rememberSaveable { mapOf("key" to "value") } + val mutableStateValue = rememberSaveable { mutableStateOf("value") } + val lambdaInvokedValue = rememberSaveable { { "value" }() } + val lambdaValue = rememberSaveable { { "value" } } + val lambdaBlockValue = rememberSaveable { + val a = { "value" } + val b = { "other" } + val c = if (Random.nextInt() == 42) 1 else 2 + if (c == 1) { + a + } else { + ({ "other" }) + } } - } } """ From eea9ccb8f2300d9d07aa935a7e8a4cfcb22aebac Mon Sep 17 00:00:00 2001 From: Josh Stagg Date: Tue, 2 Sep 2025 17:45:35 -0700 Subject: [PATCH 05/10] update to check the policy --- .../RememberSaveableAcceptableDetector.kt | 71 +++++++++++--- .../RememberSaveableAcceptableDetectorTest.kt | 95 +++++++++++++++++++ 2 files changed, 151 insertions(+), 15 deletions(-) diff --git a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt index dd96c0dc..299f57a8 100644 --- a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt +++ b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt @@ -24,7 +24,9 @@ import org.jetbrains.uast.UElement import org.jetbrains.uast.UExpression import org.jetbrains.uast.UReturnExpression import org.jetbrains.uast.UastEmptyExpression +import org.jetbrains.uast.skipParenthesizedExprDown import org.jetbrains.uast.tryResolve +import org.jetbrains.uast.tryResolveNamed import org.jetbrains.uast.visitor.AbstractUastVisitor import slack.lint.util.Package import slack.lint.util.implements @@ -32,9 +34,10 @@ import slack.lint.util.isBoxedPrimitive import slack.lint.util.isInPackageName import slack.lint.util.sourceImplementation +private const val REMEMBER_SAVEABLE_METHOD_NAME = "rememberSaveable" + private val ComposeRuntimePackageName = Package("androidx.compose.runtime") private val RememberSaveablePackageName = Package("androidx.compose.runtime.saveable") -private const val RememberSaveableMethodName = "rememberSaveable" private val AUTO_SAVER = UastEmptyExpression(null) // todo Rewrite this so its checking this instead @@ -51,7 +54,7 @@ class RememberSaveableAcceptableDetector : Detector(), SourceCodeScanner { object : UElementHandler() { @Suppress("ReturnCount") override fun visitCallExpression(node: UCallExpression) { - if (node.methodName != RememberSaveableMethodName) return + if (node.methodName != REMEMBER_SAVEABLE_METHOD_NAME) return val evaluator = context.evaluator val method = node.resolve() val returnType = node.returnType @@ -72,7 +75,7 @@ class RememberSaveableAcceptableDetector : Detector(), SourceCodeScanner { // If there is no init expression just return. val init = arguments.getInit() ?: return // Check whats created in the init expression. - if (returnsKnownMutableState(returnType, init)) { + if (returnsKnownMutableState(returnType, init, context)) { // Found a known parcelable mutable state. return } @@ -136,9 +139,13 @@ private fun Map.getInit(): UExpression? { * The default Android `mutableStateOf` is `ParcelableSnapshotMutableState`, so check if all the * return statements are default `mutableStateOf` calls. */ -private fun returnsKnownMutableState(returnType: PsiType, expression: UExpression): Boolean { +private fun returnsKnownMutableState( + returnType: PsiType, + expression: UExpression, + context: JavaContext, +): Boolean { return if (returnType.isAcceptableMutableStateClass()) { - val visitor = MutableStateOfVisitor() + val visitor = MutableStateOfVisitor(context) expression.accept(visitor) val allReturned = visitor.mutableStateOfs.isNotEmpty() && @@ -160,31 +167,62 @@ private fun PsiType?.isAcceptableMutableStateClass(): Boolean { return isMutableState && psiClassType.parameters.all { it.isAcceptableType() } } -private fun isKnownMutableStateFunction(node: UElement): Boolean { +private fun isKnownMutableStateFunction(node: UElement, context: JavaContext): Boolean { val resolved = node.tryResolve() val isKnownMutableState = resolved is PsiMethod && resolved.isInPackageName(ComposeRuntimePackageName) && resolved.name in AcceptableMutableStateMethods - if (isKnownMutableState) { + return if (isKnownMutableState) { // Check the type of mutableStateOf() if (resolved.name == "mutableStateOf") { - val mutableClass = node.asCall()?.returnType?.asClass() - mutableClass?.parameters?.all { it.isAcceptableType() } == true + val call = node.asCall() + val mutableClass = call?.returnType?.asClass() + val typeIsAcceptable = mutableClass?.parameters?.all { it.isAcceptableType() } == true + val policyIsAcceptable = hasAcceptablePolicy(call, context) + typeIsAcceptable && policyIsAcceptable } else { // Known mutable[Primitive]StateOf() true } + } else false +} + +/** + * Checks if the mutableStateOf call uses an acceptable SnapshotMutationPolicy. Acceptable policies + * are: neverEqualPolicy, structuralEqualityPolicy, referentialEqualityPolicy + */ +private fun hasAcceptablePolicy(call: UCallExpression?, context: JavaContext): Boolean { + val method = call?.resolve() ?: return true // Default policy is acceptable + val evaluator = context.evaluator + val arguments = evaluator.computeArgumentMapping(call, method) + + // Find the policy argument by parameter name + val policyArgument = + arguments.firstNotNullOfOrNull { (expression, parameter) -> + if (parameter.name == "policy") expression.skipParenthesizedExprDown() else null + } + + // If no policy is specified, the default (structuralEqualityPolicy) is acceptable + if (policyArgument == null) return true + + val resolved = policyArgument.tryResolveNamed() + if ( + resolved is PsiMethod && + resolved.isInPackageName(ComposeRuntimePackageName) && + resolved.name in AcceptablePolicyMethods + ) { + return true } - return isKnownMutableState + return false } -private class MutableStateOfVisitor : AbstractUastVisitor() { +private class MutableStateOfVisitor(private val context: JavaContext) : AbstractUastVisitor() { val mutableStateOfs = mutableListOf() override fun visitCallExpression(node: UCallExpression): Boolean = - if (isKnownMutableStateFunction(node)) { + if (isKnownMutableStateFunction(node, context)) { mutableStateOfs.add(node) true } else false @@ -265,9 +303,9 @@ private fun canBeSavedToBundle(value: Any): Boolean { } */ -// From -// https://github.com/androidx/androidx/blob/989d1e676252c69e1b9b2e0639c3dba039e7ac99/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt#L151 -/** +/* + * From: https://github.com/androidx/androidx/blob/989d1e676252c69e1b9b2e0639c3dba039e7ac99/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt#L151 + * * Contains Classes which can be stored inside [Bundle]. * * Some of the classes are not added separately because: @@ -301,3 +339,6 @@ private val AcceptableMutableStateMethods = "mutableDoubleStateOf", "mutableLongStateOf", ) + +private val AcceptablePolicyMethods = + setOf("neverEqualPolicy", "structuralEqualityPolicy", "referentialEqualityPolicy") diff --git a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt index 3a83f5d4..39aa53e9 100644 --- a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt +++ b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt @@ -182,11 +182,13 @@ import androidx.compose.runtime.mutableFloatStateOf import androidx.compose.runtime.mutableIntStateOf import androidx.compose.runtime.mutableLongStateOf import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.structuralEqualityPolicy import androidx.compose.runtime.saveable.rememberSaveable @Composable fun TestComposable() { val mutableState = rememberSaveable { mutableStateOf("value") } + val mutableStateWithPolicy = rememberSaveable { mutableStateOf("value", structuralEqualityPolicy()) } val mutableIntState = rememberSaveable { mutableIntStateOf(1) } val mutableFloatState = rememberSaveable { mutableFloatStateOf(1f) } val mutableDoubleState = rememberSaveable { mutableDoubleStateOf(1.0) } @@ -257,6 +259,94 @@ fun TestComposable() { .expectClean() } + @Test + fun acceptableTypes_mutableStateWithAcceptablePolicy_noError() { + val source = + kotlin( + """ +package test + +import androidx.compose.runtime.Composable +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.neverEqualPolicy +import androidx.compose.runtime.referentialEqualityPolicy +import androidx.compose.runtime.structuralEqualityPolicy +import androidx.compose.runtime.saveable.rememberSaveable + +@Composable +fun TestComposable() { + val structuralPolicy = rememberSaveable { mutableStateOf("value", structuralEqualityPolicy()) } + val referentialPolicy = rememberSaveable { mutableStateOf("value", referentialEqualityPolicy()) } + val neverEqualPolicy = rememberSaveable { mutableStateOf("value", neverEqualPolicy()) } + val defaultPolicy = rememberSaveable { mutableStateOf("value") } +} + + """ + .trimIndent() + ) + .indented() + + lint() + .files( + JAVA_IO, + JAVA_UTIL, + ANDROID_OS, + ANDROID_UTIL, + COMPOSE_RUNTIME, + COMPOSE_SAVEABLE, + TEST_SAVEABLES, + source, + ) + .run() + .expectClean() + } + + @Test + fun unacceptableTypes_mutableStateWithCustomPolicy_hasError() { + val source = + kotlin( + """ +package test + +import androidx.compose.runtime.Composable +import androidx.compose.runtime.mutableStateOf +import androidx.compose.runtime.saveable.rememberSaveable + +@Composable +fun TestComposable() { + fun customPolicy(): SnapshotMutationPolicy = TODO() + + val customPolicyState = rememberSaveable { mutableStateOf("value", customPolicy()) } +} + + """ + .trimIndent() + ) + .indented() + + lint() + .files( + JAVA_IO, + JAVA_UTIL, + ANDROID_OS, + ANDROID_UTIL, + COMPOSE_RUNTIME, + COMPOSE_SAVEABLE, + TEST_SAVEABLES, + source, + ) + .run() + .expect( + """ +src/test/test.kt:11: Error: Brief description [RememberSaveableTypeMustBeAcceptable] + val customPolicyState = rememberSaveable { mutableStateOf("value", customPolicy()) } + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +1 errors, 0 warnings + """ + .trimIndent() + ) + } + @Test fun acceptableTypes_collectionsAndLambdas_noError() { val source = @@ -414,6 +504,11 @@ private val COMPOSE_RUNTIME = } fun mutableLongStateOf(value: Long): MutableLongState = TODO() + + fun structuralEqualityPolicy(): SnapshotMutationPolicy = TODO() + fun referentialEqualityPolicy(): SnapshotMutationPolicy = TODO() + fun neverEqualPolicy(): SnapshotMutationPolicy = TODO() + """ .trimIndent() ) From 213fb543b6ca6d77dd44fbbc37b31c8b6d0eeea2 Mon Sep 17 00:00:00 2001 From: Josh Stagg Date: Tue, 2 Sep 2025 17:59:33 -0700 Subject: [PATCH 06/10] test updates --- .../RememberSaveableAcceptableDetector.kt | 4 - .../RememberSaveableAcceptableDetectorTest.kt | 655 +++++++++--------- 2 files changed, 316 insertions(+), 343 deletions(-) diff --git a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt index 299f57a8..4a06bcb7 100644 --- a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt +++ b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt @@ -40,10 +40,6 @@ private val ComposeRuntimePackageName = Package("androidx.compose.runtime") private val RememberSaveablePackageName = Package("androidx.compose.runtime.saveable") private val AUTO_SAVER = UastEmptyExpression(null) -// todo Rewrite this so its checking this instead -// Android only source set -// rememberSaveable with autoSaver -> Check value type is in "AcceptableClasses" -// rememberSaveable with custom saver -> Check Saver.save is in in "AcceptableClasses" class RememberSaveableAcceptableDetector : Detector(), SourceCodeScanner { diff --git a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt index 39aa53e9..08011d82 100644 --- a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt +++ b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt @@ -4,6 +4,7 @@ package slack.lint.compose import com.android.tools.lint.checks.infrastructure.LintDetectorTest.java import com.android.tools.lint.checks.infrastructure.LintDetectorTest.kotlin +import com.android.tools.lint.checks.infrastructure.TestFile import org.junit.Test import slack.lint.BaseSlackLintTest @@ -13,31 +14,7 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { override fun getIssues() = listOf(RememberSaveableAcceptableDetector.ISSUE) - @Test - fun acceptableTypes_primitives_noError() { - val source = - kotlin( - """ -package test - -import androidx.compose.runtime.Composable -import androidx.compose.runtime.saveable.rememberSaveable - -@Composable -fun TestComposable() { - val stringValue = rememberSaveable { "test" } - val intValue = rememberSaveable { 42 } - val booleanValue = rememberSaveable { true } - val floatValue = rememberSaveable { 1.0f } - val doubleValue = rememberSaveable { 1.0 } - val longValue = rememberSaveable { 1L } -} - - """ - .trimIndent() - ) - .indented() - + private fun test(source: TestFile) = lint() .files( JAVA_IO, @@ -50,7 +27,33 @@ fun TestComposable() { source, ) .run() - .expectClean() + + @Test + fun acceptableTypes_primitives_noError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val stringValue = rememberSaveable { "test" } + val intValue = rememberSaveable { 42 } + val booleanValue = rememberSaveable { true } + val floatValue = rememberSaveable { 1.0f } + val doubleValue = rememberSaveable { 1.0 } + val longValue = rememberSaveable { 1L } + } + + """ + .trimIndent() + ) + .indented() + + test(source).expectClean() } @Test @@ -58,40 +61,28 @@ fun TestComposable() { val source = kotlin( """ -package test - -import androidx.compose.runtime.Composable -import androidx.compose.runtime.saveable.rememberSaveable - -@Composable -fun TestComposable() { - val stringArray = rememberSaveable { arrayOf("test") } - val intArray = rememberSaveable { intArrayOf(42) } - val booleanArray = rememberSaveable { booleanArrayOf(true) } - val floatArray = rememberSaveable { floatArrayOf(1.0f) } - val doubleArray = rememberSaveable { doubleArrayOf(1.0) } - val longArray = rememberSaveable { longArrayOf(1L) } - val parcelableArray = rememberSaveable { arrayOf() } -} - - """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val stringArray = rememberSaveable { arrayOf("test") } + val intArray = rememberSaveable { intArrayOf(42) } + val booleanArray = rememberSaveable { booleanArrayOf(true) } + val floatArray = rememberSaveable { floatArrayOf(1.0f) } + val doubleArray = rememberSaveable { doubleArrayOf(1.0) } + val longArray = rememberSaveable { longArrayOf(1L) } + val parcelableArray = rememberSaveable { arrayOf() } + } + + """ .trimIndent() ) .indented() - lint() - .files( - JAVA_IO, - JAVA_UTIL, - ANDROID_OS, - ANDROID_UTIL, - COMPOSE_RUNTIME, - COMPOSE_SAVEABLE, - TEST_SAVEABLES, - source, - ) - .run() - .expectClean() + test(source).expectClean() } @Test @@ -99,37 +90,25 @@ fun TestComposable() { val source = kotlin( """ -package test + package test -import androidx.compose.runtime.Composable -import androidx.compose.runtime.saveable.rememberSaveable -import java.util.ArrayList + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + import java.util.ArrayList -@Composable -fun TestComposable() { - val serializableValue = rememberSaveable { TestSerializable() } - val parcelableValue = rememberSaveable { TestParcelable() } - val arrayListValue = rememberSaveable { ArrayList() } -} + @Composable + fun TestComposable() { + val serializableValue = rememberSaveable { TestSerializable() } + val parcelableValue = rememberSaveable { TestParcelable() } + val arrayListValue = rememberSaveable { ArrayList() } + } - """ + """ .trimIndent() ) .indented() - lint() - .files( - JAVA_IO, - JAVA_UTIL, - ANDROID_OS, - ANDROID_UTIL, - COMPOSE_RUNTIME, - COMPOSE_SAVEABLE, - TEST_SAVEABLES, - source, - ) - .run() - .expectClean() + test(source).expectClean() } @Test @@ -137,36 +116,24 @@ fun TestComposable() { val source = kotlin( """ -package test + package test -import androidx.compose.runtime.Composable -import androidx.compose.runtime.saveable.rememberSaveable + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable -@Composable -fun TestComposable() { - val nullableString = rememberSaveable { null } - val nullableInt = rememberSaveable { null } - val nullableBoolean = rememberSaveable { null } -} + @Composable + fun TestComposable() { + val nullableString = rememberSaveable { null } + val nullableInt = rememberSaveable { null } + val nullableBoolean = rememberSaveable { null } + } - """ + """ .trimIndent() ) .indented() - lint() - .files( - JAVA_IO, - JAVA_UTIL, - ANDROID_OS, - ANDROID_UTIL, - COMPOSE_RUNTIME, - COMPOSE_SAVEABLE, - TEST_SAVEABLES, - source, - ) - .run() - .expectClean() + test(source).expectClean() } @Test @@ -174,51 +141,39 @@ fun TestComposable() { val source = kotlin( """ -package test - -import androidx.compose.runtime.Composable -import androidx.compose.runtime.mutableDoubleStateOf -import androidx.compose.runtime.mutableFloatStateOf -import androidx.compose.runtime.mutableIntStateOf -import androidx.compose.runtime.mutableLongStateOf -import androidx.compose.runtime.mutableStateOf -import androidx.compose.runtime.structuralEqualityPolicy -import androidx.compose.runtime.saveable.rememberSaveable - -@Composable -fun TestComposable() { - val mutableState = rememberSaveable { mutableStateOf("value") } - val mutableStateWithPolicy = rememberSaveable { mutableStateOf("value", structuralEqualityPolicy()) } - val mutableIntState = rememberSaveable { mutableIntStateOf(1) } - val mutableFloatState = rememberSaveable { mutableFloatStateOf(1f) } - val mutableDoubleState = rememberSaveable { mutableDoubleStateOf(1.0) } - val mutableLongState = rememberSaveable { mutableLongStateOf(1L) } - val mutableTestParcelableState = rememberSaveable { - mutableStateOf(TestParcelable()).apply { value = TestParcelable() } - } - val mutableTestSerializableState = rememberSaveable { - mutableStateOf(TestSerializable()).apply { value = TestSerializable() } - } -} - - """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.mutableDoubleStateOf + import androidx.compose.runtime.mutableFloatStateOf + import androidx.compose.runtime.mutableIntStateOf + import androidx.compose.runtime.mutableLongStateOf + import androidx.compose.runtime.mutableStateOf + import androidx.compose.runtime.structuralEqualityPolicy + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val mutableState = rememberSaveable { mutableStateOf("value") } + val mutableStateWithPolicy = rememberSaveable { mutableStateOf("value", structuralEqualityPolicy()) } + val mutableIntState = rememberSaveable { mutableIntStateOf(1) } + val mutableFloatState = rememberSaveable { mutableFloatStateOf(1f) } + val mutableDoubleState = rememberSaveable { mutableDoubleStateOf(1.0) } + val mutableLongState = rememberSaveable { mutableLongStateOf(1L) } + val mutableTestParcelableState = rememberSaveable { + mutableStateOf(TestParcelable()).apply { value = TestParcelable() } + } + val mutableTestSerializableState = rememberSaveable { + mutableStateOf(TestSerializable()).apply { value = TestSerializable() } + } + } + + """ .trimIndent() ) .indented() - lint() - .files( - JAVA_IO, - JAVA_UTIL, - ANDROID_OS, - ANDROID_UTIL, - COMPOSE_RUNTIME, - COMPOSE_SAVEABLE, - TEST_SAVEABLES, - source, - ) - .run() - .expectClean() + test(source).expectClean() } @Test @@ -226,37 +181,25 @@ fun TestComposable() { val source = kotlin( """ -package test + package test -import androidx.compose.runtime.Composable -import androidx.compose.runtime.mutableStateOf -import androidx.compose.runtime.saveable.autoSaver -import androidx.compose.runtime.saveable.rememberSaveable + import androidx.compose.runtime.Composable + import androidx.compose.runtime.mutableStateOf + import androidx.compose.runtime.saveable.autoSaver + import androidx.compose.runtime.saveable.rememberSaveable -@Composable -fun TestComposable() { - val mutableStateLabeled = - rememberSaveable(saver = autoSaver(), init = { mutableStateOf("value") }) -} + @Composable + fun TestComposable() { + val mutableStateLabeled = + rememberSaveable(saver = autoSaver(), init = { mutableStateOf("value") }) + } - """ + """ .trimIndent() ) .indented() - lint() - .files( - JAVA_IO, - JAVA_UTIL, - ANDROID_OS, - ANDROID_UTIL, - COMPOSE_RUNTIME, - COMPOSE_SAVEABLE, - TEST_SAVEABLES, - source, - ) - .run() - .expectClean() + test(source).expectClean() } @Test @@ -264,41 +207,29 @@ fun TestComposable() { val source = kotlin( """ -package test - -import androidx.compose.runtime.Composable -import androidx.compose.runtime.mutableStateOf -import androidx.compose.runtime.neverEqualPolicy -import androidx.compose.runtime.referentialEqualityPolicy -import androidx.compose.runtime.structuralEqualityPolicy -import androidx.compose.runtime.saveable.rememberSaveable - -@Composable -fun TestComposable() { - val structuralPolicy = rememberSaveable { mutableStateOf("value", structuralEqualityPolicy()) } - val referentialPolicy = rememberSaveable { mutableStateOf("value", referentialEqualityPolicy()) } - val neverEqualPolicy = rememberSaveable { mutableStateOf("value", neverEqualPolicy()) } - val defaultPolicy = rememberSaveable { mutableStateOf("value") } -} - - """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.mutableStateOf + import androidx.compose.runtime.neverEqualPolicy + import androidx.compose.runtime.referentialEqualityPolicy + import androidx.compose.runtime.structuralEqualityPolicy + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val structuralPolicy = rememberSaveable { mutableStateOf("value", structuralEqualityPolicy()) } + val referentialPolicy = rememberSaveable { mutableStateOf("value", referentialEqualityPolicy()) } + val neverEqualPolicy = rememberSaveable { mutableStateOf("value", neverEqualPolicy()) } + val defaultPolicy = rememberSaveable { mutableStateOf("value") } + } + + """ .trimIndent() ) .indented() - lint() - .files( - JAVA_IO, - JAVA_UTIL, - ANDROID_OS, - ANDROID_UTIL, - COMPOSE_RUNTIME, - COMPOSE_SAVEABLE, - TEST_SAVEABLES, - source, - ) - .run() - .expectClean() + test(source).expectClean() } @Test @@ -306,108 +237,151 @@ fun TestComposable() { val source = kotlin( """ -package test + package test -import androidx.compose.runtime.Composable -import androidx.compose.runtime.mutableStateOf -import androidx.compose.runtime.saveable.rememberSaveable + import androidx.compose.runtime.Composable + import androidx.compose.runtime.mutableStateOf + import androidx.compose.runtime.saveable.rememberSaveable -@Composable -fun TestComposable() { - fun customPolicy(): SnapshotMutationPolicy = TODO() - - val customPolicyState = rememberSaveable { mutableStateOf("value", customPolicy()) } -} + @Composable + fun TestComposable() { + fun customPolicy(): SnapshotMutationPolicy = TODO() + + val customPolicyState = rememberSaveable { mutableStateOf("value", customPolicy()) } + } - """ + """ .trimIndent() ) .indented() - lint() - .files( - JAVA_IO, - JAVA_UTIL, - ANDROID_OS, - ANDROID_UTIL, - COMPOSE_RUNTIME, - COMPOSE_SAVEABLE, - TEST_SAVEABLES, - source, - ) - .run() + test(source) .expect( """ -src/test/test.kt:11: Error: Brief description [RememberSaveableTypeMustBeAcceptable] - val customPolicyState = rememberSaveable { mutableStateOf("value", customPolicy()) } - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -1 errors, 0 warnings - """ + src/test/test.kt:11: Error: Brief description [RememberSaveableTypeMustBeAcceptable] + val customPolicyState = rememberSaveable { mutableStateOf("value", customPolicy()) } + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + 1 errors, 0 warnings + """ .trimIndent() ) } @Test - fun acceptableTypes_collectionsAndLambdas_noError() { + fun acceptableTypes_collections_noError() { val source = kotlin( """ -package test - -import androidx.compose.runtime.Composable -import androidx.compose.runtime.mutableStateOf -import androidx.compose.runtime.saveable.rememberSaveable -import kotlin.random.Random - -@Composable -fun TestComposable() { - val listValue = rememberSaveable { listOf("test") } - val mapValue = rememberSaveable { mapOf("key" to "value") } - val mutableStateValue = rememberSaveable { mutableStateOf("value") } - val lambdaInvokedValue = rememberSaveable { { "value" }() } - val lambdaValue = rememberSaveable { { "value" } } - val lambdaBlockValue = rememberSaveable { - val a = { "value" } - val b = { "other" } - val c = if (Random.nextInt() == 42) 1 else 2 - if (c == 1) { - a - } else { - ({ "other" }) - } - } -} + package test - """ + import androidx.compose.runtime.Composable + import androidx.compose.runtime.mutableStateOf + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val listValue = rememberSaveable { listOf("test") } + val mapValue = rememberSaveable { mapOf("key" to "value") } + val mutableStateValue = rememberSaveable { mutableStateOf("value") } + } + + """ .trimIndent() ) .indented() - lint() - .files( - JAVA_IO, - JAVA_UTIL, - ANDROID_OS, - ANDROID_UTIL, - COMPOSE_RUNTIME, - COMPOSE_SAVEABLE, - TEST_SAVEABLES, - source, + test(source).expectClean() + } + + @Test + fun acceptableTypes_lambdas_noError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val lambdaInvokedValue = rememberSaveable { { "value" }() } + } + + """ + .trimIndent() + ) + .indented() + + test(source).expectClean() + } + + @Test + fun acceptableTypes_lambda_errors() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val lambdaValue = rememberSaveable { { "value" } } + } + + """ + .trimIndent() + ) + .indented() + + test(source).expect("") + } + + @Test + fun acceptableTypes_multiReturnLambdas_errors() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + import kotlin.random.Random + + @Composable + fun TestComposable() { + val lambdaBlockValue = rememberSaveable { + val a = { "value" } + val b = { "other" } + val c = if (Random.nextInt() == 42) 1 else 2 + if (c == 1) { + a + } else { + ({ "other" }) + } + } + } + + """ + .trimIndent() ) - .run() - .expectClean() + .indented() + + test(source).expect("") } } private val JAVA_IO = java( """ - package java.io; + package java.io; - public interface Serializable { - } + public interface Serializable { + } - """ + """ .trimIndent() ) .indented() @@ -415,13 +389,13 @@ private val JAVA_IO = private val JAVA_UTIL = java( """ - package java.util; + package java.util; - import java.io.Serializable; + import java.io.Serializable; - public class ArrayList implements Serializable {} + public class ArrayList implements Serializable {} - """ + """ .trimIndent() ) .indented() @@ -456,60 +430,62 @@ private val ANDROID_UTIL = private val COMPOSE_RUNTIME = kotlin( """ - package androidx.compose.runtime + package androidx.compose.runtime - annotation class Composable + annotation class Composable - annotation class Stable + annotation class Stable - interface SnapshotMutationPolicy + interface SnapshotMutationPolicy - interface State { - val value: T - } + interface State { + val value: T + } - interface MutableState : State { - override var value: T - } + interface MutableState : State { + override var value: T + } - fun mutableStateOf( - value: T, - policy: SnapshotMutationPolicy = structuralEqualityPolicy(), - ): MutableState = TODO() + fun mutableStateOf( + value: T, + policy: SnapshotMutationPolicy = structuralEqualityPolicy(), + ): MutableState = TODO() - interface MutableIntState : IntState, MutableState { - override var value: Int - override var intValue: Int - } + interface MutableIntState : IntState, MutableState { + override var value: Int + override var intValue: Int + } - fun mutableIntStateOf(value: Int): MutableIntState = TODO() + fun mutableIntStateOf(value: Int): MutableIntState = TODO() - interface MutableFloatState : FloatState, MutableState { - override var value: Float - override var floatValue: Float - } + interface MutableFloatState : FloatState, MutableState { + override var value: Float + override var floatValue: Float + } - fun mutableFloatStateOf(value: Float): MutableFloatState = TODO() + fun mutableFloatStateOf(value: Float): MutableFloatState = TODO() - interface MutableDoubleState : MutableState { + interface MutableDoubleState : MutableState { override var value: Double override var doubleValue: Double - } + } - fun mutableDoubleStateOf(value: Double): MutableDoubleState = TODO() + fun mutableDoubleStateOf(value: Double): MutableDoubleState = TODO() - interface MutableLongState : MutableState { + interface MutableLongState : MutableState { override var value: Long override var longValue: Long - } + } - fun mutableLongStateOf(value: Long): MutableLongState = TODO() + fun mutableLongStateOf(value: Long): MutableLongState = TODO() - fun structuralEqualityPolicy(): SnapshotMutationPolicy = TODO() - fun referentialEqualityPolicy(): SnapshotMutationPolicy = TODO() - fun neverEqualPolicy(): SnapshotMutationPolicy = TODO() + fun structuralEqualityPolicy(): SnapshotMutationPolicy = TODO() - """ + fun referentialEqualityPolicy(): SnapshotMutationPolicy = TODO() + + fun neverEqualPolicy(): SnapshotMutationPolicy = TODO() + + """ .trimIndent() ) .indented() @@ -517,47 +493,48 @@ private val COMPOSE_RUNTIME = private val COMPOSE_SAVEABLE = kotlin( """ - package androidx.compose.runtime.saveable - - interface Saver { - fun save(value: Original): Saveable? - fun restore(value: Saveable): Original? - } - - fun autoSaver(): Saver = TODO() - - @Composable - fun rememberSaveable( - vararg inputs: Any?, - saver: Saver = autoSaver(), - key: String? = null, - init: () -> T - ): T = TODO() - - @Composable - fun rememberSaveable( - vararg inputs: Any?, - stateSaver: Saver, - key: String? = null, - init: () -> MutableState - ): MutableState = TODO() + package androidx.compose.runtime.saveable + + interface Saver { + fun save(value: Original): Saveable? + fun restore(value: Saveable): Original? + } + + fun autoSaver(): Saver = TODO() + + @Composable + fun rememberSaveable( + vararg inputs: Any?, + saver: Saver = autoSaver(), + key: String? = null, + init: () -> T + ): T = TODO() + + @Composable + fun rememberSaveable( + vararg inputs: Any?, + stateSaver: Saver, + key: String? = null, + init: () -> MutableState + ): MutableState = TODO() - """ + """ .trimIndent() ) .indented() +@Suppress("") private val TEST_SAVEABLES = kotlin( """ - package test + package test - import android.os.Parcelable - import java.io.Serializable + import android.os.Parcelable + import java.io.Serializable - class TestSerializable : Serializable - class TestParcelable : Parcelable - """ + class TestSerializable : Serializable + class TestParcelable : Parcelable + """ .trimIndent() ) .indented() From 20dcf31705ed3ada29bf8a884bd67c1e8a0c0e82 Mon Sep 17 00:00:00 2001 From: Josh Stagg Date: Tue, 2 Sep 2025 18:15:48 -0700 Subject: [PATCH 07/10] lambda returns --- .../RememberSaveableAcceptableDetector.kt | 51 +++++++++++-------- .../RememberSaveableAcceptableDetectorTest.kt | 17 +++++-- 2 files changed, 43 insertions(+), 25 deletions(-) diff --git a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt index 4a06bcb7..0508b4e2 100644 --- a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt +++ b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt @@ -40,7 +40,6 @@ private val ComposeRuntimePackageName = Package("androidx.compose.runtime") private val RememberSaveablePackageName = Package("androidx.compose.runtime.saveable") private val AUTO_SAVER = UastEmptyExpression(null) - class RememberSaveableAcceptableDetector : Detector(), SourceCodeScanner { override fun getApplicableUastTypes(): List> = @@ -190,27 +189,19 @@ private fun isKnownMutableStateFunction(node: UElement, context: JavaContext): B */ private fun hasAcceptablePolicy(call: UCallExpression?, context: JavaContext): Boolean { val method = call?.resolve() ?: return true // Default policy is acceptable - val evaluator = context.evaluator - val arguments = evaluator.computeArgumentMapping(call, method) + val arguments = context.evaluator.computeArgumentMapping(call, method) // Find the policy argument by parameter name val policyArgument = - arguments.firstNotNullOfOrNull { (expression, parameter) -> - if (parameter.name == "policy") expression.skipParenthesizedExprDown() else null - } - - // If no policy is specified, the default (structuralEqualityPolicy) is acceptable - if (policyArgument == null) return true + arguments + .firstNotNullOfOrNull { (expression, parameter) -> + if (parameter.name == "policy") expression.skipParenthesizedExprDown() else null + } + ?.tryResolveNamed() ?: return true // If no policy is specified, assume its the default - val resolved = policyArgument.tryResolveNamed() - if ( - resolved is PsiMethod && - resolved.isInPackageName(ComposeRuntimePackageName) && - resolved.name in AcceptablePolicyMethods - ) { - return true - } - return false + return policyArgument is PsiMethod && + policyArgument.isInPackageName(ComposeRuntimePackageName) && + policyArgument.name in AcceptablePolicyMethods } private class MutableStateOfVisitor(private val context: JavaContext) : AbstractUastVisitor() { @@ -240,13 +231,31 @@ private class ReturnsTracker(tracked: UElement, var returned: Boolean = false) : */ private fun returnsLambdaExpression(returnType: PsiType, expression: UExpression): Boolean { val isFunction = returnType.asClass()?.resolve()?.implements("kotlin.Function") == true - // todo Find all the function returns and see if any are lambda expressions + + // Find all the function returns and see if any are lambda expressions val visitor = ReturnsLambdaVisitor() expression.accept(visitor) - return isFunction // && visitor.returned + + // Return true if the return type is a function AND we found lambda returns + return isFunction && visitor.returnedLambda } -private class ReturnsLambdaVisitor(var returned: Boolean = false) : AbstractUastVisitor() {} +private class ReturnsLambdaVisitor(var returnedLambda: Boolean = false) : AbstractUastVisitor() { + + override fun visitReturnExpression(node: UReturnExpression): Boolean { + val returnValue = node.returnExpression + if (returnValue != null && isLambdaExpression(returnValue)) { + returnedLambda = true + } + return super.visitReturnExpression(node) + } + + private fun isLambdaExpression(expression: UExpression): Boolean { + // Check if this is a lambda expression by looking at the expression type + val expressionType = expression.getExpressionType() + return expressionType?.asClass()?.resolve()?.implements("kotlin.Function") == true + } +} private fun PsiType?.asClass(): PsiClassType? = this as? PsiClassType diff --git a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt index 08011d82..958f667b 100644 --- a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt +++ b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt @@ -343,7 +343,7 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { fun acceptableTypes_multiReturnLambdas_errors() { val source = kotlin( - """ + """ package test import androidx.compose.runtime.Composable @@ -365,11 +365,20 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { } """ - .trimIndent() - ) + .trimIndent() + ) .indented() - test(source).expect("") + test(source) + .expect( + """ + src/test/test.kt:9: Error: Brief description (LAMBDA) [RememberSaveableTypeMustBeAcceptable] + val lambdaBlockValue = rememberSaveable { + ^ + 1 error + """ + .trimIndent() + ) } } From ea3bcc373d08ecc9f0ce656bb9dc8824e59e7f3a Mon Sep 17 00:00:00 2001 From: Josh Stagg Date: Tue, 2 Sep 2025 18:29:13 -0700 Subject: [PATCH 08/10] todos --- .../RememberSaveableAcceptableDetector.kt | 7 ++++ .../RememberSaveableAcceptableDetectorTest.kt | 41 ++++++++++++++++--- 2 files changed, 43 insertions(+), 5 deletions(-) diff --git a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt index 0508b4e2..154aac8f 100644 --- a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt +++ b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt @@ -40,6 +40,13 @@ private val ComposeRuntimePackageName = Package("androidx.compose.runtime") private val RememberSaveablePackageName = Package("androidx.compose.runtime.saveable") private val AUTO_SAVER = UastEmptyExpression(null) +// todo +// - Collections (listOf, mapOf, etc) checks +// - ArrayList check? +// Think the savers likely go in a different detector +// - MapSaver check +// - ListSaver check +// - Custom saver check class RememberSaveableAcceptableDetector : Detector(), SourceCodeScanner { override fun getApplicableUastTypes(): List> = diff --git a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt index 958f667b..9c1a511e 100644 --- a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt +++ b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt @@ -268,21 +268,19 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { } @Test - fun acceptableTypes_collections_noError() { + fun acceptableTypes_collections_wrong_saver_error() { val source = kotlin( """ package test import androidx.compose.runtime.Composable - import androidx.compose.runtime.mutableStateOf - import androidx.compose.runtime.saveable.rememberSaveable + import androidx.compose.runtime.saveable.rememberSaveable @Composable fun TestComposable() { val listValue = rememberSaveable { listOf("test") } val mapValue = rememberSaveable { mapOf("key" to "value") } - val mutableStateValue = rememberSaveable { mutableStateOf("value") } } """ @@ -293,6 +291,30 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { test(source).expectClean() } + @Test + fun acceptableTypes_collections_wrong_saver_error() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val listValue = rememberSaveable { listOf("test") } + val mapValue = rememberSaveable { mapOf("key" to "value") } + } + + """ + .trimIndent() + ) + .indented() + + test(source).expectClean() + } + @Test fun acceptableTypes_lambdas_noError() { val source = @@ -336,7 +358,16 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { ) .indented() - test(source).expect("") + test(source) + .expect( + """ + src/test/test.kt:8: Error: Brief description (LAMBDA) [RememberSaveableTypeMustBeAcceptable] + val lambdaValue = rememberSaveable { { "value" } } + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + 1 error + """ + .trimIndent() + ) } @Test From 1303d3a3d9aa2ddef99ce003b4cec6cf6fe304b4 Mon Sep 17 00:00:00 2001 From: Josh Stagg Date: Wed, 3 Sep 2025 12:34:37 -0700 Subject: [PATCH 09/10] cleanup --- .../RememberSaveableAcceptableDetector.kt | 362 ++++++++++-------- .../main/java/slack/lint/util/LintUtils.kt | 3 + .../RememberSaveableAcceptableDetectorTest.kt | 53 +-- 3 files changed, 231 insertions(+), 187 deletions(-) diff --git a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt index 154aac8f..e312748f 100644 --- a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt +++ b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt @@ -19,6 +19,7 @@ import com.intellij.psi.PsiMethod import com.intellij.psi.PsiParameter import com.intellij.psi.PsiPrimitiveType import com.intellij.psi.PsiType +import kotlin.contracts.ExperimentalContracts import org.jetbrains.uast.UCallExpression import org.jetbrains.uast.UElement import org.jetbrains.uast.UExpression @@ -29,6 +30,7 @@ import org.jetbrains.uast.tryResolve import org.jetbrains.uast.tryResolveNamed import org.jetbrains.uast.visitor.AbstractUastVisitor import slack.lint.util.Package +import slack.lint.util.asClass import slack.lint.util.implements import slack.lint.util.isBoxedPrimitive import slack.lint.util.isInPackageName @@ -52,54 +54,14 @@ class RememberSaveableAcceptableDetector : Detector(), SourceCodeScanner { override fun getApplicableUastTypes(): List> = listOf(UCallExpression::class.java) - override fun createUastHandler(context: JavaContext) = - object : UElementHandler() { - @Suppress("ReturnCount") - override fun visitCallExpression(node: UCallExpression) { - if (node.methodName != REMEMBER_SAVEABLE_METHOD_NAME) return - val evaluator = context.evaluator - val method = node.resolve() - val returnType = node.returnType - if ( - method == null || - returnType == null || - !method.isInPackageName(RememberSaveablePackageName) - ) { - return - } - - val arguments = evaluator.computeArgumentMapping(node, method) - val saver = arguments.getSaver() - // With an auto saver check the return type. - if (saver == AUTO_SAVER && returnType.isAcceptableType()) { - return - } - // If there is no init expression just return. - val init = arguments.getInit() ?: return - // Check whats created in the init expression. - if (returnsKnownMutableState(returnType, init, context)) { - // Found a known parcelable mutable state. - return - } - if (returnsLambdaExpression(returnType, init)) { - // todo Report specific error language about kotlin Lambdas - context.report( - ISSUE, - context.getLocation(node), - ISSUE.getBriefDescription(TextFormat.TEXT) + " (LAMBDA)", - ) - return - } - context.report(ISSUE, context.getLocation(node), ISSUE.getBriefDescription(TextFormat.TEXT)) - } - } + override fun createUastHandler(context: JavaContext): UElementHandler = + RememberSaveableElementHandler(context) companion object { const val MESSAGE = "remember" const val ISSUE_ID = "RememberSaveableTypeMustBeAcceptable" const val BRIEF_DESCRIPTION = "Brief description" const val EXPLANATION = """Full explanation""" - val ISSUE: Issue = Issue.create( ISSUE_ID, @@ -113,6 +75,84 @@ class RememberSaveableAcceptableDetector : Detector(), SourceCodeScanner { } } +private class RememberSaveableElementHandler(val context: JavaContext) : UElementHandler() { + override fun visitCallExpression(node: UCallExpression) { + if (node.methodName != REMEMBER_SAVEABLE_METHOD_NAME) return + val evaluator = context.evaluator + val method = node.resolve() + val returnType = node.returnType + if ( + method != null && returnType != null && method.isInPackageName(RememberSaveablePackageName) + ) { + val arguments = evaluator.computeArgumentMapping(node, method) + val saver = arguments.getSaver() + if (saver == AUTO_SAVER) { + visitAutoSaver(node, returnType, arguments) + } else { + visitCustomSaver(node, saver) + } + } + } + + private fun visitAutoSaver( + node: UCallExpression, + returnType: PsiType, + arguments: Map, + ) { + val init = arguments.getInit() + // If there is no init expression or if the return is an acceptable type, just return. + if (init == null || returnType.isAcceptableType()) { + return + } + // Check what is created in the init expression. + val (allAcceptableMutableStates, customPolices) = + returnsKnownMutableState(returnType, init, context) + if (allAcceptableMutableStates) { + // Found a known parcelable mutable state! + // Report the custom policy error when using the auto saver. + // todo Report error message "If you use a custom SnapshotMutationPolicy for your + // MutableState you have to write a custom Saver" + customPolices.forEach { + context.report( + RememberSaveableAcceptableDetector.Companion.ISSUE, + context.getLocation(it), + RememberSaveableAcceptableDetector.Companion.ISSUE.getBriefDescription(TextFormat.TEXT) + + " (Custom policy)", + ) + } + return + } + + if (returnsLambdaExpression(returnType, init)) { + // todo Report specific error language about kotlin Lambdas + context.report( + RememberSaveableAcceptableDetector.Companion.ISSUE, + context.getLocation(node), + RememberSaveableAcceptableDetector.Companion.ISSUE.getBriefDescription(TextFormat.TEXT) + + " (LAMBDA)", + ) + } else { + // "The default implementation only supports types which can be stored inside the Bundle. + // Please consider implementing a custom Saver for this class and pass it to + // rememberSaveable()." + context.report( + RememberSaveableAcceptableDetector.Companion.ISSUE, + context.getLocation(node), + RememberSaveableAcceptableDetector.Companion.ISSUE.getBriefDescription(TextFormat.TEXT), + ) + } + } + + private fun visitCustomSaver(node: UCallExpression, saver: UExpression) { + // todo Custom saver, check the return type of the save call and report an error. + context.report( + RememberSaveableAcceptableDetector.Companion.ISSUE, + context.getLocation(node), + RememberSaveableAcceptableDetector.Companion.ISSUE.getBriefDescription(TextFormat.TEXT), + ) + } +} + private fun Map.getSaver(): UExpression { val saver = firstNotNullOfOrNull { (expression, parameter) -> if (parameter.name == "saver") expression else null @@ -145,21 +185,22 @@ private fun returnsKnownMutableState( returnType: PsiType, expression: UExpression, context: JavaContext, -): Boolean { +): ReturnsKnownMutableStateResult { return if (returnType.isAcceptableMutableStateClass()) { val visitor = MutableStateOfVisitor(context) expression.accept(visitor) - val allReturned = - visitor.mutableStateOfs.isNotEmpty() && - visitor.mutableStateOfs.all { tracked -> - val returnsTracker = ReturnsTracker(tracked) - expression.accept(returnsTracker) - returnsTracker.returned - } - allReturned - } else false + val allAcceptableMutableStates = visitor.checkAllReturned(expression) + ReturnsKnownMutableStateResult(allAcceptableMutableStates, visitor.customPolices) + } else { + ReturnsKnownMutableStateResult(false, emptyList()) + } } +private data class ReturnsKnownMutableStateResult( + val allAcceptableMutableStates: Boolean, + val customPolices: List, +) + private fun PsiType?.isAcceptableMutableStateClass(): Boolean { val psiClassType = asClass() val isMutableState = @@ -169,31 +210,89 @@ private fun PsiType?.isAcceptableMutableStateClass(): Boolean { return isMutableState && psiClassType.parameters.all { it.isAcceptableType() } } -private fun isKnownMutableStateFunction(node: UElement, context: JavaContext): Boolean { - val resolved = node.tryResolve() - val isKnownMutableState = - resolved is PsiMethod && - resolved.isInPackageName(ComposeRuntimePackageName) && - resolved.name in AcceptableMutableStateMethods - return if (isKnownMutableState) { - // Check the type of mutableStateOf() - if (resolved.name == "mutableStateOf") { - val call = node.asCall() - val mutableClass = call?.returnType?.asClass() - val typeIsAcceptable = mutableClass?.parameters?.all { it.isAcceptableType() } == true - val policyIsAcceptable = hasAcceptablePolicy(call, context) - typeIsAcceptable && policyIsAcceptable - } else { - // Known mutable[Primitive]StateOf() +/** + * Visitor that tracks `mutableStateOf` function calls within an expression to determine if all such + * calls are returned from their containing expressions. + * + * This visitor is used to verify that when `rememberSaveable` contains `mutableStateOf` calls, + * those calls are actually being returned (and thus can be properly saved/restored by the + * auto-saver mechanism). + * + * @param context The JavaContext for the current lint analysis + */ +private class MutableStateOfVisitor(private val context: JavaContext) : AbstractUastVisitor() { + + /** List of all `mutableStateOf` function calls found during traversal */ + val mutableStateOfs = mutableListOf() + val customPolices = mutableListOf() + + /** + * Visits call expressions and tracks those that are known mutable state functions. + * + * @param node The call expression to examine + * @return true if the node is a known mutable state function, false otherwise + */ + override fun visitCallExpression(node: UCallExpression): Boolean = + if (isKnownMutableStateFunction(node, context)) { + mutableStateOfs.add(node) true + } else false + + /** + * Checks if all tracked `mutableStateOf` calls are returned from their containing expressions. + * + * @return true if all tracked calls are returned, false otherwise + */ + fun checkAllReturned(expression: UExpression): Boolean { + return mutableStateOfs.isNotEmpty() && + mutableStateOfs.all { tracked -> + val visitor = ReturnsTracker(tracked) + expression.accept(visitor) + visitor.returned + } + } + + /** + * Check if the call is one of the acceptable mutable state functions, like `mutableStateOf` or + * `mutableFloatStateOf`. + */ + private fun isKnownMutableStateFunction(node: UElement, context: JavaContext): Boolean { + val resolved = node.tryResolve() + val isKnownMutableState = + resolved is PsiMethod && + resolved.isInPackageName(ComposeRuntimePackageName) && + resolved.name in AcceptableMutableStateMethods + return if (isKnownMutableState) { + // Check the type of mutableStateOf() + if (resolved.name == "mutableStateOf") { + val call = node.asCall() + val mutableClass = call?.returnType?.asClass() + val typeIsAcceptable = mutableClass?.parameters?.all { it.isAcceptableType() } == true + // Side effect to track custom policies + if (!hasAcceptablePolicy(call, context) && call != null) { + customPolices += call + } + typeIsAcceptable + } else { + // Known mutable[Primitive]StateOf() + true + } + } else false + } + + private class ReturnsTracker(tracked: UElement, var returned: Boolean = false) : + DataFlowAnalyzer(listOf(tracked)) { + override fun returns(expression: UReturnExpression) { + returned = true } - } else false + } } /** * Checks if the mutableStateOf call uses an acceptable SnapshotMutationPolicy. Acceptable policies * are: neverEqualPolicy, structuralEqualityPolicy, referentialEqualityPolicy */ +@OptIn(ExperimentalContracts::class) private fun hasAcceptablePolicy(call: UCallExpression?, context: JavaContext): Boolean { val method = call?.resolve() ?: return true // Default policy is acceptable val arguments = context.evaluator.computeArgumentMapping(call, method) @@ -202,49 +301,35 @@ private fun hasAcceptablePolicy(call: UCallExpression?, context: JavaContext): B val policyArgument = arguments .firstNotNullOfOrNull { (expression, parameter) -> - if (parameter.name == "policy") expression.skipParenthesizedExprDown() else null + if (parameter.name == "policy") expression else null } - ?.tryResolveNamed() ?: return true // If no policy is specified, assume its the default - - return policyArgument is PsiMethod && - policyArgument.isInPackageName(ComposeRuntimePackageName) && - policyArgument.name in AcceptablePolicyMethods -} - -private class MutableStateOfVisitor(private val context: JavaContext) : AbstractUastVisitor() { + ?.skipParenthesizedExprDown() + ?.tryResolveNamed() - val mutableStateOfs = mutableListOf() - - override fun visitCallExpression(node: UCallExpression): Boolean = - if (isKnownMutableStateFunction(node, context)) { - mutableStateOfs.add(node) - true - } else false -} + // Check if the policy is acceptable + val isAcceptablePolicy = + policyArgument is PsiMethod && + policyArgument.isInPackageName(ComposeRuntimePackageName) && + policyArgument.name in AcceptablePolicyMethods -private class ReturnsTracker(tracked: UElement, var returned: Boolean = false) : - DataFlowAnalyzer(listOf(tracked)) { - override fun returns(expression: UReturnExpression) { - returned = true - } + // If no policy is specified, assume its the default + return policyArgument == null || isAcceptablePolicy } /** * > Lambdas in Kotlin implement Serializable, but will crash if you really try to save them. We * > check for both Function and Serializable (see kotlin.jvm.internal.Lambda) to support custom * > user defined classes implementing Function interface. - * - From: - * https://github.com/androidx/androidx/blob/989d1e676252c69e1b9b2e0639c3dba039e7ac99/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt#L122-L124 + * + * https://github.com/androidx/androidx/blob/989d1e676252c69e1b9b2e0639c3dba039e7ac99/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt#L122-L124 */ private fun returnsLambdaExpression(returnType: PsiType, expression: UExpression): Boolean { val isFunction = returnType.asClass()?.resolve()?.implements("kotlin.Function") == true - - // Find all the function returns and see if any are lambda expressions - val visitor = ReturnsLambdaVisitor() - expression.accept(visitor) - - // Return true if the return type is a function AND we found lambda returns - return isFunction && visitor.returnedLambda + return if (isFunction) { + val visitor = ReturnsLambdaVisitor() + expression.accept(visitor) + visitor.returnedLambda + } else false } private class ReturnsLambdaVisitor(var returnedLambda: Boolean = false) : AbstractUastVisitor() { @@ -264,8 +349,6 @@ private class ReturnsLambdaVisitor(var returnedLambda: Boolean = false) : Abstra } } -private fun PsiType?.asClass(): PsiClassType? = this as? PsiClassType - private fun PsiType.isAcceptableType(): Boolean { return when (this) { is PsiPrimitiveType -> true @@ -282,55 +365,22 @@ private fun PsiClass.isAcceptableClassType(): Boolean { return isBoxedPrimitive() || AcceptableClasses.any { implements(it) } } -// From -// https://github.com/androidx/androidx/blob/989d1e676252c69e1b9b2e0639c3dba039e7ac99/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt#L108 -/* -/** Checks that [value] can be stored inside [Bundle]. */ -private fun canBeSavedToBundle(value: Any): Boolean { - // SnapshotMutableStateImpl is Parcelable, but we do extra checks - if (value is SnapshotMutableState<*>) { - if ( - value.policy === neverEqualPolicy() || - value.policy === structuralEqualityPolicy() || - value.policy === referentialEqualityPolicy() - ) { - val stateValue = value.value - return if (stateValue == null) true else canBeSavedToBundle(stateValue) - } else { - return false - } - } - // lambdas in Kotlin implement Serializable, but will crash if you really try to save them. - // we check for both Function and Serializable (see kotlin.jvm.internal.Lambda) to support - // custom user defined classes implementing Function interface. - if (value is Function<*> && value is Serializable) { - return false - } - for (cl in AcceptableClasses) { - if (cl.isInstance(value)) { - return true - } - } - return false -} - */ - -/* - * From: https://github.com/androidx/androidx/blob/989d1e676252c69e1b9b2e0639c3dba039e7ac99/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt#L151 - * - * Contains Classes which can be stored inside [Bundle]. - * - * Some of the classes are not added separately because: - * - * This classes implement Serializable: - * - Arrays (DoubleArray, BooleanArray, IntArray, LongArray, ByteArray, FloatArray, ShortArray, - * CharArray, Array) - * - ArrayList - * - Primitives (Boolean, Int, Long, Double, Float, Byte, Short, Char) will be boxed when casted to - * Any, and all the boxed classes implements Serializable. This class implements Parcelable: - * - Bundle +/** + * Based on this set used by `canBeSavedToBundle()`: + * ``` + * private val AcceptableClasses = + * arrayOf( + * Serializable::class.java, + * Parcelable::class.java, + * String::class.java, + * SparseArray::class.java, + * Binder::class.java, + * Size::class.java, + * SizeF::class.java, + * ) + * ``` * - * Note: it is simplified copy of the array from SavedStateHandle (lifecycle-viewmodel-savedstate). + * https://github.com/androidx/androidx/blob/989d1e676252c69e1b9b2e0639c3dba039e7ac99/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt#L151-L160 */ private val AcceptableClasses = setOf( @@ -352,5 +402,17 @@ private val AcceptableMutableStateMethods = "mutableLongStateOf", ) +/** + * Based on this check performed in `canBeSavedToBundle()`: + * ``` + * if ( + * value.policy === neverEqualPolicy() || + * value.policy === structuralEqualityPolicy() || + * value.policy === referentialEqualityPolicy() + * ) { + * ``` + * + * https://github.com/androidx/androidx/blob/989d1e676252c69e1b9b2e0639c3dba039e7ac99/compose/ui/ui/src/androidMain/kotlin/androidx/compose/ui/platform/DisposableSaveableStateRegistry.android.kt#L112-L114 + */ private val AcceptablePolicyMethods = setOf("neverEqualPolicy", "structuralEqualityPolicy", "referentialEqualityPolicy") diff --git a/slack-lint-checks/src/main/java/slack/lint/util/LintUtils.kt b/slack-lint-checks/src/main/java/slack/lint/util/LintUtils.kt index f73d3967..c994efb5 100644 --- a/slack-lint-checks/src/main/java/slack/lint/util/LintUtils.kt +++ b/slack-lint-checks/src/main/java/slack/lint/util/LintUtils.kt @@ -366,3 +366,6 @@ internal val PsiMethod.returnsUnit */ internal val PsiType?.isVoidOrUnit get() = this == PsiTypes.voidType() || this?.canonicalText == "kotlin.Unit" + +/** Cast as a [PsiClassType] if possible. */ +internal fun PsiType?.asClass(): PsiClassType? = this as? PsiClassType diff --git a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt index 9c1a511e..32b180cc 100644 --- a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt +++ b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt @@ -111,31 +111,6 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { test(source).expectClean() } - @Test - fun acceptableTypes_nullables_noError() { - val source = - kotlin( - """ - package test - - import androidx.compose.runtime.Composable - import androidx.compose.runtime.saveable.rememberSaveable - - @Composable - fun TestComposable() { - val nullableString = rememberSaveable { null } - val nullableInt = rememberSaveable { null } - val nullableBoolean = rememberSaveable { null } - } - - """ - .trimIndent() - ) - .indented() - - test(source).expectClean() - } - @Test fun acceptableTypes_mutableStates_noError() { val source = @@ -258,29 +233,33 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { test(source) .expect( """ - src/test/test.kt:11: Error: Brief description [RememberSaveableTypeMustBeAcceptable] - val customPolicyState = rememberSaveable { mutableStateOf("value", customPolicy()) } - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - 1 errors, 0 warnings - """ + src/test/test.kt:11: Error: Brief description (Custom policy) [RememberSaveableTypeMustBeAcceptable] + val customPolicyState = rememberSaveable { mutableStateOf("value", customPolicy()) } + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + 1 error + """ .trimIndent() ) } @Test - fun acceptableTypes_collections_wrong_saver_error() { + fun unacceptableTypes_mutableStateWithCustomPolicy_noError() { val source = kotlin( """ package test import androidx.compose.runtime.Composable - import androidx.compose.runtime.saveable.rememberSaveable + import androidx.compose.runtime.mutableStateOf + import androidx.compose.runtime.saveable.rememberSaveable @Composable fun TestComposable() { - val listValue = rememberSaveable { listOf("test") } - val mapValue = rememberSaveable { mapOf("key" to "value") } + fun customPolicy(): SnapshotMutationPolicy = TODO() + val saver = + Saver, String>({ it.value }, { mutableStateOf(it, customPolicy()) }) + val customPolicyState = + rememberSaveable(saver = saver) { mutableStateOf("value", customPolicy()) } } """ @@ -295,7 +274,7 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { fun acceptableTypes_collections_wrong_saver_error() { val source = kotlin( - """ + """ package test import androidx.compose.runtime.Composable @@ -308,8 +287,8 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { } """ - .trimIndent() - ) + .trimIndent() + ) .indented() test(source).expectClean() From f8afdfa45a3b4ac1870e6c5d3a5f889878b9810a Mon Sep 17 00:00:00 2001 From: Josh Stagg Date: Wed, 3 Sep 2025 15:52:03 -0700 Subject: [PATCH 10/10] Some custom saver checking --- .../RememberSaveableAcceptableDetector.kt | 54 +++-- .../RememberSaveableAcceptableDetectorTest.kt | 208 +++++++++++++++--- 2 files changed, 220 insertions(+), 42 deletions(-) diff --git a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt index e312748f..cf8da7d3 100644 --- a/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt +++ b/slack-lint-checks/src/main/java/slack/lint/compose/RememberSaveableAcceptableDetector.kt @@ -19,7 +19,6 @@ import com.intellij.psi.PsiMethod import com.intellij.psi.PsiParameter import com.intellij.psi.PsiPrimitiveType import com.intellij.psi.PsiType -import kotlin.contracts.ExperimentalContracts import org.jetbrains.uast.UCallExpression import org.jetbrains.uast.UElement import org.jetbrains.uast.UExpression @@ -28,7 +27,9 @@ import org.jetbrains.uast.UastEmptyExpression import org.jetbrains.uast.skipParenthesizedExprDown import org.jetbrains.uast.tryResolve import org.jetbrains.uast.tryResolveNamed +import org.jetbrains.uast.unwrapReferenceNameElement import org.jetbrains.uast.visitor.AbstractUastVisitor +import slack.lint.util.Name import slack.lint.util.Package import slack.lint.util.asClass import slack.lint.util.implements @@ -39,7 +40,8 @@ import slack.lint.util.sourceImplementation private const val REMEMBER_SAVEABLE_METHOD_NAME = "rememberSaveable" private val ComposeRuntimePackageName = Package("androidx.compose.runtime") -private val RememberSaveablePackageName = Package("androidx.compose.runtime.saveable") +private val ComposeSaveablePackageName = Package("androidx.compose.runtime.saveable") +private val SaverFQN = Name(ComposeSaveablePackageName, "Saver").javaFqn private val AUTO_SAVER = UastEmptyExpression(null) // todo @@ -48,7 +50,6 @@ private val AUTO_SAVER = UastEmptyExpression(null) // Think the savers likely go in a different detector // - MapSaver check // - ListSaver check -// - Custom saver check class RememberSaveableAcceptableDetector : Detector(), SourceCodeScanner { override fun getApplicableUastTypes(): List> = @@ -82,7 +83,7 @@ private class RememberSaveableElementHandler(val context: JavaContext) : UElemen val method = node.resolve() val returnType = node.returnType if ( - method != null && returnType != null && method.isInPackageName(RememberSaveablePackageName) + method != null && returnType != null && method.isInPackageName(ComposeSaveablePackageName) ) { val arguments = evaluator.computeArgumentMapping(node, method) val saver = arguments.getSaver() @@ -145,23 +146,49 @@ private class RememberSaveableElementHandler(val context: JavaContext) : UElemen private fun visitCustomSaver(node: UCallExpression, saver: UExpression) { // todo Custom saver, check the return type of the save call and report an error. - context.report( - RememberSaveableAcceptableDetector.Companion.ISSUE, - context.getLocation(node), - RememberSaveableAcceptableDetector.Companion.ISSUE.getBriefDescription(TextFormat.TEXT), - ) + val unwrapedElement = unwrapReferenceNameElement(saver) + val visitor = CustomSaverVisitor() + unwrapedElement?.accept(visitor) + if (visitor.saveableType?.isAcceptableType() == false) { + context.report( + RememberSaveableAcceptableDetector.Companion.ISSUE, + context.getLocation(node), + RememberSaveableAcceptableDetector.Companion.ISSUE.getBriefDescription(TextFormat.TEXT), + ) + } } } -private fun Map.getSaver(): UExpression { - val saver = firstNotNullOfOrNull { (expression, parameter) -> - if (parameter.name == "saver") expression else null +private class CustomSaverVisitor : AbstractUastVisitor() { + + var saveableType: PsiType? = null + + override fun visitExpression(node: UExpression): Boolean { + this.saveableType = node.getExpressionType()?.asClass()?.let { unwrapSaveableType(it) } + return saveableType != null || super.visitExpression(node) } + + private fun unwrapSaveableType(element: PsiClassType): PsiType? { + val resolved = element.resolve() + return when { + resolved == null -> null + resolved.qualifiedName == SaverFQN -> element.parameters[1] + else -> resolved.superTypes.firstNotNullOfOrNull { type -> unwrapSaveableType(type) } + } + } +} + +private fun Map.getSaver(): UExpression { + val saver = + firstNotNullOfOrNull { (expression, parameter) -> + if (parameter.name == "saver") expression else null + } + ?.skipParenthesizedExprDown() val resolved = saver?.tryResolve() if ( resolved is PsiMethod && resolved.name == "autoSaver" && - resolved.isInPackageName(RememberSaveablePackageName) + resolved.isInPackageName(ComposeSaveablePackageName) ) { return AUTO_SAVER } @@ -292,7 +319,6 @@ private class MutableStateOfVisitor(private val context: JavaContext) : Abstract * Checks if the mutableStateOf call uses an acceptable SnapshotMutationPolicy. Acceptable policies * are: neverEqualPolicy, structuralEqualityPolicy, referentialEqualityPolicy */ -@OptIn(ExperimentalContracts::class) private fun hasAcceptablePolicy(call: UCallExpression?, context: JavaContext): Boolean { val method = call?.resolve() ?: return true // Default policy is acceptable val arguments = context.evaluator.computeArgumentMapping(call, method) diff --git a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt index 32b180cc..0a9239da 100644 --- a/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt +++ b/slack-lint-checks/src/test/java/slack/lint/compose/RememberSaveableAcceptableDetectorTest.kt @@ -5,6 +5,7 @@ package slack.lint.compose import com.android.tools.lint.checks.infrastructure.LintDetectorTest.java import com.android.tools.lint.checks.infrastructure.LintDetectorTest.kotlin import com.android.tools.lint.checks.infrastructure.TestFile +import org.junit.Ignore import org.junit.Test import slack.lint.BaseSlackLintTest @@ -14,20 +15,6 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { override fun getIssues() = listOf(RememberSaveableAcceptableDetector.ISSUE) - private fun test(source: TestFile) = - lint() - .files( - JAVA_IO, - JAVA_UTIL, - ANDROID_OS, - ANDROID_UTIL, - COMPOSE_RUNTIME, - COMPOSE_SAVEABLE, - TEST_SAVEABLES, - source, - ) - .run() - @Test fun acceptableTypes_primitives_noError() { val source = @@ -221,7 +208,7 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { @Composable fun TestComposable() { fun customPolicy(): SnapshotMutationPolicy = TODO() - + val customPolicyState = rememberSaveable { mutableStateOf("value", customPolicy()) } } @@ -233,11 +220,11 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { test(source) .expect( """ - src/test/test.kt:11: Error: Brief description (Custom policy) [RememberSaveableTypeMustBeAcceptable] - val customPolicyState = rememberSaveable { mutableStateOf("value", customPolicy()) } - ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - 1 error - """ + src/test/test.kt:11: Error: Brief description (Custom policy) [RememberSaveableTypeMustBeAcceptable] + val customPolicyState = rememberSaveable { mutableStateOf("value", customPolicy()) } + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + 1 error + """ .trimIndent() ) } @@ -255,14 +242,18 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { @Composable fun TestComposable() { - fun customPolicy(): SnapshotMutationPolicy = TODO() - val saver = - Saver, String>({ it.value }, { mutableStateOf(it, customPolicy()) }) - val customPolicyState = - rememberSaveable(saver = saver) { mutableStateOf("value", customPolicy()) } + fun customPolicy(): SnapshotMutationPolicy = TODO() + val saver = + Saver, String>( + save = { it.value }, + restore = { mutableStateOf(value = it, customPolicy()) }, + ) + val customPolicyState = + rememberSaveable(saver = saver) { + mutableStateOf(value = "value", policy = customPolicy()) + } } - - """ + """ .trimIndent() ) .indented() @@ -270,6 +261,7 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { test(source).expectClean() } + @Ignore // TODO implement @Test fun acceptableTypes_collections_wrong_saver_error() { val source = @@ -326,7 +318,7 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { import androidx.compose.runtime.Composable import androidx.compose.runtime.saveable.rememberSaveable - + @Composable fun TestComposable() { val lambdaValue = rememberSaveable { { "value" } } @@ -390,6 +382,161 @@ class RememberSaveableAcceptableDetectorTest : BaseSlackLintTest() { .trimIndent() ) } + + @Test + fun customSaver_acceptableTypes_primitives_noError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.Saver + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val string1Saver = Saver(save = { it }, restore = { it }) + val string3Saver = + object : Saver { + override fun restore(value: String) = value + + override fun SaverScope.save(value: String) = value + } + val string1Value = rememberSaveable(saver = string1Saver) { "test" } + val string2Value = rememberSaveable(saver = Saver(save = { it }, restore = { it })) { "test" } + val string3Value = rememberSaveable(saver = string3Saver) { "test" } + val intSaver = Saver(save = { it }, restore = { it }) + val intValue = rememberSaveable(saver = intSaver) { 42 } + val booleanSaver = Saver(save = { it }, restore = { it }) + val booleanValue = rememberSaveable(saver = booleanSaver) { true } + val floatSaver = Saver(save = { it }, restore = { it }) + val floatValue = rememberSaveable { 1.0f } + val doubleSaver = Saver(save = { it }, restore = { it }) + val doubleValue = rememberSaveable(saver = doubleSaver) { 1.0 } + val longSaver = Saver(save = { it }, restore = { it }) + val longValue = rememberSaveable { 1L } + } + + """ + .trimIndent() + ) + .indented() + + test(source).expectClean() + } + + @Test + fun customSaver_acceptableTypes_arrays_noError() { + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.Saver + import androidx.compose.runtime.saveable.rememberSaveable + + @Composable + fun TestComposable() { + val stringSaver = Saver, Array>(save = { it }, restore = { it }) + val stringArray = rememberSaveable(saver = stringSaver) { arrayOf("test") } + val intSaver = Saver(save = { it }, restore = { it }) + val intArray = rememberSaveable(saver = intSaver) { intArrayOf(42) } + val booleanSaver = Saver(save = { it }, restore = { it }) + val booleanArray = rememberSaveable(saver = booleanSaver) { booleanArrayOf(true) } + val floatSaver = Saver(save = { it }, restore = { it }) + val floatArray = rememberSaveable(saver = floatSaver) { floatArrayOf(1.0f) } + val doubleSaver = Saver(save = { it }, restore = { it }) + val doubleArray = rememberSaveable(saver = doubleSaver) { doubleArrayOf(1.0) } + val longSaver = Saver(save = { it }, restore = { it }) + val longArray = rememberSaveable(saver = longSaver) { longArrayOf(1L) } + val parcelableSaver = + Saver, Array>(save = { it }, restore = { it }) + val parcelableArray = rememberSaveable(saver = parcelableSaver) { arrayOf() } + } + + """ + .trimIndent() + ) + .indented() + + test(source).expectClean() + } + + @Test + fun customSaver_acceptableTypes_classes_noError() { + val saver = + kotlin( + """ + package test.saver + + import androidx.compose.runtime.saveable.Saver + import test.TestParcelable + + class CustomTestParcelableSaver : Saver { + override fun restore(value: TestParcelable) = value + + override fun SaverScope.save(value: TestParcelable) = value + } + + class CustomTestSerializableSaver : Saver { + override fun restore(value: TestSerializable) = value + + override fun SaverScope.save(value: TestSerializable) = value + } + + class CustomArrayListSaver : Saver, ArrayList> { + override fun restore(value: ArrayList) = value + + override fun SaverScope.save(value: ArrayList) = value + } + + """ + .trimIndent() + ) + .indented() + + val source = + kotlin( + """ + package test + + import androidx.compose.runtime.Composable + import androidx.compose.runtime.saveable.rememberSaveable + import test.saver.CustomTestParcelableSaver + import test.saver.CustomTestSerializableSaver + + @Composable + fun TestComposable() { + val serializableValue = + rememberSaveable(saver = CustomTestSerializableSaver()) { TestSerializable() } + val parcelableValue = + rememberSaveable(saver = CustomTestParcelableSaver()) { TestParcelable() } + val arrayListValue = rememberSaveable(saver = CustomArrayListSaver()) { ArrayList() } + } + + """ + .trimIndent() + ) + .indented() + + test(saver, source).expectClean() + } + + private fun test(vararg source: TestFile) = + lint() + .files( + JAVA_IO, + JAVA_UTIL, + ANDROID_OS, + ANDROID_UTIL, + COMPOSE_RUNTIME, + COMPOSE_SAVEABLE, + TEST_SAVEABLES, + *source, + ) + .run() } private val JAVA_IO = @@ -519,6 +666,11 @@ private val COMPOSE_SAVEABLE = fun restore(value: Saveable): Original? } + public fun Saver( + save: SaverScope.(value: Original) -> Saveable?, + restore: (value: Saveable) -> Original?, + ): Saver = TODO() + fun autoSaver(): Saver = TODO() @Composable