From 5b49b86b236c50b03379df3aff69eb4d891905ae Mon Sep 17 00:00:00 2001 From: Joey Yakimowich-Payne Date: Sun, 5 Apr 2020 21:36:01 -0600 Subject: [PATCH] Refactor decodeField and get* --- protobuf_serialization.nim | 117 +++++++++++-------------------------- 1 file changed, 35 insertions(+), 82 deletions(-) diff --git a/protobuf_serialization.nim b/protobuf_serialization.nim index 53590fe..e4c4293 100644 --- a/protobuf_serialization.nim +++ b/protobuf_serialization.nim @@ -41,6 +41,8 @@ type AnyProtoType* = SomeVarint | SomeLengthDelimited | SomeFixed | object + UnexpectedTypeError* = object of ValueError + proc newProtoBuffer*(): ProtoBuffer = ProtoBuffer(outstream: OutputStream.init(), fieldNum: 1) @@ -111,12 +113,12 @@ proc encodeField(stream: OutputStreamVar, fieldNum: int, value: SomeFixed32) {.i stream.put(value) proc put(stream: OutputStreamVar, value: SomeLengthDelimited) {.inline.} = + stream.put(len(value).uint) for b in value: stream.append byte(b) proc encodeField(stream: OutputStreamVar, fieldNum: int, value: SomeLengthDelimited) {.inline.} = stream.append protoHeader(fieldNum, LengthDelimited) - stream.put(len(value).uint) stream.put(value) proc put(stream: OutputStreamVar, value: object) {.inline.} @@ -131,7 +133,6 @@ proc encodeField(stream: OutputStreamVar, fieldNum: int, value: object) {.inline let objOutput = objStream.getOutput() if objOutput.len > 0: stream.append protoHeader(fieldNum, LengthDelimited) - stream.put(len(objOutput).uint) stream.put(objOutput) proc put(stream: OutputStreamVar, value: object) {.inline.} = @@ -145,14 +146,14 @@ proc put(stream: OutputStreamVar, value: object) {.inline.} = proc encode*(protobuf: var ProtoBuffer, value: object) {.inline.} = protobuf.outstream.put(value) -proc encodeField*(protobuf: var ProtoBuffer, value: AnyProtoType) {.inline.} = - protobuf.outstream.encodeField(protobuf.fieldNum, value) - inc protobuf.fieldNum - proc encodeField*(protobuf: var ProtoBuffer, fieldNum: int, value: AnyProtoType) {.inline.} = protobuf.outstream.encodeField(fieldNum, value) -proc getFixed*[T: SomeFixed]( +proc encodeField*(protobuf: var ProtoBuffer, value: AnyProtoType) {.inline.} = + protobuf.encodeField(protobuf.fieldNum, value) + inc protobuf.fieldNum + +proc get*[T: SomeFixed]( bytes: var seq[byte], ty: typedesc[T], outOffset: var int, @@ -173,7 +174,7 @@ proc getFixed*[T: SomeFixed]( result = cast[T](value) -proc getVarint[T: SomeVarint]( +proc get[T: SomeVarint]( bytes: var seq[byte], ty: typedesc[T], outOffset: var int, @@ -188,6 +189,7 @@ proc getVarint[T: SomeVarint]( var value: byte else: var value: T + var shiftAmount = 0 while true: value += type(value)(bytes[outOffset] and 0b0111_1111) shl shiftAmount @@ -206,53 +208,30 @@ proc getVarint[T: SomeVarint]( else: result = T(value) -proc decodeField*[T: SomeVarint]( - bytes: var seq[byte], - ty: typedesc[T], - outOffset: var int, - outBytesProcessed: var int, - numBytesToRead = none(int) -): ProtoField[T] {.inline.} = - # Only up to 128 bits supported by the spec - assert sizeof(T) <= 16 - - var bytesRead = 0 - - let wireTy = wireType(bytes[outOffset]) +proc checkType[T: SomeVarint](tyByte: byte, ty: typedesc[T], offset: int) {.inline.} = + let wireTy = wireType(tyByte) if wireTy != Varint: - raise newException(Exception, fmt"Not a varint at offset {outOffset}! Received a {wireTy}") + raise newException(UnexpectedTypeError, fmt"Not a varint at offset {offset}! Received a {wireTy}") - result.index = fieldNumber(bytes[outOffset]) - increaseBytesRead() +proc checkType[T: SomeFixed](tyByte: byte, ty: typedesc[T], offset: int) {.inline.} = + let wireTy = wireType(tyByte) + if wireTy notin {Fixed32, Fixed64}: + raise newException(UnexpectedTypeError, fmt"Not a fixed32 or fixed64 at offset {offset}! Received a {wireTy}") - result.value = getVarint(bytes, ty, outOffset, outBytesProcessed, numBytesToRead) +proc checkType[T: SomeLengthDelimited](tyByte: byte, ty: typedesc[T], offset: int) {.inline.} = + let wireTy = wireType(tyByte) + if wireTy != LengthDelimited: + raise newException(UnexpectedTypeError, fmt"Not a length delimited value at offset {offset}! Received a {wireTy}") -proc decodeField*[T: SomeFixed]( +proc get*[T: SomeLengthDelimited]( bytes: var seq[byte], ty: typedesc[T], outOffset: var int, outBytesProcessed: var int, numBytesToRead = none(int) -): ProtoField[T] {.inline.} = - var bytesRead = 0 - - let wireTy = wireType(bytes[outOffset]) - if wireTy notin {Fixed32, Fixed64}: - raise newException(Exception, fmt"Not a fixed32 or fixed64 at offset {outOffset}! Received a {wireTy}") - - result.index = fieldNumber(bytes[outOffset]) - increaseBytesRead() - - result.value = getFixed(bytes, ty, outOffset, outBytesProcessed, numBytesToRead) - -proc getLengthDelimited*[T: SomeLengthDelimited]( - bytes: var seq[byte], - ty: typedesc[T], outOffset: var int, - outBytesProcessed: var int, - numBytesToRead = none(int) ): T {.inline.} = var bytesRead = 0 - let decodedSize = getVarint(bytes, uint, outOffset, outBytesProcessed, numBytesToRead) + let decodedSize = bytes.get(uint, outOffset, outBytesProcessed, numBytesToRead) let length = decodedSize.int when T is string: @@ -268,7 +247,7 @@ proc getLengthDelimited*[T: SomeLengthDelimited]( increaseBytesRead(length) -proc decodeField*[T: SomeLengthDelimited]( +proc decodeField*[T: SomeFixed | SomeVarint | SomeLengthDelimited]( bytes: var seq[byte], ty: typedesc[T], outOffset: var int, @@ -276,47 +255,21 @@ proc decodeField*[T: SomeLengthDelimited]( numBytesToRead = none(int) ): ProtoField[T] {.inline.} = var bytesRead = 0 - let wireTy = wireType(bytes[outOffset]) - if wireTy != LengthDelimited: - raise newException(Exception, fmt"Not a length delimited value at offset {outOffset}! Received a {wireTy}") + + checkType(bytes[outOffset], ty, outOffset) result.index = fieldNumber(bytes[outOffset]) increaseBytesRead() - result.value = getLengthDelimited(bytes, ty, outOffset, outBytesProcessed, numBytesToRead) + result.value = bytes.get(ty, outOffset, outBytesProcessed, numBytesToRead) -macro getField(obj: typed, fieldNum: int, ty: typedesc): untyped = - template fieldTypeCheck(obj, field, fieldNum, ty) = - when type(obj.field) is type(ty): - obj.field - else: - let fnum {.inject.} = fieldNum - raise newException(Exception, fmt"Could not find field at position {fnum}.") - - let typeFields = obj.getTypeInst.getType - - let objFields = typeFields[2] - expectKind objFields, nnkRecList - - result = newStmtList() - let caseStmt = newNimNode(nnkCaseStmt) - caseStmt.add(fieldNum) - - for i in 0 ..< len(objFields) - 1: - let field = objFields[i] - let ofBranch = newNimNode(nnkOfBranch) - ofBranch.add(newLit(i+1)) - ofBranch.add(getAst(fieldTypeCheck(obj, field, fieldNum, ty))) - caseStmt.add(ofBranch) - - let field = objFields[len(objFields) - 1] - let elseBranch = newNimNode(nnkElse) - elseBranch.add( - nnkStmtList.newTree(getAst(fieldTypeCheck(obj, field, fieldNum, ty))) - ) - caseStmt.add(elseBranch) - - result.add(caseStmt) +proc decodeField*[T: object]( + bytes: var seq[byte], + ty: typedesc[T], + outOffset: var int, + outBytesProcessed: var int, + numBytesToRead = none(int) +): ProtoField[T] {.inline.} macro setField(obj: typed, fieldNum: int, offset: int, bytesProcessed: int, bytesToRead: Option[int], value: untyped): untyped = let typeFields = obj.getTypeInst.getType @@ -367,7 +320,7 @@ proc decodeField*[T: object]( # read LD header # then read only amount of bytes needed increaseBytesRead() - let decodedSize = getVarint(bytes, uint, outOffset, outBytesProcessed, numBytesToRead) + let decodedSize = bytes.get(uint, outOffset, outBytesProcessed, numBytesToRead) let bytesToRead = some(decodedSize.int) let oldOffset = outOffset