From 604c3932cdba38dc86d0d5c28a249d3c915a4c2c Mon Sep 17 00:00:00 2001 From: valere Date: Fri, 3 Feb 2023 15:38:16 +0100 Subject: [PATCH] Flow collector causing strange NPE in some occasions --- .../sdk/internal/crypto/FlowCollectors.kt | 113 ++++++++++++++++++ .../android/sdk/internal/crypto/OlmMachine.kt | 45 +++---- 2 files changed, 128 insertions(+), 30 deletions(-) create mode 100644 matrix-sdk-android/src/rustCrypto/java/org/matrix/android/sdk/internal/crypto/FlowCollectors.kt diff --git a/matrix-sdk-android/src/rustCrypto/java/org/matrix/android/sdk/internal/crypto/FlowCollectors.kt b/matrix-sdk-android/src/rustCrypto/java/org/matrix/android/sdk/internal/crypto/FlowCollectors.kt new file mode 100644 index 0000000000..391c0a2ae7 --- /dev/null +++ b/matrix-sdk-android/src/rustCrypto/java/org/matrix/android/sdk/internal/crypto/FlowCollectors.kt @@ -0,0 +1,113 @@ +/* + * Copyright 2023 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.matrix.android.sdk.internal.crypto + +import kotlinx.coroutines.channels.SendChannel +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import org.matrix.android.sdk.api.session.crypto.crosssigning.MXCrossSigningInfo +import org.matrix.android.sdk.api.session.crypto.crosssigning.PrivateKeysInfo +import org.matrix.android.sdk.api.session.crypto.model.CryptoDeviceInfo +import org.matrix.android.sdk.api.util.Optional + +internal data class UserIdentityCollector(val userId: String, val collector: SendChannel>) : + SendChannel> by collector + +internal data class DevicesCollector(val userIds: List, val collector: SendChannel>) : + SendChannel> by collector + +private typealias PrivateKeysCollector = SendChannel> + +internal class FlowCollectors { + private val userIdentityCollectors = mutableListOf() + private val privateKeyCollectors = mutableListOf() + private val deviceCollectors = ArrayList() + + private val identityLock = Mutex() + private val keysLock = Mutex() + private val deviceLock = Mutex() + + suspend fun addIdentityCollector(collector: UserIdentityCollector) { + identityLock.withLock { + userIdentityCollectors.add(collector) + } + } + + fun removeIdentityCollector(collector: UserIdentityCollector) { + // Annoying but it's called when the channel is closed and can't call + // something suspendable there :/ + runBlocking { + identityLock.withLock { + userIdentityCollectors.remove(collector) + } + } + } + + suspend fun forEachIdentityCollector(block: suspend ((UserIdentityCollector) -> Unit)) { + val safeCopy = identityLock.withLock { + userIdentityCollectors.toList() + } + safeCopy.onEach { block(it) } + } + + suspend fun addPrivateKeysCollector(collector: PrivateKeysCollector) { + keysLock.withLock { + privateKeyCollectors.add(collector) + } + } + + fun removePrivateKeysCollector(collector: PrivateKeysCollector) { + // Annoying but it's called when the channel is closed and can't call + // something suspendable there :/ + runBlocking { + keysLock.withLock { + privateKeyCollectors.remove(collector) + } + } + } + + suspend fun forEachPrivateKeysCollector(block: suspend ((PrivateKeysCollector) -> Unit)) { + val safeCopy = keysLock.withLock { + privateKeyCollectors.toList() + } + safeCopy.onEach { block(it) } + } + + suspend fun addDevicesCollector(collector: DevicesCollector) { + deviceLock.withLock { + deviceCollectors.add(collector) + } + } + + fun removeDevicesCollector(collector: DevicesCollector) { + // Annoying but it's called when the channel is closed and can't call + // something suspendable there :/ + runBlocking { + deviceLock.withLock { + deviceCollectors.remove(collector) + } + } + } + + suspend fun forEachDevicesCollector(block: suspend ((DevicesCollector) -> Unit)) { + val safeCopy = deviceLock.withLock { + deviceCollectors.toList() + } + safeCopy.onEach { block(it) } + } +} diff --git a/matrix-sdk-android/src/rustCrypto/java/org/matrix/android/sdk/internal/crypto/OlmMachine.kt b/matrix-sdk-android/src/rustCrypto/java/org/matrix/android/sdk/internal/crypto/OlmMachine.kt index f5b9ec17a1..9c52ef9da5 100644 --- a/matrix-sdk-android/src/rustCrypto/java/org/matrix/android/sdk/internal/crypto/OlmMachine.kt +++ b/matrix-sdk-android/src/rustCrypto/java/org/matrix/android/sdk/internal/crypto/OlmMachine.kt @@ -20,8 +20,6 @@ import androidx.lifecycle.LiveData import androidx.lifecycle.asLiveData import com.squareup.moshi.Moshi import com.squareup.moshi.Types -import com.squareup.moshi.adapter -import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.channelFlow import kotlinx.coroutines.runBlocking @@ -99,19 +97,6 @@ private class CryptoProgressListener(private val listener: ProgressListener?) : } } -private data class UserIdentityCollector(val userId: String, val collector: SendChannel>) : - SendChannel> by collector - -private data class DevicesCollector(val userIds: List, val collector: SendChannel>) : - SendChannel> by collector -private typealias PrivateKeysCollector = SendChannel> - -private class FlowCollectors { - val userIdentityCollectors = ArrayList() - val privateKeyCollectors = ArrayList() - val deviceCollectors = ArrayList() -} - fun setRustLogger() { setLogger(CryptoLogger() as Logger) } @@ -130,7 +115,7 @@ internal class OlmMachine @Inject constructor( private val ensureUsersKeys: EnsureUsersKeysUseCase, private val matrixConfiguration: MatrixConfiguration, private val megolmSessionImportManager: MegolmSessionImportManager, - private val rustEncryptionConfiguration: RustEncryptionConfiguration, + rustEncryptionConfiguration: RustEncryptionConfiguration, ) { private val inner: InnerMachine @@ -165,23 +150,23 @@ internal class OlmMachine @Inject constructor( } private suspend fun updateLiveDevices() { - for (deviceCollector in flowCollectors.deviceCollectors) { - val devices = getCryptoDeviceInfo(deviceCollector.userIds) - deviceCollector.trySend(devices) + flowCollectors.forEachDevicesCollector { + val devices = getCryptoDeviceInfo(it.userIds) + it.trySend(devices) } } private suspend fun updateLiveUserIdentities() { - for (userIdentityCollector in flowCollectors.userIdentityCollectors) { - val identity = getIdentity(userIdentityCollector.userId)?.toMxCrossSigningInfo() - userIdentityCollector.trySend(identity.toOptional()) + flowCollectors.forEachIdentityCollector { + val identity = getIdentity(it.userId)?.toMxCrossSigningInfo().toOptional() + it.trySend(identity) } } private suspend fun updateLivePrivateKeys() { val keys = exportCrossSigningKeys().toOptional() - for (privateKeyCollector in flowCollectors.privateKeyCollectors) { - privateKeyCollector.trySend(keys) + flowCollectors.forEachPrivateKeysCollector { + it.trySend(keys) } } @@ -699,9 +684,9 @@ internal class OlmMachine @Inject constructor( return channelFlow { val userIdentityCollector = UserIdentityCollector(userId, this) val onClose = safeInvokeOnClose { - flowCollectors.userIdentityCollectors.remove(userIdentityCollector) + flowCollectors.removeIdentityCollector(userIdentityCollector) } - flowCollectors.userIdentityCollectors.add(userIdentityCollector) + flowCollectors.addIdentityCollector(userIdentityCollector) val identity = getIdentity(userId)?.toMxCrossSigningInfo().toOptional() send(identity) onClose.await() @@ -719,9 +704,9 @@ internal class OlmMachine @Inject constructor( fun getPrivateCrossSigningKeysFlow(): Flow> { return channelFlow { val onClose = safeInvokeOnClose { - flowCollectors.privateKeyCollectors.remove(this) + flowCollectors.removePrivateKeysCollector(this) } - flowCollectors.privateKeyCollectors.add(this) + flowCollectors.addPrivateKeysCollector(this) val keys = this@OlmMachine.exportCrossSigningKeys().toOptional() send(keys) onClose.await() @@ -746,9 +731,9 @@ internal class OlmMachine @Inject constructor( return channelFlow { val devicesCollector = DevicesCollector(userIds, this) val onClose = safeInvokeOnClose { - flowCollectors.deviceCollectors.remove(devicesCollector) + flowCollectors.removeDevicesCollector(devicesCollector) } - flowCollectors.deviceCollectors.add(devicesCollector) + flowCollectors.addDevicesCollector(devicesCollector) val devices = getCryptoDeviceInfo(userIds) send(devices) onClose.await()