From da48ed5b54dfe796dc330d295a5220c020fc66cb Mon Sep 17 00:00:00 2001 From: Dean Lee Date: Sun, 1 Jan 2023 06:59:24 +0800 Subject: [PATCH] vipc: add function to get available streams (#400) * add function to get available streams * add test case * cleanup * use set * public type * Update visionipc_server.cc * apply review Co-authored-by: Adeeb Shihadeh --- visionipc/visionipc_client.cc | 45 ++++++++++++++++++++++------------- visionipc/visionipc_client.h | 8 ++++--- visionipc/visionipc_server.cc | 13 ++++++++++ visionipc/visionipc_tests.cc | 11 +++++++++ 4 files changed, 58 insertions(+), 19 deletions(-) diff --git a/visionipc/visionipc_client.cc b/visionipc/visionipc_client.cc index 9d6fa9f..0b86d9d 100644 --- a/visionipc/visionipc_client.cc +++ b/visionipc/visionipc_client.cc @@ -8,6 +8,17 @@ #include "cereal/visionipc/visionipc_server.h" #include "cereal/logger/logger.h" +static int connect_to_vipc_server(const std::string &name, bool blocking) { + std::string path = "/tmp/visionipc_" + name; + int socket_fd = ipc_connect(path.c_str()); + while (socket_fd < 0 && blocking) { + std::cout << "VisionIpcClient connecting" << std::endl; + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + socket_fd = ipc_connect(path.c_str()); + } + return socket_fd; +} + VisionIpcClient::VisionIpcClient(std::string name, VisionStreamType type, bool conflate, cl_device_id device_id, cl_context ctx) : name(name), type(type), device_id(device_id), ctx(ctx) { msg_ctx = Context::create(); sock = SubSocket::create(msg_ctx, get_endpoint_name(name, type), "127.0.0.1", conflate, false); @@ -29,23 +40,10 @@ bool VisionIpcClient::connect(bool blocking){ num_buffers = 0; - // Connect to server socket and ask for all FDs of type - std::string path = "/tmp/visionipc_" + name; - - int socket_fd = -1; - while (socket_fd < 0) { - socket_fd = ipc_connect(path.c_str()); - - if (socket_fd < 0) { - if (blocking){ - std::cout << "VisionIpcClient connecting" << std::endl; - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - } else { - return false; - } - } + int socket_fd = connect_to_vipc_server(name, blocking); + if (socket_fd < 0) { + return false; } - // Send stream type to server to request FDs int r = ipc_sendrecv_with_fds(true, socket_fd, &type, sizeof(type), nullptr, 0, nullptr); assert(r == sizeof(type)); @@ -114,7 +112,22 @@ VisionBuf * VisionIpcClient::recv(VisionIpcBufExtra * extra, const int timeout_m return buf; } +std::set VisionIpcClient::getAvailableStreams(const std::string &name, bool blocking) { + int socket_fd = connect_to_vipc_server(name, blocking); + if (socket_fd < 0) { + return {}; + } + // Send VISION_STREAM_MAX to server to request available streams + int request = VISION_STREAM_MAX; + int r = ipc_sendrecv_with_fds(true, socket_fd, &request, sizeof(request), nullptr, 0, nullptr); + assert(r == sizeof(request)); + VisionStreamType available_streams[VISION_STREAM_MAX] = {}; + r = ipc_sendrecv_with_fds(false, socket_fd, &available_streams, sizeof(available_streams), nullptr, 0, nullptr); + assert(r >= sizeof(VisionStreamType) && r % sizeof(VisionStreamType) == 0); + close(socket_fd); + return std::set(available_streams, available_streams + r / sizeof(VisionStreamType)); +} VisionIpcClient::~VisionIpcClient(){ for (size_t i = 0; i < num_buffers; i++){ diff --git a/visionipc/visionipc_client.h b/visionipc/visionipc_client.h index c11ee23..8dff9b2 100644 --- a/visionipc/visionipc_client.h +++ b/visionipc/visionipc_client.h @@ -1,6 +1,8 @@ #pragma once -#include + +#include #include +#include #include #include "cereal/messaging/messaging.h" @@ -14,8 +16,6 @@ private: SubSocket * sock; Poller * poller; - VisionStreamType type; - cl_device_id device_id = nullptr; cl_context ctx = nullptr; @@ -23,6 +23,7 @@ private: public: bool connected = false; + VisionStreamType type; int num_buffers = 0; VisionBuf buffers[VISIONIPC_MAX_FDS]; VisionIpcClient(std::string name, VisionStreamType type, bool conflate, cl_device_id device_id=nullptr, cl_context ctx=nullptr); @@ -30,4 +31,5 @@ public: VisionBuf * recv(VisionIpcBufExtra * extra=nullptr, const int timeout_ms=100); bool connect(bool blocking=true); bool is_connected() { return connected; } + static std::set getAvailableStreams(const std::string &name, bool blocking = true); }; diff --git a/visionipc/visionipc_server.cc b/visionipc/visionipc_server.cc index 4d34cfe..ccf70be 100644 --- a/visionipc/visionipc_server.cc +++ b/visionipc/visionipc_server.cc @@ -111,6 +111,19 @@ void VisionIpcServer::listener(){ VisionStreamType type = VisionStreamType::VISION_STREAM_MAX; int r = ipc_sendrecv_with_fds(false, fd, &type, sizeof(type), nullptr, 0, nullptr); assert(r == sizeof(type)); + + // send available stream types + if (type == VisionStreamType::VISION_STREAM_MAX) { + std::vector available_stream_types; + for (auto& [stream_type, _] : buffers) { + available_stream_types.push_back(stream_type); + } + r = ipc_sendrecv_with_fds(true, fd, available_stream_types.data(), available_stream_types.size() * sizeof(VisionStreamType), nullptr, 0, nullptr); + assert(r == available_stream_types.size() * sizeof(VisionStreamType)); + close(fd); + continue; + } + if (buffers.count(type) <= 0) { std::cout << "got request for invalid buffer type: " << type << std::endl; close(fd); diff --git a/visionipc/visionipc_tests.cc b/visionipc/visionipc_tests.cc index 55a4395..4a081df 100644 --- a/visionipc/visionipc_tests.cc +++ b/visionipc/visionipc_tests.cc @@ -22,6 +22,17 @@ TEST_CASE("Connecting"){ REQUIRE(client.connected); } +TEST_CASE("getAvailableStreams"){ + VisionIpcServer server("camerad"); + server.create_buffers(VISION_STREAM_ROAD, 1, false, 100, 100); + server.create_buffers(VISION_STREAM_WIDE_ROAD, 1, false, 100, 100); + server.start_listener(); + auto available_streams = VisionIpcClient::getAvailableStreams("camerad"); + REQUIRE(available_streams.size() == 2); + REQUIRE(available_streams.count(VISION_STREAM_ROAD) == 1); + REQUIRE(available_streams.count(VISION_STREAM_WIDE_ROAD) == 1); +} + TEST_CASE("Check buffers"){ size_t width = 100, height = 200, num_buffers = 5; VisionIpcServer server("camerad");