nim-protobuf-serialization/protobuf_serialization.nim
Joey Yakimowich-Payne 86c8143567 Add checkType for object
2020-04-05 21:39:21 -06:00

344 lines
No EOL
10 KiB
Nim

import macros, strformat, typetraits, options
import faststreams
# Custom pragmas named after protobuf scalar encodings. Each one is an empty
# `{.pragma.}` template, so it can be attached to a declaration as an
# annotation.
# NOTE(review): presumably meant to select the wire encoding of an annotated
# field; none of the code visible in this part of the file reads them — confirm.
template sint32*() {.pragma.}
template sint64*() {.pragma.}
template sfixed32*() {.pragma.}
template sfixed64*() {.pragma.}
template fixed32*() {.pragma.}
template fixed64*() {.pragma.}
template float*() {.pragma.}
template double*() {.pragma.}
const
  # 1 shl 22 == 4_194_304.
  # NOTE(review): presumably an upper bound on a message's size in bytes; it
  # is not referenced by the code visible here — confirm the intended use.
  MaxMessageSize* = 1'u shl 22
type
  # Encoder state: the output stream being written to plus the field number
  # that the auto-numbering `encodeField` overload will use next.
  ProtoBuffer* = object
    fieldNum: int
    outstream: OutputStreamVar
  ProtoWireType* = enum
    ## Protobuf's field types enum
    # Declaration order fixes the ordinals 0..5; `wireType` converts the low
    # three bits of a header byte straight into this enum.
    Varint, Fixed64, LengthDelimited, StartGroup, EndGroup, Fixed32
  # NOTE(review): not used by any code visible here — confirm whether it is
  # still needed.
  EncodingKind* = enum
    ekNormal, ekZigzag
  ProtoField*[T] = object
    ## Protobuf's message field representation object
    index*: int
    value*: T
  # Types the varint `put`/`get` treat as signed (zigzag-encoded). Enums are
  # converted to their ordinal before encoding.
  SomeSVarint* = int | int64 | int32 | int16 | int8 | enum
  SomeByte* = byte | bool | char | uint8
  # Types written as plain (non-zigzag) varints.
  SomeUVarint* = uint | uint64 | uint32 | uint16 | SomeByte
  SomeVarint* = SomeSVarint | SomeUVarint
  # Types written as a varint length followed by that many bytes.
  SomeLengthDelimited* = string | seq[SomeByte] | cstring
  # Floating point types written as fixed-width 8-byte / 4-byte values.
  SomeFixed64* = float64
  SomeFixed32* = float32
  SomeFixed* = SomeFixed32 | SomeFixed64
  # Everything the public `encodeField` overloads accept.
  AnyProtoType* = SomeVarint | SomeLengthDelimited | SomeFixed | object
  # Raised by the `checkType` procs when a header byte carries a wire type
  # other than the one expected for the requested Nim type.
  UnexpectedTypeError* = object of ValueError
proc newProtoBuffer*(): ProtoBuffer =
  ## Creates an encoder backed by a fresh output stream, with automatic field
  ## numbering starting at 1.
  result = ProtoBuffer(fieldNum: 1, outstream: OutputStream.init())
proc output*(proto: ProtoBuffer): seq[byte] {.inline.} =
  ## Returns the bytes written to `proto`'s output stream.
  result = getOutput(proto.outstream)
template wireType(firstByte: byte): ProtoWireType =
  ## Wire type stored in the low three bits of a field header byte.
  ProtoWireType(firstByte and 0b111)
template fieldNumber(firstByte: byte): int =
  ## Field number stored in bits 3..6 of a field header byte (0..15); the top
  ## bit of the byte is masked off.
  int((firstByte shr 3) and 0b1111)
template protoHeader*(fieldNum: int, wire: ProtoWireType): byte =
  ## Builds the single header byte for field ``fieldNum`` with wire type
  ## ``wire``: the field number shifted left by three bits, or-ed with the
  ## wire type's ordinal.
  byte((cast[uint](fieldNum) shl 3) or cast[uint](wire))
template increaseBytesRead(amount = 1) =
  ## Advances every read counter of the calling proc by `amount`.
  ##
  ## Expects `bytesRead`, `outOffset`, `outBytesProcessed` and
  ## `numBytesToRead` to be in scope at the call site. Raises once `bytesRead`
  ## exceeds the limit held in `numBytesToRead`, if one is set.
  mixin isSome
  outBytesProcessed += amount
  outOffset += amount
  bytesRead += amount
  if numBytesToRead.isSome() and unlikely(bytesRead > numBytesToRead.get()):
    raise newException(Exception, &"Number of bytes read ({bytesRead}) exceeded bytes requested ({numBytesToRead})")
proc put(stream: OutputStreamVar, value: SomeVarint) {.inline.} =
  ## Appends `value` to `stream` as a base-128 varint (no field header).
  ##
  ## Enums are written by ordinal; `bool` and `char` by their byte value.
  ## Signed inputs (including enum ordinals) are zigzag-encoded first;
  ## unsigned inputs are written as they are.
  when value is enum:
    let raw = ord(value)
  elif value is bool or value is char:
    let raw = cast[byte](value)
  else:
    let raw = value

  when type(raw) is SomeSVarint:
    # Zigzag in 64 bits: doubling the value inside a narrower signed type
    # could flip its sign bit and corrupt the result.
    let wide = int64(raw)
    let zigzag =
      if wide < 0: not(wide shl 1)
      else: wide shl 1
    # Continue with the raw bit pattern as an unsigned number, so the
    # comparison and shift below stay correct when the top bit is set.
    var bits = cast[uint64](zigzag)
  else:
    var bits = uint64(raw)

  # Emit seven bits per byte, least significant group first; the high bit of
  # every byte except the last marks that more bytes follow.
  while bits > 0b0111_1111'u64:
    stream.append byte((bits and 0b0111_1111'u64) or 0b1000_0000'u64)
    bits = bits shr 7
  stream.append byte(bits)
proc encodeField(stream: OutputStreamVar, fieldNum: int, value: SomeVarint) {.inline.} =
  ## Writes a `Varint` header byte for `fieldNum`, then `value` as a varint.
  let header = protoHeader(fieldNum, Varint)
  stream.append header
  put(stream, value)
proc put(stream: OutputStreamVar, value: SomeFixed) {.inline.} =
  ## Appends the raw bits of a float to `stream`, least significant byte
  ## first: eight bytes for `SomeFixed64`, four bytes otherwise.
  when typeof(value) is SomeFixed64:
    var bits = cast[int64](value)
  else:
    var bits = cast[int32](value)
  var remaining = sizeof(bits)
  while remaining > 0:
    stream.append byte(bits and 0b1111_1111)
    bits = bits shr 8
    dec remaining
proc encodeField(stream: OutputStreamVar, fieldNum: int, value: SomeFixed64) {.inline.} =
  ## Writes a `Fixed64` header byte for `fieldNum`, then the eight value bytes.
  let header = protoHeader(fieldNum, Fixed64)
  stream.append header
  put(stream, value)
proc encodeField(stream: OutputStreamVar, fieldNum: int, value: SomeFixed32) {.inline.} =
  ## Writes a `Fixed32` header byte for `fieldNum`, then the four value bytes.
  let header = protoHeader(fieldNum, Fixed32)
  stream.append header
  put(stream, value)
proc put(stream: OutputStreamVar, value: SomeLengthDelimited) {.inline.} =
  ## Appends `value` as a varint length prefix followed by each of its
  ## elements converted to a byte.
  let size = uint(len(value))
  stream.put(size)
  for item in value:
    stream.append byte(item)
proc encodeField(stream: OutputStreamVar, fieldNum: int, value: SomeLengthDelimited) {.inline.} =
  ## Writes a `LengthDelimited` header byte for `fieldNum`, then the
  ## length-prefixed payload.
  let header = protoHeader(fieldNum, LengthDelimited)
  stream.append header
  put(stream, value)
# Forward declaration: `encodeField` for objects calls `put` for objects, and
# `put` for objects calls `encodeField` on each field.
proc put(stream: OutputStreamVar, value: object) {.inline.}
proc encodeField(stream: OutputStreamVar, fieldNum: int, value: object) {.inline.} =
  ## Writes a nested object as a length-delimited field. Nothing at all is
  ## written when the object serializes to zero bytes.
  # The nested message goes to a scratch stream first, because its byte
  # length must be known before the payload can be appended.
  # Maybe there is a better way to do this.
  let scratch = OutputStream.init()
  scratch.put(value)
  let payload = scratch.getOutput()
  if payload.len <= 0:
    return
  stream.append protoHeader(fieldNum, LengthDelimited)
  stream.put(payload)
proc put(stream: OutputStreamVar, value: object) {.inline.} =
  ## Serializes every field of `value` that differs from its type's default.
  ## Field numbers follow declaration order starting at 1; skipped fields
  ## still consume their number.
  var fieldNum = 0
  for val in value.fields:
    inc fieldNum
    # Fields equal to their default are omitted from the output
    if default(type(val)) != val:
      stream.encodeField(fieldNum, val)
proc encode*(protobuf: var ProtoBuffer, value: object) {.inline.} =
  ## Serializes the fields of `value` into `protobuf`'s output stream.
  put(protobuf.outstream, value)
proc encodeField*(protobuf: var ProtoBuffer, fieldNum: int, value: AnyProtoType) {.inline.} =
  ## Writes `value` as field number `fieldNum`. The buffer's own running
  ## field counter is left untouched.
  encodeField(protobuf.outstream, fieldNum, value)
proc encodeField*(protobuf: var ProtoBuffer, value: AnyProtoType) {.inline.} =
  ## Writes `value` under the buffer's current field number, then advances
  ## that number by one.
  let current = protobuf.fieldNum
  protobuf.encodeField(current, value)
  protobuf.fieldNum += 1
proc get*[T: SomeFixed](
  bytes: var seq[byte],
  ty: typedesc[T],
  outOffset: var int,
  outBytesProcessed: var int,
  numBytesToRead = none(int)
): T {.inline.} =
  ## Reads `sizeof(T)` bytes starting at `outOffset`, least significant byte
  ## first, and reinterprets the resulting bits as `T`.
  ##
  ## `outOffset` and `outBytesProcessed` advance by one per byte consumed.
  var bytesRead = 0
  when T is SomeFixed64:
    var bits: int64
  else:
    var bits: int32
  for byteIdx in 0 ..< sizeof(T):
    # Each byte fills its own 8-bit lane, so or-ing the lanes together is the
    # same as summing them.
    bits = bits or (type(bits)(bytes[outOffset]) shl (byteIdx * 8))
    increaseBytesRead()
  result = cast[T](bits)
proc get[T: SomeVarint](
  bytes: var seq[byte],
  ty: typedesc[T],
  outOffset: var int,
  outBytesProcessed: var int,
  numBytesToRead = none(int)
): T {.inline.} =
  ## Decodes one base-128 varint from `bytes` starting at `outOffset` and
  ## converts it to `T`.
  ##
  ## `outOffset` and `outBytesProcessed` advance by one per byte consumed.
  ## When `numBytesToRead` is set, `increaseBytesRead` raises once this call
  ## has consumed more bytes than that. Types matching `SomeSVarint`
  ## (including enums) are zigzag-decoded; all others are converted directly.
  var bytesRead = 0
  # Only up to 128 bits supported by the spec
  # Pick the accumulator type: the ordinal type for enums and chars, a byte
  # for bools, and `T` itself for everything else.
  when T is enum or T is char:
    var value: type(ord(result))
  elif T is bool:
    var value: byte
  else:
    var value: T
  var shiftAmount = 0
  while true:
    # The low seven bits of each byte are payload, least significant group
    # first.
    value += type(value)(bytes[outOffset] and 0b0111_1111) shl shiftAmount
    shiftAmount += 7
    # A clear high bit marks the final byte of the varint.
    if (bytes[outOffset] shr 7) == 0:
      break
    increaseBytesRead()
  # Account for the final byte, which the loop exits on without counting.
  increaseBytesRead()
  when ty is SomeSVarint:
    # Zigzag decode: an odd encoded value maps to a negative number.
    if (value and type(value)(1)) != type(value)(0):
      result = cast[T](not(value shr type(value)(1)))
    else:
      result = cast[T](value shr type(value)(1))
  else:
    result = T(value)
proc checkType[T: SomeVarint](tyByte: byte, ty: typedesc[T], offset: int) {.inline.} =
  ## Raises `UnexpectedTypeError` unless the header byte `tyByte` carries the
  ## `Varint` wire type. `offset` is only used in the error message.
  let wireTy = wireType(tyByte)
  if wireTy == Varint:
    return
  raise newException(UnexpectedTypeError, fmt"Not a varint at offset {offset}! Received a {wireTy}")
proc checkType[T: SomeFixed](tyByte: byte, ty: typedesc[T], offset: int) {.inline.} =
  ## Raises `UnexpectedTypeError` unless the header byte `tyByte` carries the
  ## fixed-width wire type that matches `T`: `Fixed64` for `SomeFixed64`,
  ## `Fixed32` for `SomeFixed32`. `offset` is only used in the error message.
  ##
  ## The width must match exactly: decoding then reads `sizeof(T)` bytes, so
  ## accepting the other fixed width would consume the wrong number of bytes
  ## and misalign every field that follows.
  const expected = when T is SomeFixed64: Fixed64 else: Fixed32
  let wireTy = wireType(tyByte)
  if wireTy != expected:
    raise newException(UnexpectedTypeError, fmt"Not a {expected} value at offset {offset}! Received a {wireTy}")
proc checkType[T: SomeLengthDelimited](tyByte: byte, ty: typedesc[T], offset: int) {.inline.} =
  ## Raises `UnexpectedTypeError` unless the header byte `tyByte` carries the
  ## `LengthDelimited` wire type. `offset` is only used in the error message.
  let wireTy = wireType(tyByte)
  if wireTy == LengthDelimited:
    return
  raise newException(UnexpectedTypeError, fmt"Not a length delimited value at offset {offset}! Received a {wireTy}")
proc checkType[T: object](tyByte: byte, ty: typedesc[T], offset: int) {.inline.} =
  ## Raises `UnexpectedTypeError` unless the header byte `tyByte` carries the
  ## `LengthDelimited` wire type, which is how nested objects are written.
  ## `offset` is only used in the error message.
  let wireTy = wireType(tyByte)
  if wireTy == LengthDelimited:
    return
  raise newException(UnexpectedTypeError, fmt"Not an object value at offset {offset}! Received a {wireTy}")
proc get*[T: SomeLengthDelimited](
  bytes: var seq[byte],
  ty: typedesc[T],
  outOffset: var int,
  outBytesProcessed: var int,
  numBytesToRead = none(int)
): T {.inline.} =
  ## Reads a length-delimited value: a varint byte count followed by that
  ## many payload bytes, returned as a string, cstring or byte-like seq.
  ##
  ## `outOffset` and `outBytesProcessed` advance past both the length prefix
  ## and the payload.
  var bytesRead = 0
  # Decoding the length prefix already moves outOffset to the first payload
  # byte.
  let decodedSize = bytes.get(uint, outOffset, outBytesProcessed, numBytesToRead)
  let length = decodedSize.int
  # NOTE(review): `length` comes straight from the input and is not compared
  # with `bytes.len` before the result is allocated and the payload is
  # indexed — confirm whether a bounds check is wanted.
  when T is string:
    result = newString(length)
    for i in outOffset ..< (outOffset + length):
      result[i - outOffset] = bytes[i].chr
  elif T is cstring:
    # NOTE(review): this casts a temporary seq slice to cstring; it looks like
    # the result may not point at memory that outlives this call — confirm.
    result = cast[cstring](bytes[outOffset ..< (outOffset + length)])
  else:
    result.setLen(length)
    for i in outOffset ..< (outOffset + length):
      result[i - outOffset] = type(result[0])(bytes[i])
  # Count the payload bytes; the prefix bytes were counted by the varint read.
  increaseBytesRead(length)
proc decodeField*[T: SomeFixed | SomeVarint | SomeLengthDelimited](
  bytes: var seq[byte],
  ty: typedesc[T],
  outOffset: var int,
  outBytesProcessed: var int,
  numBytesToRead = none(int)
): ProtoField[T] {.inline.} =
  ## Decodes one scalar or length-delimited field located at `outOffset`:
  ## checks the header byte's wire type against `T`, records the field
  ## number, then reads the value.
  ##
  ## Raises `UnexpectedTypeError` when the wire type does not fit `T`.
  var bytesRead = 0
  let header = bytes[outOffset]
  checkType(header, ty, outOffset)
  result.index = fieldNumber(header)
  # Step over the header byte before decoding the payload
  increaseBytesRead()
  result.value = get(bytes, ty, outOffset, outBytesProcessed, numBytesToRead)
# Forward declaration: the `setField` macro below generates calls to
# `decodeField`, including this object overload, before its body is defined.
proc decodeField*[T: object](
  bytes: var seq[byte],
  ty: typedesc[T],
  outOffset: var int,
  outBytesProcessed: var int,
  numBytesToRead = none(int)
): ProtoField[T] {.inline.}
macro setField(obj: typed, fieldNum: int, offset: int, bytesProcessed: int, bytesToRead: Option[int], value: untyped): untyped =
  ## Generates a `case fieldNum` statement that decodes the next field from
  ## `value` (the byte buffer) into the matching field of `obj`.
  ##
  ## The fields of `obj` are numbered 1..n in declaration order. Fields 1 to
  ## n-1 each get an `of` branch; the last field is placed in the `else`
  ## branch, so every other field number is decoded into the last field.
  let typeFields = obj.getTypeInst.getType
  # Child 2 of the object type description is expected to be its field list.
  let objFields = typeFields[2]
  expectKind objFields, nnkRecList
  result = newStmtList()
  let caseStmt = newNimNode(nnkCaseStmt)
  caseStmt.add(fieldNum)
  # One `of` branch per field except the last; branch labels are 1-based.
  for i in 0 ..< len(objFields) - 1:
    let field = objFields[i]
    let ofBranch = newNimNode(nnkOfBranch)
    ofBranch.add(newLit(i+1))
    ofBranch.add(
      quote do:
        `obj`.`field` = decodeField(`value`, type(`obj`.`field`), `offset`, `bytesProcessed`, `bytesToRead`).value
    )
    caseStmt.add(ofBranch)
  # The last field becomes the `else` branch of the case statement.
  let field = objFields[len(objFields) - 1]
  let elseBranch = newNimNode(nnkElse)
  elseBranch.add(
    nnkStmtList.newTree(
      quote do:
        `obj`.`field` = decodeField(`value`, type(`obj`.`field`), `offset`, `bytesProcessed`, `bytesToRead`).value
    )
  )
  caseStmt.add(elseBranch)
  result.add(caseStmt)
proc decodeField*[T: object](
  bytes: var seq[byte],
  ty: typedesc[T],
  outOffset: var int,
  outBytesProcessed: var int,
  numBytesToRead = none(int)
): ProtoField[T] {.inline.} =
  ## Decodes a nested object stored as a length-delimited field at
  ## `outOffset`: header byte, varint payload length, then the object's own
  ## fields until that many payload bytes have been consumed.
  ##
  ## Raises `UnexpectedTypeError` when the header is not length delimited.
  var bytesRead = 0
  let header = bytes[outOffset]
  checkType(header, ty, outOffset)
  result.index = fieldNumber(header)
  # Step over the header byte, then read the payload length
  increaseBytesRead()
  let decodedSize = bytes.get(uint, outOffset, outBytesProcessed, numBytesToRead)
  # The payload length also serves as the read limit for each nested field
  let bytesToRead = some(int(decodedSize))
  let payloadEnd = outOffset + bytesToRead.get()
  while outOffset < payloadEnd:
    let fieldNum = fieldNumber(bytes[outOffset])
    setField(result.value, fieldNum, outOffset, outBytesProcessed, bytesToRead, bytes)
proc decode*[T: object](
  bytes: var seq[byte],
  ty: typedesc[T],
): T {.inline.} =
  ## Decodes a top-level message of type `T` from `bytes`, one field at a
  ## time, assigning each decoded value to the matching field of the result.
  ##
  ## Decoding stops once fewer than two bytes remain.
  var
    offset = 0
    bytesRead = 0
  while offset + 1 < bytes.len:
    let fieldNum = fieldNumber(bytes[offset])
    setField(result, fieldNum, offset, bytesRead, none(int), bytes)