Quick refactor to use same mechanism as updateOneTimeKeyCount

This commit is contained in:
Valere 2021-11-15 13:56:45 +01:00
parent c603135398
commit 10671a53a4
2 changed files with 45 additions and 33 deletions

View File

@ -330,7 +330,7 @@ internal class DefaultCryptoService @Inject constructor(
uploadDeviceKeys() uploadDeviceKeys()
} }
oneTimeKeysUploader.maybeUploadOneTimeKeys(shouldGenerateFallbackKey = false) oneTimeKeysUploader.maybeUploadOneTimeKeys()
// this can throw if no backup // this can throw if no backup
tryOrNull { tryOrNull {
keysBackupService.checkAndStartKeysBackup() keysBackupService.checkAndStartKeysBackup()
@ -434,15 +434,13 @@ internal class DefaultCryptoService @Inject constructor(
deviceListManager.refreshOutdatedDeviceLists() deviceListManager.refreshOutdatedDeviceLists()
// The presence of device_unused_fallback_key_types indicates that the server supports fallback keys. // The presence of device_unused_fallback_key_types indicates that the server supports fallback keys.
// If there's no unused signed_curve25519 fallback key we need a new one. // If there's no unused signed_curve25519 fallback key we need a new one.
val shouldGenerateFallbackKey = if (syncResponse.deviceUnusedFallbackKeyTypes != null) { if (syncResponse.deviceUnusedFallbackKeyTypes != null
// Generate a fallback key only if the server does not already have an unused fallback key. // Generate a fallback key only if the server does not already have an unused fallback key.
!syncResponse.deviceUnusedFallbackKeyTypes.contains(KEY_SIGNED_CURVE_25519_TYPE) && !syncResponse.deviceUnusedFallbackKeyTypes.contains(KEY_SIGNED_CURVE_25519_TYPE)) {
} else { oneTimeKeysUploader.setNeedsNewFallback()
// Server does not support fallbackKey
false
} }
oneTimeKeysUploader.maybeUploadOneTimeKeys(shouldGenerateFallbackKey) oneTimeKeysUploader.maybeUploadOneTimeKeys()
incomingGossipingRequestManager.processReceivedGossipingRequests() incomingGossipingRequestManager.processReceivedGossipingRequests()
} }
} }

View File

@ -40,6 +40,7 @@ internal class OneTimeKeysUploader @Inject constructor(
// last OTK check timestamp // last OTK check timestamp
private var lastOneTimeKeyCheck: Long = 0 private var lastOneTimeKeyCheck: Long = 0
private var oneTimeKeyCount: Int? = null private var oneTimeKeyCount: Int? = null
private var needNewFallbackKey: Boolean = false
/** /**
* Stores the current one_time_key count which will be handled later (in a call of * Stores the current one_time_key count which will be handled later (in a call of
@ -51,10 +52,14 @@ internal class OneTimeKeysUploader @Inject constructor(
oneTimeKeyCount = currentCount oneTimeKeyCount = currentCount
} }
fun setNeedsNewFallback() {
needNewFallbackKey = true
}
/** /**
* Check if the OTK must be uploaded. * Check if the OTK must be uploaded.
*/ */
suspend fun maybeUploadOneTimeKeys(shouldGenerateFallbackKey: Boolean) { suspend fun maybeUploadOneTimeKeys() {
if (oneTimeKeyCheckInProgress) { if (oneTimeKeyCheckInProgress) {
Timber.v("maybeUploadOneTimeKeys: already in progress") Timber.v("maybeUploadOneTimeKeys: already in progress")
return return
@ -68,10 +73,6 @@ internal class OneTimeKeysUploader @Inject constructor(
lastOneTimeKeyCheck = System.currentTimeMillis() lastOneTimeKeyCheck = System.currentTimeMillis()
oneTimeKeyCheckInProgress = true oneTimeKeyCheckInProgress = true
if (shouldGenerateFallbackKey) {
olmDevice.generateFallbackKey()
}
// We then check how many keys we can store in the Account object. // We then check how many keys we can store in the Account object.
val maxOneTimeKeys = olmDevice.getMaxNumberOfOneTimeKeys() val maxOneTimeKeys = olmDevice.getMaxNumberOfOneTimeKeys()
@ -100,7 +101,7 @@ internal class OneTimeKeysUploader @Inject constructor(
// So we need some kind of engineering compromise to balance all of // So we need some kind of engineering compromise to balance all of
// these factors. // these factors.
tryOrNull("Unable to upload OTK") { tryOrNull("Unable to upload OTK") {
val uploadedKeys = uploadOTK(oneTimeKeyCountFromSync, keyLimit, shouldGenerateFallbackKey) val uploadedKeys = uploadOTK(oneTimeKeyCountFromSync, keyLimit)
Timber.v("## uploadKeys() : success, $uploadedKeys key(s) sent") Timber.v("## uploadKeys() : success, $uploadedKeys key(s) sent")
} }
} else { } else {
@ -124,22 +125,32 @@ internal class OneTimeKeysUploader @Inject constructor(
* @param keyLimit the limit * @param keyLimit the limit
* @return the number of uploaded keys * @return the number of uploaded keys
*/ */
private suspend fun uploadOTK(keyCount: Int, keyLimit: Int, shouldUploadFallbackKey: Boolean): Int { private suspend fun uploadOTK(keyCount: Int, keyLimit: Int): Int {
if (keyLimit <= keyCount && !shouldUploadFallbackKey) { if (keyLimit <= keyCount && !needNewFallbackKey) {
// If we don't need to generate any more keys then we are done. // If we don't need to generate any more keys then we are done.
return 0 return 0
} }
val keysThisLoop = min(keyLimit - keyCount, ONE_TIME_KEY_GENERATION_MAX_NUMBER) var keysThisLoop = 0
olmDevice.generateOneTimeKeys(keysThisLoop) if (keyLimit > keyCount) {
// Creating keys can be an expensive operation so we limit the
// number we generate in one go to avoid blocking the application
// for too long.
keysThisLoop = min(keyLimit - keyCount, ONE_TIME_KEY_GENERATION_MAX_NUMBER)
olmDevice.generateOneTimeKeys(keysThisLoop)
}
if (needNewFallbackKey) {
Timber.d("## CRYPTO: New fallback key needed")
olmDevice.generateFallbackKey()
}
val fallbackKey = if (shouldUploadFallbackKey) olmDevice.getFallbackKey() else null val response = uploadOneTimeKeys(olmDevice.getOneTimeKeys())
val response = uploadOneTimeKeys(olmDevice.getOneTimeKeys(), fallbackKey)
olmDevice.markKeysAsPublished() olmDevice.markKeysAsPublished()
needNewFallbackKey = false
if (response.hasOneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE)) { if (response.hasOneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE)) {
// Maybe upload other keys // Maybe upload other keys
return keysThisLoop + uploadOTK(response.oneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE), keyLimit, false) return keysThisLoop + uploadOTK(response.oneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE), keyLimit)
} else { } else {
Timber.e("## uploadOTK() : response for uploading keys does not contain one_time_key_counts.signed_curve25519") Timber.e("## uploadOTK() : response for uploading keys does not contain one_time_key_counts.signed_curve25519")
throw Exception("response for uploading keys does not contain one_time_key_counts.signed_curve25519") throw Exception("response for uploading keys does not contain one_time_key_counts.signed_curve25519")
@ -149,7 +160,7 @@ internal class OneTimeKeysUploader @Inject constructor(
/** /**
* Upload curve25519 one time keys. * Upload curve25519 one time keys.
*/ */
private suspend fun uploadOneTimeKeys(oneTimeKeys: Map<String, Map<String, String>>?, fallbackKey: Map<String, Map<String, String>>?): KeysUploadResponse { private suspend fun uploadOneTimeKeys(oneTimeKeys: Map<String, Map<String, String>>?): KeysUploadResponse {
val oneTimeJson = mutableMapOf<String, Any>() val oneTimeJson = mutableMapOf<String, Any>()
val curve25519Map = oneTimeKeys?.get(OlmAccount.JSON_KEY_ONE_TIME_KEY).orEmpty() val curve25519Map = oneTimeKeys?.get(OlmAccount.JSON_KEY_ONE_TIME_KEY).orEmpty()
@ -166,16 +177,19 @@ internal class OneTimeKeysUploader @Inject constructor(
oneTimeJson["signed_curve25519:$key_id"] = k oneTimeJson["signed_curve25519:$key_id"] = k
} }
val fallbackJson = mutableMapOf<String, Any>()
val fallbackCurve25519Map = fallbackKey?.get(OlmAccount.JSON_KEY_ONE_TIME_KEY).orEmpty()
fallbackCurve25519Map.forEach { (key_id, key) ->
val k = mutableMapOf<String, Any>()
k["key"] = key
k["fallback"] = true
val canonicalJson = JsonCanonicalizer.getCanonicalJson(Map::class.java, k)
k["signatures"] = objectSigner.signObject(canonicalJson)
fallbackJson["signed_curve25519:$key_id"] = k val fallbackJson = mutableMapOf<String, Any>()
if (needNewFallbackKey) {
val fallbackCurve25519Map = olmDevice.getFallbackKey()?.get(OlmAccount.JSON_KEY_ONE_TIME_KEY).orEmpty()
fallbackCurve25519Map.forEach { (key_id, key) ->
val k = mutableMapOf<String, Any>()
k["key"] = key
k["fallback"] = true
val canonicalJson = JsonCanonicalizer.getCanonicalJson(Map::class.java, k)
k["signatures"] = objectSigner.signObject(canonicalJson)
fallbackJson["signed_curve25519:$key_id"] = k
}
} }
// For now, we set the device id explicitly, as we may not be using the // For now, we set the device id explicitly, as we may not be using the
@ -185,7 +199,7 @@ internal class OneTimeKeysUploader @Inject constructor(
oneTimeKeys = oneTimeJson, oneTimeKeys = oneTimeJson,
fallbackKeys = fallbackJson.takeIf { fallbackJson.isNotEmpty() } fallbackKeys = fallbackJson.takeIf { fallbackJson.isNotEmpty() }
) )
return uploadKeysTask.execute(uploadParams) return uploadKeysTask.executeRetry(uploadParams, 3)
} }
companion object { companion object {