diff --git a/esp32/include/communication/comm_base.hpp b/esp32/include/communication/comm_base.hpp index aef5a23..410b8ef 100644 --- a/esp32/include/communication/comm_base.hpp +++ b/esp32/include/communication/comm_base.hpp @@ -1,15 +1,85 @@ #pragma once -#include #include -#include -#include -#include +#include +#include +#include +#include -enum message_type_t { CONNECT = 0, DISCONNECT = 1, EVENT = 2, PING = 3, PONG = 4, BINARY_EVENT = 5 }; +template +struct MessageTraits; -typedef std::function EventCallback; -typedef std::function SubscribeCallback; +template <> +struct MessageTraits { + static constexpr pb_size_t tag = socket_message_WebsocketMessage_imu_tag; + static void assign(socket_message_WebsocketMessage& msg, const socket_message_IMUData& data) { + msg.message.imu = data; + } +}; + +template <> +struct MessageTraits { + static constexpr pb_size_t tag = socket_message_WebsocketMessage_mode_tag; + static void assign(socket_message_WebsocketMessage& msg, const socket_message_ModeData& data) { + msg.message.mode = data; + } +}; + +template <> +struct MessageTraits { + static constexpr pb_size_t tag = socket_message_WebsocketMessage_analytics_tag; + static void assign(socket_message_WebsocketMessage& msg, const socket_message_AnalyticsData& data) { + msg.message.analytics = data; + } +}; + +template <> +struct MessageTraits { + static constexpr pb_size_t tag = socket_message_WebsocketMessage_angles_tag; + static void assign(socket_message_WebsocketMessage& msg, const socket_message_AnglesData& data) { + msg.message.angles = data; + } +}; + +template <> +struct MessageTraits { + static constexpr pb_size_t tag = socket_message_WebsocketMessage_rssi_tag; + static void assign(socket_message_WebsocketMessage& msg, const socket_message_RSSIData& data) { + msg.message.rssi = data; + } +}; + +template <> +struct MessageTraits { + static constexpr pb_size_t tag = socket_message_WebsocketMessage_kinematic_data_tag; + static void assign(socket_message_WebsocketMessage& msg, const socket_message_KinematicData& data) { + msg.message.kinematic_data = data; + } +}; + +template <> +struct MessageTraits { + static constexpr pb_size_t tag = socket_message_WebsocketMessage_imu_calibrate_tag; + static void assign(socket_message_WebsocketMessage& msg, const socket_message_IMUCalibrateData& data) { + msg.message.imu_calibrate = data; + } +}; + +template <> +struct MessageTraits { + static constexpr pb_size_t tag = socket_message_WebsocketMessage_i2c_scan_tag; + static void assign(socket_message_WebsocketMessage& msg, const socket_message_I2CScanData& data) { + msg.message.i2c_scan = data; + } +}; + +template <> +struct MessageTraits { + static constexpr pb_size_t tag = socket_message_WebsocketMessage_peripheral_settings_tag; + static void assign(socket_message_WebsocketMessage& msg, const socket_message_PeripheralSettingsData& data) { + msg.message.peripheral_settings = data; + } +}; class CommAdapterBase { public: @@ -18,142 +88,97 @@ class CommAdapterBase { virtual void begin() {} - bool hasSubscribers(const char *event) { return !client_subscriptions[event].empty(); } - - void onEvent(std::string event, EventCallback callback) { event_callbacks[event].push_back(std::move(callback)); } - - void onSubscribe(std::string event, SubscribeCallback callback) { - subscribe_callbacks[event].push_back(std::move(callback)); + bool hasSubscribers(int32_t tag) { + xSemaphoreTake(mutex_, portMAX_DELAY); + bool result = !client_subscriptions_[tag].empty(); + xSemaphoreGive(mutex_); + return result; } - void emit(const char *event, JsonVariant &payload, const char *originId = "", bool onlyToSameOrigin = false) { - int originSubscriptionId = originId[0] ? atoi(originId) : -1; - xSemaphoreTake(mutex_, portMAX_DELAY); - auto &subscriptions = client_subscriptions[event]; - if (subscriptions.empty()) { - xSemaphoreGive(mutex_); + ProtoDecoder& decoder() { return decoder_; } + + template + void emit(const T& data, int clientId = -1) { + constexpr pb_size_t tag = MessageTraits::tag; + + if (clientId < 0 && !hasSubscribers(tag)) return; + + msg_.which_message = tag; + MessageTraits::assign(msg_, data); + + pb_ostream_t stream = pb_ostream_from_buffer(buffer_, sizeof(buffer_)); + if (!pb_encode(&stream, socket_message_WebsocketMessage_fields, &msg_)) { return; } - JsonDocument doc; - JsonArray array = doc.to(); - array.add(static_cast(message_type_t::EVENT)); - array.add(event); - array.add(payload); - -#if USE_MSGPACK - std::string bin; - serializeMsgPack(doc, bin); - xSemaphoreGive(mutex_); - send(reinterpret_cast(bin.data()), bin.size(), -1); -#else - String out; - serializeJson(doc, out); - xSemaphoreGive(mutex_); - send(out.c_str(), -1); -#endif - } - - void send_wsm_by_function( void (*setmsg)(socket_message_WebsocketMessage* message), int cid ) { - setmsg(&msg); - send_wsm(&msg, cid); + if (clientId >= 0) { + send(buffer_, stream.bytes_written, clientId); + } else { + sendToSubscribers(tag, buffer_, stream.bytes_written); + } } protected: - socket_message_WebsocketMessage msg; - uint8_t data_buffer[512]; - void send(const char *data, int cid = -1) { send(reinterpret_cast(data), strlen(data), cid); } - virtual void send(const uint8_t *data, size_t len, int cid = -1) = 0; - void send_wsm(socket_message_WebsocketMessage* message, int cid) { - pb_ostream_t ostream = pb_ostream_from_buffer(data_buffer, sizeof(data_buffer)); - // Encode the message - bool ostatus = pb_encode(&ostream, &socket_message_WebsocketMessage_msg, message); + virtual void send(const uint8_t* data, size_t len, int cid = -1) = 0; - if (!ostatus) { - // TODO: Make a re-encoder using malloc instead (which increases exponentially but only if the error is the buffer size) - printf("Encoding of socket message failed: %s\n", PB_GET_ERROR(&ostream)); - return; - } - - send(data_buffer, ostream.bytes_written, cid); - } - - - void subscribe(const char *event, int cid = 0) { + void subscribe(int32_t tag, int cid = 0) { xSemaphoreTake(mutex_, portMAX_DELAY); - client_subscriptions[event].push_back(cid); + client_subscriptions_[tag].push_back(cid); xSemaphoreGive(mutex_); + ESP_LOGI("ProtoComm", "Client %d subscribed to tag %d", cid, (int)tag); } - void unsubscribe(const char *event, int cid = 0) { + + void unsubscribe(int32_t tag, int cid = 0) { xSemaphoreTake(mutex_, portMAX_DELAY); - client_subscriptions[event].remove(cid); + client_subscriptions_[tag].remove(cid); + xSemaphoreGive(mutex_); + ESP_LOGI("ProtoComm", "Client %d unsubscribed from tag %d", cid, (int)tag); + } + + void removeClient(int cid) { + xSemaphoreTake(mutex_, portMAX_DELAY); + for (auto& [tag, clients] : client_subscriptions_) { + clients.remove(cid); + } xSemaphoreGive(mutex_); } - void handleEventCallbacks(std::string event, JsonVariant &jsonObject, int originId) { - for (auto &callback : event_callbacks[event]) { - callback(jsonObject, originId); + void handleIncoming(const uint8_t* data, size_t len, int cid) { + if (!decoder_.decode(data, len, cid)) { + ESP_LOGE("ProtoComm", "Failed to decode incoming message from client %d", cid); } } - virtual void handleIncoming(const uint8_t *data, size_t len, int cid = 0) { - JsonDocument doc; -#if USE_MSGPACK - DeserializationError error = deserializeMsgPack(doc, data, len); -#else - DeserializationError error = deserializeJson(doc, data, len); -#endif - if (error) { - ESP_LOGE("Comm Base", "Failed to deserialize incoming: (%s)", error.c_str()); - return; - } - - JsonArray obj = doc.as(); // TODO: Make const - message_type_t type = static_cast(obj[0].as()); - - switch (type) { - case message_type_t::CONNECT: { - const char *event = obj[1].as(); - ESP_LOGI("Comm Base", "CONNECT topic: %s (cid=%d)", event, cid); - subscribe(event, cid); - break; - } - - case message_type_t::DISCONNECT: { - const char *event = obj[1].as(); - ESP_LOGI("Comm Base", "DISCONNECT topic: %s (cid=%d)", event, cid); - unsubscribe(event, cid); - break; - } - - case message_type_t::EVENT: { - const char *event = obj[1].as(); - JsonVariant payload = obj[2].as(); - handleEventCallbacks(event, payload, cid); - break; - } - case message_type_t::PING: { - ESP_LOGI("Comm Base", "PING (cid=%d)", cid); - ping(cid); - break; - } - - case message_type_t::PONG: ESP_LOGI("Comm Base", "PONG (cid=%d)", cid); break; - default: ESP_LOGW("Comm Base", "Unknown message type: %d", static_cast(type)); break; + void sendPong(int cid) { + uint8_t pongBuffer[16]; + msg_.which_message = socket_message_WebsocketMessage_pongmsg_tag; + msg_.message.pongmsg = socket_message_PongMsg_init_zero; + pb_ostream_t stream = pb_ostream_from_buffer(pongBuffer, sizeof(pongBuffer)); + if (pb_encode(&stream, socket_message_WebsocketMessage_fields, &msg_)) { + send(pongBuffer, stream.bytes_written, cid); } } - void ping(int cid) { -#if USE_MSGPACK - static const uint8_t pong[] = {0x91, 0x04}; - send(pong, sizeof(pong), cid); -#else - send("[4]", cid); -#endif + void setupDecoderHandlers() { + decoder_.onSubscribe([this](int32_t tag, int cid) { subscribe(tag, cid); }); + + decoder_.onUnsubscribe([this](int32_t tag, int cid) { unsubscribe(tag, cid); }); + + decoder_.onPing([this](int cid) { sendPong(cid); }); } SemaphoreHandle_t mutex_; - std::map> client_subscriptions; - std::map> event_callbacks; - std::map> subscribe_callbacks; -}; \ No newline at end of file + std::map> client_subscriptions_; + ProtoDecoder decoder_; + socket_message_WebsocketMessage msg_ = socket_message_WebsocketMessage_init_zero; + uint8_t buffer_[PROTO_BUFFER_SIZE]; + + private: + void sendToSubscribers(int32_t tag, const uint8_t* data, size_t len) { + xSemaphoreTake(mutex_, portMAX_DELAY); + for (int cid : client_subscriptions_[tag]) { + send(data, len, cid); + } + xSemaphoreGive(mutex_); + } +}; diff --git a/esp32/include/communication/proto_helpers.h b/esp32/include/communication/proto_helpers.h new file mode 100644 index 0000000..e87476e --- /dev/null +++ b/esp32/include/communication/proto_helpers.h @@ -0,0 +1,109 @@ +#pragma once + +#include +#include +#include +#include + +#define PROTO_BUFFER_SIZE 512 + +class ProtoDecoder { + public: + using SubscribeHandler = std::function; + using UnsubscribeHandler = std::function; + using PingHandler = std::function; + using ModeHandler = std::function; + using InputHandler = std::function; + using AnglesHandler = std::function; + using KinematicHandler = std::function; + using WalkGaitHandler = std::function; + using IMUCalibrateExecHandler = std::function; + using I2CScanRequestHandler = std::function; + using PeripheralSettingsRequestHandler = std::function; + + void onSubscribe(SubscribeHandler handler) { subscribeHandler = handler; } + void onUnsubscribe(UnsubscribeHandler handler) { unsubscribeHandler = handler; } + void onPing(PingHandler handler) { pingHandler = handler; } + void onMode(ModeHandler handler) { modeHandler = handler; } + void onInput(InputHandler handler) { inputHandler = handler; } + void onAngles(AnglesHandler handler) { anglesHandler = handler; } + void onKinematic(KinematicHandler handler) { kinematicHandler = handler; } + void onWalkGait(WalkGaitHandler handler) { walkGaitHandler = handler; } + void onIMUCalibrateExec(IMUCalibrateExecHandler handler) { imuCalibrateExecHandler = handler; } + void onI2CScanRequest(I2CScanRequestHandler handler) { i2cScanRequestHandler = handler; } + void onPeripheralSettingsRequest(PeripheralSettingsRequestHandler handler) { + peripheralSettingsRequestHandler = handler; + } + + bool decode(const uint8_t* data, size_t len, int clientId) { + pb_istream_t stream = pb_istream_from_buffer(data, len); + + if (!pb_decode(&stream, socket_message_WebsocketMessage_fields, &msg_)) { + return false; + } + + switch (msg_.which_message) { + case socket_message_WebsocketMessage_sub_notif_tag: + if (subscribeHandler) subscribeHandler(msg_.message.sub_notif.tag, clientId); + break; + + case socket_message_WebsocketMessage_unsub_notif_tag: + if (unsubscribeHandler) unsubscribeHandler(msg_.message.unsub_notif.tag, clientId); + break; + + case socket_message_WebsocketMessage_pingmsg_tag: + if (pingHandler) pingHandler(clientId); + break; + + case socket_message_WebsocketMessage_mode_tag: + if (modeHandler) modeHandler(msg_.message.mode, clientId); + break; + + case socket_message_WebsocketMessage_human_input_data_tag: + if (inputHandler) inputHandler(msg_.message.human_input_data, clientId); + break; + + case socket_message_WebsocketMessage_angles_tag: + if (anglesHandler) anglesHandler(msg_.message.angles, clientId); + break; + + case socket_message_WebsocketMessage_kinematic_data_tag: + if (kinematicHandler) kinematicHandler(msg_.message.kinematic_data, clientId); + break; + + case socket_message_WebsocketMessage_walk_gait_tag: + if (walkGaitHandler) walkGaitHandler(msg_.message.walk_gait, clientId); + break; + + case socket_message_WebsocketMessage_imu_calibrate_execute_tag: + if (imuCalibrateExecHandler) imuCalibrateExecHandler(clientId); + break; + + case socket_message_WebsocketMessage_i2c_scan_data_request_tag: + if (i2cScanRequestHandler) i2cScanRequestHandler(clientId); + break; + + case socket_message_WebsocketMessage_peripheral_settings_data_request_tag: + if (peripheralSettingsRequestHandler) peripheralSettingsRequestHandler(clientId); + break; + + default: return false; + } + + return true; + } + + private: + socket_message_WebsocketMessage msg_ = socket_message_WebsocketMessage_init_zero; + SubscribeHandler subscribeHandler; + UnsubscribeHandler unsubscribeHandler; + PingHandler pingHandler; + ModeHandler modeHandler; + InputHandler inputHandler; + AnglesHandler anglesHandler; + KinematicHandler kinematicHandler; + WalkGaitHandler walkGaitHandler; + IMUCalibrateExecHandler imuCalibrateExecHandler; + I2CScanRequestHandler i2cScanRequestHandler; + PeripheralSettingsRequestHandler peripheralSettingsRequestHandler; +}; diff --git a/esp32/include/communication/websocket_adapter.h b/esp32/include/communication/websocket_adapter.h index 50d4f0a..f43a719 100644 --- a/esp32/include/communication/websocket_adapter.h +++ b/esp32/include/communication/websocket_adapter.h @@ -1,5 +1,4 @@ -#ifndef Socket_h -#define Socket_h +#pragma once #include #include