import macros, strformat, typetraits, options import faststreams template sint32*() {.pragma.} template sint64*() {.pragma.} template sfixed32*() {.pragma.} template sfixed64*() {.pragma.} template fixed32*() {.pragma.} template fixed64*() {.pragma.} template float*() {.pragma.} template double*() {.pragma.} const MaxMessageSize* = 1'u shl 22 type ProtoBuffer* = object fieldNum: int outstream: OutputStreamVar ProtoWireType* = enum ## Protobuf's field types enum Varint, Fixed64, LengthDelimited, StartGroup, EndGroup, Fixed32 EncodingKind* = enum ekNormal, ekZigzag ProtoField*[T] = object ## Protobuf's message field representation object index*: int value*: T SomeSVarint* = int | int64 | int32 | int16 | int8 | enum SomeByte* = byte | bool | char | uint8 SomeUVarint* = uint | uint64 | uint32 | uint16 | SomeByte SomeVarint* = SomeSVarint | SomeUVarint SomeLengthDelimited* = string | seq[SomeByte] | cstring SomeFixed64* = float64 SomeFixed32* = float32 SomeFixed* = SomeFixed32 | SomeFixed64 AnyProtoType* = SomeVarint | SomeLengthDelimited | SomeFixed | object UnexpectedTypeError* = object of ValueError proc newProtoBuffer*(): ProtoBuffer = ProtoBuffer(outstream: OutputStream.init(), fieldNum: 1) proc output*(proto: ProtoBuffer): seq[byte] {.inline.} = proto.outstream.getOutput template wireType(firstByte: byte): ProtoWireType = (firstByte and 0b111).ProtoWireType template fieldNumber(firstByte: byte): int = ((firstByte shr 3) and 0b1111).int template protoHeader*(fieldNum: int, wire: ProtoWireType): byte = ## Get protobuf's field header integer for ``index`` and ``wire``. ((cast[uint](fieldNum) shl 3) or cast[uint](wire)).byte template increaseBytesRead(amount = 1) = ## Convenience template for increasing ## all of the counts mixin isSome bytesRead += amount outOffset += amount outBytesProcessed += amount if numBytesToRead.isSome(): if (bytesRead > numBytesToRead.get()).unlikely: raise newException(Exception, &"Number of bytes read ({bytesRead}) exceeded bytes requested ({numBytesToRead})") proc put(stream: OutputStreamVar, value: SomeVarint) {.inline.} = when value is enum: var value = cast[type(ord(value))](value) elif value is bool or value is char: var value = cast[byte](value) else: var value = value when type(value) is SomeSVarint: # Encode using zigzag if value < type(value)(0): value = not(value shl type(value)(1)) else: value = value shl type(value)(1) while value > type(value)(0b0111_1111): stream.append byte((value and 0b0111_1111) or 0b1000_0000) value = value shr 7 stream.append byte(value and 0b1111_1111) proc encodeField(stream: OutputStreamVar, fieldNum: int, value: SomeVarint) {.inline.} = stream.append protoHeader(fieldNum, Varint) stream.put(value) proc put(stream: OutputStreamVar, value: SomeFixed) {.inline.} = when typeof(value) is SomeFixed64: var value = cast[int64](value) else: var value = cast[int32](value) for _ in 0 ..< sizeof(value): stream.append byte(value and 0b1111_1111) value = value shr 8 proc encodeField(stream: OutputStreamVar, fieldNum: int, value: SomeFixed64) {.inline.} = stream.append protoHeader(fieldNum, Fixed64) stream.put(value) proc encodeField(stream: OutputStreamVar, fieldNum: int, value: SomeFixed32) {.inline.} = stream.append protoHeader(fieldNum, Fixed32) stream.put(value) proc put(stream: OutputStreamVar, value: SomeLengthDelimited) {.inline.} = stream.put(len(value).uint) for b in value: stream.append byte(b) proc encodeField(stream: OutputStreamVar, fieldNum: int, value: SomeLengthDelimited) {.inline.} = stream.append protoHeader(fieldNum, LengthDelimited) stream.put(value) proc put(stream: OutputStreamVar, value: object) {.inline.} proc encodeField(stream: OutputStreamVar, fieldNum: int, value: object) {.inline.} = # This is currently needed in order to get the size # of the output before adding it to the stream. # Maybe there is a better way to do this let objStream = OutputStream.init() objStream.put(value) let objOutput = objStream.getOutput() if objOutput.len > 0: stream.append protoHeader(fieldNum, LengthDelimited) stream.put(objOutput) proc put(stream: OutputStreamVar, value: object) {.inline.} = var fieldNum = 1 for _, val in value.fieldPairs: # Only store the value if default(type(val)) != val: stream.encodeField(fieldNum, val) inc fieldNum proc encode*(protobuf: var ProtoBuffer, value: object) {.inline.} = protobuf.outstream.put(value) proc encodeField*(protobuf: var ProtoBuffer, fieldNum: int, value: AnyProtoType) {.inline.} = protobuf.outstream.encodeField(fieldNum, value) proc encodeField*(protobuf: var ProtoBuffer, value: AnyProtoType) {.inline.} = protobuf.encodeField(protobuf.fieldNum, value) inc protobuf.fieldNum proc get*[T: SomeFixed]( bytes: var seq[byte], ty: typedesc[T], outOffset: var int, outBytesProcessed: var int, numBytesToRead = none(int) ): T {.inline.} = var bytesRead = 0 when T is SomeFixed64: var value: int64 else: var value: int32 var shiftAmount = 0 for _ in 0 ..< sizeof(T): value += type(value)(bytes[outOffset]) shl shiftAmount shiftAmount += 8 increaseBytesRead() result = cast[T](value) proc get[T: SomeVarint]( bytes: var seq[byte], ty: typedesc[T], outOffset: var int, outBytesProcessed: var int, numBytesToRead = none(int) ): T {.inline.} = var bytesRead = 0 # Only up to 128 bits supported by the spec when T is enum or T is char: var value: type(ord(result)) elif T is bool: var value: byte else: var value: T var shiftAmount = 0 while true: value += type(value)(bytes[outOffset] and 0b0111_1111) shl shiftAmount shiftAmount += 7 if (bytes[outOffset] shr 7) == 0: break increaseBytesRead() increaseBytesRead() when ty is SomeSVarint: if (value and type(value)(1)) != type(value)(0): result = cast[T](not(value shr type(value)(1))) else: result = cast[T](value shr type(value)(1)) else: result = T(value) proc checkType[T: SomeVarint](tyByte: byte, ty: typedesc[T], offset: int) {.inline.} = let wireTy = wireType(tyByte) if wireTy != Varint: raise newException(UnexpectedTypeError, fmt"Not a varint at offset {offset}! Received a {wireTy}") proc checkType[T: SomeFixed](tyByte: byte, ty: typedesc[T], offset: int) {.inline.} = let wireTy = wireType(tyByte) if wireTy notin {Fixed32, Fixed64}: raise newException(UnexpectedTypeError, fmt"Not a fixed32 or fixed64 at offset {offset}! Received a {wireTy}") proc checkType[T: SomeLengthDelimited](tyByte: byte, ty: typedesc[T], offset: int) {.inline.} = let wireTy = wireType(tyByte) if wireTy != LengthDelimited: raise newException(UnexpectedTypeError, fmt"Not a length delimited value at offset {offset}! Received a {wireTy}") proc checkType[T: object](tyByte: byte, ty: typedesc[T], offset: int) {.inline.} = let wireTy = wireType(tyByte) if wireTy != LengthDelimited: raise newException(UnexpectedTypeError, fmt"Not an object value at offset {offset}! Received a {wireTy}") proc get*[T: SomeLengthDelimited]( bytes: var seq[byte], ty: typedesc[T], outOffset: var int, outBytesProcessed: var int, numBytesToRead = none(int) ): T {.inline.} = var bytesRead = 0 let decodedSize = bytes.get(uint, outOffset, outBytesProcessed, numBytesToRead) let length = decodedSize.int when T is string: result = newString(length) for i in outOffset ..< (outOffset + length): result[i - outOffset] = bytes[i].chr elif T is cstring: result = cast[cstring](bytes[outOffset ..< (outOffset + length)]) else: result.setLen(length) for i in outOffset ..< (outOffset + length): result[i - outOffset] = type(result[0])(bytes[i]) increaseBytesRead(length) proc decodeField*[T: SomeFixed | SomeVarint | SomeLengthDelimited]( bytes: var seq[byte], ty: typedesc[T], outOffset: var int, outBytesProcessed: var int, numBytesToRead = none(int) ): ProtoField[T] {.inline.} = var bytesRead = 0 checkType(bytes[outOffset], ty, outOffset) result.index = fieldNumber(bytes[outOffset]) increaseBytesRead() result.value = bytes.get(ty, outOffset, outBytesProcessed, numBytesToRead) proc decodeField*[T: object]( bytes: var seq[byte], ty: typedesc[T], outOffset: var int, outBytesProcessed: var int, numBytesToRead = none(int) ): ProtoField[T] {.inline.} macro setField(obj: typed, fieldNum: int, offset: int, bytesProcessed: int, bytesToRead: Option[int], value: untyped): untyped = let typeFields = obj.getTypeInst.getType let objFields = typeFields[2] expectKind objFields, nnkRecList result = newStmtList() let caseStmt = newNimNode(nnkCaseStmt) caseStmt.add(fieldNum) for i in 0 ..< len(objFields) - 1: let field = objFields[i] let ofBranch = newNimNode(nnkOfBranch) ofBranch.add(newLit(i+1)) ofBranch.add( quote do: `obj`.`field` = decodeField(`value`, type(`obj`.`field`), `offset`, `bytesProcessed`, `bytesToRead`).value ) caseStmt.add(ofBranch) let field = objFields[len(objFields) - 1] let elseBranch = newNimNode(nnkElse) elseBranch.add( nnkStmtList.newTree( quote do: `obj`.`field` = decodeField(`value`, type(`obj`.`field`), `offset`, `bytesProcessed`, `bytesToRead`).value ) ) caseStmt.add(elseBranch) result.add(caseStmt) proc decodeField*[T: object]( bytes: var seq[byte], ty: typedesc[T], outOffset: var int, outBytesProcessed: var int, numBytesToRead = none(int) ): ProtoField[T] {.inline.} = var bytesRead = 0 checkType(bytes[outOffset], ty, outOffset) result.index = fieldNumber(bytes[outOffset]) # read LD header # then read only amount of bytes needed increaseBytesRead() let decodedSize = bytes.get(uint, outOffset, outBytesProcessed, numBytesToRead) let bytesToRead = some(decodedSize.int) let oldOffset = outOffset while outOffset < oldOffset + bytesToRead.get(): let fieldNum = fieldNumber(bytes[outOffset]) setField(result.value, fieldNum, outOffset, outBytesProcessed, bytesToRead, bytes) proc decode*[T: object]( bytes: var seq[byte], ty: typedesc[T], ): T {.inline.} = var bytesRead = 0 var offset = 0 while offset < bytes.len - 1: let fieldNum = fieldNumber(bytes[offset]) setField(result, fieldNum, offset, bytesRead, none(int), bytes)