|
1 | 1 | import com.google.protobuf.InvalidProtocolBufferException
|
2 | 2 | import com.thesamet.proto.e2e.reqs.RequiredFields
|
| 3 | +import protobuf_unittest.unittest.TestEmptyMessage |
| 4 | +import scalapb.UnknownFieldSet |
3 | 5 |
|
4 | 6 | import org.scalatest.flatspec.AnyFlatSpec
|
5 | 7 | import org.scalatest.matchers.must.Matchers
|
6 | 8 |
|
7 | 9 | class RequiredFieldsSpec extends AnyFlatSpec with Matchers {
|
| 10 | + |
| 11 | + private val descriptor = RequiredFields.javaDescriptor |
| 12 | + |
| 13 | + private def partialMessage(fields: Map[Int, Int]): Array[Byte] = { |
| 14 | + val fieldSet = fields.foldLeft(UnknownFieldSet.empty){ case (fieldSet, (field, value)) => |
| 15 | + fieldSet |
| 16 | + .withField(field, UnknownFieldSet.Field(varint = Seq(value))) |
| 17 | + } |
| 18 | + |
| 19 | + TestEmptyMessage(fieldSet).toByteArray |
| 20 | + } |
| 21 | + |
| 22 | + private val allFieldsSet: Map[Int, Int] = (100 to 164).map(i => (i, i)).toMap |
| 23 | + |
8 | 24 | "RequiredMessage" should "throw InvalidProtocolBufferException for empty byte array" in {
|
9 |
| - intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(Array[Byte]())) |
| 25 | + val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(Array[Byte]())) |
| 26 | + |
| 27 | + exception.getMessage() must startWith("Message missing required fields") |
| 28 | + } |
| 29 | + |
| 30 | + it should "throw no exception when all fields are set correctly" in { |
| 31 | + val parsed = RequiredFields.parseFrom(partialMessage(allFieldsSet)) |
| 32 | + parsed must be(a[RequiredFields]) |
| 33 | + parsed.f0 must be(100) |
| 34 | + parsed.f64 must be(164) |
| 35 | + } |
| 36 | + |
| 37 | + it should "throw an exception if a field is missing and name the missing field" in { |
| 38 | + val fields = allFieldsSet.removed(123) |
| 39 | + val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(partialMessage(fields))) |
| 40 | + |
| 41 | + exception.getMessage() must be("Message missing required fields: f23") |
| 42 | + } |
| 43 | + |
| 44 | + it should "throw an exception if a multiple fields are missing and name those missing fields" in { |
| 45 | + val fields = allFieldsSet.removed(123).removed(164).removed(130) |
| 46 | + val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(partialMessage(fields))) |
| 47 | + |
| 48 | + exception.getMessage() must be("Message missing required fields: f23, f30, f64") |
| 49 | + } |
| 50 | + |
| 51 | + it should "sort the missing fields by field number" in { |
| 52 | + val fields = Map.empty[Int, Int] |
| 53 | + val exception = intercept[InvalidProtocolBufferException](RequiredFields.parseFrom(partialMessage(fields))) |
| 54 | + val missingFields =exception.getMessage().stripPrefix("Message missing required fields: ").split(", ") |
| 55 | + |
| 56 | + missingFields.sortBy[Int](field => descriptor.findFieldByName(field).getNumber()) must be(missingFields) |
| 57 | + |
| 58 | + missingFields.toSeq mustBe Seq.tabulate(65)(i => s"f$i") |
10 | 59 | }
|
11 | 60 | }
|
0 commit comments