WIP Object serialization/deserialization

This commit is contained in:
Joey Yakimowich-Payne 2020-04-03 17:44:50 -06:00
commit 5f2219bae7
2 changed files with 260 additions and 84 deletions

View file

@ -1,11 +1,20 @@
import macros, strformat
import macros, strformat, typetraits, options
import faststreams
# Custom pragma markers (declared with {.pragma.}) named after protobuf scalar
# encodings. In the visible code only `sfixed32` is applied, to an object field
# (`g {.sfixed32.}: int`); nothing shown here reads these pragmas back yet.
# NOTE(review): `float` is also the name of a built-in Nim type - presumably the
# name mirrors protobuf's type names; confirm it does not cause lookup clashes.
template sint32*() {.pragma.}
template sint64*() {.pragma.}
template sfixed32*() {.pragma.}
template sfixed64*() {.pragma.}
template fixed32*() {.pragma.}
template fixed64*() {.pragma.}
template float*() {.pragma.}
template double*() {.pragma.}
const
# 1 shl 22 = 4_194_304. Not referenced anywhere in the visible code;
# presumably an upper bound on message size in bytes - TODO confirm.
MaxMessageSize* = 1'u shl 22
type
ProtoBuffer* = ref object
ProtoBuffer* = object
fieldNum: int
outstream: OutputStreamVar
@ -13,23 +22,16 @@ type
## Protobuf's field types enum
Varint, Fixed64, LengthDelimited, StartGroup, EndGroup, Fixed32
# NOTE(review): this span looks like a rendered diff with the +/- markers lost.
# Two declarations of ProtoField follow each other: a non-generic variant
# object (`case kind: ProtoWireType`) and a generic `ProtoField*[T]` with
# `index`/`value`. The decode procs in this file return `ProtoField[T]` and
# read/write `.index` and `.value`, so the generic form appears to be the
# current one - confirm before compiling.
ProtoField* = object
# Encoding selector with two modes; not referenced in the visible code.
EncodingKind* = enum
ekNormal, ekZigzag
# Generic field: the field number taken from the header (`index`) paired with
# the decoded payload (`value`).
ProtoField*[T] = object
## Protobuf's message field representation object
index: int
case kind: ProtoWireType
of Varint:
vint*: uint64
of Fixed64:
vfloat64*: float64
of LengthDelimited:
vbuffer*: OutputStreamVar
of Fixed32:
vfloat32*: float32
of StartGroup, EndGroup:
discard
index*: int
value*: T
# Type classes used to select the encoder/decoder overloads below.
# Signed integers and enums: the varint writer zigzag-encodes these.
SomeSVarint* = int | int64 | int32 | int16 | int8 | enum
# NOTE(review): SomeUVarint is listed twice (without and with `char`). This
# looks like the removed and the added line of a diff; the varint writer has a
# `value is char` branch, so the `char` variant is presumably current - confirm.
SomeUVarint* = uint | uint64 | uint32 | uint16 | uint8 | byte | bool
SomeUVarint* = uint | uint64 | uint32 | uint16 | uint8 | byte | bool | char
# Union of everything that is written with the Varint wire type.
SomeVarint* = SomeSVarint | SomeUVarint
# Types written with the LengthDelimited wire type (length prefix + raw bytes).
SomeLengthDelimited* = string | seq[byte] | seq[uint8] | cstring
@ -49,22 +51,32 @@ proc output*(proto: ProtoBuffer): seq[byte] {.inline.} =
template wireType(firstByte: byte): ProtoWireType =
  ## Reads the wire type out of a field header byte: the three lowest bits,
  ## converted to ``ProtoWireType``.
  ProtoWireType(firstByte and 0b111)
# Extracts the field number from a single header byte: drops the three
# wire-type bits with `shr 3` and keeps the next four bits (values 0..15).
# NOTE(review): two variants are present back to back (returning `uint` and
# returning `int`). This looks like the removed and the added line of a diff;
# the decode procs assign the result to `ProtoField.index` (an int), so the
# `int` variant is presumably current - confirm.
template fieldNumber(firstByte: byte): uint =
(firstByte shr 3) and 0b1111
template fieldNumber(firstByte: byte): int =
((firstByte shr 3) and 0b1111).int
template protoHeader*(fieldNum: int, wire: ProtoWireType): byte =
  ## Builds the one-byte field header: ``wire`` occupies the low three bits
  ## and ``fieldNum`` is placed in the bits above them.
  byte(cast[uint](wire) or (cast[uint](fieldNum) shl 3))
proc putVarint(stream: OutputStreamVar, value: SomeVarint) {.inline.} =
# Advances the caller's read counters by `amount` bytes.
#
# The template is expanded inside the decoding procs and intentionally refers
# to names that must exist at the call site:
#   bytesRead, outOffset, outBytesProcessed - mutable ints, each incremented
#   numBytesToRead - an Option[int]; when set, it caps `bytesRead`
#
# Raises ValueError when a cap is set and `bytesRead` goes past it. A
# CatchableError subtype is used (instead of the bare Exception base type) so
# callers can handle malformed input without a catch-all handler that would
# also trap defects; handlers written as `except Exception` still catch it.
template increaseBytesRead(amount = 1) =
  mixin isSome
  bytesRead += amount
  outOffset += amount
  outBytesProcessed += amount
  if numBytesToRead.isSome():
    if (bytesRead > numBytesToRead.get()).unlikely:
      raise newException(ValueError, "Number of bytes read exceeded")
proc put(stream: OutputStreamVar, value: SomeVarint) {.inline.} =
when value is enum:
var value = cast[type(ord(value))](value)
elif value is bool:
elif value is bool or value is char:
var value = cast[byte](value)
else:
var value = value
when type(value) is SomeSVarint:
# Encode using zigzag
if value < type(value)(0):
value = not(value shl type(value)(1))
else:
@ -77,102 +89,235 @@ proc putVarint(stream: OutputStreamVar, value: SomeVarint) {.inline.} =
# Varint field encoders.
# The stream-level overload appends the header byte for (fieldNum, Varint) and
# then the varint payload. The exported overload encodes at the buffer's
# current field number and then increments that field number.
# NOTE(review): this span looks like a rendered diff with the +/- markers lost:
# `stream.putVarint(value)` / `stream.put(value)` and the two `encode*` headers
# (`ProtoBuffer` vs `var ProtoBuffer`) are old/new pairs. `inc protobuf.fieldNum`
# mutates the buffer, so with a value-type ProtoBuffer the `var` header is
# presumably the current one - confirm.
proc encode(stream: OutputStreamVar, fieldNum: int, value: SomeVarint) {.inline.} =
stream.append protoHeader(fieldNum, Varint)
stream.putVarint(value)
stream.put(value)
proc encode*(protobuf: ProtoBuffer, value: SomeVarint) {.inline.} =
proc encode*(protobuf: var ProtoBuffer, value: SomeVarint) {.inline.} =
protobuf.outstream.encode(protobuf.fieldNum, value)
inc protobuf.fieldNum
# Length-delimited field encoders.
# NOTE(review): this span looks like a rendered diff with the +/- markers lost:
# `putLengthDelimited`/`put`, `putVarint`/`put` and the two `encode*` headers
# are old/new pairs - confirm which lines are current before compiling.
#
# Raw writer: appends every element of `value` to the stream as one byte.
proc putLengthDelimited(stream: OutputStreamVar, value: SomeLengthDelimited) {.inline.} =
proc put(stream: OutputStreamVar, value: SomeLengthDelimited) {.inline.} =
for b in value:
stream.append byte(b)
# Field writer: header byte for (fieldNum, LengthDelimited), then `len(value)`
# as an unsigned varint, then the raw bytes.
proc encode(stream: OutputStreamVar, fieldNum: int, value: SomeLengthDelimited) {.inline.} =
stream.append protoHeader(fieldNum, LengthDelimited)
stream.putVarint(len(value).uint)
stream.putLengthDelimited(value)
stream.put(len(value).uint)
stream.put(value)
# Exported entry point: encodes at the buffer's current field number, then
# increments that field number.
proc encode*(protobuf: ProtoBuffer, value: SomeLengthDelimited) {.inline.} =
proc encode*(protobuf: var ProtoBuffer, value: SomeLengthDelimited) {.inline.} =
protobuf.outstream.encode(protobuf.fieldNum, value)
inc protobuf.fieldNum
proc getVarint[T: SomeVarint](bytes: var seq[byte], ty: typedesc[T], offset = 0): tuple[value: T, bytesProcessed: int] {.inline.} =
proc put(stream: OutputStreamVar, value: object) {.inline.}

proc encode(stream: OutputStreamVar, fieldNum: int, value: object) {.inline.} =
  ## Writes ``value`` as a length-delimited field: the header byte, then the
  ## length of the serialized object as an unsigned varint, then the
  ## serialized object itself.
  #TODO Encode generic objects
  stream.append protoHeader(fieldNum, LengthDelimited)
  # The object goes through a scratch stream first because its length has to
  # be written ahead of its contents.
  let nested = OutputStream.init()
  put(nested, value)
  let payload = getOutput(nested)
  put(stream, uint(len(payload)))
  put(stream, payload)
proc encode*(protobuf: var ProtoBuffer, value: object) {.inline.} =
  ## Encodes ``value`` under the buffer's current field number and then
  ## moves the buffer on to the next field number.
  encode(protobuf.outstream, protobuf.fieldNum, value)
  protobuf.fieldNum += 1
proc put(stream: OutputStreamVar, value: object) {.inline.} =
  ## Encodes each field of ``value`` in turn, numbering the fields
  ## 1, 2, 3, ... in iteration order.
  var nextFieldNum = 1
  for fieldValue in value.fields:
    encode(stream, nextFieldNum, fieldValue)
    inc nextFieldNum
# Reads one varint of type `T` from `bytes`, starting at `outOffset`.
# Each byte contributes its low 7 bits, shifted left by 7 more bits per byte;
# a byte whose top bit is 0 ends the value. For signed types the zigzag
# encoding is undone: a set lowest bit yields `not(value shr 1)`, otherwise
# `value shr 1`.
# `outOffset` and `outBytesProcessed` are advanced through increaseBytesRead,
# which also enforces `numBytesToRead` when it is set.
# NOTE(review): this span looks like a rendered diff with the +/- markers lost.
# Lines using `offset`, `i`, `result.value` and `result.bytesProcessed` do not
# match the signature shown here (no `offset` parameter, return type `T`) and
# appear to belong to the removed version - confirm before compiling.
proc getVarint[T: SomeVarint](
bytes: var seq[byte],
ty: typedesc[T],
outOffset: var int,
outBytesProcessed: var int,
numBytesToRead = none(int)
): T {.inline.} =
# Local counter that increaseBytesRead updates (the template resolves this
# name at the call site).
var bytesRead = 0
# Only up to 128 bits supported by the spec
when T is enum:
var value: type(ord(result.value))
var value: type(ord(result))
else:
var value: T
var shiftAmount = 0
var i = offset
while true:
value += type(value)(bytes[i] and 0b0111_1111) shl shiftAmount
value += type(value)(bytes[outOffset] and 0b0111_1111) shl shiftAmount
shiftAmount += 7
# Top bit clear: this was the last byte of the varint.
if (bytes[i] shr 7) == 0:
if (bytes[outOffset] shr 7) == 0:
break
i += 1
increaseBytesRead()
result.bytesProcessed = i
# Account for the terminating byte, which the loop exits on without counting.
increaseBytesRead()
when ty is SomeSVarint:
if (value and type(value)(1)) != type(value)(0):
result.value = cast[T](not(value shr type(value)(1)))
result = cast[T](not(value shr type(value)(1)))
else:
result.value = cast[T](value shr type(value)(1))
result = cast[T](value shr type(value)(1))
else:
result.value = value
result = value
# Decodes one varint field: checks that the header byte at the current offset
# carries the Varint wire type, stores the header's field number, steps past
# the header and reads the payload with getVarint.
# Raises Exception when the header's wire type is not Varint.
# NOTE(review): this span looks like a rendered diff with the +/- markers lost.
# Two proc headers follow each other (tuple-returning with `offset = 0`, and
# `ProtoField[T]`-returning with `outOffset`/`outBytesProcessed`), and the body
# mixes statements for both - confirm which lines are current before compiling.
proc decode*[T: SomeVarint](bytes: var seq[byte], ty: typedesc[T], offset = 0): tuple[fieldNum: uint, value: T, bytesProcessed: int] {.inline.} =
proc decode*[T: SomeVarint](
bytes: var seq[byte],
ty: typedesc[T],
outOffset: var int,
outBytesProcessed: var int,
numBytesToRead = none(int)
): ProtoField[T] {.inline.} =
# Only up to 128 bits supported by the spec
# NOTE(review): this bounds the length of the whole input seq, not of a single
# varint - confirm that is intended if this line is still current.
assert (bytes.len - 1) <= 16
let wireTy = wireType(bytes[offset])
# Local counter that increaseBytesRead updates (resolved at the call site).
var bytesRead = 0
let wireTy = wireType(bytes[outOffset])
if wireTy != Varint:
raise newException(Exception, fmt"Not a varint at offset {offset}! Received a {wireTy}")
raise newException(Exception, fmt"Not a varint at offset {outOffset}! Received a {wireTy}")
result.fieldNum = fieldNumber(bytes[offset])
var offset = offset + 1
result.index = fieldNumber(bytes[outOffset])
# Step past the header byte.
increaseBytesRead()
result.value = getVarint(bytes, ty, outOffset, outBytesProcessed, numBytesToRead)
let varGet = getVarint(bytes, ty, offset)
result.value = varGet.value
result.bytesProcessed = varGet.bytesProcessed + offset
# Reads a length-delimited payload: first the length as an unsigned varint
# (via getVarint), then that many bytes starting at the current offset,
# converted to `T` (string, cstring, or a byte seq).
# NOTE(review): this span looks like a rendered diff with the +/- markers lost.
# Two parameter lists / return types follow each other, and each branch holds
# both a `result.value`/`offset` form and a `result`/`outOffset` form - confirm
# which lines are current before compiling.
proc getLengthDelimited*[T: SomeLengthDelimited](
bytes: var seq[byte],
ty: typedesc[T], offset = 0
): tuple[value: T, bytesProcessed: int] {.inline.} =
var offset = offset
let decodedSize = getVarint(bytes, uint, offset = offset)
offset += decodedSize.bytesProcessed
let length = decodedSize.value.int
ty: typedesc[T], outOffset: var int,
outBytesProcessed: var int,
numBytesToRead = none(int)
): T {.inline.} =
# Local counter that increaseBytesRead updates (resolved at the call site).
var bytesRead = 0
let decodedSize = getVarint(bytes, uint, outOffset, outBytesProcessed, numBytesToRead)
let length = decodedSize.int
when T is string:
result.value = newString(length)
for i in offset ..< (offset + length):
result.value[i - offset] = bytes[i].chr
result = newString(length)
for i in outOffset ..< (outOffset + length):
result[i - outOffset] = bytes[i].chr
elif T is cstring:
# NOTE(review): the slice expression builds a temporary seq and the cast
# reinterprets it as a cstring; the lifetime and layout of that pointer are
# not evident from this code - verify.
result.value = cast[cstring](bytes[offset ..< (offset + length)])
result = cast[cstring](bytes[outOffset ..< (outOffset + length)])
else:
result.value = newSeq(length)
for i in offset ..< (offset + length):
result.value[i - offset] = bytes[i].chr
result.setLen(length)
for i in outOffset ..< (outOffset + length):
result[i - outOffset] = type(result[0])(bytes[i])
result.bytesProcessed += length
# The payload is copied first; the counters are advanced past it afterwards.
increaseBytesRead(length)
# Decodes one length-delimited field: checks that the header byte at the
# current offset carries the LengthDelimited wire type, stores the header's
# field number, steps past the header and reads the payload with
# getLengthDelimited.
# Raises Exception when the header's wire type is not LengthDelimited.
# NOTE(review): this span looks like a rendered diff with the +/- markers lost.
# Two parameter lists / return types follow each other and the body mixes
# `offset`/`result.fieldNum`/`result.bytesProcessed` statements with
# `outOffset`/`result.index` ones - confirm which lines are current.
proc decode*[T: SomeLengthDelimited](
bytes: var seq[byte],
ty: typedesc[T], offset = 0
): tuple[fieldNum: uint, value: T, bytesProcessed: int] {.inline.} =
var offset = offset
let wireTy = wireType(bytes[offset])
ty: typedesc[T],
outOffset: var int,
outBytesProcessed: var int,
numBytesToRead = none(int)
): ProtoField[T] {.inline.} =
# Local counter that increaseBytesRead updates (resolved at the call site).
var bytesRead = 0
let wireTy = wireType(bytes[outOffset])
if wireTy != LengthDelimited:
raise newException(Exception, fmt"Not a length delimited value at offset {offset}! Received a {wireTy}")
raise newException(Exception, fmt"Not a length delimited value at offset {outOffset}! Received a {wireTy}")
result.fieldNum = fieldNumber(bytes[offset])
result.index = fieldNumber(bytes[outOffset])
# Step past the header byte.
increaseBytesRead()
offset += 1
result.value = getLengthDelimited(bytes, ty, outOffset, outBytesProcessed, numBytesToRead)
let lengthDelimited = getLengthDelimited(bytes, ty, offset)
result.bytesProcessed = offset + lengthDelimited.bytesProcessed
result.value = lengthDelimited.value
# Sample object types: Test1 wraps a single uint; Test3 has two ints (one
# tagged with the `sfixed32` pragma) and a nested Test1.
# NOTE(review): the same two declarations appear again in the test file and
# nothing else in this module refers to them - presumably scratch types for
# the WIP object codec; confirm whether they belong in the library module.
type
Test1 = object
a: uint
Test3 = object
g {.sfixed32.}: int
h: int
i: Test1
# Builds a `case fieldNum` statement that selects a field of `obj` by its
# 1-based position.
#   obj      - typed expression whose object type is inspected
#   fieldNum - expression used as the case selector
#   ty       - type the selected field is expected to have
# Every field except the last gets an `of <position>` branch; the last field
# is emitted as the `else` branch, so any selector value not matched by an
# earlier branch falls through to the last field.
# Each branch expands `fieldTypeCheck`: it yields `obj.field` when the field's
# type matches `ty`, and otherwise raises Exception at run time.
# Not invoked anywhere in the visible code.
macro getField(obj: typed, fieldNum: int, ty: typedesc): untyped =
template fieldTypeCheck(obj, field, fieldNum, ty) =
when type(obj.field) is type(ty):
obj.field
else:
# `fnum` is injected so the fmt string below can refer to it.
let fnum {.inject.} = fieldNum
raise newException(Exception, fmt"Could not find field at position {fnum}.")
# NOTE(review): `typeImpl` is never used below - candidate for removal.
let typeImpl = obj.getTypeInst.getImpl
let typeFields = obj.getTypeInst.getType
# Child 2 of the type description is expected to be the record list of fields.
let objFields = typeFields[2]
expectKind objFields, nnkRecList
result = newStmtList()
let caseStmt = newNimNode(nnkCaseStmt)
caseStmt.add(fieldNum)
# All fields but the last: one `of` branch each, keyed by 1-based position.
for i in 0 ..< len(objFields) - 1:
let field = objFields[i]
let ofBranch = newNimNode(nnkOfBranch)
ofBranch.add(newLit(i+1))
ofBranch.add(getAst(fieldTypeCheck(obj, field, fieldNum, ty)))
caseStmt.add(ofBranch)
# Last field becomes the `else` branch.
# NOTE(review): indexes `len(objFields) - 1` without checking for an empty
# record list - confirm objects without fields cannot reach this macro.
let field = objFields[len(objFields) - 1]
let elseBranch = newNimNode(nnkElse)
elseBranch.add(
nnkStmtList.newTree(getAst(fieldTypeCheck(obj, field, fieldNum, ty)))
)
caseStmt.add(elseBranch)
result.add(caseStmt)
# Builds a `case fieldNum` statement that decodes into a field of `obj`
# selected by its 1-based position.
# Each branch assigns
#   obj.field = decode(value, type(obj.field), offset, bytesProcessed, bytesToRead).value
# so the decode overload is chosen by the field's declared type.
# Every field except the last gets an `of <position>` branch; the last field
# is emitted as the `else` branch, so any selector value not matched by an
# earlier branch decodes into the last field.
# Invoked from the object overload of `decode`.
macro setField(obj: typed, fieldNum: int, offset: int, bytesProcessed: int, bytesToRead: Option[int], value: untyped): untyped =
# NOTE(review): `typeImpl` is never used below - candidate for removal.
let typeImpl = obj.getTypeInst.getImpl
let typeFields = obj.getTypeInst.getType
# Child 2 of the type description is expected to be the record list of fields.
let objFields = typeFields[2]
expectKind objFields, nnkRecList
result = newStmtList()
let caseStmt = newNimNode(nnkCaseStmt)
caseStmt.add(fieldNum)
# All fields but the last: one `of` branch each, keyed by 1-based position.
for i in 0 ..< len(objFields) - 1:
let field = objFields[i]
let ofBranch = newNimNode(nnkOfBranch)
ofBranch.add(newLit(i+1))
ofBranch.add(
quote do:
`obj`.`field` = decode(`value`, type(`obj`.`field`), `offset`, `bytesProcessed`, `bytesToRead`).value
)
caseStmt.add(ofBranch)
# Last field becomes the `else` branch.
# NOTE(review): indexes `len(objFields) - 1` without checking for an empty
# record list - confirm objects without fields cannot reach this macro.
let field = objFields[len(objFields) - 1]
let elseBranch = newNimNode(nnkElse)
elseBranch.add(
nnkStmtList.newTree(
quote do:
`obj`.`field` = decode(`value`, type(`obj`.`field`), `offset`, `bytesProcessed`, `bytesToRead`).value
)
)
caseStmt.add(elseBranch)
result.add(caseStmt)
# Object overload of decode (work in progress).
# Reads the header byte at `outOffset`, stores its field number in
# `result.index`, and then makes a single `setField` call that decodes into
# the field of `result.value` whose position equals that number.
# When the header's wire type is LengthDelimited, the header byte is consumed
# and the following varint is read as a byte count that is passed on as the
# read limit; otherwise the caller's `numBytesToRead` is passed through.
# NOTE(review): only one setField call is made per invocation and there is no
# loop over the remaining bytes - presumably unfinished; confirm how the other
# fields of the object are meant to be filled.
# NOTE(review): every decode proc starts its own `bytesRead` at 0, so the
# limit handed to setField is checked per nested call rather than cumulatively
# - confirm this is intended.
proc decode*[T: object](
bytes: var seq[byte],
ty: typedesc[T],
outOffset: var int,
outBytesProcessed: var int,
numBytesToRead = none(int)
): ProtoField[T] {.inline.} =
# Local counter that increaseBytesRead updates (resolved at the call site).
var bytesRead = 0
let wireTy = wireType(bytes[outOffset])
result.index = fieldNumber(bytes[outOffset])
if wireTy == LengthDelimited:
# read LD header
# then read only amount of bytes needed
increaseBytesRead()
let decodedSize = getVarint(bytes, uint, outOffset, outBytesProcessed, numBytesToRead)
let bytesToRead = some(decodedSize.int)
setField(result.value, result.index, outOffset, outBytesProcessed, bytesToRead, bytes)
else:
setField(result.value, result.index, outOffset, outBytesProcessed, numBytesToRead, bytes)

View file

@ -5,52 +5,83 @@ import protobuf_serialization
# Enum fixture for the varint tests (ME1, ME2, ME3 have ordinals 0, 1, 2).
type
MyEnum = enum
ME1, ME2, ME3
# Object fixtures for the object round-trip test: Test3 nests a Test1.
type
Test1 = object
a: uint
Test3 = object
g {.sfixed32.}: int
h: int
i: Test1
suite "Test Varint Encoding":
# Encodes ME3 then ME2 as fields 1 and 2, checks the exact output bytes, and
# decodes both back, sharing `offset`/`bytesProcessed` across the two calls.
# NOTE(review): this span looks like a rendered diff with the +/- markers lost
# (`let proto`/`var proto`, two forms of each decode call, `.fieldNum`/`.index`)
# - confirm which lines are current.
test "Can encode/decode enum":
let proto = newProtoBuffer()
var proto = newProtoBuffer()
var bytesProcessed: int
proto.encode(ME3)
proto.encode(ME2)
var output = proto.output
# Layout: header for field 1, payload, header for field 2, payload.
assert output == @[8.byte, 4, 16, 2]
var offset = 0
let decodedME3 = decode(output, MyEnum)
let decodedME3 = decode(output, MyEnum, offset, bytesProcessed)
assert decodedME3.value == ME3
assert decodedME3.fieldNum == 1
assert decodedME3.index == 1
let decodedME2 = decode(output, MyEnum, offset=decodedME3.bytesProcessed)
let decodedME2 = decode(output, MyEnum, offset, bytesProcessed)
assert decodedME2.value == ME2
assert decodedME2.fieldNum == 2
assert decodedME2.index == 2
# Round-trips a negative int through the varint encoder/decoder and pins the
# exact encoded bytes.
# NOTE(review): this span looks like a rendered diff with the +/- markers lost
# (`let proto`/`var proto`, two decode calls, `.fieldNum`/`.index`) - confirm
# which lines are current.
test "Can encode/decode negative number":
let proto = newProtoBuffer()
var proto = newProtoBuffer()
let num = -153452
var bytesProcessed: int
proto.encode(num)
var output = proto.output
assert output == @[8.byte, 215, 221, 18]
let decoded = decode(output, int)
var offset = 0
let decoded = decode(output, int, offset, bytesProcessed)
assert decoded.value == num
assert decoded.fieldNum == 1
assert decoded.index == 1
# Round-trips a uint through the varint encoder/decoder and pins the exact
# encoded bytes.
# NOTE(review): this span looks like a rendered diff with the +/- markers lost
# (`let proto`/`var proto`, two decode calls, `.fieldNum`/`.index`) - confirm
# which lines are current.
test "Can encode/decode unsigned number":
let proto = newProtoBuffer()
var proto = newProtoBuffer()
let num = 123151.uint
var bytesProcessed: int
proto.encode(num)
var output = proto.output
assert output == @[8.byte, 143, 194, 7]
var offset = 0
let decoded = decode(output, uint)
let decoded = decode(output, uint, offset, bytesProcessed)
assert decoded.value == num
assert decoded.fieldNum == 1
assert decoded.index == 1
# Round-trips a 20-character string through the length-delimited
# encoder/decoder; the expected bytes are a header byte, the length 20, and
# the characters of the string.
# NOTE(review): this span looks like a rendered diff with the +/- markers lost
# (`let proto`/`var proto`, two decode calls, `.fieldNum`/`.index`) - confirm
# which lines are current.
test "Can encode/decode string":
let proto = newProtoBuffer()
var proto = newProtoBuffer()
let str = "hey this is a string"
var bytesProcessed: int
proto.encode(str)
var output = proto.output
assert output == @[10.byte, 20, 104, 101, 121, 32, 116, 104, 105, 115, 32, 105, 115, 32, 97, 32, 115, 116, 114, 105, 110, 103]
let decoded = decode(output, string)
var offset = 0
let decoded = decode(output, string, offset, bytesProcessed)
assert decoded.value == str
assert decoded.fieldNum == 1
assert decoded.index == 1
# Encodes a Test3 (with a nested Test1), decodes it again and prints both the
# decoded field and the raw bytes.
# NOTE(review): the test ends in an unconditional `assert false`, so it always
# fails, and it checks nothing about `decoded` - presumably a placeholder
# while object decoding is work in progress; replace the echo/assert with real
# expectations once the decoder is finished.
test "Can encode/decode object":
var proto = newProtoBuffer()
let obj = Test3(g: 300, h: 200, i: Test1(a: 100))
proto.encode(obj)
var offset, bytesProcessed: int
var output = proto.output
let decoded = decode(output, Test3, offset, bytesProcessed)
echo decoded
echo output
assert false