filename:
src/network/Network.cpp
branch:
feature/world
back to repo
/*
*
* _____ _ _
* / ___|| | | |
* \ `--. | |_ _ __ __ _ | |_ ___ ___
* `--. \| __|| '__| / _` || __| / _ \ / __|
* /\__/ /| |_ | | | (_| || |_ | (_) |\__ \
* \____/ \__||_| \__,_| \__| \___/ |___/
*
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Copyright (C) 2025 Armen Deroian
*
*/
#include "Network.h"
#include "openssl/ssl.h"
#include "session/NetworkClient.h"
#include "spdlog/logger.h"
#include "utils/crypto/CryptoUtils.h"
#include <ranges>
#ifdef __linux__
#include <sys/epoll.h>
#endif
bool stratos::NetworkManager::start() {
if (running) throw std::runtime_error("Attempted to start NetworkManager while it is already running");
running = true;
encryptionEnabled = false; // TODO: settings from server config
if (encryptionEnabled) {
try {
encryptionKey = std::move(generateEncryptionKey());
logger->info("Generated encryption key of size {} bytes", encodeServerPublicKey(&encryptionKey).size());
} catch (const std::exception& e) {
logger->error("Failed to generate encryption key: {}", e.what());
return false;
}
}
try {
if (socketServer.isValid()) socketServer.listen(100);
} catch (const std::exception& e) {
logger->error("Socket Error: {}", e.what());
return false;
}
logger->info("Listening on {}:{}", socketServer.getAddress(), socketServer.getPort());
bossThread = std::make_unique<BossThread>(this, 10);
bossThread->start();
return true;
}
// Stops the network layer: stops the boss thread (which in turn stops its workers), then closes the
// listening socket. Socket close errors are logged, not propagated.
// Throws std::runtime_error if the manager is not running.
void stratos::NetworkManager::stop() {
    if (!running) throw std::runtime_error("Attempted to stop NetworkManager while it is not running");
    running = false;
    // Wait for the socket thread to finish. Guard against a manager that never got as far as
    // creating its boss thread.
    if (bossThread) bossThread->stop();
    try {
        socketServer.close();
    } catch (const std::exception& e) {
        logger->error("Socket Error: {}", e.what());
    }
}
void stratos::NetworkManager::tick() {
processIncomingConnections();
// Tick all sessions
std::vector<SessionId> staleSessions;
for (const auto& [sessionId, session] : sessions) {
session->tick();
if (session->isStale()) {
staleSessions.push_back(sessionId);
}
}
// Remove stale sessions
for (const auto& sessionId : staleSessions) {
sessions.erase(sessionId);
}
}
// Returns the (non-owning) pointer to the server this network manager belongs to.
stratos::Server* stratos::NetworkManager::getServer() const {
    return this->server;
}
// Looks up the session registered under sessionId; returns nullptr when there is none.
std::shared_ptr<stratos::NetworkSession> stratos::NetworkManager::getSession(const SessionId& sessionId) {
    const auto found = sessions.find(sessionId);
    if (found == sessions.end()) return nullptr;
    return found->second;
}
// Returns a snapshot (shared ownership) of every currently registered session.
std::vector<std::shared_ptr<stratos::NetworkSession>> stratos::NetworkManager::getSessions() {
    std::vector<std::shared_ptr<NetworkSession>> snapshot;
    snapshot.reserve(sessions.size());
    for (const auto& entry : sessions) snapshot.emplace_back(entry.second);
    return snapshot;
}
// Queues a connection to be turned into a session on the next tick() (see processIncomingConnections).
void stratos::NetworkManager::createSession(std::shared_ptr<NetworkConnection> connection) {
    auto pending = std::move(connection);
    sessionsQueue.enqueue(std::move(pending));
}
// Removes the session registered under sessionId. Returns true if a session was removed.
bool stratos::NetworkManager::removeSession(const SessionId& sessionId) {
    const auto found = sessions.find(sessionId);
    if (found == sessions.end()) return false;
    sessions.erase(found);
    return true;
}
// Drains the pending-connection queue filled by createSession() and creates one NetworkSession per
// connection. Null entries are ignored. A connection whose ClientInfo already maps to a session is
// logged and skipped. Processing continues with the rest of the queue.
void stratos::NetworkManager::processIncomingConnections() {
    std::shared_ptr<NetworkConnection> connection;
    while (sessionsQueue.try_dequeue(connection)) {
        if (!connection) continue;
        ClientInfo client{connection->getFd(), connection->getAddress(), connection->getPort()};
        if (sessions.find(client) != sessions.end()) {
            // TODO: This could be due to a client losing a connection and reconnecting, reuse session?
            logger->warn("Session already exists for client {}:{}", client.ip, client.port);
            // Skip just this duplicate. Returning here used to abandon the rest of the queue for this tick.
            continue;
        }
        // Create a new session
        logger->info("Creating new network session for client {}:{}", client.ip, client.port);
        const auto session = std::make_shared<NetworkSession>(this, client, std::move(connection));
        sessions[client] = session;
    }
}
void stratos::BossThread::start() {
if (running.exchange(true)) return;
workers = std::vector<std::shared_ptr<WorkerThread>>();
workers.reserve(workerThreads);
thread = std::thread([this] {
while (running) {
try {
// Accept new connections
if (const auto client = network->socketServer.accept(); client.socket != INVALID_SOCKET_FD) {
std::string ip = client.ip;
int port = client.port;
SocketFd socket = client.socket;
network->getLogger()->info("New connection from {}:{}", ip, port);
// Create a new session for the client
try {
std::shared_ptr<NetworkConnection> connection;
if (workers.size() < workerThreads) {
auto worker = std::make_shared<WorkerThread>(network, workers.size());
worker->start();
connection = std::make_shared<NetworkConnection>(socket, ip, port, network, network->getLogger(), worker);
worker->addConnection(connection);
workers.push_back(std::move(worker));
} else {
connection = std::make_shared<NetworkConnection>(socket, ip, port, network, network->getLogger(), workers[connectionCount % workerThreads]);
workers[connectionCount % workerThreads]->addConnection(connection);
}
// Create a new network session
//network->createSession(connection);
connectionCount++;
network->logger->info("Client '{}:{} - {}' connected", ip, port, socket);
} catch (std::exception& e) {
network->logger->error("Failed to connect client '{}:{}': {}", ip, port, e.what());
#ifdef _WIN32
closesocket(socket);
#else
::close(socket);
#endif
}
}
} catch (const std::exception& e) {
network->getLogger()->error("Socket Error: {}", e.what());
}
}
});
}
// Stops the accept loop and then every worker it created. Calling it when not running is a no-op.
void stratos::BossThread::stop() {
    if (!running.exchange(false)) return;
    // Join the accept loop first. It is the only writer of `workers`, so once it has exited the list
    // can be walked without a data race, and no new connection can be handed to a worker that is
    // already shutting down.
    if (thread.joinable()) thread.join();
    // Notify all worker threads to stop
    for (const auto& worker : workers) {
        worker->stop();
    }
}
void stratos::WorkerThread::start() {
#ifdef __linux__
epollFd = epoll_create1(0);
if (epollFd == -1) throw std::runtime_error("epoll_create1 failed: " + std::string(strerror(errno)));
#endif
if (!running.exchange(true)) {
thread = std::thread([this] {
while (running) {
processIncomingConnections();
processSendNotifications();
#ifdef _WIN32
if (connectionPollFds.empty()) {
Sleep(2);
continue;
}
if (const int result = WSAPoll(connectionPollFds.data(), connectionPollFds.size(), 100); result == SOCKET_ERROR) {
network->getLogger()->error("WSAPoll failed: {}", WSAGetLastError());
continue;
}
for (size_t i = 0; i < connectionPollFds.size(); ++i) {
auto& [fd, events, revents] = connectionPollFds[i];
auto it = connections.find(fd);
if (it == connections.end()) continue;
auto& conn = it->second;
// Handle socket errors
if (revents & POLLERR || revents & POLLHUP || revents & POLLNVAL) {
network->getLogger()->info("Connection closed for client {}:{}", conn->getAddress(), conn->getPort());
conn->close();
connectionPollFds.erase(connectionPollFds.begin() + i);
removeConnection(fd);
--i;
continue;
}
// Handle readable events
if (revents & POLLRDNORM) {
if (!conn->handleReceive()) {
network->getLogger()->info("Connection closed for client {}:{}", conn->getAddress(), conn->getPort());
conn->close();
connectionPollFds.erase(connectionPollFds.begin() + i);
removeConnection(fd);
--i;
continue;
}
}
// Handle writable events
if (revents & POLLWRNORM) {
conn->flushSend();
if (!conn->hasSendData()) {
events &= ~POLLWRNORM;
bool expected = true;
conn->dirty.compare_exchange_strong(expected, false);
}
}
// Handle server -> client disconnects
if (conn->isDisconnected()) {
network->getLogger()->info("Client {}:{} disconnected", conn->getAddress(), conn->getPort());
conn->close();
connectionPollFds.erase(connectionPollFds.begin() + i);
removeConnection(fd);
}
}
#elifdef __linux__
constexpr int MAX_EVENTS = 64;
epoll_event events[MAX_EVENTS];
const int ready = epoll_wait(epollFd, events, MAX_EVENTS, 100); // wait up to 100 ms
if (ready == -1) {
if (errno == EINTR) continue;
throw std::runtime_error("epoll_wait failed: " + std::string(strerror(errno)));
}
for (int i = 0; i < ready; ++i) {
int fd = events[i].data.fd;
auto it = connections.find(fd);
if (it == connections.end()) continue;
auto& conn = it->second;
if (events[i].events & EPOLLIN) {
if (conn->handleReceive() == 0) {
network->getLogger()->info("Connection closed for client {}:{}", conn->getAddress(), conn->getPort());
conn->close();
removeConnection(fd);
continue;
}
}
if (!conn->isClosed() && events[i].events & EPOLLOUT) {
conn->flushSend();
// If queue is empty, remove EPOLLOUT to prevent epoll wakeups
epoll_event ev{};
ev.events = EPOLLIN | EPOLLET; // Edge-triggered, no out until we have data to send
if (conn->hasSendData()) {
ev.events |= EPOLLOUT;
} else {
bool expected = true;
conn->dirty.compare_exchange_strong(expected, false);
}
ev.data.fd = fd;
epoll_ctl(epollFd, EPOLL_CTL_MOD, fd, &ev);
}
// Handle server -> client disconnects
if (conn->isDisconnected()) {
network->getLogger()->info("Client {}:{} disconnected", conn->getAddress(), conn->getPort());
conn->close();
removeConnection(fd);
}
}
#endif
}
});
}
}
// Stops the worker loop, then closes every connection it still owns. Calling it when not running is a no-op.
void stratos::WorkerThread::stop() {
    if (!running.exchange(false)) return;
    // Wait for the worker thread to finish first. Closing connections while the loop is still
    // receiving/flushing on them (and iterating `connections` concurrently) was a data race.
    if (thread.joinable()) {
        thread.join();
    }
    // Notify all connections to stop
    for (const auto& conn : connections | std::views::values) {
        conn->close();
    }
}
// Hands a connection to this worker. It is picked up by processIncomingConnections() on the worker's
// next loop iteration.
void stratos::WorkerThread::addConnection(std::shared_ptr<NetworkConnection> connection) {
    auto incoming = std::move(connection);
    inConnectionQueue.enqueue(std::move(incoming));
}
// Unregisters a socket from this worker: removes it from the epoll set (Linux) and from the
// connection map. The connection counter is only decremented when an entry was actually removed.
void stratos::WorkerThread::removeConnection(const SocketFd connection) {
#ifdef __linux__
    // Failure (e.g. the fd was already closed and thus auto-removed from epoll) is harmless and ignored.
    epoll_ctl(epollFd, EPOLL_CTL_DEL, connection, nullptr);
#endif
    std::lock_guard lock(connectionMutex);
    // Removing an fd that is not present must not make connectionCount drift downwards.
    if (connections.erase(connection) > 0) connectionCount--;
}
void stratos::WorkerThread::notifySend(const SocketFd& socketFd) {
sendNotifyQueue.enqueue(socketFd);
}
// Returns the connection owned by this worker for socketFd, or nullptr when it is unknown.
// NOTE(review): this lookup does not take connectionMutex, although processIncomingConnections() and
// removeConnection() mutate `connections` under it. Confirm that callers only use this from the
// worker thread.
std::shared_ptr<stratos::NetworkConnection> stratos::WorkerThread::getConnection(const SocketFd& socketFd) {
    const auto found = connections.find(socketFd);
    if (found == connections.end()) return nullptr;
    return found->second;
}
void stratos::WorkerThread::processIncomingConnections() {
std::lock_guard lock(connectionMutex);
while (true) {
if (std::shared_ptr<NetworkConnection> connection; inConnectionQueue.try_dequeue(connection)) {
#ifdef __WIN32
connectionPollFds.push_back({connection->getFd(), POLLRDNORM, 0});
#elifdef __linux__
epoll_event ev{};
ev.events = EPOLLIN | EPOLLET; // Edge-triggered, no out until we have data to send
ev.data.fd = connection->getFd();
if (epoll_ctl(epollFd, EPOLL_CTL_ADD, connection->getFd(), &ev) == -1) {
if (errno != EEXIST) throw std::runtime_error("epoll_ctl(ADD) failed: " + std::string(strerror(errno)));
epoll_ctl(epollFd, EPOLL_CTL_MOD, connection->getFd(), &ev);
}
#endif
connections[connection->getFd()] = std::move(connection);
connectionCount++;
} else {
break;
}
}
}
// Applies queued notifySend() requests by enabling write readiness for the affected sockets:
// POLLWRNORM on the matching pollfd entries on Windows, or an EPOLLIN|EPOLLOUT|EPOLLET re-arm on Linux.
void stratos::WorkerThread::processSendNotifications() {
#ifdef _WIN32
    std::vector<SocketFd> pending;
    for (SocketFd fd; sendNotifyQueue.try_dequeue(fd);) pending.push_back(fd);
    std::lock_guard lock(connectionMutex);
    for (auto& entry : connectionPollFds) {
        const bool wantsWrite = std::ranges::find(pending, entry.fd) != pending.end();
        if (wantsWrite) entry.events |= POLLWRNORM;
    }
#elifdef __linux__
    for (SocketFd fd; sendNotifyQueue.try_dequeue(fd);) {
        epoll_event rearm{};
        rearm.data.fd = fd;
        rearm.events = EPOLLIN | EPOLLOUT | EPOLLET; // Edge-triggered, write interest until the queue drains
        epoll_ctl(epollFd, EPOLL_CTL_MOD, fd, &rearm);
    }
#endif
}