diff --git a/protobuf_serialization.nim b/protobuf_serialization.nim index 9a2596b..fa2dcd3 100644 --- a/protobuf_serialization.nim +++ b/protobuf_serialization.nim @@ -1,11 +1,20 @@ -import macros, strformat +import macros, strformat, typetraits, options import faststreams +template sint32*() {.pragma.} +template sint64*() {.pragma.} +template sfixed32*() {.pragma.} +template sfixed64*() {.pragma.} +template fixed32*() {.pragma.} +template fixed64*() {.pragma.} +template float*() {.pragma.} +template double*() {.pragma.} + const MaxMessageSize* = 1'u shl 22 type - ProtoBuffer* = ref object + ProtoBuffer* = object fieldNum: int outstream: OutputStreamVar @@ -13,23 +22,16 @@ type ## Protobuf's field types enum Varint, Fixed64, LengthDelimited, StartGroup, EndGroup, Fixed32 - ProtoField* = object + EncodingKind* = enum + ekNormal, ekZigzag + + ProtoField*[T] = object ## Protobuf's message field representation object - index: int - case kind: ProtoWireType - of Varint: - vint*: uint64 - of Fixed64: - vfloat64*: float64 - of LengthDelimited: - vbuffer*: OutputStreamVar - of Fixed32: - vfloat32*: float32 - of StartGroup, EndGroup: - discard + index*: int + value*: T SomeSVarint* = int | int64 | int32 | int16 | int8 | enum - SomeUVarint* = uint | uint64 | uint32 | uint16 | uint8 | byte | bool + SomeUVarint* = uint | uint64 | uint32 | uint16 | uint8 | byte | bool | char SomeVarint* = SomeSVarint | SomeUVarint SomeLengthDelimited* = string | seq[byte] | seq[uint8] | cstring @@ -49,22 +51,32 @@ proc output*(proto: ProtoBuffer): seq[byte] {.inline.} = template wireType(firstByte: byte): ProtoWireType = (firstByte and 0b111).ProtoWireType -template fieldNumber(firstByte: byte): uint = - (firstByte shr 3) and 0b1111 +template fieldNumber(firstByte: byte): int = + ((firstByte shr 3) and 0b1111).int template protoHeader*(fieldNum: int, wire: ProtoWireType): byte = ## Get protobuf's field header integer for ``index`` and ``wire``. ((cast[uint](fieldNum) shl 3) or cast[uint](wire)).byte -proc putVarint(stream: OutputStreamVar, value: SomeVarint) {.inline.} = +template increaseBytesRead(amount = 1) = + mixin isSome + bytesRead += amount + outOffset += amount + outBytesProcessed += amount + if numBytesToRead.isSome(): + if (bytesRead > numBytesToRead.get()).unlikely: + raise newException(Exception, "Number of bytes read exceeded") + +proc put(stream: OutputStreamVar, value: SomeVarint) {.inline.} = when value is enum: var value = cast[type(ord(value))](value) - elif value is bool: + elif value is bool or value is char: var value = cast[byte](value) else: var value = value when type(value) is SomeSVarint: + # Encode using zigzag if value < type(value)(0): value = not(value shl type(value)(1)) else: @@ -77,102 +89,235 @@ proc putVarint(stream: OutputStreamVar, value: SomeVarint) {.inline.} = proc encode(stream: OutputStreamVar, fieldNum: int, value: SomeVarint) {.inline.} = stream.append protoHeader(fieldNum, Varint) - stream.putVarint(value) + stream.put(value) -proc encode*(protobuf: ProtoBuffer, value: SomeVarint) {.inline.} = +proc encode*(protobuf: var ProtoBuffer, value: SomeVarint) {.inline.} = protobuf.outstream.encode(protobuf.fieldNum, value) inc protobuf.fieldNum -proc putLengthDelimited(stream: OutputStreamVar, value: SomeLengthDelimited) {.inline.} = +proc put(stream: OutputStreamVar, value: SomeLengthDelimited) {.inline.} = for b in value: stream.append byte(b) proc encode(stream: OutputStreamVar, fieldNum: int, value: SomeLengthDelimited) {.inline.} = stream.append protoHeader(fieldNum, LengthDelimited) - stream.putVarint(len(value).uint) - stream.putLengthDelimited(value) + stream.put(len(value).uint) + stream.put(value) -proc encode*(protobuf: ProtoBuffer, value: SomeLengthDelimited) {.inline.} = +proc encode*(protobuf: var ProtoBuffer, value: SomeLengthDelimited) {.inline.} = protobuf.outstream.encode(protobuf.fieldNum, value) inc protobuf.fieldNum -proc getVarint[T: SomeVarint](bytes: var seq[byte], ty: typedesc[T], offset = 0): tuple[value: T, bytesProcessed: int] {.inline.} = +proc put(stream: OutputStreamVar, value: object) {.inline.} + +proc encode(stream: OutputStreamVar, fieldNum: int, value: object) {.inline.} = + #TODO Encode generic objects + stream.append protoHeader(fieldNum, LengthDelimited) + let objStream = OutputStream.init() + objStream.put(value) + let objOutput = objStream.getOutput() + stream.put(len(objOutput).uint) + stream.put(objOutput) + +proc encode*(protobuf: var ProtoBuffer, value: object) {.inline.} = + protobuf.outstream.encode(protobuf.fieldNum, value) + inc protobuf.fieldNum + +proc put(stream: OutputStreamVar, value: object) {.inline.} = + var fieldNum = 1 + for field, val in value.fieldPairs: + stream.encode(fieldNum, val) + fieldNum += 1 + +proc getVarint[T: SomeVarint]( + bytes: var seq[byte], + ty: typedesc[T], + outOffset: var int, + outBytesProcessed: var int, + numBytesToRead = none(int) +): T {.inline.} = + var bytesRead = 0 # Only up to 128 bits supported by the spec when T is enum: - var value: type(ord(result.value)) + var value: type(ord(result)) else: var value: T var shiftAmount = 0 - var i = offset while true: - value += type(value)(bytes[i] and 0b0111_1111) shl shiftAmount + value += type(value)(bytes[outOffset] and 0b0111_1111) shl shiftAmount shiftAmount += 7 - if (bytes[i] shr 7) == 0: + if (bytes[outOffset] shr 7) == 0: break - i += 1 + increaseBytesRead() - result.bytesProcessed = i + increaseBytesRead() when ty is SomeSVarint: if (value and type(value)(1)) != type(value)(0): - result.value = cast[T](not(value shr type(value)(1))) + result = cast[T](not(value shr type(value)(1))) else: - result.value = cast[T](value shr type(value)(1)) + result = cast[T](value shr type(value)(1)) else: - result.value = value + result = value -proc decode*[T: SomeVarint](bytes: var seq[byte], ty: typedesc[T], offset = 0): tuple[fieldNum: uint, value: T, bytesProcessed: int] {.inline.} = +proc decode*[T: SomeVarint]( + bytes: var seq[byte], + ty: typedesc[T], + outOffset: var int, + outBytesProcessed: var int, + numBytesToRead = none(int) +): ProtoField[T] {.inline.} = # Only up to 128 bits supported by the spec assert (bytes.len - 1) <= 16 - let wireTy = wireType(bytes[offset]) + var bytesRead = 0 + + let wireTy = wireType(bytes[outOffset]) if wireTy != Varint: - raise newException(Exception, fmt"Not a varint at offset {offset}! Received a {wireTy}") + raise newException(Exception, fmt"Not a varint at offset {outOffset}! Received a {wireTy}") - result.fieldNum = fieldNumber(bytes[offset]) - var offset = offset + 1 + result.index = fieldNumber(bytes[outOffset]) + increaseBytesRead() + + result.value = getVarint(bytes, ty, outOffset, outBytesProcessed, numBytesToRead) - let varGet = getVarint(bytes, ty, offset) - result.value = varGet.value - result.bytesProcessed = varGet.bytesProcessed + offset proc getLengthDelimited*[T: SomeLengthDelimited]( bytes: var seq[byte], - ty: typedesc[T], offset = 0 -): tuple[value: T, bytesProcessed: int] {.inline.} = - - var offset = offset - let decodedSize = getVarint(bytes, uint, offset = offset) - offset += decodedSize.bytesProcessed - let length = decodedSize.value.int + ty: typedesc[T], outOffset: var int, + outBytesProcessed: var int, + numBytesToRead = none(int) +): T {.inline.} = + var bytesRead = 0 + let decodedSize = getVarint(bytes, uint, outOffset, outBytesProcessed, numBytesToRead) + let length = decodedSize.int when T is string: - result.value = newString(length) - for i in offset ..< (offset + length): - result.value[i - offset] = bytes[i].chr + result = newString(length) + for i in outOffset ..< (outOffset + length): + result[i - outOffset] = bytes[i].chr elif T is cstring: - result.value = cast[cstring](bytes[offset ..< (offset + length)]) + result = cast[cstring](bytes[outOffset ..< (outOffset + length)]) else: - result.value = newSeq(length) - for i in offset ..< (offset + length): - result.value[i - offset] = bytes[i].chr + result.setLen(length) + for i in outOffset ..< (outOffset + length): + result[i - outOffset] = type(result[0])(bytes[i]) - result.bytesProcessed += length + increaseBytesRead(length) proc decode*[T: SomeLengthDelimited]( bytes: var seq[byte], - ty: typedesc[T], offset = 0 -): tuple[fieldNum: uint, value: T, bytesProcessed: int] {.inline.} = - var offset = offset - - let wireTy = wireType(bytes[offset]) + ty: typedesc[T], + outOffset: var int, + outBytesProcessed: var int, + numBytesToRead = none(int) +): ProtoField[T] {.inline.} = + var bytesRead = 0 + let wireTy = wireType(bytes[outOffset]) if wireTy != LengthDelimited: - raise newException(Exception, fmt"Not a length delimited value at offset {offset}! Received a {wireTy}") + raise newException(Exception, fmt"Not a length delimited value at offset {outOffset}! Received a {wireTy}") - result.fieldNum = fieldNumber(bytes[offset]) + result.index = fieldNumber(bytes[outOffset]) + increaseBytesRead() - offset += 1 + result.value = getLengthDelimited(bytes, ty, outOffset, outBytesProcessed, numBytesToRead) - let lengthDelimited = getLengthDelimited(bytes, ty, offset) - result.bytesProcessed = offset + lengthDelimited.bytesProcessed - result.value = lengthDelimited.value \ No newline at end of file +type + Test1 = object + a: uint + + Test3 = object + g {.sfixed32.}: int + h: int + i: Test1 + +macro getField(obj: typed, fieldNum: int, ty: typedesc): untyped = + template fieldTypeCheck(obj, field, fieldNum, ty) = + when type(obj.field) is type(ty): + obj.field + else: + let fnum {.inject.} = fieldNum + raise newException(Exception, fmt"Could not find field at position {fnum}.") + + let typeImpl = obj.getTypeInst.getImpl + let typeFields = obj.getTypeInst.getType + + let objFields = typeFields[2] + expectKind objFields, nnkRecList + + result = newStmtList() + let caseStmt = newNimNode(nnkCaseStmt) + caseStmt.add(fieldNum) + + for i in 0 ..< len(objFields) - 1: + let field = objFields[i] + let ofBranch = newNimNode(nnkOfBranch) + ofBranch.add(newLit(i+1)) + ofBranch.add(getAst(fieldTypeCheck(obj, field, fieldNum, ty))) + caseStmt.add(ofBranch) + + let field = objFields[len(objFields) - 1] + let elseBranch = newNimNode(nnkElse) + elseBranch.add( + nnkStmtList.newTree(getAst(fieldTypeCheck(obj, field, fieldNum, ty))) + ) + caseStmt.add(elseBranch) + + result.add(caseStmt) + +macro setField(obj: typed, fieldNum: int, offset: int, bytesProcessed: int, bytesToRead: Option[int], value: untyped): untyped = + let typeImpl = obj.getTypeInst.getImpl + let typeFields = obj.getTypeInst.getType + + let objFields = typeFields[2] + expectKind objFields, nnkRecList + + result = newStmtList() + + let caseStmt = newNimNode(nnkCaseStmt) + caseStmt.add(fieldNum) + + for i in 0 ..< len(objFields) - 1: + let field = objFields[i] + let ofBranch = newNimNode(nnkOfBranch) + ofBranch.add(newLit(i+1)) + ofBranch.add( + quote do: + `obj`.`field` = decode(`value`, type(`obj`.`field`), `offset`, `bytesProcessed`, `bytesToRead`).value + ) + caseStmt.add(ofBranch) + + let field = objFields[len(objFields) - 1] + let elseBranch = newNimNode(nnkElse) + elseBranch.add( + nnkStmtList.newTree( + quote do: + `obj`.`field` = decode(`value`, type(`obj`.`field`), `offset`, `bytesProcessed`, `bytesToRead`).value + ) + ) + caseStmt.add(elseBranch) + + result.add(caseStmt) + +proc decode*[T: object]( + bytes: var seq[byte], + ty: typedesc[T], + outOffset: var int, + outBytesProcessed: var int, + numBytesToRead = none(int) +): ProtoField[T] {.inline.} = + var bytesRead = 0 + + let wireTy = wireType(bytes[outOffset]) + result.index = fieldNumber(bytes[outOffset]) + + if wireTy == LengthDelimited: + # read LD header + # then read only amount of bytes needed + increaseBytesRead() + + let decodedSize = getVarint(bytes, uint, outOffset, outBytesProcessed, numBytesToRead) + let bytesToRead = some(decodedSize.int) + setField(result.value, result.index, outOffset, outBytesProcessed, bytesToRead, bytes) + else: + setField(result.value, result.index, outOffset, outBytesProcessed, numBytesToRead, bytes) \ No newline at end of file diff --git a/tests/test_serialization.nim b/tests/test_serialization.nim index 2a11db7..37ae148 100644 --- a/tests/test_serialization.nim +++ b/tests/test_serialization.nim @@ -5,52 +5,83 @@ import protobuf_serialization type MyEnum = enum ME1, ME2, ME3 +type + Test1 = object + a: uint + + Test3 = object + g {.sfixed32.}: int + h: int + i: Test1 suite "Test Varint Encoding": test "Can encode/decode enum": - let proto = newProtoBuffer() + var proto = newProtoBuffer() + var bytesProcessed: int proto.encode(ME3) proto.encode(ME2) var output = proto.output assert output == @[8.byte, 4, 16, 2] + var offset = 0 - let decodedME3 = decode(output, MyEnum) + let decodedME3 = decode(output, MyEnum, offset, bytesProcessed) assert decodedME3.value == ME3 - assert decodedME3.fieldNum == 1 + assert decodedME3.index == 1 - let decodedME2 = decode(output, MyEnum, offset=decodedME3.bytesProcessed) + let decodedME2 = decode(output, MyEnum, offset, bytesProcessed) assert decodedME2.value == ME2 - assert decodedME2.fieldNum == 2 + assert decodedME2.index == 2 test "Can encode/decode negative number": - let proto = newProtoBuffer() + var proto = newProtoBuffer() let num = -153452 + var bytesProcessed: int proto.encode(num) var output = proto.output assert output == @[8.byte, 215, 221, 18] - let decoded = decode(output, int) + var offset = 0 + let decoded = decode(output, int, offset, bytesProcessed) assert decoded.value == num - assert decoded.fieldNum == 1 + assert decoded.index == 1 test "Can encode/decode unsigned number": - let proto = newProtoBuffer() + var proto = newProtoBuffer() let num = 123151.uint + var bytesProcessed: int proto.encode(num) var output = proto.output assert output == @[8.byte, 143, 194, 7] + var offset = 0 - let decoded = decode(output, uint) + let decoded = decode(output, uint, offset, bytesProcessed) assert decoded.value == num - assert decoded.fieldNum == 1 + assert decoded.index == 1 test "Can encode/decode string": - let proto = newProtoBuffer() + var proto = newProtoBuffer() let str = "hey this is a string" + var bytesProcessed: int proto.encode(str) var output = proto.output assert output == @[10.byte, 20, 104, 101, 121, 32, 116, 104, 105, 115, 32, 105, 115, 32, 97, 32, 115, 116, 114, 105, 110, 103] - let decoded = decode(output, string) + var offset = 0 + let decoded = decode(output, string, offset, bytesProcessed) assert decoded.value == str - assert decoded.fieldNum == 1 \ No newline at end of file + assert decoded.index == 1 + + test "Can encode/decode object": + var proto = newProtoBuffer() + + let obj = Test3(g: 300, h: 200, i: Test1(a: 100)) + + proto.encode(obj) + var offset, bytesProcessed: int + + var output = proto.output + let decoded = decode(output, Test3, offset, bytesProcessed) + echo decoded + + echo output + assert false \ No newline at end of file