diff --git a/esp32/include/communication/comm_base.hpp b/esp32/include/communication/comm_base.hpp new file mode 100644 index 0000000..1a55429 --- /dev/null +++ b/esp32/include/communication/comm_base.hpp @@ -0,0 +1,141 @@ +#pragma once + +#include +#include + +enum message_type_t { CONNECT = 0, DISCONNECT = 1, EVENT = 2, PING = 3, PONG = 4, BINARY_EVENT = 5 }; + +typedef std::function EventCallback; +typedef std::function SubscribeCallback; + +class CommAdapterBase { + public: + CommAdapterBase() { mutex_ = xSemaphoreCreateMutex(); } + ~CommAdapterBase() { vSemaphoreDelete(mutex_); } + + virtual void begin() {} + + bool hasSubscribers(const char *event) { return !client_subscriptions[event].empty(); } + + void onEvent(std::string event, EventCallback callback) { event_callbacks[event].push_back(std::move(callback)); } + + void onSubscribe(std::string event, SubscribeCallback callback) { + subscribe_callbacks[event].push_back(std::move(callback)); + } + + void emit(const char *event, JsonVariant &payload, const char *originId = "", bool onlyToSameOrigin = false) { + int originSubscriptionId = originId[0] ? atoi(originId) : -1; + xSemaphoreTake(mutex_, portMAX_DELAY); + auto &subscriptions = client_subscriptions[event]; + if (subscriptions.empty()) { + xSemaphoreGive(mutex_); + return; + } + + JsonDocument doc; + JsonArray array = doc.to(); + array.add(static_cast(message_type_t::EVENT)); + array.add(event); + array.add(payload); + + // TODO: Only send to subscribed + +#if USE_MSGPACK + std::string bin; + serializeMsgPack(doc, bin); + send(reinterpret_cast(bin.data()), bin.size(), -1); // TODO: Make CID dynamic +#else + String out; + serializeJson(doc, out); + send(out.c_str(), cid); +#endif + } + + protected: + void send(const char *data, int cid = -1) { send(reinterpret_cast(data), strlen(data), cid); } + virtual void send(const uint8_t *data, size_t len, int cid = -1) = 0; + + void subscribe(const char *event, int cid = 0) { client_subscriptions[event].push_back(cid); } + void unsubscribe(const char *event, int cid = 0) { client_subscriptions[event].push_back(cid); } + + void handleEventCallbacks(std::string event, JsonVariant &jsonObject, int originId) { + for (auto &callback : event_callbacks[event]) { + callback(jsonObject, originId); + } + } + + virtual void handleIncoming(const uint8_t *data, size_t len, int cid = 0) { + JsonDocument doc; +#if USE_MSGPACK + DeserializationError error = deserializeMsgPack(doc, data, len); +#else + DeserializationError error = deserializeJson(doc, data, len); +#endif + if (error) { + ESP_LOGE("Comm Base", "Failed to deserialize incoming: (%s)", error.c_str()); + return; + } + + JsonArray obj = doc.as(); // TODO: Make const + message_type_t type = static_cast(obj[0].as()); + + switch (type) { + case message_type_t::CONNECT: { + const char *event = obj[1].as(); + ESP_LOGI("Comm Base", "CONNECT topic: %s (cid=%d)", event, cid); + subscribe(event, cid); + break; + } + + case message_type_t::DISCONNECT: { + const char *event = obj[1].as(); + ESP_LOGI("Comm Base", "DISCONNECT topic: %s (cid=%d)", event, cid); + unsubscribe(event, cid); + break; + } + + case message_type_t::EVENT: { + const char *event = obj[1].as(); + JsonVariant payload = obj[2].as(); + handleEventCallbacks(event, payload, cid); + break; + } + case message_type_t::PING: { + ESP_LOGI("Comm Base", "PING (cid=%d)", cid); +#if USE_MSGPACK + static const uint8_t pong[] = {0x91, 0x04}; + send(pong, sizeof(pong), cid); +#else + send("[4]", cid); +#endif + break; + } + case message_type_t::PONG: ESP_LOGI("Comm Base", "PONG (cid=%d)", cid); break; + default: ESP_LOGW("Comm Base", "Unknown message type: %d", static_cast(type)); break; + } + + if (type == PONG) { + ESP_LOGV("EventSocket", "Pong"); + return; + } else if (type == PING) { + ESP_LOGV("EventSocket", "Ping"); + ping(cid); + return; + } + } + + void ping(int cid) { +#if USE_MSGPACK + const uint8_t out[] = {0x91, 0x04}; + send(out, sizeof(out), cid); +#else + const char *out = "[4]"; + send(out, strlen(out), cid); +#endif + } + + SemaphoreHandle_t mutex_; + std::map> client_subscriptions; + std::map> event_callbacks; + std::map> subscribe_callbacks; +}; \ No newline at end of file diff --git a/esp32/include/communication/websocket_adapter.h b/esp32/include/communication/websocket_adapter.h index efb3e15..3d659a0 100644 --- a/esp32/include/communication/websocket_adapter.h +++ b/esp32/include/communication/websocket_adapter.h @@ -8,23 +8,16 @@ #include #include -enum message_type_t { CONNECT = 0, DISCONNECT = 1, EVENT = 2, PING = 3, PONG = 4, BINARY_EVENT = 5 }; +#include -typedef std::function EventCallback; -typedef std::function SubscribeCallback; - -class EventSocket { +class Websocket : CommAdapterBase { public: - EventSocket(PsychicHttpServer &server, const char *route = "/api/ws"); + Websocket(PsychicHttpServer &server, const char *route = "/api/ws"); - void begin(); - - bool hasSubscribers(const char *event); + void begin() override; void onEvent(std::string event, EventCallback callback); - void onSubscribe(std::string event, SubscribeCallback callback); - void emit(const char *event, JsonVariant &payload, const char *originId = "", bool onlyToSameOrigin = false); private: @@ -32,16 +25,11 @@ class EventSocket { PsychicHttpServer &_server; const char *_route; - std::map> client_subscriptions; - std::map> event_callbacks; - std::map> subscribe_callbacks; - void handleEventCallbacks(std::string event, JsonVariant &jsonObject, int originId); - void send(PsychicWebSocketClient *client, const char *data, size_t len); - void handleSubscribeCallbacks(std::string event, const std::string &originId); - void onWSOpen(PsychicWebSocketClient *client); void onWSClose(PsychicWebSocketClient *client); esp_err_t onFrame(PsychicWebSocketRequest *request, httpd_ws_frame *frame); + + void send(const uint8_t *data, size_t len, int cid = -1) override; }; #endif diff --git a/esp32/src/communication/websocket_adapter.cpp b/esp32/src/communication/websocket_adapter.cpp index 3f7bb18..1f48cc9 100644 --- a/esp32/src/communication/websocket_adapter.cpp +++ b/esp32/src/communication/websocket_adapter.cpp @@ -1,177 +1,73 @@ #include #include -SemaphoreHandle_t clientSubscriptionsMutex = xSemaphoreCreateMutex(); +static const char *TAG = "Websocket"; -EventSocket::EventSocket(PsychicHttpServer &server, const char *route) : _server(server), _route(route) { - _socket.onOpen((std::bind(&EventSocket::onWSOpen, this, std::placeholders::_1))); - _socket.onClose(std::bind(&EventSocket::onWSClose, this, std::placeholders::_1)); - _socket.onFrame(std::bind(&EventSocket::onFrame, this, std::placeholders::_1, std::placeholders::_2)); +Websocket::Websocket(PsychicHttpServer &server, const char *route) : _server(server), _route(route) { + _socket.onOpen((std::bind(&Websocket::onWSOpen, this, std::placeholders::_1))); + _socket.onClose(std::bind(&Websocket::onWSClose, this, std::placeholders::_1)); + _socket.onFrame(std::bind(&Websocket::onFrame, this, std::placeholders::_1, std::placeholders::_2)); } -void EventSocket::begin() { _server.on(_route, &_socket); } +void Websocket::begin() { _server.on(_route, &_socket); } -void EventSocket::onWSOpen(PsychicWebSocketClient *client) { +void Websocket::onWSOpen(PsychicWebSocketClient *client) { ESP_LOGI("EventSocket", "ws[%s][%u] connect", client->remoteIP().toString().c_str(), client->socket()); + ping(client->socket()); } -void EventSocket::onWSClose(PsychicWebSocketClient *client) { - xSemaphoreTake(clientSubscriptionsMutex, portMAX_DELAY); +void Websocket::onWSClose(PsychicWebSocketClient *client) { + xSemaphoreTake(mutex_, portMAX_DELAY); for (auto &event_subscriptions : client_subscriptions) { event_subscriptions.second.remove(client->socket()); } - xSemaphoreGive(clientSubscriptionsMutex); + xSemaphoreGive(mutex_); ESP_LOGI("EventSocket", "ws[%s][%u] disconnect", client->remoteIP().toString().c_str(), client->socket()); } -esp_err_t EventSocket::onFrame(PsychicWebSocketRequest *request, httpd_ws_frame *frame) { - ESP_LOGV("EventSocket", "ws[%s][%u] opcode[%d]", request->client()->remoteIP().toString().c_str(), +esp_err_t Websocket::onFrame(PsychicWebSocketRequest *request, httpd_ws_frame *frame) { + ESP_LOGV(TAG, "ws[%s][%u] opcode[%d]", request->client()->remoteIP().toString().c_str(), request->client()->socket(), frame->type); - JsonDocument doc; + if (frame->type != HTTPD_WS_TYPE_TEXT && frame->type != HTTPD_WS_TYPE_BINARY) { + ESP_LOGE(TAG, "Unsupported frame type: %d", frame->type); + return ESP_OK; + } #if USE_MSGPACK - if (frame->type != HTTPD_WS_TYPE_BINARY) { - ESP_LOGE("EventSocket", "Unsupported frame type: %d", frame->type); - return ESP_OK; + if (frame->type == HTTPD_WS_TYPE_BINARY) { + handleIncoming(frame->payload, frame->len, request->client()->socket()); + } else { + ESP_LOGE(TAG, "Expected binary, got text"); } - if (deserializeMsgPack(doc, frame->payload, frame->len)) { - ESP_LOGE("EventSocket", "Could not deserialize msgpack"); - return ESP_OK; - }; #else - if (frame->type != HTTPD_WS_TYPE_TEXT) { - ESP_LOGE("EventSocket", "Unsupported frame type: %d", frame->type); - return ESP_OK; + if (frame->type == HTTPD_WS_TYPE_TEXT) { + handleIncoming(frame->payload, frame->len, request->client()->socket()); + } else { + ESP_LOGE(TAG, "Expected text, got binary"); } - if (deserializeJson(doc, frame->payload, frame->len)) { - ESP_LOGE("EventSocket", "Could not deserialize json"); - return ESP_OK; - }; #endif - auto msg = doc.as(); - - message_type_t message_type = static_cast(msg[0].as()); - - if (message_type == PONG) { - ESP_LOGV("EventSocket", "Pong"); - return ESP_OK; - } else if (message_type == PING) { - ESP_LOGV("EventSocket", "Ping"); -#if USE_MSGPACK - const uint8_t out[] = {0x91, 0x04}; - send(request->client(), reinterpret_cast(out), sizeof(out)); -#else - const char *out = "[4]"; - send(request->client(), out, strlen(out)); -#endif - return ESP_OK; - } - - const char *event = msg[1].as(); - - if (!event) { - ESP_LOGE("EventSocket", "Invalid event name"); - return ESP_OK; - } - - if (message_type == CONNECT) { - ESP_LOGV("EventSocket", "Connect: %s", event); - client_subscriptions[event].push_back(request->client()->socket()); - handleSubscribeCallbacks(event, std::to_string(request->client()->socket())); - } else if (message_type == DISCONNECT) { - ESP_LOGV("EventSocket", "Disconnect: %s", event); - client_subscriptions[event].remove(request->client()->socket()); - } else if (message_type == EVENT) { - JsonVariant payload = msg[2].as(); - handleEventCallbacks(event, payload, request->client()->socket()); - return ESP_OK; - } return ESP_OK; } -bool EventSocket::hasSubscribers(const char *event) { return !client_subscriptions[event].empty(); } - -void EventSocket::emit(const char *event, JsonVariant &payload, const char *originId, bool onlyToSameOrigin) { - int originSubscriptionId = originId[0] ? atoi(originId) : -1; - xSemaphoreTake(clientSubscriptionsMutex, portMAX_DELAY); - auto &subscriptions = client_subscriptions[event]; - if (subscriptions.empty()) { - xSemaphoreGive(clientSubscriptionsMutex); - return; - } - - JsonDocument doc; - auto a = doc.to(); - a.add(static_cast(message_type_t::EVENT)); - a.add(event); - a.add(payload); - +void Websocket::send(const uint8_t *data, size_t len, int cid) { + if (cid != -1) { + auto *client = _socket.getClient(cid); + if (client) { + ESP_LOGV(TAG, "Sending to client %s: %s", client->remoteIP().toString().c_str(), data); #if USE_MSGPACK - static char out[512]; - size_t len = serializeMsgPack(doc, out, sizeof(out)); - if (len == 0 || len >= sizeof(out)) { - xSemaphoreGive(clientSubscriptionsMutex); - ESP_LOGE("EventSocket", "Message payload bigger than buffer (%d <= %d)", sizeof(out), len); - return; - } - const char *data = out; + client->sendMessage(HTTPD_WS_TYPE_BINARY, data, len); #else - static char out[1024]; - size_t len = serializeJson(doc, out, sizeof(out)); - if (len == 0 || len >= sizeof(out)) { - xSemaphoreGive(clientSubscriptionsMutex); - ESP_LOGE("EventSocket", "Message payload bigger than buffer (%d <= %d)", sizeof(out), len); - return; - } - const char *data = out; + client->sendMessage(HTTPD_WS_TYPE_TEXT, data, len); #endif - - auto sendTo = [&](int id) { - if (auto *c = _socket.getClient(id)) { - send(c, data, len); - } else { - subscriptions.remove(id); } - }; - - if (onlyToSameOrigin && originSubscriptionId > 0) { - sendTo(originSubscriptionId); } else { - for (int id : subscriptions) { - if (id != originSubscriptionId) sendTo(id); - } - } - xSemaphoreGive(clientSubscriptionsMutex); -} - -void EventSocket::send(PsychicWebSocketClient *client, const char *data, size_t len) { - if (!client) return; - + ESP_LOGV(TAG, "Sending to all clients: %s", data); #if USE_MSGPACK - client->sendMessage(HTTPD_WS_TYPE_BINARY, data, len); + _socket.sendAll(HTTPD_WS_TYPE_BINARY, data, len); #else - client->sendMessage(HTTPD_WS_TYPE_TEXT, data, len); + _socket.sendAll(HTTPD_WS_TYPE_TEXT, data, len); #endif -} - -void EventSocket::handleEventCallbacks(std::string event, JsonVariant &jsonObject, int originId) { - for (auto &callback : event_callbacks[event]) { - callback(jsonObject, originId); } } - -void EventSocket::handleSubscribeCallbacks(std::string event, const std::string &originId) { - for (auto &callback : subscribe_callbacks[event]) { - callback(originId, true); - } -} - -void EventSocket::onEvent(std::string event, EventCallback callback) { - event_callbacks[event].push_back(std::move(callback)); -} - -void EventSocket::onSubscribe(std::string event, SubscribeCallback callback) { - subscribe_callbacks[event].push_back(std::move(callback)); -} \ No newline at end of file diff --git a/esp32/src/main.cpp b/esp32/src/main.cpp index 0995974..018a35a 100644 --- a/esp32/src/main.cpp +++ b/esp32/src/main.cpp @@ -21,7 +21,7 @@ // Communication PsychicHttpServer server; -EventSocket socket {server, "/api/ws"}; +Websocket socket {server, "/api/ws"}; // Core Peripherals peripherals;