Flow collector causing strange NPE in some occasions

This commit is contained in:
valere 2023-02-03 15:38:16 +01:00
parent 32aaf57ecf
commit 604c3932cd
2 changed files with 128 additions and 30 deletions

View File

@ -0,0 +1,113 @@
/*
* Copyright 2023 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.matrix.android.sdk.internal.crypto
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import org.matrix.android.sdk.api.session.crypto.crosssigning.MXCrossSigningInfo
import org.matrix.android.sdk.api.session.crypto.crosssigning.PrivateKeysInfo
import org.matrix.android.sdk.api.session.crypto.model.CryptoDeviceInfo
import org.matrix.android.sdk.api.util.Optional
internal data class UserIdentityCollector(val userId: String, val collector: SendChannel<Optional<MXCrossSigningInfo>>) :
SendChannel<Optional<MXCrossSigningInfo>> by collector
internal data class DevicesCollector(val userIds: List<String>, val collector: SendChannel<List<CryptoDeviceInfo>>) :
SendChannel<List<CryptoDeviceInfo>> by collector
private typealias PrivateKeysCollector = SendChannel<Optional<PrivateKeysInfo>>
internal class FlowCollectors {
private val userIdentityCollectors = mutableListOf<UserIdentityCollector>()
private val privateKeyCollectors = mutableListOf<PrivateKeysCollector>()
private val deviceCollectors = ArrayList<DevicesCollector>()
private val identityLock = Mutex()
private val keysLock = Mutex()
private val deviceLock = Mutex()
suspend fun addIdentityCollector(collector: UserIdentityCollector) {
identityLock.withLock {
userIdentityCollectors.add(collector)
}
}
fun removeIdentityCollector(collector: UserIdentityCollector) {
// Annoying but it's called when the channel is closed and can't call
// something suspendable there :/
runBlocking {
identityLock.withLock {
userIdentityCollectors.remove(collector)
}
}
}
suspend fun forEachIdentityCollector(block: suspend ((UserIdentityCollector) -> Unit)) {
val safeCopy = identityLock.withLock {
userIdentityCollectors.toList()
}
safeCopy.onEach { block(it) }
}
suspend fun addPrivateKeysCollector(collector: PrivateKeysCollector) {
keysLock.withLock {
privateKeyCollectors.add(collector)
}
}
fun removePrivateKeysCollector(collector: PrivateKeysCollector) {
// Annoying but it's called when the channel is closed and can't call
// something suspendable there :/
runBlocking {
keysLock.withLock {
privateKeyCollectors.remove(collector)
}
}
}
suspend fun forEachPrivateKeysCollector(block: suspend ((PrivateKeysCollector) -> Unit)) {
val safeCopy = keysLock.withLock {
privateKeyCollectors.toList()
}
safeCopy.onEach { block(it) }
}
suspend fun addDevicesCollector(collector: DevicesCollector) {
deviceLock.withLock {
deviceCollectors.add(collector)
}
}
fun removeDevicesCollector(collector: DevicesCollector) {
// Annoying but it's called when the channel is closed and can't call
// something suspendable there :/
runBlocking {
deviceLock.withLock {
deviceCollectors.remove(collector)
}
}
}
suspend fun forEachDevicesCollector(block: suspend ((DevicesCollector) -> Unit)) {
val safeCopy = deviceLock.withLock {
deviceCollectors.toList()
}
safeCopy.onEach { block(it) }
}
}

View File

@ -20,8 +20,6 @@ import androidx.lifecycle.LiveData
import androidx.lifecycle.asLiveData import androidx.lifecycle.asLiveData
import com.squareup.moshi.Moshi import com.squareup.moshi.Moshi
import com.squareup.moshi.Types import com.squareup.moshi.Types
import com.squareup.moshi.adapter
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.runBlocking import kotlinx.coroutines.runBlocking
@ -99,19 +97,6 @@ private class CryptoProgressListener(private val listener: ProgressListener?) :
} }
} }
private data class UserIdentityCollector(val userId: String, val collector: SendChannel<Optional<MXCrossSigningInfo>>) :
SendChannel<Optional<MXCrossSigningInfo>> by collector
private data class DevicesCollector(val userIds: List<String>, val collector: SendChannel<List<CryptoDeviceInfo>>) :
SendChannel<List<CryptoDeviceInfo>> by collector
private typealias PrivateKeysCollector = SendChannel<Optional<PrivateKeysInfo>>
private class FlowCollectors {
val userIdentityCollectors = ArrayList<UserIdentityCollector>()
val privateKeyCollectors = ArrayList<PrivateKeysCollector>()
val deviceCollectors = ArrayList<DevicesCollector>()
}
fun setRustLogger() { fun setRustLogger() {
setLogger(CryptoLogger() as Logger) setLogger(CryptoLogger() as Logger)
} }
@ -130,7 +115,7 @@ internal class OlmMachine @Inject constructor(
private val ensureUsersKeys: EnsureUsersKeysUseCase, private val ensureUsersKeys: EnsureUsersKeysUseCase,
private val matrixConfiguration: MatrixConfiguration, private val matrixConfiguration: MatrixConfiguration,
private val megolmSessionImportManager: MegolmSessionImportManager, private val megolmSessionImportManager: MegolmSessionImportManager,
private val rustEncryptionConfiguration: RustEncryptionConfiguration, rustEncryptionConfiguration: RustEncryptionConfiguration,
) { ) {
private val inner: InnerMachine private val inner: InnerMachine
@ -165,23 +150,23 @@ internal class OlmMachine @Inject constructor(
} }
private suspend fun updateLiveDevices() { private suspend fun updateLiveDevices() {
for (deviceCollector in flowCollectors.deviceCollectors) { flowCollectors.forEachDevicesCollector {
val devices = getCryptoDeviceInfo(deviceCollector.userIds) val devices = getCryptoDeviceInfo(it.userIds)
deviceCollector.trySend(devices) it.trySend(devices)
} }
} }
private suspend fun updateLiveUserIdentities() { private suspend fun updateLiveUserIdentities() {
for (userIdentityCollector in flowCollectors.userIdentityCollectors) { flowCollectors.forEachIdentityCollector {
val identity = getIdentity(userIdentityCollector.userId)?.toMxCrossSigningInfo() val identity = getIdentity(it.userId)?.toMxCrossSigningInfo().toOptional()
userIdentityCollector.trySend(identity.toOptional()) it.trySend(identity)
} }
} }
private suspend fun updateLivePrivateKeys() { private suspend fun updateLivePrivateKeys() {
val keys = exportCrossSigningKeys().toOptional() val keys = exportCrossSigningKeys().toOptional()
for (privateKeyCollector in flowCollectors.privateKeyCollectors) { flowCollectors.forEachPrivateKeysCollector {
privateKeyCollector.trySend(keys) it.trySend(keys)
} }
} }
@ -699,9 +684,9 @@ internal class OlmMachine @Inject constructor(
return channelFlow { return channelFlow {
val userIdentityCollector = UserIdentityCollector(userId, this) val userIdentityCollector = UserIdentityCollector(userId, this)
val onClose = safeInvokeOnClose { val onClose = safeInvokeOnClose {
flowCollectors.userIdentityCollectors.remove(userIdentityCollector) flowCollectors.removeIdentityCollector(userIdentityCollector)
} }
flowCollectors.userIdentityCollectors.add(userIdentityCollector) flowCollectors.addIdentityCollector(userIdentityCollector)
val identity = getIdentity(userId)?.toMxCrossSigningInfo().toOptional() val identity = getIdentity(userId)?.toMxCrossSigningInfo().toOptional()
send(identity) send(identity)
onClose.await() onClose.await()
@ -719,9 +704,9 @@ internal class OlmMachine @Inject constructor(
fun getPrivateCrossSigningKeysFlow(): Flow<Optional<PrivateKeysInfo>> { fun getPrivateCrossSigningKeysFlow(): Flow<Optional<PrivateKeysInfo>> {
return channelFlow { return channelFlow {
val onClose = safeInvokeOnClose { val onClose = safeInvokeOnClose {
flowCollectors.privateKeyCollectors.remove(this) flowCollectors.removePrivateKeysCollector(this)
} }
flowCollectors.privateKeyCollectors.add(this) flowCollectors.addPrivateKeysCollector(this)
val keys = this@OlmMachine.exportCrossSigningKeys().toOptional() val keys = this@OlmMachine.exportCrossSigningKeys().toOptional()
send(keys) send(keys)
onClose.await() onClose.await()
@ -746,9 +731,9 @@ internal class OlmMachine @Inject constructor(
return channelFlow { return channelFlow {
val devicesCollector = DevicesCollector(userIds, this) val devicesCollector = DevicesCollector(userIds, this)
val onClose = safeInvokeOnClose { val onClose = safeInvokeOnClose {
flowCollectors.deviceCollectors.remove(devicesCollector) flowCollectors.removeDevicesCollector(devicesCollector)
} }
flowCollectors.deviceCollectors.add(devicesCollector) flowCollectors.addDevicesCollector(devicesCollector)
val devices = getCryptoDeviceInfo(userIds) val devices = getCryptoDeviceInfo(userIds)
send(devices) send(devices)
onClose.await() onClose.await()