♻️ Updates combase to use protobufs

This commit is contained in:
Rune Harlyk
2026-01-03 12:37:05 +01:00
committed by nikguin04
parent 28bb35d104
commit c0c13754f4
6 changed files with 347 additions and 169 deletions
+146 -121
View File
@@ -1,15 +1,85 @@
#pragma once
#include <ArduinoJson.h>
#include <functional>
#include <pb_encode.h>
#include <pb_decode.h>
#include <platform_shared/websocket_message.pb.h>
#include <list>
#include <map>
#include <type_traits>
#include <communication/proto_helpers.h>
enum message_type_t { CONNECT = 0, DISCONNECT = 1, EVENT = 2, PING = 3, PONG = 4, BINARY_EVENT = 5 };
template <typename T>
struct MessageTraits;
typedef std::function<void(JsonVariant &root, int originId)> EventCallback;
typedef std::function<void(const std::string &originId, bool sync)> SubscribeCallback;
template <>
struct MessageTraits<socket_message_IMUData> {
static constexpr pb_size_t tag = socket_message_WebsocketMessage_imu_tag;
static void assign(socket_message_WebsocketMessage& msg, const socket_message_IMUData& data) {
msg.message.imu = data;
}
};
template <>
struct MessageTraits<socket_message_ModeData> {
static constexpr pb_size_t tag = socket_message_WebsocketMessage_mode_tag;
static void assign(socket_message_WebsocketMessage& msg, const socket_message_ModeData& data) {
msg.message.mode = data;
}
};
template <>
struct MessageTraits<socket_message_AnalyticsData> {
static constexpr pb_size_t tag = socket_message_WebsocketMessage_analytics_tag;
static void assign(socket_message_WebsocketMessage& msg, const socket_message_AnalyticsData& data) {
msg.message.analytics = data;
}
};
template <>
struct MessageTraits<socket_message_AnglesData> {
static constexpr pb_size_t tag = socket_message_WebsocketMessage_angles_tag;
static void assign(socket_message_WebsocketMessage& msg, const socket_message_AnglesData& data) {
msg.message.angles = data;
}
};
template <>
struct MessageTraits<socket_message_RSSIData> {
static constexpr pb_size_t tag = socket_message_WebsocketMessage_rssi_tag;
static void assign(socket_message_WebsocketMessage& msg, const socket_message_RSSIData& data) {
msg.message.rssi = data;
}
};
template <>
struct MessageTraits<socket_message_KinematicData> {
static constexpr pb_size_t tag = socket_message_WebsocketMessage_kinematic_data_tag;
static void assign(socket_message_WebsocketMessage& msg, const socket_message_KinematicData& data) {
msg.message.kinematic_data = data;
}
};
template <>
struct MessageTraits<socket_message_IMUCalibrateData> {
static constexpr pb_size_t tag = socket_message_WebsocketMessage_imu_calibrate_tag;
static void assign(socket_message_WebsocketMessage& msg, const socket_message_IMUCalibrateData& data) {
msg.message.imu_calibrate = data;
}
};
template <>
struct MessageTraits<socket_message_I2CScanData> {
static constexpr pb_size_t tag = socket_message_WebsocketMessage_i2c_scan_tag;
static void assign(socket_message_WebsocketMessage& msg, const socket_message_I2CScanData& data) {
msg.message.i2c_scan = data;
}
};
template <>
struct MessageTraits<socket_message_PeripheralSettingsData> {
static constexpr pb_size_t tag = socket_message_WebsocketMessage_peripheral_settings_tag;
static void assign(socket_message_WebsocketMessage& msg, const socket_message_PeripheralSettingsData& data) {
msg.message.peripheral_settings = data;
}
};
class CommAdapterBase {
public:
@@ -18,142 +88,97 @@ class CommAdapterBase {
virtual void begin() {}
bool hasSubscribers(const char *event) { return !client_subscriptions[event].empty(); }
void onEvent(std::string event, EventCallback callback) { event_callbacks[event].push_back(std::move(callback)); }
void onSubscribe(std::string event, SubscribeCallback callback) {
subscribe_callbacks[event].push_back(std::move(callback));
bool hasSubscribers(int32_t tag) {
xSemaphoreTake(mutex_, portMAX_DELAY);
bool result = !client_subscriptions_[tag].empty();
xSemaphoreGive(mutex_);
return result;
}
void emit(const char *event, JsonVariant &payload, const char *originId = "", bool onlyToSameOrigin = false) {
int originSubscriptionId = originId[0] ? atoi(originId) : -1;
xSemaphoreTake(mutex_, portMAX_DELAY);
auto &subscriptions = client_subscriptions[event];
if (subscriptions.empty()) {
xSemaphoreGive(mutex_);
ProtoDecoder& decoder() { return decoder_; }
template <typename T>
void emit(const T& data, int clientId = -1) {
constexpr pb_size_t tag = MessageTraits<T>::tag;
if (clientId < 0 && !hasSubscribers(tag)) return;
msg_.which_message = tag;
MessageTraits<T>::assign(msg_, data);
pb_ostream_t stream = pb_ostream_from_buffer(buffer_, sizeof(buffer_));
if (!pb_encode(&stream, socket_message_WebsocketMessage_fields, &msg_)) {
return;
}
JsonDocument doc;
JsonArray array = doc.to<JsonArray>();
array.add(static_cast<uint8_t>(message_type_t::EVENT));
array.add(event);
array.add(payload);
#if USE_MSGPACK
std::string bin;
serializeMsgPack(doc, bin);
xSemaphoreGive(mutex_);
send(reinterpret_cast<const uint8_t *>(bin.data()), bin.size(), -1);
#else
String out;
serializeJson(doc, out);
xSemaphoreGive(mutex_);
send(out.c_str(), -1);
#endif
}
void send_wsm_by_function( void (*setmsg)(socket_message_WebsocketMessage* message), int cid ) {
setmsg(&msg);
send_wsm(&msg, cid);
if (clientId >= 0) {
send(buffer_, stream.bytes_written, clientId);
} else {
sendToSubscribers(tag, buffer_, stream.bytes_written);
}
}
protected:
socket_message_WebsocketMessage msg;
uint8_t data_buffer[512];
void send(const char *data, int cid = -1) { send(reinterpret_cast<const uint8_t *>(data), strlen(data), cid); }
virtual void send(const uint8_t *data, size_t len, int cid = -1) = 0;
void send_wsm(socket_message_WebsocketMessage* message, int cid) {
pb_ostream_t ostream = pb_ostream_from_buffer(data_buffer, sizeof(data_buffer));
// Encode the message
bool ostatus = pb_encode(&ostream, &socket_message_WebsocketMessage_msg, message);
virtual void send(const uint8_t* data, size_t len, int cid = -1) = 0;
if (!ostatus) {
// TODO: Make a re-encoder using malloc instead (which increases exponentially but only if the error is the buffer size)
printf("Encoding of socket message failed: %s\n", PB_GET_ERROR(&ostream));
return;
}
send(data_buffer, ostream.bytes_written, cid);
}
void subscribe(const char *event, int cid = 0) {
void subscribe(int32_t tag, int cid = 0) {
xSemaphoreTake(mutex_, portMAX_DELAY);
client_subscriptions[event].push_back(cid);
client_subscriptions_[tag].push_back(cid);
xSemaphoreGive(mutex_);
ESP_LOGI("ProtoComm", "Client %d subscribed to tag %d", cid, (int)tag);
}
void unsubscribe(const char *event, int cid = 0) {
void unsubscribe(int32_t tag, int cid = 0) {
xSemaphoreTake(mutex_, portMAX_DELAY);
client_subscriptions[event].remove(cid);
client_subscriptions_[tag].remove(cid);
xSemaphoreGive(mutex_);
ESP_LOGI("ProtoComm", "Client %d unsubscribed from tag %d", cid, (int)tag);
}
void removeClient(int cid) {
xSemaphoreTake(mutex_, portMAX_DELAY);
for (auto& [tag, clients] : client_subscriptions_) {
clients.remove(cid);
}
xSemaphoreGive(mutex_);
}
void handleEventCallbacks(std::string event, JsonVariant &jsonObject, int originId) {
for (auto &callback : event_callbacks[event]) {
callback(jsonObject, originId);
void handleIncoming(const uint8_t* data, size_t len, int cid) {
if (!decoder_.decode(data, len, cid)) {
ESP_LOGE("ProtoComm", "Failed to decode incoming message from client %d", cid);
}
}
virtual void handleIncoming(const uint8_t *data, size_t len, int cid = 0) {
JsonDocument doc;
#if USE_MSGPACK
DeserializationError error = deserializeMsgPack(doc, data, len);
#else
DeserializationError error = deserializeJson(doc, data, len);
#endif
if (error) {
ESP_LOGE("Comm Base", "Failed to deserialize incoming: (%s)", error.c_str());
return;
}
JsonArray obj = doc.as<JsonArray>(); // TODO: Make const
message_type_t type = static_cast<message_type_t>(obj[0].as<uint8_t>());
switch (type) {
case message_type_t::CONNECT: {
const char *event = obj[1].as<const char *>();
ESP_LOGI("Comm Base", "CONNECT topic: %s (cid=%d)", event, cid);
subscribe(event, cid);
break;
}
case message_type_t::DISCONNECT: {
const char *event = obj[1].as<const char *>();
ESP_LOGI("Comm Base", "DISCONNECT topic: %s (cid=%d)", event, cid);
unsubscribe(event, cid);
break;
}
case message_type_t::EVENT: {
const char *event = obj[1].as<const char *>();
JsonVariant payload = obj[2].as<JsonVariant>();
handleEventCallbacks(event, payload, cid);
break;
}
case message_type_t::PING: {
ESP_LOGI("Comm Base", "PING (cid=%d)", cid);
ping(cid);
break;
}
case message_type_t::PONG: ESP_LOGI("Comm Base", "PONG (cid=%d)", cid); break;
default: ESP_LOGW("Comm Base", "Unknown message type: %d", static_cast<int>(type)); break;
void sendPong(int cid) {
uint8_t pongBuffer[16];
msg_.which_message = socket_message_WebsocketMessage_pongmsg_tag;
msg_.message.pongmsg = socket_message_PongMsg_init_zero;
pb_ostream_t stream = pb_ostream_from_buffer(pongBuffer, sizeof(pongBuffer));
if (pb_encode(&stream, socket_message_WebsocketMessage_fields, &msg_)) {
send(pongBuffer, stream.bytes_written, cid);
}
}
void ping(int cid) {
#if USE_MSGPACK
static const uint8_t pong[] = {0x91, 0x04};
send(pong, sizeof(pong), cid);
#else
send("[4]", cid);
#endif
void setupDecoderHandlers() {
decoder_.onSubscribe([this](int32_t tag, int cid) { subscribe(tag, cid); });
decoder_.onUnsubscribe([this](int32_t tag, int cid) { unsubscribe(tag, cid); });
decoder_.onPing([this](int cid) { sendPong(cid); });
}
SemaphoreHandle_t mutex_;
std::map<std::string, std::list<int>> client_subscriptions;
std::map<std::string, std::list<EventCallback>> event_callbacks;
std::map<std::string, std::list<SubscribeCallback>> subscribe_callbacks;
};
std::map<int32_t, std::list<int>> client_subscriptions_;
ProtoDecoder decoder_;
socket_message_WebsocketMessage msg_ = socket_message_WebsocketMessage_init_zero;
uint8_t buffer_[PROTO_BUFFER_SIZE];
private:
void sendToSubscribers(int32_t tag, const uint8_t* data, size_t len) {
xSemaphoreTake(mutex_, portMAX_DELAY);
for (int cid : client_subscriptions_[tag]) {
send(data, len, cid);
}
xSemaphoreGive(mutex_);
}
};
+109
View File
@@ -0,0 +1,109 @@
#pragma once
#include <pb_encode.h>
#include <pb_decode.h>
#include <platform_shared/websocket_message.pb.h>
#include <functional>
#define PROTO_BUFFER_SIZE 512
class ProtoDecoder {
public:
using SubscribeHandler = std::function<void(int32_t tag, int clientId)>;
using UnsubscribeHandler = std::function<void(int32_t tag, int clientId)>;
using PingHandler = std::function<void(int clientId)>;
using ModeHandler = std::function<void(const socket_message_ModeData& data, int clientId)>;
using InputHandler = std::function<void(const socket_message_HumanInputData& data, int clientId)>;
using AnglesHandler = std::function<void(const socket_message_AnglesData& data, int clientId)>;
using KinematicHandler = std::function<void(const socket_message_KinematicData& data, int clientId)>;
using WalkGaitHandler = std::function<void(const socket_message_WalkGaitData& data, int clientId)>;
using IMUCalibrateExecHandler = std::function<void(int clientId)>;
using I2CScanRequestHandler = std::function<void(int clientId)>;
using PeripheralSettingsRequestHandler = std::function<void(int clientId)>;
void onSubscribe(SubscribeHandler handler) { subscribeHandler = handler; }
void onUnsubscribe(UnsubscribeHandler handler) { unsubscribeHandler = handler; }
void onPing(PingHandler handler) { pingHandler = handler; }
void onMode(ModeHandler handler) { modeHandler = handler; }
void onInput(InputHandler handler) { inputHandler = handler; }
void onAngles(AnglesHandler handler) { anglesHandler = handler; }
void onKinematic(KinematicHandler handler) { kinematicHandler = handler; }
void onWalkGait(WalkGaitHandler handler) { walkGaitHandler = handler; }
void onIMUCalibrateExec(IMUCalibrateExecHandler handler) { imuCalibrateExecHandler = handler; }
void onI2CScanRequest(I2CScanRequestHandler handler) { i2cScanRequestHandler = handler; }
void onPeripheralSettingsRequest(PeripheralSettingsRequestHandler handler) {
peripheralSettingsRequestHandler = handler;
}
bool decode(const uint8_t* data, size_t len, int clientId) {
pb_istream_t stream = pb_istream_from_buffer(data, len);
if (!pb_decode(&stream, socket_message_WebsocketMessage_fields, &msg_)) {
return false;
}
switch (msg_.which_message) {
case socket_message_WebsocketMessage_sub_notif_tag:
if (subscribeHandler) subscribeHandler(msg_.message.sub_notif.tag, clientId);
break;
case socket_message_WebsocketMessage_unsub_notif_tag:
if (unsubscribeHandler) unsubscribeHandler(msg_.message.unsub_notif.tag, clientId);
break;
case socket_message_WebsocketMessage_pingmsg_tag:
if (pingHandler) pingHandler(clientId);
break;
case socket_message_WebsocketMessage_mode_tag:
if (modeHandler) modeHandler(msg_.message.mode, clientId);
break;
case socket_message_WebsocketMessage_human_input_data_tag:
if (inputHandler) inputHandler(msg_.message.human_input_data, clientId);
break;
case socket_message_WebsocketMessage_angles_tag:
if (anglesHandler) anglesHandler(msg_.message.angles, clientId);
break;
case socket_message_WebsocketMessage_kinematic_data_tag:
if (kinematicHandler) kinematicHandler(msg_.message.kinematic_data, clientId);
break;
case socket_message_WebsocketMessage_walk_gait_tag:
if (walkGaitHandler) walkGaitHandler(msg_.message.walk_gait, clientId);
break;
case socket_message_WebsocketMessage_imu_calibrate_execute_tag:
if (imuCalibrateExecHandler) imuCalibrateExecHandler(clientId);
break;
case socket_message_WebsocketMessage_i2c_scan_data_request_tag:
if (i2cScanRequestHandler) i2cScanRequestHandler(clientId);
break;
case socket_message_WebsocketMessage_peripheral_settings_data_request_tag:
if (peripheralSettingsRequestHandler) peripheralSettingsRequestHandler(clientId);
break;
default: return false;
}
return true;
}
private:
socket_message_WebsocketMessage msg_ = socket_message_WebsocketMessage_init_zero;
SubscribeHandler subscribeHandler;
UnsubscribeHandler unsubscribeHandler;
PingHandler pingHandler;
ModeHandler modeHandler;
InputHandler inputHandler;
AnglesHandler anglesHandler;
KinematicHandler kinematicHandler;
WalkGaitHandler walkGaitHandler;
IMUCalibrateExecHandler imuCalibrateExecHandler;
I2CScanRequestHandler i2cScanRequestHandler;
PeripheralSettingsRequestHandler peripheralSettingsRequestHandler;
};
@@ -1,5 +1,4 @@
#ifndef Socket_h
#define Socket_h
#pragma once
#include <PsychicHttp.h>
#include <template/stateful_service.h>
@@ -16,12 +15,6 @@ class Websocket : public CommAdapterBase {
void begin() override;
void onEvent(std::string event, EventCallback callback);
void emit(const char *event, JsonVariant &payload, const char *originId = "", bool onlyToSameOrigin = false);
void emit_raw(const char *event, uint8_t* payload, size_t event_length, size_t payload_length);
private:
PsychicWebSocketHandler _socket;
PsychicHttpServer &_server;
@@ -33,6 +26,3 @@ class Websocket : public CommAdapterBase {
void send(const uint8_t *data, size_t len, int cid = -1) override;
};
#endif
+5
View File
@@ -10,6 +10,7 @@
#include <features.h>
#include <settings/peripherals_settings.h>
#include <template/stateful_endpoint.h>
#include <platform_shared/websocket_message.pb.h>
#include <list>
#include <SPI.h>
@@ -47,11 +48,15 @@ class Peripherals : public StatefulService<PeripheralsConfiguration> {
void scanI2C(uint8_t lower = 1, uint8_t higher = 127);
void getI2CResult(JsonVariant &root);
void getI2CResultProto(socket_message_I2CScanData &data);
void getIMUResult(JsonVariant &root);
void getIMUProto(socket_message_IMUData &data);
void getSonarResult(JsonVariant &root);
void getSettingsProto(socket_message_PeripheralSettingsData &data);
/* IMU FUNCTIONS */
bool readImu();