MercurySession.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. #include "MercurySession.h"
  2. #include <string.h> // for memcpy
  3. #include <memory> // for shared_ptr
  4. #include <mutex> // for scoped_lock
  5. #include <stdexcept> // for runtime_error
  6. #include <type_traits> // for remove_extent_t, __underlying_type_impl<>:...
  7. #include <utility> // for pair
  8. #ifndef _WIN32
  9. #include <arpa/inet.h> // for htons, ntohs, htonl, ntohl
  10. #endif
  11. #include "BellLogger.h" // for AbstractLogger
  12. #include "BellTask.h" // for Task
  13. #include "BellUtils.h" // for BELL_SLEEP_MS
  14. #include "Logger.h" // for CSPOT_LOG
  15. #include "NanoPBHelper.h" // for pbPutString, pbDecode, pbEncode
  16. #include "PlainConnection.h" // for PlainConnection
  17. #include "ShannonConnection.h" // for ShannonConnection
  18. #include "TimeProvider.h" // for TimeProvider
  19. #include "Utils.h" // for extract, pack, hton64
  20. using namespace cspot;
  21. MercurySession::MercurySession(std::shared_ptr<TimeProvider> timeProvider)
  22. : bell::Task("mercury_dispatcher", 4 * 1024, 3, 1) {
  23. this->timeProvider = timeProvider;
  24. }
  25. MercurySession::~MercurySession() {
  26. std::scoped_lock lock(this->isRunningMutex);
  27. }
  28. void MercurySession::runTask() {
  29. isRunning = true;
  30. std::scoped_lock lock(this->isRunningMutex);
  31. this->executeEstabilishedCallback = true;
  32. while (isRunning) {
  33. cspot::Packet packet = {};
  34. try {
  35. packet = shanConn->recvPacket();
  36. CSPOT_LOG(info, "Received packet, command: %d", packet.command);
  37. if (static_cast<RequestType>(packet.command) == RequestType::PING) {
  38. timeProvider->syncWithPingPacket(packet.data);
  39. this->lastPingTimestamp = timeProvider->getSyncedTimestamp();
  40. this->shanConn->sendPacket(0x49, packet.data);
  41. } else {
  42. this->packetQueue.push(packet);
  43. }
  44. } catch (const std::runtime_error& e) {
  45. CSPOT_LOG(error, "Error while receiving packet: %s", e.what());
  46. failAllPending();
  47. if (!isRunning)
  48. return;
  49. reconnect();
  50. continue;
  51. }
  52. }
  53. }
  54. void MercurySession::reconnect() {
  55. isReconnecting = true;
  56. try {
  57. this->conn = nullptr;
  58. this->shanConn = nullptr;
  59. this->connectWithRandomAp();
  60. this->authenticate(this->authBlob);
  61. CSPOT_LOG(info, "Reconnection successful");
  62. BELL_SLEEP_MS(100);
  63. lastPingTimestamp = timeProvider->getSyncedTimestamp();
  64. isReconnecting = false;
  65. this->executeEstabilishedCallback = true;
  66. } catch (...) {
  67. CSPOT_LOG(error, "Cannot reconnect, will retry in 5s");
  68. BELL_SLEEP_MS(5000);
  69. if (isRunning) {
  70. return reconnect();
  71. }
  72. }
  73. }
  74. void MercurySession::setConnectedHandler(
  75. ConnectionEstabilishedCallback callback) {
  76. this->connectionReadyCallback = callback;
  77. }
  78. bool MercurySession::triggerTimeout() {
  79. if (!isRunning)
  80. return true;
  81. auto currentTimestamp = timeProvider->getSyncedTimestamp();
  82. if (currentTimestamp - this->lastPingTimestamp > PING_TIMEOUT_MS) {
  83. CSPOT_LOG(debug, "Reconnection required, no ping received");
  84. return true;
  85. }
  86. return false;
  87. }
  88. void MercurySession::unregister(uint64_t sequenceId) {
  89. auto callback = this->callbacks.find(sequenceId);
  90. if (callback != this->callbacks.end()) {
  91. this->callbacks.erase(callback);
  92. }
  93. }
  94. void MercurySession::unregisterAudioKey(uint32_t sequenceId) {
  95. auto callback = this->audioKeyCallbacks.find(sequenceId);
  96. if (callback != this->audioKeyCallbacks.end()) {
  97. this->audioKeyCallbacks.erase(callback);
  98. }
  99. }
  100. void MercurySession::disconnect() {
  101. CSPOT_LOG(info, "Disconnecting mercury session");
  102. this->isRunning = false;
  103. conn->close();
  104. std::scoped_lock lock(this->isRunningMutex);
  105. }
  106. std::string MercurySession::getCountryCode() {
  107. return this->countryCode;
  108. }
  109. void MercurySession::handlePacket() {
  110. Packet packet = {};
  111. this->packetQueue.wtpop(packet, 200);
  112. if (executeEstabilishedCallback && this->connectionReadyCallback != nullptr) {
  113. executeEstabilishedCallback = false;
  114. this->connectionReadyCallback();
  115. }
  116. switch (static_cast<RequestType>(packet.command)) {
  117. case RequestType::COUNTRY_CODE_RESPONSE: {
  118. this->countryCode = std::string();
  119. this->countryCode.resize(2);
  120. memcpy(this->countryCode.data(), packet.data.data(), 2);
  121. CSPOT_LOG(debug, "Received country code %s", this->countryCode.c_str());
  122. break;
  123. }
  124. case RequestType::AUDIO_KEY_FAILURE_RESPONSE:
  125. case RequestType::AUDIO_KEY_SUCCESS_RESPONSE: {
  126. // this->lastRequestTimestamp = -1;
  127. // First four bytes mark the sequence id
  128. auto seqId = ntohl(extract<uint32_t>(packet.data, 0));
  129. if (this->audioKeyCallbacks.count(seqId) > 0) {
  130. auto success = static_cast<RequestType>(packet.command) ==
  131. RequestType::AUDIO_KEY_SUCCESS_RESPONSE;
  132. this->audioKeyCallbacks[seqId](success, packet.data);
  133. }
  134. break;
  135. }
  136. case RequestType::SEND:
  137. case RequestType::SUB:
  138. case RequestType::UNSUB: {
  139. CSPOT_LOG(debug, "Received mercury packet");
  140. auto response = this->decodeResponse(packet.data);
  141. if (this->callbacks.count(response.sequenceId) > 0) {
  142. auto seqId = response.sequenceId;
  143. this->callbacks[response.sequenceId](response);
  144. this->callbacks.erase(this->callbacks.find(seqId));
  145. }
  146. break;
  147. }
  148. case RequestType::SUBRES: {
  149. auto response = decodeResponse(packet.data);
  150. auto uri = std::string(response.mercuryHeader.uri);
  151. if (this->subscriptions.count(uri) > 0) {
  152. this->subscriptions[uri](response);
  153. }
  154. break;
  155. }
  156. default:
  157. break;
  158. }
  159. }
  160. void MercurySession::failAllPending() {
  161. Response response = {};
  162. response.fail = true;
  163. // Fail all callbacks
  164. for (auto& it : this->callbacks) {
  165. it.second(response);
  166. }
  167. // Fail all subscriptions
  168. for (auto& it : this->subscriptions) {
  169. it.second(response);
  170. }
  171. // Remove references
  172. this->subscriptions = {};
  173. this->callbacks = {};
  174. }
  175. MercurySession::Response MercurySession::decodeResponse(
  176. const std::vector<uint8_t>& data) {
  177. Response response = {};
  178. response.parts = {};
  179. auto sequenceLength = ntohs(extract<uint16_t>(data, 0));
  180. response.sequenceId = hton64(extract<uint64_t>(data, 2));
  181. auto partsNumber = ntohs(extract<uint16_t>(data, 11));
  182. auto headerSize = ntohs(extract<uint16_t>(data, 13));
  183. auto headerBytes =
  184. std::vector<uint8_t>(data.begin() + 15, data.begin() + 15 + headerSize);
  185. auto pos = 15 + headerSize;
  186. while (pos < data.size()) {
  187. auto partSize = ntohs(extract<uint16_t>(data, pos));
  188. response.parts.push_back(std::vector<uint8_t>(
  189. data.begin() + pos + 2, data.begin() + pos + 2 + partSize));
  190. pos += 2 + partSize;
  191. }
  192. pbDecode(response.mercuryHeader, Header_fields, headerBytes);
  193. response.fail = false;
  194. return response;
  195. }
  196. uint64_t MercurySession::executeSubscription(RequestType method,
  197. const std::string& uri,
  198. ResponseCallback callback,
  199. ResponseCallback subscription,
  200. DataParts& payload) {
  201. CSPOT_LOG(debug, "Executing Mercury Request, type %s",
  202. RequestTypeMap[method].c_str());
  203. // Encode header
  204. pbPutString(uri, tempMercuryHeader.uri);
  205. pbPutString(RequestTypeMap[method], tempMercuryHeader.method);
  206. tempMercuryHeader.has_method = true;
  207. tempMercuryHeader.has_uri = true;
  208. // GET and SEND are actually the same. Therefore the override
  209. // The difference between them is only in header's method
  210. if (method == RequestType::GET) {
  211. method = RequestType::SEND;
  212. }
  213. if (method == RequestType::SUB) {
  214. this->subscriptions.insert({uri, subscription});
  215. }
  216. auto headerBytes = pbEncode(Header_fields, &tempMercuryHeader);
  217. this->callbacks.insert({sequenceId, callback});
  218. // Structure: [Sequence size] [SequenceId] [0x1] [Payloads number]
  219. // [Header size] [Header] [Payloads (size + data)]
  220. // Pack sequenceId
  221. auto sequenceIdBytes = pack<uint64_t>(hton64(this->sequenceId));
  222. auto sequenceSizeBytes = pack<uint16_t>(htons(sequenceIdBytes.size()));
  223. sequenceIdBytes.insert(sequenceIdBytes.begin(), sequenceSizeBytes.begin(),
  224. sequenceSizeBytes.end());
  225. sequenceIdBytes.push_back(0x01);
  226. auto payloadNum = pack<uint16_t>(htons(payload.size() + 1));
  227. sequenceIdBytes.insert(sequenceIdBytes.end(), payloadNum.begin(),
  228. payloadNum.end());
  229. auto headerSizePayload = pack<uint16_t>(htons(headerBytes.size()));
  230. sequenceIdBytes.insert(sequenceIdBytes.end(), headerSizePayload.begin(),
  231. headerSizePayload.end());
  232. sequenceIdBytes.insert(sequenceIdBytes.end(), headerBytes.begin(),
  233. headerBytes.end());
  234. // Encode all the payload parts
  235. for (int x = 0; x < payload.size(); x++) {
  236. headerSizePayload = pack<uint16_t>(htons(payload[x].size()));
  237. sequenceIdBytes.insert(sequenceIdBytes.end(), headerSizePayload.begin(),
  238. headerSizePayload.end());
  239. sequenceIdBytes.insert(sequenceIdBytes.end(), payload[x].begin(),
  240. payload[x].end());
  241. }
  242. // Bump sequence id
  243. this->sequenceId += 1;
  244. try {
  245. this->shanConn->sendPacket(
  246. static_cast<std::underlying_type<RequestType>::type>(method),
  247. sequenceIdBytes);
  248. } catch (...) {
  249. // @TODO: handle disconnect
  250. }
  251. return this->sequenceId - 1;
  252. }
  253. uint32_t MercurySession::requestAudioKey(const std::vector<uint8_t>& trackId,
  254. const std::vector<uint8_t>& fileId,
  255. AudioKeyCallback audioCallback) {
  256. auto buffer = fileId;
  257. // Store callback
  258. this->audioKeyCallbacks.insert({this->audioKeySequence, audioCallback});
  259. // Structure: [FILEID] [TRACKID] [4 BYTES SEQUENCE ID] [0x00, 0x00]
  260. buffer.insert(buffer.end(), trackId.begin(), trackId.end());
  261. auto audioKeySequenceBuffer = pack<uint32_t>(htonl(this->audioKeySequence));
  262. buffer.insert(buffer.end(), audioKeySequenceBuffer.begin(),
  263. audioKeySequenceBuffer.end());
  264. auto suffix = std::vector<uint8_t>({0x00, 0x00});
  265. buffer.insert(buffer.end(), suffix.begin(), suffix.end());
  266. // Bump audio key sequence
  267. this->audioKeySequence += 1;
  268. // Used for broken connection detection
  269. // this->lastRequestTimestamp = timeProvider->getSyncedTimestamp();
  270. try {
  271. this->shanConn->sendPacket(
  272. static_cast<uint8_t>(RequestType::AUDIO_KEY_REQUEST_COMMAND), buffer);
  273. } catch (...) {
  274. // @TODO: Handle disconnect
  275. }
  276. return audioKeySequence - 1;
  277. }