#include "MercurySession.h" #include // for memcpy #include // for shared_ptr #include // for scoped_lock #include // for runtime_error #include // for remove_extent_t, __underlying_type_impl<>:... #include // for pair #ifndef _WIN32 #include // for htons, ntohs, htonl, ntohl #endif #include "BellLogger.h" // for AbstractLogger #include "BellTask.h" // for Task #include "BellUtils.h" // for BELL_SLEEP_MS #include "Logger.h" // for CSPOT_LOG #include "NanoPBHelper.h" // for pbPutString, pbDecode, pbEncode #include "PlainConnection.h" // for PlainConnection #include "ShannonConnection.h" // for ShannonConnection #include "TimeProvider.h" // for TimeProvider #include "Utils.h" // for extract, pack, hton64 using namespace cspot; MercurySession::MercurySession(std::shared_ptr timeProvider) : bell::Task("mercury_dispatcher", 4 * 1024, 3, 1) { this->timeProvider = timeProvider; } MercurySession::~MercurySession() { std::scoped_lock lock(this->isRunningMutex); } void MercurySession::runTask() { isRunning = true; std::scoped_lock lock(this->isRunningMutex); this->executeEstabilishedCallback = true; while (isRunning) { cspot::Packet packet = {}; try { packet = shanConn->recvPacket(); CSPOT_LOG(info, "Received packet, command: %d", packet.command); if (static_cast(packet.command) == RequestType::PING) { timeProvider->syncWithPingPacket(packet.data); this->lastPingTimestamp = timeProvider->getSyncedTimestamp(); this->shanConn->sendPacket(0x49, packet.data); } else { this->packetQueue.push(packet); } } catch (const std::runtime_error& e) { CSPOT_LOG(error, "Error while receiving packet: %s", e.what()); failAllPending(); if (!isRunning) return; reconnect(); continue; } } } void MercurySession::reconnect() { isReconnecting = true; try { this->conn = nullptr; this->shanConn = nullptr; this->connectWithRandomAp(); this->authenticate(this->authBlob); CSPOT_LOG(info, "Reconnection successful"); BELL_SLEEP_MS(100); lastPingTimestamp = timeProvider->getSyncedTimestamp(); isReconnecting = false; this->executeEstabilishedCallback = true; } catch (...) { CSPOT_LOG(error, "Cannot reconnect, will retry in 5s"); BELL_SLEEP_MS(5000); if (isRunning) { return reconnect(); } } } void MercurySession::setConnectedHandler( ConnectionEstabilishedCallback callback) { this->connectionReadyCallback = callback; } bool MercurySession::triggerTimeout() { if (!isRunning) return true; auto currentTimestamp = timeProvider->getSyncedTimestamp(); if (currentTimestamp - this->lastPingTimestamp > PING_TIMEOUT_MS) { CSPOT_LOG(debug, "Reconnection required, no ping received"); return true; } return false; } void MercurySession::unregister(uint64_t sequenceId) { auto callback = this->callbacks.find(sequenceId); if (callback != this->callbacks.end()) { this->callbacks.erase(callback); } } void MercurySession::unregisterAudioKey(uint32_t sequenceId) { auto callback = this->audioKeyCallbacks.find(sequenceId); if (callback != this->audioKeyCallbacks.end()) { this->audioKeyCallbacks.erase(callback); } } void MercurySession::disconnect() { CSPOT_LOG(info, "Disconnecting mercury session"); this->isRunning = false; conn->close(); std::scoped_lock lock(this->isRunningMutex); } std::string MercurySession::getCountryCode() { return this->countryCode; } void MercurySession::handlePacket() { Packet packet = {}; this->packetQueue.wtpop(packet, 200); if (executeEstabilishedCallback && this->connectionReadyCallback != nullptr) { executeEstabilishedCallback = false; this->connectionReadyCallback(); } switch (static_cast(packet.command)) { case RequestType::COUNTRY_CODE_RESPONSE: { this->countryCode = std::string(); this->countryCode.resize(2); memcpy(this->countryCode.data(), packet.data.data(), 2); CSPOT_LOG(debug, "Received country code %s", this->countryCode.c_str()); break; } case RequestType::AUDIO_KEY_FAILURE_RESPONSE: case RequestType::AUDIO_KEY_SUCCESS_RESPONSE: { // this->lastRequestTimestamp = -1; // First four bytes mark the sequence id auto seqId = ntohl(extract(packet.data, 0)); if (this->audioKeyCallbacks.count(seqId) > 0) { auto success = static_cast(packet.command) == RequestType::AUDIO_KEY_SUCCESS_RESPONSE; this->audioKeyCallbacks[seqId](success, packet.data); } break; } case RequestType::SEND: case RequestType::SUB: case RequestType::UNSUB: { CSPOT_LOG(debug, "Received mercury packet"); auto response = this->decodeResponse(packet.data); if (this->callbacks.count(response.sequenceId) > 0) { auto seqId = response.sequenceId; this->callbacks[response.sequenceId](response); this->callbacks.erase(this->callbacks.find(seqId)); } break; } case RequestType::SUBRES: { auto response = decodeResponse(packet.data); auto uri = std::string(response.mercuryHeader.uri); if (this->subscriptions.count(uri) > 0) { this->subscriptions[uri](response); } break; } default: break; } } void MercurySession::failAllPending() { Response response = {}; response.fail = true; // Fail all callbacks for (auto& it : this->callbacks) { it.second(response); } // Fail all subscriptions for (auto& it : this->subscriptions) { it.second(response); } // Remove references this->subscriptions = {}; this->callbacks = {}; } MercurySession::Response MercurySession::decodeResponse( const std::vector& data) { Response response = {}; response.parts = {}; auto sequenceLength = ntohs(extract(data, 0)); response.sequenceId = hton64(extract(data, 2)); auto partsNumber = ntohs(extract(data, 11)); auto headerSize = ntohs(extract(data, 13)); auto headerBytes = std::vector(data.begin() + 15, data.begin() + 15 + headerSize); auto pos = 15 + headerSize; while (pos < data.size()) { auto partSize = ntohs(extract(data, pos)); response.parts.push_back(std::vector( data.begin() + pos + 2, data.begin() + pos + 2 + partSize)); pos += 2 + partSize; } pbDecode(response.mercuryHeader, Header_fields, headerBytes); response.fail = false; return response; } uint64_t MercurySession::executeSubscription(RequestType method, const std::string& uri, ResponseCallback callback, ResponseCallback subscription, DataParts& payload) { CSPOT_LOG(debug, "Executing Mercury Request, type %s", RequestTypeMap[method].c_str()); // Encode header pbPutString(uri, tempMercuryHeader.uri); pbPutString(RequestTypeMap[method], tempMercuryHeader.method); tempMercuryHeader.has_method = true; tempMercuryHeader.has_uri = true; // GET and SEND are actually the same. Therefore the override // The difference between them is only in header's method if (method == RequestType::GET) { method = RequestType::SEND; } if (method == RequestType::SUB) { this->subscriptions.insert({uri, subscription}); } auto headerBytes = pbEncode(Header_fields, &tempMercuryHeader); this->callbacks.insert({sequenceId, callback}); // Structure: [Sequence size] [SequenceId] [0x1] [Payloads number] // [Header size] [Header] [Payloads (size + data)] // Pack sequenceId auto sequenceIdBytes = pack(hton64(this->sequenceId)); auto sequenceSizeBytes = pack(htons(sequenceIdBytes.size())); sequenceIdBytes.insert(sequenceIdBytes.begin(), sequenceSizeBytes.begin(), sequenceSizeBytes.end()); sequenceIdBytes.push_back(0x01); auto payloadNum = pack(htons(payload.size() + 1)); sequenceIdBytes.insert(sequenceIdBytes.end(), payloadNum.begin(), payloadNum.end()); auto headerSizePayload = pack(htons(headerBytes.size())); sequenceIdBytes.insert(sequenceIdBytes.end(), headerSizePayload.begin(), headerSizePayload.end()); sequenceIdBytes.insert(sequenceIdBytes.end(), headerBytes.begin(), headerBytes.end()); // Encode all the payload parts for (int x = 0; x < payload.size(); x++) { headerSizePayload = pack(htons(payload[x].size())); sequenceIdBytes.insert(sequenceIdBytes.end(), headerSizePayload.begin(), headerSizePayload.end()); sequenceIdBytes.insert(sequenceIdBytes.end(), payload[x].begin(), payload[x].end()); } // Bump sequence id this->sequenceId += 1; try { this->shanConn->sendPacket( static_cast::type>(method), sequenceIdBytes); } catch (...) { // @TODO: handle disconnect } return this->sequenceId - 1; } uint32_t MercurySession::requestAudioKey(const std::vector& trackId, const std::vector& fileId, AudioKeyCallback audioCallback) { auto buffer = fileId; // Store callback this->audioKeyCallbacks.insert({this->audioKeySequence, audioCallback}); // Structure: [FILEID] [TRACKID] [4 BYTES SEQUENCE ID] [0x00, 0x00] buffer.insert(buffer.end(), trackId.begin(), trackId.end()); auto audioKeySequenceBuffer = pack(htonl(this->audioKeySequence)); buffer.insert(buffer.end(), audioKeySequenceBuffer.begin(), audioKeySequenceBuffer.end()); auto suffix = std::vector({0x00, 0x00}); buffer.insert(buffer.end(), suffix.begin(), suffix.end()); // Bump audio key sequence this->audioKeySequence += 1; // Used for broken connection detection // this->lastRequestTimestamp = timeProvider->getSyncedTimestamp(); try { this->shanConn->sendPacket( static_cast(RequestType::AUDIO_KEY_REQUEST_COMMAND), buffer); } catch (...) { // @TODO: Handle disconnect } return audioKeySequence - 1; }