diff --git a/app/src/lib/stores/socket.ts b/app/src/lib/stores/socket.ts index c608dda..5ecbe67 100644 --- a/app/src/lib/stores/socket.ts +++ b/app/src/lib/stores/socket.ts @@ -6,13 +6,14 @@ import type { BinaryWriter } from '@bufbuild/protobuf/wire' // -------- START PARSING PROTO DATA -------- -// Auto-build reverse mapping from MessageFns to event key and tag -const MESSAGE_TYPE_TO_KEY = new Map, string>() -const MESSAGE_TYPE_TO_TAG = new Map, number>() +// Auto-build reverse mapping from MessageFns to event key and tag +export const MESSAGE_TYPE_TO_KEY = new Map, string>() +export const MESSAGE_TYPE_TO_TAG = new Map, number>() +export const MESSAGE_KEY_TO_TAG = new Map() // Build the mapping using references from metadata const websocketMessageType = websocket_md.fileDescriptor.messageType?.find( - msg => msg.name === 'WebsocketMessage' + ( msg: { name: string } ) => msg.name === 'WebsocketMessage' ) if (websocketMessageType?.field) { @@ -23,6 +24,7 @@ if (websocketMessageType?.field) { if (messageFns && field.jsonName && field.number) { MESSAGE_TYPE_TO_KEY.set(messageFns, field.jsonName) MESSAGE_TYPE_TO_TAG.set(messageFns, field.number) + MESSAGE_KEY_TO_TAG.set(field.jsonName, field.number) } } } @@ -50,29 +52,34 @@ function get_tag_from_messagetype(event_type: MessageFns): number { const socketEvents = ['open', 'close', 'error', 'message', 'unresponsive'] as const type SocketEvent = (typeof socketEvents)[number] -type TaggedSocketMessage = [string, WebsocketMessage] +type TaggedSocketMessage = {"tag": number, "msg": WebsocketMessage} - -const decodeMessage = (data: ArrayBuffer): TaggedSocketMessage => { +// Only exported for socket test +export const decodeMessage = (data: ArrayBuffer): TaggedSocketMessage => { const decoded = WebsocketMessage.decode(new Uint8Array(data)); const values = Object.entries(decoded).filter(([, value]) => value !== undefined) // Filter all values which are not undefined if (values.length != 1) { throw new Error("Message included either 0 or more than 1 data point") } - const [event, value] = values[0] - return [event, decoded] + const fieldName = values[0][0] + const tag = MESSAGE_KEY_TO_TAG.get(fieldName) + if (tag === undefined) { + throw new Error(`Tag not found for field: ${fieldName}`) + } + return {"tag": tag, "msg": decoded} } -const encodeMessage = (data: WebsocketMessage): Uint8Array => { +export const encodeMessage = (data: WebsocketMessage): Uint8Array => { const encoded = WebsocketMessage.encode(data).finish(); return encoded; } function createWebSocket() { - const listeners = new Map void>>() + const message_listeners = new Map void>>() + const event_listeners = new Map void>>() const { subscribe, set } = writable(false) const reconnectTimeoutTime = 5000 let unresponsiveTimeoutId: ReturnType @@ -85,16 +92,22 @@ function createWebSocket() { connect() } - function getListeners(event: MessageFns | string): Set<(data?: unknown) => void> { - if (typeof event != "string") { // Parse messagefns to string - event = get_name_from_messagetype(event) - } + function getMsgListeners(event_type: MessageFns): Set<(data?: unknown) => void> { + const type_tag = get_tag_from_messagetype(event_type) - const event_listeners = listeners.get(event); - if (event_listeners == undefined) { + const type_listeners = message_listeners.get(type_tag); + if (type_listeners == undefined) { return new Set() } - return event_listeners; + return type_listeners; + } + function getListeners(event: string): Set<(data?: unknown) => void> { + + const event_listeners_forevent = event_listeners.get(event); + if (event_listeners_forevent == undefined) { + return new Set() + } + return event_listeners_forevent; } function disconnect(reason: SocketEvent, event?: Event) { @@ -102,7 +115,7 @@ function createWebSocket() { set(false) clearTimeout(unresponsiveTimeoutId) clearTimeout(reconnectTimeoutId) - listeners.get(reason)?.forEach(listener => listener(event)) + event_listeners.get(reason)?.forEach(listener => listener(event)) reconnectTimeoutId = setTimeout(connect, reconnectTimeoutTime) } @@ -113,7 +126,7 @@ function createWebSocket() { ping() set(true) clearTimeout(reconnectTimeoutId) - listeners.get('open')?.forEach(listener => listener(ev)) + event_listeners.get('open')?.forEach(listener => listener(ev)) // TODO: Check if this makes sense? we also call subscribe to event when a new listen calls the "on" function // for (const event of listeners.keys()) { // if (socketEvents.includes(event as SocketEvent)) continue @@ -122,26 +135,24 @@ function createWebSocket() { } ws.onmessage = frame => { resetUnresponsiveCheck() - const [event, message] = decodeMessage(frame.data) - if (event) listeners.get(event)?.forEach(listener => listener(message)) + const {tag, msg} = decodeMessage(frame.data) + if (tag) message_listeners.get(tag)?.forEach(listener => listener(msg)) } ws.onerror = ev => disconnect('error', ev) ws.onclose = ev => disconnect('close', ev) } - function unsubscribe(event_type: MessageFns, listener?: (data: unknown) => void) { - const event = get_name_from_messagetype(event_type) - const eventListeners = listeners.get(event) - if (!eventListeners) return + function unsubscribe(event_type: MessageFns, listener: (data: unknown) => void) { + const tag = get_tag_from_messagetype(event_type) + const message_listeners_totag = message_listeners.get(tag) + if (!message_listeners_totag) return - if (!eventListeners.size) { + // TODO: This looks like it deletes an individual listener, but unsubscribe unsubscribes for everyone. Not sure what it is supposed to do right now + message_listeners_totag?.delete(listener) + if (message_listeners_totag.size == 0) { // No more listeners, so we can unsubscribe unsubscribeToEvent(event_type) } - if (listener) { - eventListeners?.delete(listener) - } else { - listeners.delete(event) - } + } function resetUnresponsiveCheck() { @@ -191,18 +202,15 @@ function createWebSocket() { sendEvent, init, on: (event_type: MessageFns, listener: (data: T) => void): (() => void) => { - const event = get_name_from_messagetype(event_type); + const tag = get_tag_from_messagetype(event_type); - let eventListeners = listeners.get(event) - if (!eventListeners) { + let message_listeners_totag = message_listeners.get(tag) + if (!message_listeners_totag) { // If this is the first listener to this event, also call subscribe to the server - if (!socketEvents.includes(event as SocketEvent)) { - subscribeToEvent(event_type) - } - eventListeners = new Set() - listeners.set(event, eventListeners) + message_listeners_totag = new Set() + message_listeners.set(tag, message_listeners_totag) } - eventListeners.add(listener as (data: unknown) => void) + message_listeners_totag.add(listener as (data: unknown) => void) return () => { unsubscribe(event_type, listener as (data: unknown) => void) diff --git a/app/tests/unit/socket.spec.ts b/app/tests/unit/socket.spec.ts index 8041025..f2ab590 100644 --- a/app/tests/unit/socket.spec.ts +++ b/app/tests/unit/socket.spec.ts @@ -1,7 +1,7 @@ import { describe, it, expect, beforeEach, afterEach } from 'vitest' import { WebSocketServer } from 'ws' -import { socket } from '../../src/lib/stores/socket' -import { IMUData, RSSIData, WebsocketMessage } from '../../src/lib/platform_shared/websocket_message' +import { decodeMessage, MESSAGE_KEY_TO_TAG, socket } from '../../src/lib/stores/socket' +import { IMUData, PingMsg, PongMsg, RSSIData, WebsocketMessage, protoMetadata as websocket_md } from '../../src/lib/platform_shared/websocket_message' // Helper function to create encoded WebSocket messages function createEncodedMessage(messageType: 'imu' | 'rssi' | 'mode', data: any): Uint8Array { @@ -207,6 +207,17 @@ describe('WebsocketMessage Protobuf Encoding/Decoding', () => { expect(decoded.temp).toBe(imuData.temp) }) + it('should encode and decode two empty types correctly', () => { + + const encoded_ping = WebsocketMessage.encode(WebsocketMessage.create({ pingmsg: PingMsg.create() })).finish() + const decoded_ping = decodeMessage(encoded_ping.buffer) + expect(decoded_ping.tag).toBe(MESSAGE_KEY_TO_TAG.get("pingmsg")) + + const encoded_pong = WebsocketMessage.encode(WebsocketMessage.create({ pongmsg: PongMsg.create() })).finish() + const decoded_pong = decodeMessage(encoded_pong.buffer) + expect(decoded_pong.tag).toBe(MESSAGE_KEY_TO_TAG.get("pongmsg")) + }) + it('should encode and decode complete WebsocketMessage', () => {