diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b0276a2..4c2f0b7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,24 +13,24 @@ jobs: fail-fast: false matrix: target: - - os: linux - cpu: amd64 - - os: linux - cpu: i386 - - os: macos - cpu: amd64 + #- os: linux + # cpu: amd64 + #- os: linux + # cpu: i386 + #- os: macos + # cpu: amd64 - os: windows cpu: amd64 nim: [1.6.16, devel] include: - - target: - os: linux - builder: ubuntu-22.04 - shell: bash - - target: - os: macos - builder: macos-13 - shell: bash + #- target: + # os: linux + # builder: ubuntu-22.04 + # shell: bash + #- target: + # os: macos + # builder: macos-13 + # shell: bash - target: os: windows builder: windows-2022 @@ -65,6 +65,17 @@ jobs: - name: Install deps run: | nimble install -dy + # git clone https://github.com/sctplab/usrsctp.git + # cd usrsctp + # mkdir cmake_install + # mkdir cmake_build + # git checkout 01cc4e042e2235b29d9d489d89728a6f9ac063ed + # cmake --version + # ls -la + # cmake -S . -B cmake_build -DCMAKE_BUILD_TYPE=Debug -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DCMAKE_VERBOSE_MAKEFILE:BOOL=ON -Dsctp_debug=ON -Dsctp_invariants=ON -Dsctp_inet6=ON -Dsctp_inet=ON -Dsctp_build_programs=ON -Dsctp_build_fuzzer=OFF -DCMAKE_INSTALL_PREFIX=cmake_install -DCMAKE_INSTALL_LIBDIR=lib -DCMAKE_SYSTEM_PROCESSOR=x86_64 . + # echo "install done" + # cmake --build cmake_build --parallel 2 --config Debug --target install --clean-first --verbose + # cd .. - name: Run tests run: | diff --git a/examples/ping.nim b/examples/ping.nim new file mode 100644 index 0000000..bee1971 --- /dev/null +++ b/examples/ping.nim @@ -0,0 +1,21 @@ +import chronos, stew/byteutils +import ../webrtc/udp_transport +import ../webrtc/stun/stun_transport +import ../webrtc/dtls/dtls_transport +import ../webrtc/sctp/[sctp_transport, sctp_connection] + +proc main() {.async.} = + let laddr = initTAddress("127.0.0.1:4244") + let udp = UdpTransport.new(laddr) + let stun = Stun.new(udp) + let dtls = Dtls.new(stun) + let sctp = Sctp.new(dtls) + + let conn = await sctp.connect(initTAddress("127.0.0.1:4242"), sctpPort = 13) + while true: + await conn.write("ping".toBytes) + let msg = await conn.read() + echo "Received: ", string.fromBytes(msg.data) + await sleepAsync(1.seconds) + +waitFor(main()) diff --git a/examples/pong.nim b/examples/pong.nim new file mode 100644 index 0000000..79c018c --- /dev/null +++ b/examples/pong.nim @@ -0,0 +1,27 @@ +import chronos, stew/byteutils +import ../webrtc/udp_transport +import ../webrtc/stun/stun_transport +import ../webrtc/dtls/dtls_transport +import ../webrtc/sctp/[sctp_transport, sctp_connection] + +proc sendPong(conn: SctpConn) {.async.} = + var i = 0 + while true: + let msg = await conn.read() + echo "Received: ", string.fromBytes(msg.data) + await conn.write(("pong " & $i).toBytes) + i.inc() + +proc main() {.async.} = + let laddr = initTAddress("127.0.0.1:4242") + let udp = UdpTransport.new(laddr) + let stun = Stun.new(udp) + let dtls = Dtls.new(stun) + let sctp = Sctp.new(dtls) + + sctp.listen(13) + while true: + let conn = await sctp.accept() + asyncSpawn conn.sendPong() + +waitFor(main()) diff --git a/tests/runalltests.nim b/tests/runalltests.nim index 3a49806..87404aa 100644 --- a/tests/runalltests.nim +++ b/tests/runalltests.nim @@ -9,5 +9,8 @@ {.used.} -import teststun -import testdtls +{.passc: "-DSCTP_DEBUG".} + +#import teststun +#import testdtls +import testsctp diff --git a/tests/testsctp.nim b/tests/testsctp.nim new file mode 100644 index 0000000..b368995 --- /dev/null +++ b/tests/testsctp.nim @@ -0,0 +1,119 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +{.used.} + +import chronos +import ../webrtc/udp_transport +import ../webrtc/stun/stun_transport +import ../webrtc/dtls/dtls_transport +import ../webrtc/sctp/sctp_transport +import ../webrtc/sctp/sctp_connection +import ./asyncunit + +suite "SCTP": + teardown: + checkLeaks() + echo "TEARDOWN" + + type + SctpStackForTest = object + localAddress: TransportAddress + udp: UdpTransport + stun: Stun + dtls: Dtls + sctp: Sctp + + proc initSctpStack(la: TransportAddress, isServer: bool): SctpStackForTest = + result.udp = UdpTransport.new(la) + result.localAddress = result.udp.localAddress() + echo result.localAddress + result.stun = Stun.new(result.udp) + result.dtls = Dtls.new(result.stun) + result.sctp = Sctp.new(result.dtls) + result.sctp.listen() + + proc closeSctpStack(self: SctpStackForTest) {.async: (raises: [CancelledError]).} = + echo "=> 1" + await self.sctp.stop() + echo "=> 2" + await self.dtls.stop() + echo "=> 3" + await self.stun.stop() + echo "=> 4" + await self.udp.close() + echo "=> 5" + +# asyncTest "Two SCTP nodes connecting to each other, then sending/receiving data": +# var +# sctpServer = initSctpStack(initTAddress("127.0.0.1:0")) +# sctpClient = initSctpStack(initTAddress("127.0.0.1:0")) +# echo "Before Accept" +# let serverConnFut = sctpServer.sctp.accept() +# echo "Before Connect" +# let clientConn = await sctpClient.sctp.connect(sctpServer.localAddress) +# echo "Before await accept" +# let serverConn = await serverConnFut +# echo "Connected :tada:" +# +# await clientConn.write(@[1'u8, 2, 3, 4]) +# check (await serverConn.read()).data == @[1'u8, 2, 3, 4] +# +# await serverConn.write(@[5'u8, 6, 7, 8]) +# check (await clientConn.read()).data == @[5'u8, 6, 7, 8] +# +# await clientConn.write(@[10'u8, 11, 12, 13]) +# await serverConn.write(@[14'u8, 15, 16, 17]) +# check (await clientConn.read()).data == @[14'u8, 15, 16, 17] +# check (await serverConn.read()).data == @[10'u8, 11, 12, 13] +# +# await allFutures(clientConn.close(), serverConn.close()) +# await allFutures(sctpClient.closeSctpStack(), sctpServer.closeSctpStack()) + + asyncTest "Two DTLS nodes connecting to the same DTLS server, sending/receiving data": + echo "==========> Second test" + var + sctpServer = initSctpStack(initTAddress("127.0.0.1:0"), true) + sctpClient1 = initSctpStack(initTAddress("127.0.0.1:0"), false) + sctpClient2 = initSctpStack(initTAddress("127.0.0.1:0"), false) + let + serverConn1Fut = sctpServer.sctp.accept() + clientConn1 = await sctpClient1.sctp.connect(sctpServer.localAddress) + serverConn1 = await serverConn1Fut + serverConn2Fut = sctpServer.sctp.accept() + clientConn2 = await sctpClient2.sctp.connect(sctpServer.localAddress) + serverConn2 = await serverConn2Fut + + echo "==========> Connected" + await serverConn1.write(@[1'u8, 2, 3, 4]) + await serverConn2.write(@[5'u8, 6, 7, 8]) + await clientConn1.write(@[9'u8, 10, 11, 12]) + await clientConn2.write(@[13'u8, 14, 15, 16]) + check: + (await clientConn1.read()).data == @[1'u8, 2, 3, 4] + (await clientConn2.read()).data == @[5'u8, 6, 7, 8] + (await serverConn1.read()).data == @[9'u8, 10, 11, 12] + (await serverConn2.read()).data == @[13'u8, 14, 15, 16] + echo "==========> Read first" + await allFutures(clientConn1.close(), serverConn1.close()) + + await serverConn2.write(@[5'u8, 6, 7, 8]) + await clientConn2.write(@[13'u8, 14, 15, 16]) + check: + (await clientConn2.read()).data == @[5'u8, 6, 7, 8] + (await serverConn2.read()).data == @[13'u8, 14, 15, 16] + await allFutures(clientConn2.close(), serverConn2.close()) + + echo "==========> Read second" + await sctpClient1.closeSctpStack() + echo "1" + await sctpClient2.closeSctpStack() + echo "2" + await sctpServer.closeSctpStack() + echo "==========> END" diff --git a/webrtc.nimble b/webrtc.nimble index da9e66f..f2ee122 100644 --- a/webrtc.nimble +++ b/webrtc.nimble @@ -21,13 +21,22 @@ var cfg = " --styleCheck:usages --styleCheck:error" & (if verbose: "" else: " --verbosity:0 --hints:off") & " --skipParentCfg --skipUserCfg -f" & - " --threads:on --opt:speed" + " --threads:on --opt:speed" & + " --d:chronicles_enabled_topics=sctp:TRACE,dtls:TRACE" when defined(windows): - cfg = cfg & " --clib:ws2_32" + cfg = cfg & " --clib:ws2_32 --clib:iphlpapi" import hashes +proc buildExample(filename: string, run = false, extraFlags = "") = + var excstr = nimc & " " & lang & " " & cfg & " " & flags & " -p:. " & extraFlags + excstr.add(" examples/" & filename) + exec excstr + if run: + exec "./examples/" & filename.toExe + rmFile "examples/" & filename.toExe + proc runTest(filename: string) = var excstr = nimc & " " & lang & " -d:debug " & cfg & " " & flags excstr.add(" -d:nimOldCaseObjects") # TODO: fix this in binary-serialization @@ -36,5 +45,10 @@ proc runTest(filename: string) = exec excstr & " -r " & " tests/" & filename rmFile "tests/" & filename.toExe -task test, "Run test": +task test, "Run the test suite": runTest("runalltests") + exec "nimble build_example" + +task build_example, "Build the examples": + buildExample("ping") + buildExample("pong") diff --git a/webrtc/dtls/dtls_connection.nim b/webrtc/dtls/dtls_connection.nim index b75757f..32c3309 100644 --- a/webrtc/dtls/dtls_connection.nim +++ b/webrtc/dtls/dtls_connection.nim @@ -16,7 +16,7 @@ import import ../errors, ../stun/[stun_connection], ./dtls_utils logScope: - topics = "webrtc dtls_conn" + topics = "webrtc dtls dtls_conn" const DtlsConnTracker* = "webrtc.dtls.conn" diff --git a/webrtc/dtls/dtls_transport.nim b/webrtc/dtls/dtls_transport.nim index 9efd6d5..3207f86 100644 --- a/webrtc/dtls/dtls_transport.nim +++ b/webrtc/dtls/dtls_transport.nim @@ -19,7 +19,7 @@ import ./[dtls_utils, dtls_connection], ../errors, ../stun/[stun_connection, stun_transport] logScope: - topics = "webrtc dtls" + topics = "webrtc dtls dtls_transport" # Implementation of a DTLS client and a DTLS Server by using the Mbed-TLS library. # Multiple things here are unintuitive partly because of the callbacks diff --git a/webrtc/sctp/sctp_connection.nim b/webrtc/sctp/sctp_connection.nim new file mode 100644 index 0000000..283c0ff --- /dev/null +++ b/webrtc/sctp/sctp_connection.nim @@ -0,0 +1,264 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import nativesockets, bitops, sequtils +import usrsctp, chronos, chronicles, stew/[ptrops, endians2, byteutils] +import ./sctp_utils, ../errors, ../dtls/dtls_connection +from posix import EINPROGRESS + +logScope: + topics = "webrtc sctp sctp_connection" + +const SctpConnTracker* = "webrtc.sctp.conn" + +type + SctpConnOnClose* = proc() {.raises: [], gcsafe.} + + SctpState* = enum + SctpConnecting + SctpConnected + SctpClosed + + SctpMessageParameters* = object + protocolId*: uint32 + streamId*: uint16 + endOfRecord*: bool + unordered*: bool + + SctpMessage* = ref object + data*: seq[byte] + info*: sctp_recvv_rn + params*: SctpMessageParameters + + SctpConn* = ref object + conn: DtlsConn + state*: SctpState + onClose: seq[SctpConnOnClose] + connectEvent*: AsyncEvent + acceptEvent*: AsyncEvent + readLoop: Future[void].Raising([CancelledError, WebRtcError]) + sctpSocket*: ptr socket + dataRecv: AsyncQueue[SctpMessage] + sendQueue: seq[byte] + +proc remoteAddress*(self: SctpConn): TransportAddress = + if self.conn.isNil(): + raise newException(WebRtcError, "SCTP - Connection not set") + return self.conn.remoteAddress() + +template usrsctpAwait(self: SctpConn, body: untyped): untyped = + # usrsctpAwait is template which set `sendQueue` to @[] then calls + # an usrsctp function. If during the synchronous run of the usrsctp function + # `sendQueue` is set, it is sent at the end of the function. + proc trySend(conn: SctpConn) {.async: (raises: [CancelledError]).} = + try: + trace "Send To", address = conn.remoteAddress() + await conn.conn.write(self.sendQueue) + except CatchableError as exc: + trace "Send Failed", exceptionMsg = exc.msg + + self.sendQueue = @[] + when type(body) is void: + (body) + if self.sendQueue.len() > 0: + await self.trySend() + else: + let res = (body) + if self.sendQueue.len() > 0: + await self.trySend() + res + +# -- usrsctp send and receive callback -- + +proc recvCallback*(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = + # Callback procedure called when we receive data after a connection + # has been established. + let + conn = cast[SctpConn](data) + events = usrsctp_get_events(sock) + + trace "Receive callback", events + if bitand(events, SCTP_EVENT_READ) != 0: + var + message = SctpMessage(data: newSeq[byte](4096)) + address: Sockaddr_storage + rn: sctp_recvv_rn + addressLen = sizeof(Sockaddr_storage).SockLen + rnLen = sizeof(sctp_recvv_rn).SockLen + infotype: uint + flags: int + let n = sock.usrsctp_recvv( + cast[pointer](addr message.data[0]), + message.data.len.uint, + cast[ptr SockAddr](addr address), + cast[ptr SockLen](addr addressLen), + cast[pointer](addr message.info), + cast[ptr SockLen](addr rnLen), + cast[ptr cuint](addr infotype), + cast[ptr cint](addr flags), + ) + if n < 0: + warn "usrsctp_recvv", error = sctpStrerror() + return + elif n > 0: + # It might be necessary to check if infotype == SCTP_RECVV_RCVINFO + message.data.delete(n ..< message.data.len()) + trace "message info from handle upcall", msginfo = message.info + message.params = SctpMessageParameters( + protocolId: message.info.recvv_rcvinfo.rcv_ppid.swapBytes(), + streamId: message.info.recvv_rcvinfo.rcv_sid, + ) + if bitand(flags, MSG_NOTIFICATION) != 0: + trace "Notification received", length = n + else: + try: + conn.dataRecv.addLastNoWait(message) + except AsyncQueueFullError: + trace "Queue full, dropping packet" + elif bitand(events, SCTP_EVENT_WRITE) != 0: + trace "sctp event write in the upcall" + else: + warn "Handle Upcall unexpected event", events + +proc sendCallback*( + ctx: pointer, buffer: pointer, length: uint, tos: uint8, set_df: uint8 +): cint {.cdecl.} = + # This proc is called by usrsctp everytime usrsctp tries to send data. + let + conn = cast[SctpConn](ctx) + buf = @(buffer.makeOpenArray(byte, int(length))) + trace "sendCallback", sctpPacket = $(buf.getSctpPacket()) + proc testSend() {.async: (raises: [CancelledError]).} = + try: + trace "Send To", address = conn.remoteAddress() + await conn.conn.write(buf) + except CatchableError as exc: + trace "Send Failed", message = exc.msg + + conn.sendQueue = buf + +proc addOnClose*(self: SctpConn, onCloseProc: SctpConnOnClose) = + ## Adds a proc to be called when SctpConn is closed + ## + self.onClose.add(onCloseProc) + +proc readLoopProc(self: SctpConn) {.async: (raises: [CancelledError, WebRtcError]).} = + while true: + let msg = await self.conn.read() + if msg == @[]: + trace "Sctp read loop stopped, DTLS connection closed" + return + trace "Receive data", + remoteAddress = self.conn.remoteAddress(), sctPacket = $(msg.getSctpPacket()) + self.usrsctpAwait: + usrsctp_conninput(cast[pointer](self), unsafeAddr msg[0], uint(msg.len), 0) + +proc new*(T: typedesc[SctpConn], conn: DtlsConn): T = + result = T( + conn: conn, + state: SctpConnecting, + connectEvent: AsyncEvent(), + acceptEvent: AsyncEvent(), + dataRecv: newAsyncQueue[SctpMessage](), + ) + result.readLoop = result.readLoopProc() + usrsctp_register_address(cast[pointer](result)) + +proc connect*(self: SctpConn, sctpPort: uint16) {.async: (raises: [CancelledError, WebRtcError]).} = + var sconn: Sockaddr_conn + when compiles(sconn.sconn_len): # when macos apple or openbsd for example + echo "???" + sconn.sconn_len = sizeof(sconn).uint8 + sconn.sconn_family = AF_CONN + sconn.sconn_port = htons(sctpPort) + sconn.sconn_addr = cast[pointer](self) + + echo "======> before connect", sconn.sconn_family + let connErr = self.usrsctpAwait: self.sctpSocket.usrsctp_connect( + cast[ptr SockAddr](addr sconn), SockLen(sizeof(sconn)) + ) + if connErr != 0 and errno != posix.EINPROGRESS: + echo "======> after connect (if failed) ", connErr, " ", errno, " ", posix.EINPROGRESS + raise + newException(WebRtcError, "SCTP - Connection failed: " & sctpStrerror()) + +proc read*(self: SctpConn): Future[SctpMessage] {.async: (raises: [CancelledError, WebRtcError]).} = + # Used by DataChannel, returns SctpMessage in order to get the stream + # and protocol ids + if self.state == SctpClosed: + raise newException(WebRtcError, "Try to read on an already closed SctpConn") + return await self.dataRecv.popFirst() + +proc toFlags(params: SctpMessageParameters): uint16 = + if params.endOfRecord: + result = result or SCTP_EOR + if params.unordered: + result = result or SCTP_UNORDERED + +proc write*( + self: SctpConn, buf: seq[byte], sendParams = default(SctpMessageParameters) +) {.async: (raises: [CancelledError, WebRtcError]).} = + # Used by DataChannel, writes buf on the Dtls connection. + if self.state == SctpClosed: + raise newException(WebRtcError, "Try to write on an already closed SctpConn") + var cpy = buf + let sendvErr = + if sendParams == default(SctpMessageParameters): + # If writes is called by DataChannel, sendParams should never + # be the default value. This split is useful for testing. + self.usrsctpAwait: + self.sctpSocket.usrsctp_sendv( + cast[pointer](addr cpy[0]), + cpy.len().uint, + nil, + 0, + nil, + 0, + SCTP_SENDV_NOINFO.cuint, + 0, + ) + else: + var sendInfo = sctp_sndinfo( + snd_sid: sendParams.streamId, + snd_ppid: sendParams.protocolId.swapBytes(), + snd_flags: sendParams.toFlags(), + ) + self.usrsctpAwait: + self.sctpSocket.usrsctp_sendv( + cast[pointer](addr cpy[0]), + cpy.len().uint, + nil, + 0, + cast[pointer](addr sendInfo), + sizeof(sendInfo).SockLen, + SCTP_SENDV_SNDINFO.cuint, + 0, + ) + if sendvErr < 0: + raise newException(WebRtcError, "SCTP - " & sctpStrerror()) + +proc write*( + self: SctpConn, s: string +) {.async: (raises: [CancelledError, WebRtcError]).} = + await self.write(s.toBytes()) + +proc close*(self: SctpConn) {.async: (raises: [CancelledError, WebRtcError]).} = + if self.state == SctpClosed: + debug "Try to close SctpConn twice" + return + usrsctp_deregister_address(cast[pointer](self)) + self.usrsctpAwait: + self.sctpSocket.usrsctp_close() + await self.readLoop.cancelAndWait() + self.state = SctpClosed + untrackCounter(SctpConnTracker) + await self.conn.close() + for onCloseProc in self.onClose: + onCloseProc() + self.onClose = @[] diff --git a/webrtc/sctp/sctp_transport.nim b/webrtc/sctp/sctp_transport.nim new file mode 100644 index 0000000..fe54fd3 --- /dev/null +++ b/webrtc/sctp/sctp_transport.nim @@ -0,0 +1,263 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import tables, bitops, nativesockets, strutils, sequtils +import usrsctp, chronos, chronicles +import + ./[sctp_connection, sctp_utils], ../errors, ../dtls/dtls_transport +when defined(windows): + from winlean import AF_INET + +export chronicles + +const + SctpTransportTracker* = "webrtc.sctp.transport" + IPPROTO_SCTP = 132 + +logScope: + topics = "webrtc sctp sctp_transport" + +# Implementation of an Sctp client and server using the usrsctp library. +# Usrsctp is usable with a single thread but this is not the intended +# way to use it. As a result, there are many callbacks that calls each +# other synchronously. + +proc printf( + format: cstring +) {.cdecl, importc: "printf", varargs, header: "", gcsafe.} + +type Sctp* = ref object + dtls: Dtls # Underlying Dtls Transport + connections: Table[TransportAddress, SctpConn] # List of all the Sctp connections + isServer: bool + sockServer: ptr socket # usrsctp "server" socket to accept new connections + +# -- usrsctp accept and connect callbacks -- + +from posix import EINPROGRESS +proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = + # Callback procedure called when a connection is about to be accepted. + echo "======> handle accept", posix.EINPROGRESS + var + sconn: Sockaddr_conn + slen: Socklen = sizeof(Sockaddr_conn).uint32 + echo "======> before usrsctp_accept" + let + sctp = cast[Sctp](data) + sctpSocket = + usrsctp_accept(sctp.sockServer, cast[ptr SockAddr](addr sconn), addr slen) + conn = cast[SctpConn](sconn.sconn_addr) + echo "======> after usrsctp_accept" + + if sctpSocket.isNil(): + warn "usrsctp_accept failed", error = sctpStrerror() + conn.state = SctpState.SctpClosed + else: + trace "Scpt connection accepted", remoteAddress = conn.remoteAddress() + conn.sctpSocket = sctpSocket + conn.state = SctpState.SctpConnected + conn.acceptEvent.fire() + +proc handleConnect(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} = + # Callback procedure called during usrsctp_connect + echo "======> handle connect" + let + conn = cast[SctpConn](data) + events = usrsctp_get_events(sock) + + if conn.state == SctpState.SctpConnecting: + if bitand(events, SCTP_EVENT_ERROR) != 0: + warn "Cannot connect", remoteAddress = conn.remoteAddress() + conn.state = SctpState.SctpClosed + elif bitand(events, SCTP_EVENT_WRITE) != 0: + conn.state = SctpState.SctpConnected + if usrsctp_set_upcall(conn.sctpSocket, recvCallback, data) != 0: + warn "usrsctp_set_upcall fails while connecting", error = sctpStrerror() + trace "Sctp connection connected", remoteAddress = conn.remoteAddress() + echo "======> connect event fire" + conn.connectEvent.fire() + else: + warn "Should never happen", currentState = conn.state + +proc stopServer*(self: Sctp) = + ## Sctp Transport stop acting like a server + ## + echo "====> stopServer 1" + if not self.isServer: + trace "Try to close a client" + return + echo "====> stopServer 2" + self.isServer = false + echo "====> stopServer 3 ", cast[uint64](cast[pointer](self.sockServer)), " ", sizeof(self.sockServer) + self.sockServer.usrsctp_close() + echo "====> stopServer 4" + +proc serverSetup(self: Sctp, sctpPort: uint16): bool = + # This procedure setup usrsctp to be in "server mode" and + # creates an sctp socket on which we will listen + if usrsctp_sysctl_set_sctp_blackhole(2) != 0: + warn "usrsctp_sysctl_set_sctp_blackhole failed", error = sctpStrerror() + return false + + if usrsctp_sysctl_set_sctp_no_csum_on_loopback(0) != 0: + warn "usrsctp_sysctl_set_sctp_no_csum_on_loopback failed", error = sctpStrerror() + return false + + if usrsctp_sysctl_set_sctp_delayed_sack_time_default(0) != 0: + warn "usrsctp_sysctl_set_sctp_delayed_sack_time_default failed", error = sctpStrerror() + return false + + let sock = usrsctp_socket(AF_CONN, SOCK_STREAM.toInt(), IPPROTO_SCTP, nil, nil, 0, nil) + if usrsctp_set_non_blocking(sock, 1) != 0: + warn "usrsctp_set_non_blocking failed", error = sctpStrerror() + return false + + var sin: Sockaddr_in + when defined(windows): + sin.sin_family = type(sin.sin_family)(winlean.AF_INET) + else: + sin.sin_family = type(sin.sin_family)(nativesockets.AF_INET) + sin.sin_port = htons(sctpPort) + sin.sin_addr.s_addr = htonl(INADDR_ANY) + if usrsctp_bind(sock, cast[ptr SockAddr](addr sin), SockLen(sizeof(Sockaddr_in))) != 0: + warn "usrsctp_bind failed", error = sctpStrerror() + return false + + if usrsctp_listen(sock, 1) < 0: + warn "usrsctp_listen failed", error = sctpStrerror() + return false + + if sock.usrsctp_set_upcall(handleAccept, cast[pointer](self)) != 0: + warn "usrsctp_set_upcall failed", error = sctpStrerror() + return false + + echo "????????????????????" + self.sockServer = sock + echo "sockServer: ", cast[uint64](cast[pointer](self.sockServer)) + return true + +proc listen*(self: Sctp, sctpPort: uint16 = 5000) = + ## `listen` marks the Sctp Transport as a transport that will be used to accept + ## incoming connection requests using accept. + ## + if self.isServer: + trace "Try to start the server twice" + return + self.isServer = true + trace "Sctp listening", sctpPort + if not self.serverSetup(sctpPort): + raise newException(WebRtcError, "SCTP - Fails to listen") + +proc new*(T: type Sctp, dtls: Dtls): T = + ## Creates a new Sctp Transport + ## + var self = T() + self.dtls = dtls + + usrsctp_init_nothreads(dtls.localAddress.port.uint16, sendCallback, sctpPrintf) + if usrsctp_sysctl_set_sctp_debug_on(SCTP_DEBUG_ALL.uint32) != 0: + # Enabling debug is not critical, doesn't matter if it fails + trace "usrsctp_sysctl_set_sctp_debug_on failed", error = sctpStrerror() + if usrsctp_sysctl_set_sctp_ecn_enable(1) != 0: + # In the same way, enabling explicit congestion notification isn't required + trace "usrsctp_sysctl_set_sctp_ecn_enable failed", error = sctpStrerror() + trackCounter(SctpTransportTracker) + return self + +proc stop*(self: Sctp) {.async: (raises: [CancelledError]).} = + ## Stops the Sctp Transport + ## + echo "==> stop 1 isServer? ", self.isServer + if self.isServer: + self.stopServer() + untrackCounter(SctpTransportTracker) + echo "==> stop 2" + let connections = toSeq(self.connections.values()) + echo "==> stop 3 ", connections.len() + await allFutures(connections.mapIt(it.close())) + echo "==> stop 4" + if usrsctp_finish() != 0: + warn "usrsct_finish failed", error = sctpStrerror() + echo "==> stop 5" + +proc socketSetup( + conn: SctpConn, callback: proc(a1: ptr socket, a2: pointer, a3: cint) {.cdecl.} +): bool = + # This procedure setup SctpConn. It should be in `sctp_connection.nim` file but I + # prefer not to expose it. + if conn.sctpSocket.usrsctp_set_non_blocking(1) != 0: + warn "usrsctp_set_non_blocking failed", error = sctpStrerror() + return false + + if conn.sctpSocket.usrsctp_set_upcall(callback, cast[pointer](conn)) != 0: + warn "usrsctp_set_upcall failed", error = sctpStrerror() + return false + + var nodelay: uint32 = 1 + if conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY, addr nodelay, sizeof(nodelay).SockLen) != 0: + warn "usrsctp_setsockopt nodelay failed", error = sctpStrerror() + return false + + var recvinfo: uint32 = 1 + if conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_RECVRCVINFO, addr recvinfo, sizeof(recvinfo).SockLen) != 0: + warn "usrsctp_setsockopt recvinfo failed", error = sctpStrerror() + return false + return true + +proc addConnToTable(self: Sctp, conn: SctpConn) = + let remoteAddress = conn.remoteAddress() + proc cleanup() = + self.connections.del(remoteAddress) + self.connections[remoteAddress] = conn + conn.addOnClose(cleanup) + +proc accept*( + self: Sctp +): Future[SctpConn] {.async: (raises: [CancelledError, WebRtcError]).} = + ## Accept an Sctp Connection + ## + if not self.isServer: + raise newException(WebRtcError, "SCTP - Not a server") + echo "=======> accept ?" + var conn: SctpConn + while true: + conn = SctpConn.new(await self.dtls.accept()) + conn.acceptEvent.clear() + echo "=======> wait accept event" + await conn.acceptEvent.wait() + if conn.state == SctpState.SctpConnected and conn.socketSetup(recvCallback): + break + await conn.close() + + self.addConnToTable(conn) + trackCounter(SctpConnTracker) + return conn + +proc connect*( + self: Sctp, raddr: TransportAddress, sctpPort: uint16 = 5000 +): Future[SctpConn] {.async: (raises: [CancelledError, WebRtcError]).} = + ## Connect to a remote address and returns an Sctp Connection + ## + let conn = SctpConn.new(await self.dtls.connect(raddr)) + conn.state = SctpState.SctpConnecting + conn.sctpSocket = + usrsctp_socket(AF_CONN, SOCK_STREAM.toInt(), IPPROTO_SCTP, nil, nil, 0, nil) + + if not conn.socketSetup(handleConnect): + raise newException(WebRtcError, "SCTP - Socket setup failed while connecting") + + await conn.connect(sctpPort) + + conn.connectEvent.clear() + await conn.connectEvent.wait() + if conn.state == SctpState.SctpClosed: + raise newException(WebRtcError, "SCTP - Connection failed") + self.addConnToTable(conn) + trackCounter(SctpConnTracker) + return conn diff --git a/webrtc/sctp/sctp_utils.c b/webrtc/sctp/sctp_utils.c new file mode 100644 index 0000000..ac4a726 --- /dev/null +++ b/webrtc/sctp/sctp_utils.c @@ -0,0 +1,9 @@ +#include +#include + +void sctpPrintf(const char *fmt, ...) { + va_list args; + va_start(args, fmt); // Initialize the argument list + vprintf(fmt, args); // Call vprintf with the argument list + va_end(args); // Clean up +} diff --git a/webrtc/sctp/sctp_utils.nim b/webrtc/sctp/sctp_utils.nim new file mode 100644 index 0000000..ff16913 --- /dev/null +++ b/webrtc/sctp/sctp_utils.nim @@ -0,0 +1,78 @@ +# Nim-WebRTC +# Copyright (c) 2024 Status Research & Development GmbH +# Licensed under either of +# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE)) +# * MIT license ([LICENSE-MIT](LICENSE-MIT)) +# at your option. +# This file may not be copied, modified, or distributed except according to +# those terms. + +import nativesockets +import binary_serialization, chronos + +{.compile: "sctp_utils.c".} + +proc sctpPrintf*( + format: cstring +) {.cdecl, importc: "sctpPrintf", varargs, gcsafe.} + +var errno* {.importc, header: "".}: cint ## error variable + +type + # These three objects are used for debugging/trace only + SctpChunk* = object + chunkType*: uint8 + flag*: uint8 + length* {.bin_value: it.data.len() + 4.}: uint16 + data* {.bin_len: it.length - 4.}: seq[byte] + + SctpPacketHeader* = object + srcPort*: uint16 + dstPort*: uint16 + verifTag*: uint32 + checksum*: uint32 + + SctpPacketStructure* = object + header*: SctpPacketHeader + chunks*: seq[SctpChunk] + +proc dataToString(data: seq[byte]): string = + # Only used for debugging/trace + if data.len() < 8: + return $data + result = "@[" + result &= $data[0] & ", " & $data[1] & ", " & $data[2] & ", " & $data[3] & " ... " + result &= $data[^4] & ", " & $data[^3] & ", " & $data[^2] & ", " & $data[^1] & "]" + +proc `$`*(packet: SctpPacketStructure): string = + # Only used for debugging/trace + result = "{header: {srcPort: " + result &= $(packet.header.srcPort) & ", dstPort: " + result &= $(packet.header.dstPort) & "}, chunks: @[" + var counter = 0 + for chunk in packet.chunks: + result &= "{type: " & $(chunk.chunkType) & ", len: " + result &= $(chunk.length) & ", data: " + result &= chunk.data.dataToString() + counter += 1 + if counter < packet.chunks.len(): + result &= ", " + result &= "]}" + +proc getSctpPacket*(buffer: seq[byte]): SctpPacketStructure = + # Only used for debugging/trace + result.header = Binary.decode(buffer, SctpPacketHeader) + var size = sizeof(SctpPacketHeader) + while size < buffer.len: + let chunk = Binary.decode(buffer[size ..^ 1], SctpChunk) + result.chunks.add(chunk) + size.inc(chunk.length.int) + while size mod 4 != 0: + # padding; could use `size.inc(-size %% 4)` instead but it lacks clarity + size.inc(1) + +proc sctpStrerror*(): string = + proc strerror( + error: int + ): cstring {.importc: "strerror", cdecl, header: "".} + return $errno & ": " & $(strerror(errno)) diff --git a/webrtc/stun/stun_connection.nim b/webrtc/stun/stun_connection.nim index 9d661a2..baf8a75 100644 --- a/webrtc/stun/stun_connection.nim +++ b/webrtc/stun/stun_connection.nim @@ -201,7 +201,7 @@ proc new*( ## var self = T( udp: udp, - laddr: udp.laddr, + laddr: udp.localAddress, raddr: raddr, closed: false, dataRecv: newAsyncQueue[seq[byte]](StunMaxQueuingMessages), diff --git a/webrtc/stun/stun_message.nim b/webrtc/stun/stun_message.nim index 6206bfe..1d5e696 100644 --- a/webrtc/stun/stun_message.nim +++ b/webrtc/stun/stun_message.nim @@ -17,7 +17,7 @@ import stun_attributes, ../errors export binary_serialization logScope: - topics = "webrtc stun" + topics = "webrtc stun stun_message" const StunMsgHeaderSize = 20 diff --git a/webrtc/stun/stun_transport.nim b/webrtc/stun/stun_transport.nim index becc593..3d52f53 100644 --- a/webrtc/stun/stun_transport.nim +++ b/webrtc/stun/stun_transport.nim @@ -108,7 +108,7 @@ proc new*( ## var self = T( udp: udp, - laddr: udp.laddr, + laddr: udp.localAddress(), usernameProvider: usernameProvider, usernameChecker: usernameChecker, passwordProvider: passwordProvider, diff --git a/webrtc/udp_transport.nim b/webrtc/udp_transport.nim index e458a5e..1f1625c 100644 --- a/webrtc/udp_transport.nim +++ b/webrtc/udp_transport.nim @@ -23,7 +23,7 @@ type raddr: TransportAddress UdpTransport* = ref object - laddr*: TransportAddress + laddr: TransportAddress udp: DatagramTransport dataRecv: AsyncQueue[UdpPacketInfo] closed: bool @@ -33,7 +33,7 @@ const UdpTransportTrackerName* = "webrtc.udp.transport" proc new*(T: type UdpTransport, laddr: TransportAddress): T = ## Initialize an Udp Transport ## - var self = T(laddr: laddr, closed: false) + var self = T(closed: false) proc onReceive( udp: DatagramTransport, @@ -49,6 +49,7 @@ proc new*(T: type UdpTransport, laddr: TransportAddress): T = self.dataRecv = newAsyncQueue[UdpPacketInfo]() self.udp = newDatagramTransport(onReceive, local = laddr) + self.laddr = self.udp.localAddress() trackCounter(UdpTransportTrackerName) return self @@ -87,3 +88,6 @@ proc read*(self: UdpTransport): Future[UdpPacketInfo] {.async: (raises: [Cancell return trace "UDP read" return await self.dataRecv.popFirst() + +proc localAddress*(self: UdpTransport): TransportAddress = + self.laddr