diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/rendezvous/channels/ECDHRendezvousChannel.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/rendezvous/channels/ECDHRendezvousChannel.kt index 489d20e588..ca7083b297 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/rendezvous/channels/ECDHRendezvousChannel.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/rendezvous/channels/ECDHRendezvousChannel.kt @@ -90,43 +90,45 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu } override suspend fun connect(): String { - if (olmSAS == null) { - throw RuntimeException("Channel closed") - } - val isInitiator = theirPublicKey == null + olmSAS ?.let { olmSAS -> + val isInitiator = theirPublicKey == null - if (isInitiator) { - Timber.tag(TAG).i("Waiting for other device to send their public key") - val res = this.receiveAsPayload() ?: throw RuntimeException("No reply from other device") + if (isInitiator) { + Timber.tag(TAG).i("Waiting for other device to send their public key") + val res = this.receiveAsPayload() ?: throw RuntimeException("No reply from other device") - if (res.key == null) { - throw RendezvousError( - "Unsupported algorithm: ${res.algorithm}", - RendezvousFailureReason.UnsupportedAlgorithm, + if (res.key == null) { + throw RendezvousError( + "Unsupported algorithm: ${res.algorithm}", + RendezvousFailureReason.UnsupportedAlgorithm, + ) + } + theirPublicKey = Base64.decode(res.key, Base64.NO_WRAP) + } else { + // send our public key unencrypted + Timber.tag(TAG).i("Sending public key") + send( + ECDHPayload( + algorithm = SecureRendezvousChannelAlgorithm.ECDH_V1, + key = Base64.encodeToString(ourPublicKey, Base64.NO_WRAP) + ) ) } - theirPublicKey = Base64.decode(res.key, Base64.NO_WRAP) - } else { - // send our public key unencrypted - Timber.tag(TAG).i("Sending public key") - send( - ECDHPayload( - algorithm = SecureRendezvousChannelAlgorithm.ECDH_V1, - key = Base64.encodeToString(ourPublicKey, Base64.NO_WRAP) - ) - ) - } - olmSAS!!.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP)) + synchronized(olmSAS) { + olmSAS.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP)) + olmSAS.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP)) - val initiatorKey = Base64.encodeToString(if (isInitiator) ourPublicKey else theirPublicKey, Base64.NO_WRAP) - val recipientKey = Base64.encodeToString(if (isInitiator) theirPublicKey else ourPublicKey, Base64.NO_WRAP) - val aesInfo = "${SecureRendezvousChannelAlgorithm.ECDH_V1.value}|$initiatorKey|$recipientKey" + val initiatorKey = Base64.encodeToString(if (isInitiator) ourPublicKey else theirPublicKey, Base64.NO_WRAP) + val recipientKey = Base64.encodeToString(if (isInitiator) theirPublicKey else ourPublicKey, Base64.NO_WRAP) + val aesInfo = "${SecureRendezvousChannelAlgorithm.ECDH_V1.value}|$initiatorKey|$recipientKey" - aesKey = olmSAS!!.generateShortCode(aesInfo, 32) + aesKey = olmSAS.generateShortCode(aesInfo, 32) - val rawChecksum = olmSAS!!.generateShortCode(aesInfo, 5) - return getDecimalCodeRepresentation(rawChecksum) + val rawChecksum = olmSAS.generateShortCode(aesInfo, 5) + return getDecimalCodeRepresentation(rawChecksum) + } + } ?: throw RuntimeException("Channel closed") } private suspend fun send(payload: ECDHPayload) { @@ -174,8 +176,13 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu } override suspend fun close() { - olmSAS?.releaseSas() - olmSAS = null + olmSAS ?.let { + synchronized(it) { + // this does a double release check already so we don't re-check ourselves + it.releaseSas() + olmSAS = null + } + } } private fun encrypt(plainText: ByteArray): ECDHPayload {