From f297117df2fe66c61e8589e73bc88cf7a99586bb Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Tue, 18 Oct 2022 08:48:28 +0100 Subject: [PATCH] Use mutex --- .../channels/ECDHRendezvousChannel.kt | 75 ++++++++++--------- 1 file changed, 38 insertions(+), 37 deletions(-) diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/rendezvous/channels/ECDHRendezvousChannel.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/rendezvous/channels/ECDHRendezvousChannel.kt index 2f368d6520..19c46db34d 100644 --- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/rendezvous/channels/ECDHRendezvousChannel.kt +++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/rendezvous/channels/ECDHRendezvousChannel.kt @@ -19,6 +19,8 @@ package org.matrix.android.sdk.api.rendezvous.channels import android.util.Base64 import com.squareup.moshi.Json import com.squareup.moshi.JsonClass +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import okhttp3.MediaType.Companion.toMediaType import org.matrix.android.sdk.api.logger.LoggerTag import org.matrix.android.sdk.api.rendezvous.RendezvousChannel @@ -71,6 +73,7 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu @Json val iv: String? = null ) + private var olmSASMutex = Mutex() private var olmSAS: OlmSAS? private val ourPublicKey: ByteArray private val ecdhAdapter = MatrixJsonParser.getMoshi().adapter(ECDHPayload::class.java) @@ -87,45 +90,44 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu @Throws(RendezvousError::class) override suspend fun connect(): String { - olmSAS ?.let { olmSAS -> - val isInitiator = theirPublicKey == null + val sas = olmSAS ?: throw RendezvousError("Channel closed", RendezvousFailureReason.Unknown) + val isInitiator = theirPublicKey == null - if (isInitiator) { - Timber.tag(TAG).i("Waiting for other device to send their public key") - val res = this.receiveAsPayload() ?: throw RendezvousError("No reply from other device", RendezvousFailureReason.ProtocolError) + if (isInitiator) { + Timber.tag(TAG).i("Waiting for other device to send their public key") + val res = this.receiveAsPayload() ?: throw RendezvousError("No reply from other device", RendezvousFailureReason.ProtocolError) - if (res.key == null) { - throw RendezvousError( - "Unsupported algorithm: ${res.algorithm}", - RendezvousFailureReason.UnsupportedAlgorithm, - ) - } - theirPublicKey = Base64.decode(res.key, Base64.NO_WRAP) - } else { - // send our public key unencrypted - Timber.tag(TAG).i("Sending public key") - send( - ECDHPayload( - algorithm = SecureRendezvousChannelAlgorithm.ECDH_V1, - key = Base64.encodeToString(ourPublicKey, Base64.NO_WRAP) - ) + if (res.key == null) { + throw RendezvousError( + "Unsupported algorithm: ${res.algorithm}", + RendezvousFailureReason.UnsupportedAlgorithm, ) } + theirPublicKey = Base64.decode(res.key, Base64.NO_WRAP) + } else { + // send our public key unencrypted + Timber.tag(TAG).i("Sending public key") + send( + ECDHPayload( + algorithm = SecureRendezvousChannelAlgorithm.ECDH_V1, + key = Base64.encodeToString(ourPublicKey, Base64.NO_WRAP) + ) + ) + } - synchronized(olmSAS) { - olmSAS.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP)) - olmSAS.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP)) + olmSASMutex.withLock { + sas.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP)) + sas.setTheirPublicKey(Base64.encodeToString(theirPublicKey, Base64.NO_WRAP)) - val initiatorKey = Base64.encodeToString(if (isInitiator) ourPublicKey else theirPublicKey, Base64.NO_WRAP) - val recipientKey = Base64.encodeToString(if (isInitiator) theirPublicKey else ourPublicKey, Base64.NO_WRAP) - val aesInfo = "${SecureRendezvousChannelAlgorithm.ECDH_V1.value}|$initiatorKey|$recipientKey" + val initiatorKey = Base64.encodeToString(if (isInitiator) ourPublicKey else theirPublicKey, Base64.NO_WRAP) + val recipientKey = Base64.encodeToString(if (isInitiator) theirPublicKey else ourPublicKey, Base64.NO_WRAP) + val aesInfo = "${SecureRendezvousChannelAlgorithm.ECDH_V1.value}|$initiatorKey|$recipientKey" - aesKey = olmSAS.generateShortCode(aesInfo, 32) + aesKey = sas.generateShortCode(aesInfo, 32) - val rawChecksum = olmSAS.generateShortCode(aesInfo, 5) - return getDecimalCodeRepresentation(rawChecksum) - } - } ?: throw RuntimeException("Channel closed") + val rawChecksum = sas.generateShortCode(aesInfo, 5) + return getDecimalCodeRepresentation(rawChecksum) + } } private suspend fun send(payload: ECDHPayload) { @@ -154,12 +156,11 @@ class ECDHRendezvousChannel(override var transport: RendezvousTransport, theirPu } override suspend fun close() { - olmSAS ?.let { - synchronized(it) { - // this does a double release check already so we don't re-check ourselves - it.releaseSas() - olmSAS = null - } + val sas = olmSAS ?: throw IllegalStateException("Channel already closed") + olmSASMutex.withLock { + // this does a double release check already so we don't re-check ourselves + sas.releaseSas() + olmSAS = null } transport.close() }