diff --git a/src/main/scala/com/fasterxml/jackson/module/scala/deser/ScalaObjectDeserializerModule.scala b/src/main/scala/com/fasterxml/jackson/module/scala/deser/ScalaObjectDeserializerModule.scala index fe061873..309a19e3 100644 --- a/src/main/scala/com/fasterxml/jackson/module/scala/deser/ScalaObjectDeserializerModule.scala +++ b/src/main/scala/com/fasterxml/jackson/module/scala/deser/ScalaObjectDeserializerModule.scala @@ -1,31 +1,27 @@ package com.fasterxml.jackson.module.scala.deser import com.fasterxml.jackson.core.JsonParser +import com.fasterxml.jackson.databind._ import com.fasterxml.jackson.databind.deser.Deserializers import com.fasterxml.jackson.databind.deser.std.StdDeserializer -import com.fasterxml.jackson.databind._ import com.fasterxml.jackson.module.scala.JacksonModule -import com.fasterxml.jackson.module.scala.util.ClassW +import com.fasterxml.jackson.module.scala.util.ScalaObject import scala.languageFeature.postfixOps -import scala.util.control.NonFatal -private class ScalaObjectDeserializer(clazz: Class[_]) extends StdDeserializer[Any](classOf[Any]) { - override def deserialize(p: JsonParser, ctxt: DeserializationContext): Any = { - try { - clazz.getField("MODULE$").get(null) - } catch { - case NonFatal(_) => null - } - } +private class ScalaObjectDeserializer(scalaObject: Any) extends StdDeserializer[Any](classOf[Any]) { + override def deserialize(p: JsonParser, ctxt: DeserializationContext): Any = scalaObject } private object ScalaObjectDeserializerResolver extends Deserializers.Base { override def findBeanDeserializer(javaType: JavaType, config: DeserializationConfig, beanDesc: BeanDescription): JsonDeserializer[_] = { val clazz = javaType.getRawClass - if (ClassW(clazz).isScalaObject) - new ScalaObjectDeserializer(clazz) - else null + + Option(clazz) + .collect { + case ScalaObject(value) => new ScalaObjectDeserializer(value) + } + .orNull } } diff --git a/src/main/scala/com/fasterxml/jackson/module/scala/util/ScalaObject.scala b/src/main/scala/com/fasterxml/jackson/module/scala/util/ScalaObject.scala new file mode 100644 index 00000000..60414126 --- /dev/null +++ b/src/main/scala/com/fasterxml/jackson/module/scala/util/ScalaObject.scala @@ -0,0 +1,29 @@ +package com.fasterxml.jackson.module.scala.util + +import java.lang.reflect.Field + +private [scala] object ScalaObject { + + private val MODULE_FIELD_NAME = "MODULE$" + + private def getStaticField(field: Field): Option[Any] = + try Some(field.get(null)) + catch { + case _: NullPointerException | _: IllegalAccessException => None + } + + private def moduleFieldOption(clazz: Class[_]): Option[Field] = + try Some(clazz.getDeclaredField(MODULE_FIELD_NAME)) + catch { + case _: NoSuchFieldException => None + } + + private def moduleFieldValue(clazz: Class[_]): Option[Any] = for { + moduleField <- moduleFieldOption(clazz) + value <- getStaticField(moduleField) + } yield value + + def unapply(clazz: Class[_]): Option[Any] = + if (clazz.getName.endsWith("$")) moduleFieldValue(clazz) + else None +} diff --git a/src/test/scala/com/fasterxml/jackson/module/scala/util/ScalaObjectTest.scala b/src/test/scala/com/fasterxml/jackson/module/scala/util/ScalaObjectTest.scala new file mode 100644 index 00000000..d3b96f2e --- /dev/null +++ b/src/test/scala/com/fasterxml/jackson/module/scala/util/ScalaObjectTest.scala @@ -0,0 +1,48 @@ +package com.fasterxml.jackson.module.scala.util + +import org.scalatest.matchers.should.Matchers.{contain, convertToAnyShouldWrapper, empty} +import org.scalatest.wordspec.AnyWordSpecLike + +object TestObject + +case object TestCaseObject + +class TestClass + +class TestClassWithModuleField { + val MODULE$: TestClassWithModuleField = this +} + +object BarWrapper { + object Bar { + final case class Baz() + } +} + +class ScalaObjectTest extends AnyWordSpecLike { + + "ScalaObject" must { + "return Some(TestObject) for unapply(TestObject.getClass)" in { + ScalaObject.unapply(TestObject.getClass) should contain(TestObject) + } + + "return Some(TestCaseObject) for unapply(TestCaseObject.getClass)" in { + ScalaObject.unapply(TestCaseObject.getClass) should contain(TestCaseObject) + } + + "return None for unapply(testClassInstance.getClass)" in { + val testClassInstance = new TestClass + ScalaObject.unapply(testClassInstance.getClass) shouldBe empty + } + + "return None for unapply(testClassWithModuleFieldInstance.getClass)" in { + val testClassWithModuleFieldInstance = new TestClassWithModuleField + ScalaObject.unapply(testClassWithModuleFieldInstance.getClass) shouldBe empty + } + + "return None for unapply(bazInstance.getClass)" in { + val bazInstance = BarWrapper.Bar.Baz() + ScalaObject.unapply(bazInstance.getClass) shouldBe empty + } + } +}