Fix stopping all streams when just one should stop

This commit is contained in:
loki 2020-02-12 11:28:27 +01:00
parent bb95d6ab52
commit da246d6417
7 changed files with 324 additions and 129 deletions

View File

@ -134,7 +134,6 @@ set(SUNSHINE_TARGET_FILES
sunshine/stream.h
sunshine/video.cpp
sunshine/video.h
sunshine/thread_safe.h
sunshine/input.cpp
sunshine/input.h
sunshine/audio.cpp
@ -147,6 +146,8 @@ set(SUNSHINE_TARGET_FILES
sunshine/move_by_copy.h
sunshine/task_pool.h
sunshine/thread_pool.h
sunshine/thread_safe.h
sunshine/sync.h
${PLATFORM_TARGET_FILES})
include_directories(

View File

@ -95,10 +95,10 @@ std::string_view to_enum_string(net_e net) {
}
host_t host_create(ENetAddress &addr, std::size_t peers, std::uint16_t port) {
enet_address_set_host(&addr, "0.0.0.0");
enet_address_set_host(&addr, "::");
enet_address_set_port(&addr, port);
return host_t { enet_host_create(PF_INET, &addr, peers, 1, 0, 0) };
return host_t { enet_host_create(AF_INET6, &addr, peers, 1, 0, 0) };
}
void free_host(ENetHost *host) {

View File

@ -88,6 +88,7 @@ using input_t = util::safe_ptr<void, freeInput>;
std::string get_mac_address(const std::string_view &address);
std::string from_sockaddr(const sockaddr *const);
std::pair<std::uint16_t, std::string> from_sockaddr_ex(const sockaddr *const);
std::unique_ptr<mic_t> microphone(std::uint32_t sample_rate);
std::unique_ptr<display_t> display();

View File

@ -391,6 +391,24 @@ std::string from_sockaddr(const sockaddr *const ip_addr) {
return std::string { data };
}
std::pair<std::uint16_t, std::string> from_sockaddr_ex(const sockaddr *const ip_addr) {
char data[INET6_ADDRSTRLEN];
auto family = ip_addr->sa_family;
std::uint16_t port;
if(family == AF_INET6) {
inet_ntop(AF_INET6, &((sockaddr_in6*)ip_addr)->sin6_addr, data, INET6_ADDRSTRLEN);
port = ((sockaddr_in6*)ip_addr)->sin6_port;
}
if(family == AF_INET) {
inet_ntop(AF_INET, &((sockaddr_in*)ip_addr)->sin_addr, data, INET_ADDRSTRLEN);
port = ((sockaddr_in*)ip_addr)->sin_port;
}
return { port, std::string { data } };
}
std::string get_mac_address(const std::string_view &address) {
auto ifaddrs = get_ifaddrs();
for(auto pos = ifaddrs.get(); pos != nullptr; pos = pos->ifa_next) {

View File

@ -101,6 +101,24 @@ std::string from_sockaddr(const sockaddr *const socket_address) {
return std::string { data };
}
std::pair<std::uint16_t, std::string> from_sockaddr_ex(const sockaddr *const ip_addr) {
char data[INET6_ADDRSTRLEN];
auto family = ip_addr->sa_family;
std::uint16_t port;
if(family == AF_INET6) {
inet_ntop(AF_INET6, &((sockaddr_in6*)ip_addr)->sin6_addr, data, INET6_ADDRSTRLEN);
port = ((sockaddr_in6*)ip_addr)->sin6_port;
}
if(family == AF_INET) {
inet_ntop(AF_INET, &((sockaddr_in*)ip_addr)->sin_addr, data, INET_ADDRSTRLEN);
port = ((sockaddr_in*)ip_addr)->sin_port;
}
return { port, std::string { data } };
}
adapteraddrs_t get_adapteraddrs() {
adapteraddrs_t info { nullptr };
ULONG size = 0;

View File

@ -20,6 +20,7 @@ extern "C" {
#include "utility.h"
#include "stream.h"
#include "thread_safe.h"
#include "sync.h"
#include "input.h"
#include "main.h"
@ -89,14 +90,109 @@ using audio_packet_t = util::c_ptr<audio_packet_raw_t>;
using message_queue_t = std::shared_ptr<safe::queue_t<std::pair<std::uint16_t, std::string>>>;
using message_queue_queue_t = std::shared_ptr<safe::queue_t<std::tuple<socket_e, asio::ip::address, message_queue_t>>>;
using session_queue_t = std::shared_ptr<safe::queue_t<std::pair<std::string, session_t*>>>;
static inline void while_starting_do_nothing(std::atomic<session::state_e> &state) {
while(state.load(std::memory_order_acquire) == session::state_e::STARTING) {
std::this_thread::sleep_for(1ms);
}
}
class control_server_t {
public:
control_server_t(control_server_t &&) noexcept = default;
control_server_t &operator=(control_server_t &&) noexcept = default;
explicit control_server_t(std::uint16_t port) : _host { net::host_create(_addr, config::stream.channels, port) } {}
void emplace_addr_to_session(const std::string &addr, session_t &session) {
auto lg = _map_addr_session.lock();
_map_addr_session->emplace(addr, std::make_pair(0u, &session));
}
void erase_session(session_t &session) {
auto lg = _map_addr_session.lock();
auto pos = std::find_if(std::begin(_map_addr_session.raw), std::end(_map_addr_session.raw), [session_p=&session](auto &current_port_and_session) {
return session_p == current_port_and_session.second.second;
});
_map_addr_session->erase(pos);
}
// Get session associated with address.
// If none are found, try to find a session not yet claimed. (It will be marked by a port of value 0
// If none of those are found, return nullptr
session_t *get_session(const ENetAddress &address) {
TUPLE_2D(port, addr_string, platf::from_sockaddr_ex((sockaddr*)&address.address));
auto lg = _map_addr_session.lock();
TUPLE_2D(begin, end, _map_addr_session->equal_range(addr_string));
auto it = std::end(_map_addr_session.raw);
for(auto pos = begin; pos != end; ++pos) {
TUPLE_2D_REF(session_port, session_p, pos->second);
if(port == session_port) {
return session_p;
}
else if(session_port == 0) {
it = pos;
}
}
if(it != std::end(_map_addr_session.raw)) {
TUPLE_2D_REF(session_port, session_p, it->second);
session_port = port;
return session_p;
}
return nullptr;
}
// Circular dependency:
// iterate refers to session
// session refers to broadcast_ctx_t
// broadcast_ctx_t refers to control_server_t
// Therefore, iterate is implemented further down the source file
void iterate(std::chrono::milliseconds timeout);
template<class T, class X>
void iterate(std::chrono::duration<T, X> timeout) {
iterate(std::chrono::floor<std::chrono::milliseconds>(timeout));
}
void map(uint16_t type, std::function<void(session_t *, const std::string_view&)> cb) {
_map_type_cb.emplace(type, std::move(cb));
}
void send(const std::string_view &payload) {
std::for_each(_host->peers, _host->peers + _host->peerCount, [payload](auto &peer) {
auto packet = enet_packet_create(payload.data(), payload.size(), ENET_PACKET_FLAG_RELIABLE);
if(enet_peer_send(&peer, 0, packet)) {
enet_packet_destroy(packet);
}
});
enet_host_flush(_host.get());
}
// Callbacks
std::unordered_map<std::uint16_t, std::function<void(session_t *, const std::string_view&)>> _map_type_cb;
// Mapping ip:port to session
util::sync_t<std::unordered_multimap<std::string, std::pair<std::uint16_t, session_t*>>> _map_addr_session;
ENetAddress _addr;
net::host_t _host;
};
struct broadcast_ctx_t {
video::packet_queue_t video_packets;
audio::packet_queue_t audio_packets;
message_queue_queue_t message_queue_queue;
session_queue_t session_queue;
std::thread recv_thread;
std::thread video_thread;
@ -105,8 +201,9 @@ struct broadcast_ctx_t {
asio::io_service io;
udp::socket video_sock { io, udp::endpoint(udp::v4(), VIDEO_STREAM_PORT) };
udp::socket audio_sock { io, udp::endpoint(udp::v4(), AUDIO_STREAM_PORT) };
udp::socket video_sock { io, udp::endpoint(udp::v6(), VIDEO_STREAM_PORT) };
udp::socket audio_sock { io, udp::endpoint(udp::v6(), AUDIO_STREAM_PORT) };
control_server_t control_server { CONTROL_PORT };
};
struct session_t {
@ -118,6 +215,7 @@ struct session_t {
std::chrono::steady_clock::time_point pingTimeout;
safe::shared_t<broadcast_ctx_t>::ptr_t broadcast_ref;
udp::endpoint video_peer;
udp::endpoint audio_peer;
@ -138,101 +236,57 @@ std::shared_ptr<input::input_t> input;
static auto broadcast = safe::make_shared<broadcast_ctx_t>(start_broadcast, end_broadcast);
safe::signal_t broadcast_shutdown_event;
class control_server_t {
public:
control_server_t(control_server_t &&) noexcept = default;
control_server_t &operator=(control_server_t &&) noexcept = default;
void control_server_t::iterate(std::chrono::milliseconds timeout) {
ENetEvent event;
auto res = enet_host_service(_host.get(), &event, timeout.count());
explicit control_server_t(session_queue_t session_queue, std::uint16_t port) : session_queue { session_queue }, _host { net::host_create(_addr, config::stream.channels, port) } {}
if(res > 0) {
auto session = get_session(event.peer->address);
if(!session) {
BOOST_LOG(warning) << "Rejected connection from ["sv << platf::from_sockaddr((sockaddr*)&event.peer->address.address) << "]: it's not properly set up"sv;
enet_peer_disconnect_now(event.peer, 0);
void populate_addr_to_session() {
while(session_queue->peek()) {
auto session_opt = session_queue->pop();
if(!session_opt) {
break;
}
TUPLE_2D_REF(addr_string, session, *session_opt);
if(session) {
_map_addr_session.try_emplace(addr_string, session).second;
}
else {
_map_addr_session.erase(addr_string);
}
return;
}
}
template<class T, class X>
void iterate(std::chrono::duration<T, X> timeout) {
ENetEvent event;
auto res = enet_host_service(_host.get(), &event, std::chrono::floor<std::chrono::milliseconds>(timeout).count());
session->pingTimeout = std::chrono::steady_clock::now() + config::stream.ping_timeout;
populate_addr_to_session();
if(res > 0) {
auto addr_string = platf::from_sockaddr((sockaddr*)&event.peer->address.address);
switch(event.type) {
case ENET_EVENT_TYPE_RECEIVE:
{
net::packet_t packet { event.packet };
auto it = _map_addr_session.find(addr_string);
if(it == std::end(_map_addr_session)) {
BOOST_LOG(warning) << "Rejected connection from ["sv << addr_string << "]: it's not properly set up"sv;
enet_peer_disconnect_now(event.peer, 0);
return;
}
auto &session = it->second;
session->pingTimeout = std::chrono::steady_clock::now() + config::stream.ping_timeout;
switch(event.type) {
case ENET_EVENT_TYPE_RECEIVE:
{
net::packet_t packet { event.packet };
std::uint16_t *type = (std::uint16_t *)packet->data;
std::string_view payload { (char*)packet->data + sizeof(*type), packet->dataLength - sizeof(*type) };
std::uint16_t *type = (std::uint16_t *)packet->data;
std::string_view payload { (char*)packet->data + sizeof(*type), packet->dataLength - sizeof(*type) };
auto cb = _map_type_cb.find(*type);
if(cb == std::end(_map_type_cb)) {
BOOST_LOG(warning)
<< "type [Unknown] { "sv << util::hex(*type).to_string_view() << " }"sv << std::endl
<< "---data---"sv << std::endl << util::hex_vec(payload) << std::endl << "---end data---"sv;
}
else {
cb->second(session, payload);
}
auto cb = _map_type_cb.find(*type);
if(cb == std::end(_map_type_cb)) {
BOOST_LOG(warning)
<< "type [Unknown] { "sv << util::hex(*type).to_string_view() << " }"sv << std::endl
<< "---data---"sv << std::endl << util::hex_vec(payload) << std::endl << "---end data---"sv;
}
else {
cb->second(session, payload);
}
break;
case ENET_EVENT_TYPE_CONNECT:
BOOST_LOG(info) << "CLIENT CONNECTED"sv;
break;
case ENET_EVENT_TYPE_DISCONNECT:
BOOST_LOG(info) << "CLIENT DISCONNECTED"sv;
// No more clients to send video data to ^_^
if(session->state == session::state_e::RUNNING) {
session::stop(*session);
}
break;
case ENET_EVENT_TYPE_NONE:
break;
}
break;
case ENET_EVENT_TYPE_CONNECT:
BOOST_LOG(info) << "CLIENT CONNECTED"sv;
break;
case ENET_EVENT_TYPE_DISCONNECT:
BOOST_LOG(info) << "CLIENT DISCONNECTED"sv;
// No more clients to send video data to ^_^
if(session->state == session::state_e::RUNNING) {
session::stop(*session);
}
break;
case ENET_EVENT_TYPE_NONE:
break;
}
}
void map(uint16_t type, std::function<void(session_t *, const std::string_view&)> cb) {
_map_type_cb.emplace(type, std::move(cb));
}
void send(const std::string_view &payload);
std::unordered_map<std::uint16_t, std::function<void(session_t *, const std::string_view&)>> _map_type_cb;
std::unordered_map<std::string, session_t*> _map_addr_session;
session_queue_t session_queue;
ENetAddress _addr;
net::host_t _host;
};
}
namespace fec {
using rs_t = util::safe_ptr<reed_solomon, reed_solomon_release>;
@ -338,29 +392,16 @@ std::vector<uint8_t> replace(const std::string_view &original, const std::string
return replaced;
}
void control_server_t::send(const std::string_view & payload) {
std::for_each(_host->peers, _host->peers + _host->peerCount, [payload](auto &peer) {
auto packet = enet_packet_create(payload.data(), payload.size(), ENET_PACKET_FLAG_RELIABLE);
if(enet_peer_send(&peer, 0, packet)) {
enet_packet_destroy(packet);
}
});
enet_host_flush(_host.get());
}
void controlBroadcastThread(safe::signal_t *shutdown_event, session_queue_t session_queue) {
control_server_t server { session_queue, CONTROL_PORT };
server.map(packetTypes[IDX_START_A], [&](session_t *session, const std::string_view &payload) {
void controlBroadcastThread(safe::signal_t *shutdown_event, control_server_t *server) {
server->map(packetTypes[IDX_START_A], [&](session_t *session, const std::string_view &payload) {
BOOST_LOG(debug) << "type [IDX_START_A]"sv;
});
server.map(packetTypes[IDX_START_B], [&](session_t *session, const std::string_view &payload) {
server->map(packetTypes[IDX_START_B], [&](session_t *session, const std::string_view &payload) {
BOOST_LOG(debug) << "type [IDX_START_B]"sv;
});
server.map(packetTypes[IDX_LOSS_STATS], [&](session_t *session, const std::string_view &payload) {
server->map(packetTypes[IDX_LOSS_STATS], [&](session_t *session, const std::string_view &payload) {
int32_t *stats = (int32_t*)payload.data();
auto count = stats[0];
std::chrono::milliseconds t { stats[1] };
@ -376,7 +417,7 @@ void controlBroadcastThread(safe::signal_t *shutdown_event, session_queue_t sess
<< "---end stats---";
});
server.map(packetTypes[IDX_INVALIDATE_REF_FRAMES], [&](session_t *session, const std::string_view &payload) {
server->map(packetTypes[IDX_INVALIDATE_REF_FRAMES], [&](session_t *session, const std::string_view &payload) {
std::int64_t *frames = (std::int64_t *)payload.data();
auto firstFrame = frames[0];
auto lastFrame = frames[1];
@ -389,7 +430,7 @@ void controlBroadcastThread(safe::signal_t *shutdown_event, session_queue_t sess
session->idr_events->raise(std::make_pair(firstFrame, lastFrame));
});
server.map(packetTypes[IDX_INPUT_DATA], [&](session_t *session, const std::string_view &payload) {
server->map(packetTypes[IDX_INPUT_DATA], [&](session_t *session, const std::string_view &payload) {
BOOST_LOG(debug) << "type [IDX_INPUT_DATA]"sv;
int32_t tagged_cipher_length = util::endian::big(*(int32_t*)payload.data());
@ -416,11 +457,16 @@ void controlBroadcastThread(safe::signal_t *shutdown_event, session_queue_t sess
});
while(!shutdown_event->peek()) {
auto now = std::chrono::steady_clock::now();
for(auto &[addr,session] : server._map_addr_session) {
if(now > session->pingTimeout) {
BOOST_LOG(info) << addr << ": Ping Timeout"sv;
session::stop(*session);
{
auto lg = server->_map_addr_session.lock();
auto now = std::chrono::steady_clock::now();
for(auto &[addr,port_session] : server->_map_addr_session.raw) {
auto session = port_session.second;
if(now > session->pingTimeout) {
BOOST_LOG(info) << addr << ": Ping Timeout"sv;
session::stop(*session);
}
}
}
@ -433,13 +479,13 @@ void controlBroadcastThread(safe::signal_t *shutdown_event, session_queue_t sess
payload[0] = packetTypes[IDX_TERMINATION];
payload[1] = reason;
server.send(std::string_view {(char*)payload.data(), payload.size()});
server->send(std::string_view {(char*)payload.data(), payload.size()});
shutdown_event->raise(true);
continue;
}
server.iterate(500ms);
server->iterate(500ms);
}
}
@ -650,11 +696,10 @@ int start_broadcast(broadcast_ctx_t &ctx) {
ctx.video_packets = std::make_shared<video::packet_queue_t::element_type>();
ctx.audio_packets = std::make_shared<audio::packet_queue_t::element_type>();
ctx.message_queue_queue = std::make_shared<message_queue_queue_t::element_type>();
ctx.session_queue = std::make_shared<session_queue_t::element_type>();
ctx.video_thread = std::thread { videoBroadcastThread, &broadcast_shutdown_event, std::ref(ctx.video_sock), ctx.video_packets };
ctx.audio_thread = std::thread { audioBroadcastThread, &broadcast_shutdown_event, std::ref(ctx.audio_sock), ctx.audio_packets };
ctx.control_thread = std::thread { controlBroadcastThread, &broadcast_shutdown_event, ctx.session_queue };
ctx.control_thread = std::thread { controlBroadcastThread, &broadcast_shutdown_event, &ctx.control_server };
ctx.recv_thread = std::thread { recvThread, std::ref(ctx) };
@ -727,12 +772,9 @@ void videoThread(session_t *session, std::string addr_str) {
session::stop(*session);
});
while(session->state == session::state_e::STARTING) {
std::this_thread::sleep_for(1ms);
}
while_starting_do_nothing(session->state);
auto addr = asio::ip::make_address(addr_str);
auto ref = broadcast.ref();
auto port = recv_ping(ref, socket_e::video, addr, config::stream.ping_timeout);
if(port < 0) {
@ -751,9 +793,7 @@ void audioThread(session_t *session, std::string addr_str) {
session::stop(*session);
});
while(session->state == session::state_e::STARTING) {
std::this_thread::sleep_for(1ms);
}
while_starting_do_nothing(session->state);
auto addr = asio::ip::make_address(addr_str);
@ -776,11 +816,16 @@ state_e state(session_t &session) {
}
void stop(session_t &session) {
session.broadcast_ref->session_queue->raise(session.video_peer.address().to_string(), nullptr);
session.shutdown_event.raise(true);
while_starting_do_nothing(session.state);
auto expected = state_e::RUNNING;
session.state.compare_exchange_strong(expected, state_e::STOPPING);
auto already_stopping = !session.state.compare_exchange_strong(expected, state_e::STOPPING);
if(already_stopping) {
return;
}
session.broadcast_ref->control_server.erase_session(session);
session.shutdown_event.raise(true);
}
void join(session_t &session) {
@ -792,7 +837,7 @@ void join(session_t &session) {
void start(session_t &session, const std::string &addr_string) {
session.broadcast_ref = broadcast.ref();
session.broadcast_ref->session_queue->raise(addr_string, &session);
session.broadcast_ref->control_server.emplace_addr_to_session(addr_string, session);
session.pingTimeout = std::chrono::steady_clock::now() + config::stream.ping_timeout;

112
sunshine/sync.h Normal file
View File

@ -0,0 +1,112 @@
//
// Created by loki on 16-4-19.
//
#ifndef SUNSHINE_SYNC_H
#define SUNSHINE_SYNC_H
#include <utility>
#include <mutex>
#include <array>
namespace util {
template<class T, std::size_t N = 1>
class sync_t {
public:
static_assert(N > 0, "sync_t should have more than zero mutexes");
using value_type = T;
template<std::size_t I = 0>
std::lock_guard<std::mutex> lock() {
return std::lock_guard { std::get<I>(_lock) };
}
template<class ...Args>
sync_t(Args&&... args) : raw {std::forward<Args>(args)... } {}
sync_t &operator=(sync_t &&other) noexcept {
for(auto &l : _lock) {
l.lock();
}
for(auto &l : other._lock) {
l.lock();
}
raw = std::move(other.raw);
for(auto &l : _lock) {
l.unlock();
}
for(auto &l : other._lock) {
l.unlock();
}
return *this;
}
sync_t &operator=(sync_t &other) noexcept {
for(auto &l : _lock) {
l.lock();
}
for(auto &l : other._lock) {
l.lock();
}
raw = other.raw;
for(auto &l : _lock) {
l.unlock();
}
for(auto &l : other._lock) {
l.unlock();
}
return *this;
}
sync_t &operator=(const value_type &val) noexcept {
for(auto &l : _lock) {
l.lock();
}
raw = val;
for(auto &l : _lock) {
l.unlock();
}
return *this;
}
sync_t &operator=(value_type &&val) noexcept {
for(auto &l : _lock) {
l.lock();
}
raw = std::move(val);
for(auto &l : _lock) {
l.unlock();
}
return *this;
}
value_type *operator->() {
return &raw;
}
value_type raw;
private:
std::array<std::mutex, N> _lock;
};
}
#endif //T_MAN_SYNC_H