Skip to content

Commit a811bce

Browse files
committed
Add tests for required feature
1 parent daebaa3 commit a811bce

File tree

2 files changed

+62
-11
lines changed

2 files changed

+62
-11
lines changed

compiler-plugin/src/main/scala/scalapb/compiler/ParseFromGenerator.scala

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,11 @@ private[compiler] class ParseFromGenerator(
110110

111111
private def usesBaseTypeInBuilder(field: FieldDescriptor) = field.isSingular
112112

113-
val requiredFieldMap: Map[FieldDescriptor, Int] =
114-
message.fields.filter(fd => fd.isRequired || fd.noBoxRequired).zipWithIndex.toMap
113+
private val requiredFields: Seq[(FieldDescriptor, Int)] =
114+
message.fields.filter(fd => fd.isRequired || fd.noBoxRequired).zipWithIndex
115+
116+
private val requiredFieldMap: Map[FieldDescriptor, Int] =
117+
requiredFields.toMap
115118

116119
val myFullScalaName = message.scalaType.fullNameWithMaybeRoot(message)
117120

@@ -231,16 +234,15 @@ private[compiler] class ParseFromGenerator(
231234
p.add(s"""if (${r}) {""")
232235
.indent
233236
.add("val __missingFields = Seq.newBuilder[_root_.scala.Predef.String]")
234-
.print(requiredFieldMap.toSeq.sortBy(_._2)) {
235-
case (p, (fieldDescriptor, fieldNumber)) =>
236-
val bitmask = s"0x${"%x".format(1L << fieldNumber)}L"
237-
val fieldVariable = s"__requiredFields${fieldNumber / 64}"
238-
p.add(
239-
s"""if (($fieldVariable & $bitmask) != 0L) __missingFields += "${fieldDescriptor.scalaName}""""
240-
)
237+
.print(requiredFields) { case (p, (fieldDescriptor, fieldNumber)) =>
238+
val bitmask = f"${1L << fieldNumber}%#018xL"
239+
val fieldVariable = s"__requiredFields${fieldNumber / 64}"
240+
p.add(
241+
s"""if (($fieldVariable & $bitmask) != 0L) __missingFields += "${fieldDescriptor.scalaName}""""
242+
)
241243
}
242244
.add(
243-
s"""val __message = s"Message missing required fields: $${__missingFields.result.mkString(", ")}"""",
245+
s"""val __message = s"Message missing required fields: $${__missingFields.result().mkString(", ")}"""",
244246
s"""throw new _root_.com.google.protobuf.InvalidProtocolBufferException(__message)"""
245247
)
246248
.outdent
Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,60 @@
11
import com.google.protobuf.InvalidProtocolBufferException
22
import com.thesamet.proto.e2e.reqs.RequiredFields
3+
import protobuf_unittest.unittest.TestEmptyMessage
4+
import scalapb.UnknownFieldSet
35

46
import org.scalatest.flatspec.AnyFlatSpec
57
import org.scalatest.matchers.must.Matchers
68

79
class RequiredFieldsSpec extends AnyFlatSpec with Matchers {
10+
11+
private val descriptor = RequiredFields.javaDescriptor
12+
13+
private def partialMessage(fields: Map[Int, Int]): Array[Byte] = {
14+
val fieldSet = fields.foldLeft(UnknownFieldSet.empty){ case (fieldSet, (field, value)) =>
15+
fieldSet
16+
.withField(field, UnknownFieldSet.Field(varint = Seq(value)))
17+
}
18+
19+
TestEmptyMessage(fieldSet).toByteArray
20+
}
21+
22+
private val allFieldsSet: Map[Int, Int] = (100 to 164).map(i => (i, i)).toMap
23+
824
"RequiredMessage" should "throw InvalidProtocolBufferException for empty byte array" in {
9-
intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(Array[Byte]()))
25+
val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(Array[Byte]()))
26+
27+
exception.getMessage() must startWith("Message missing required fields")
28+
}
29+
30+
it should "throw no exception when all fields are set correctly" in {
31+
val parsed = RequiredFields.parseFrom(partialMessage(allFieldsSet))
32+
parsed must be(a[RequiredFields])
33+
parsed.f0 must be(100)
34+
parsed.f64 must be(164)
35+
}
36+
37+
it should "throw an exception if a field is missing and name the missing field" in {
38+
val fields = allFieldsSet.removed(123)
39+
val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(partialMessage(fields)))
40+
41+
exception.getMessage() must be("Message missing required fields: f23")
42+
}
43+
44+
it should "throw an exception if a multiple fields are missing and name those missing fields" in {
45+
val fields = allFieldsSet.removed(123).removed(164).removed(130)
46+
val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(partialMessage(fields)))
47+
48+
exception.getMessage() must be("Message missing required fields: f23, f30, f64")
49+
}
50+
51+
it should "sort the missing fields by field number" in {
52+
val fields = Map.empty[Int, Int]
53+
val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(partialMessage(fields)))
54+
val missingFields =exception.getMessage().stripPrefix("Message missing required fields: ").split(", ")
55+
56+
missingFields.sortBy[Int](field => descriptor.findFieldByName(field).getNumber()) must be(missingFields)
57+
58+
missingFields.toSeq mustBe Seq.tabulate(65)(i => s"f$i")
1059
}
1160
}

0 commit comments

Comments
 (0)