Refactor decodeField and get*

This commit is contained in:
Joey Yakimowich-Payne 2020-04-05 21:36:01 -06:00
commit 5b49b86b23

View file

@ -41,6 +41,8 @@ type
AnyProtoType* = SomeVarint | SomeLengthDelimited | SomeFixed | object
UnexpectedTypeError* = object of ValueError
proc newProtoBuffer*(): ProtoBuffer =
ProtoBuffer(outstream: OutputStream.init(), fieldNum: 1)
@ -111,12 +113,12 @@ proc encodeField(stream: OutputStreamVar, fieldNum: int, value: SomeFixed32) {.i
stream.put(value)
proc put(stream: OutputStreamVar, value: SomeLengthDelimited) {.inline.} =
stream.put(len(value).uint)
for b in value:
stream.append byte(b)
proc encodeField(stream: OutputStreamVar, fieldNum: int, value: SomeLengthDelimited) {.inline.} =
stream.append protoHeader(fieldNum, LengthDelimited)
stream.put(len(value).uint)
stream.put(value)
proc put(stream: OutputStreamVar, value: object) {.inline.}
@ -131,7 +133,6 @@ proc encodeField(stream: OutputStreamVar, fieldNum: int, value: object) {.inline
let objOutput = objStream.getOutput()
if objOutput.len > 0:
stream.append protoHeader(fieldNum, LengthDelimited)
stream.put(len(objOutput).uint)
stream.put(objOutput)
proc put(stream: OutputStreamVar, value: object) {.inline.} =
@ -145,14 +146,14 @@ proc put(stream: OutputStreamVar, value: object) {.inline.} =
proc encode*(protobuf: var ProtoBuffer, value: object) {.inline.} =
protobuf.outstream.put(value)
proc encodeField*(protobuf: var ProtoBuffer, value: AnyProtoType) {.inline.} =
protobuf.outstream.encodeField(protobuf.fieldNum, value)
inc protobuf.fieldNum
proc encodeField*(protobuf: var ProtoBuffer, fieldNum: int, value: AnyProtoType) {.inline.} =
protobuf.outstream.encodeField(fieldNum, value)
proc getFixed*[T: SomeFixed](
proc encodeField*(protobuf: var ProtoBuffer, value: AnyProtoType) {.inline.} =
protobuf.encodeField(protobuf.fieldNum, value)
inc protobuf.fieldNum
proc get*[T: SomeFixed](
bytes: var seq[byte],
ty: typedesc[T],
outOffset: var int,
@ -173,7 +174,7 @@ proc getFixed*[T: SomeFixed](
result = cast[T](value)
proc getVarint[T: SomeVarint](
proc get[T: SomeVarint](
bytes: var seq[byte],
ty: typedesc[T],
outOffset: var int,
@ -188,6 +189,7 @@ proc getVarint[T: SomeVarint](
var value: byte
else:
var value: T
var shiftAmount = 0
while true:
value += type(value)(bytes[outOffset] and 0b0111_1111) shl shiftAmount
@ -206,53 +208,30 @@ proc getVarint[T: SomeVarint](
else:
result = T(value)
proc decodeField*[T: SomeVarint](
bytes: var seq[byte],
ty: typedesc[T],
outOffset: var int,
outBytesProcessed: var int,
numBytesToRead = none(int)
): ProtoField[T] {.inline.} =
# Only up to 128 bits supported by the spec
assert sizeof(T) <= 16
var bytesRead = 0
let wireTy = wireType(bytes[outOffset])
proc checkType[T: SomeVarint](tyByte: byte, ty: typedesc[T], offset: int) {.inline.} =
let wireTy = wireType(tyByte)
if wireTy != Varint:
raise newException(Exception, fmt"Not a varint at offset {outOffset}! Received a {wireTy}")
raise newException(UnexpectedTypeError, fmt"Not a varint at offset {offset}! Received a {wireTy}")
result.index = fieldNumber(bytes[outOffset])
increaseBytesRead()
proc checkType[T: SomeFixed](tyByte: byte, ty: typedesc[T], offset: int) {.inline.} =
let wireTy = wireType(tyByte)
if wireTy notin {Fixed32, Fixed64}:
raise newException(UnexpectedTypeError, fmt"Not a fixed32 or fixed64 at offset {offset}! Received a {wireTy}")
result.value = getVarint(bytes, ty, outOffset, outBytesProcessed, numBytesToRead)
proc checkType[T: SomeLengthDelimited](tyByte: byte, ty: typedesc[T], offset: int) {.inline.} =
let wireTy = wireType(tyByte)
if wireTy != LengthDelimited:
raise newException(UnexpectedTypeError, fmt"Not a length delimited value at offset {offset}! Received a {wireTy}")
proc decodeField*[T: SomeFixed](
proc get*[T: SomeLengthDelimited](
bytes: var seq[byte],
ty: typedesc[T],
outOffset: var int,
outBytesProcessed: var int,
numBytesToRead = none(int)
): ProtoField[T] {.inline.} =
var bytesRead = 0
let wireTy = wireType(bytes[outOffset])
if wireTy notin {Fixed32, Fixed64}:
raise newException(Exception, fmt"Not a fixed32 or fixed64 at offset {outOffset}! Received a {wireTy}")
result.index = fieldNumber(bytes[outOffset])
increaseBytesRead()
result.value = getFixed(bytes, ty, outOffset, outBytesProcessed, numBytesToRead)
proc getLengthDelimited*[T: SomeLengthDelimited](
bytes: var seq[byte],
ty: typedesc[T], outOffset: var int,
outBytesProcessed: var int,
numBytesToRead = none(int)
): T {.inline.} =
var bytesRead = 0
let decodedSize = getVarint(bytes, uint, outOffset, outBytesProcessed, numBytesToRead)
let decodedSize = bytes.get(uint, outOffset, outBytesProcessed, numBytesToRead)
let length = decodedSize.int
when T is string:
@ -268,7 +247,7 @@ proc getLengthDelimited*[T: SomeLengthDelimited](
increaseBytesRead(length)
proc decodeField*[T: SomeLengthDelimited](
proc decodeField*[T: SomeFixed | SomeVarint | SomeLengthDelimited](
bytes: var seq[byte],
ty: typedesc[T],
outOffset: var int,
@ -276,47 +255,21 @@ proc decodeField*[T: SomeLengthDelimited](
numBytesToRead = none(int)
): ProtoField[T] {.inline.} =
var bytesRead = 0
let wireTy = wireType(bytes[outOffset])
if wireTy != LengthDelimited:
raise newException(Exception, fmt"Not a length delimited value at offset {outOffset}! Received a {wireTy}")
checkType(bytes[outOffset], ty, outOffset)
result.index = fieldNumber(bytes[outOffset])
increaseBytesRead()
result.value = getLengthDelimited(bytes, ty, outOffset, outBytesProcessed, numBytesToRead)
result.value = bytes.get(ty, outOffset, outBytesProcessed, numBytesToRead)
macro getField(obj: typed, fieldNum: int, ty: typedesc): untyped =
template fieldTypeCheck(obj, field, fieldNum, ty) =
when type(obj.field) is type(ty):
obj.field
else:
let fnum {.inject.} = fieldNum
raise newException(Exception, fmt"Could not find field at position {fnum}.")
let typeFields = obj.getTypeInst.getType
let objFields = typeFields[2]
expectKind objFields, nnkRecList
result = newStmtList()
let caseStmt = newNimNode(nnkCaseStmt)
caseStmt.add(fieldNum)
for i in 0 ..< len(objFields) - 1:
let field = objFields[i]
let ofBranch = newNimNode(nnkOfBranch)
ofBranch.add(newLit(i+1))
ofBranch.add(getAst(fieldTypeCheck(obj, field, fieldNum, ty)))
caseStmt.add(ofBranch)
let field = objFields[len(objFields) - 1]
let elseBranch = newNimNode(nnkElse)
elseBranch.add(
nnkStmtList.newTree(getAst(fieldTypeCheck(obj, field, fieldNum, ty)))
)
caseStmt.add(elseBranch)
result.add(caseStmt)
proc decodeField*[T: object](
bytes: var seq[byte],
ty: typedesc[T],
outOffset: var int,
outBytesProcessed: var int,
numBytesToRead = none(int)
): ProtoField[T] {.inline.}
macro setField(obj: typed, fieldNum: int, offset: int, bytesProcessed: int, bytesToRead: Option[int], value: untyped): untyped =
let typeFields = obj.getTypeInst.getType
@ -367,7 +320,7 @@ proc decodeField*[T: object](
# read LD header
# then read only amount of bytes needed
increaseBytesRead()
let decodedSize = getVarint(bytes, uint, outOffset, outBytesProcessed, numBytesToRead)
let decodedSize = bytes.get(uint, outOffset, outBytesProcessed, numBytesToRead)
let bytesToRead = some(decodedSize.int)
let oldOffset = outOffset