Quick refactor to use same mechanism as updateOneTimeKeyCount

This commit is contained in:
Valere 2021-11-15 13:56:45 +01:00
parent c603135398
commit 10671a53a4
2 changed files with 45 additions and 33 deletions

View File

@ -330,7 +330,7 @@ internal class DefaultCryptoService @Inject constructor(
uploadDeviceKeys()
}
oneTimeKeysUploader.maybeUploadOneTimeKeys(shouldGenerateFallbackKey = false)
oneTimeKeysUploader.maybeUploadOneTimeKeys()
// this can throw if no backup
tryOrNull {
keysBackupService.checkAndStartKeysBackup()
@ -434,15 +434,13 @@ internal class DefaultCryptoService @Inject constructor(
deviceListManager.refreshOutdatedDeviceLists()
// The presence of device_unused_fallback_key_types indicates that the server supports fallback keys.
// If there's no unused signed_curve25519 fallback key we need a new one.
val shouldGenerateFallbackKey = if (syncResponse.deviceUnusedFallbackKeyTypes != null) {
// Generate a fallback key only if the server does not already have an unused fallback key.
!syncResponse.deviceUnusedFallbackKeyTypes.contains(KEY_SIGNED_CURVE_25519_TYPE)
} else {
// Server does not support fallbackKey
false
if (syncResponse.deviceUnusedFallbackKeyTypes != null
// Generate a fallback key only if the server does not already have an unused fallback key.
&& !syncResponse.deviceUnusedFallbackKeyTypes.contains(KEY_SIGNED_CURVE_25519_TYPE)) {
oneTimeKeysUploader.setNeedsNewFallback()
}
oneTimeKeysUploader.maybeUploadOneTimeKeys(shouldGenerateFallbackKey)
oneTimeKeysUploader.maybeUploadOneTimeKeys()
incomingGossipingRequestManager.processReceivedGossipingRequests()
}
}

View File

@ -40,6 +40,7 @@ internal class OneTimeKeysUploader @Inject constructor(
// last OTK check timestamp
private var lastOneTimeKeyCheck: Long = 0
private var oneTimeKeyCount: Int? = null
private var needNewFallbackKey: Boolean = false
/**
* Stores the current one_time_key count which will be handled later (in a call of
@ -51,10 +52,14 @@ internal class OneTimeKeysUploader @Inject constructor(
oneTimeKeyCount = currentCount
}
fun setNeedsNewFallback() {
needNewFallbackKey = true
}
/**
* Check if the OTK must be uploaded.
*/
suspend fun maybeUploadOneTimeKeys(shouldGenerateFallbackKey: Boolean) {
suspend fun maybeUploadOneTimeKeys() {
if (oneTimeKeyCheckInProgress) {
Timber.v("maybeUploadOneTimeKeys: already in progress")
return
@ -68,10 +73,6 @@ internal class OneTimeKeysUploader @Inject constructor(
lastOneTimeKeyCheck = System.currentTimeMillis()
oneTimeKeyCheckInProgress = true
if (shouldGenerateFallbackKey) {
olmDevice.generateFallbackKey()
}
// We then check how many keys we can store in the Account object.
val maxOneTimeKeys = olmDevice.getMaxNumberOfOneTimeKeys()
@ -100,7 +101,7 @@ internal class OneTimeKeysUploader @Inject constructor(
// So we need some kind of engineering compromise to balance all of
// these factors.
tryOrNull("Unable to upload OTK") {
val uploadedKeys = uploadOTK(oneTimeKeyCountFromSync, keyLimit, shouldGenerateFallbackKey)
val uploadedKeys = uploadOTK(oneTimeKeyCountFromSync, keyLimit)
Timber.v("## uploadKeys() : success, $uploadedKeys key(s) sent")
}
} else {
@ -124,22 +125,32 @@ internal class OneTimeKeysUploader @Inject constructor(
* @param keyLimit the limit
* @return the number of uploaded keys
*/
private suspend fun uploadOTK(keyCount: Int, keyLimit: Int, shouldUploadFallbackKey: Boolean): Int {
if (keyLimit <= keyCount && !shouldUploadFallbackKey) {
private suspend fun uploadOTK(keyCount: Int, keyLimit: Int): Int {
if (keyLimit <= keyCount && !needNewFallbackKey) {
// If we don't need to generate any more keys then we are done.
return 0
}
val keysThisLoop = min(keyLimit - keyCount, ONE_TIME_KEY_GENERATION_MAX_NUMBER)
olmDevice.generateOneTimeKeys(keysThisLoop)
var keysThisLoop = 0
if (keyLimit > keyCount) {
// Creating keys can be an expensive operation so we limit the
// number we generate in one go to avoid blocking the application
// for too long.
keysThisLoop = min(keyLimit - keyCount, ONE_TIME_KEY_GENERATION_MAX_NUMBER)
olmDevice.generateOneTimeKeys(keysThisLoop)
}
if (needNewFallbackKey) {
Timber.d("## CRYPTO: New fallback key needed")
olmDevice.generateFallbackKey()
}
val fallbackKey = if (shouldUploadFallbackKey) olmDevice.getFallbackKey() else null
val response = uploadOneTimeKeys(olmDevice.getOneTimeKeys(), fallbackKey)
val response = uploadOneTimeKeys(olmDevice.getOneTimeKeys())
olmDevice.markKeysAsPublished()
needNewFallbackKey = false
if (response.hasOneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE)) {
// Maybe upload other keys
return keysThisLoop + uploadOTK(response.oneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE), keyLimit, false)
return keysThisLoop + uploadOTK(response.oneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE), keyLimit)
} else {
Timber.e("## uploadOTK() : response for uploading keys does not contain one_time_key_counts.signed_curve25519")
throw Exception("response for uploading keys does not contain one_time_key_counts.signed_curve25519")
@ -149,7 +160,7 @@ internal class OneTimeKeysUploader @Inject constructor(
/**
* Upload curve25519 one time keys.
*/
private suspend fun uploadOneTimeKeys(oneTimeKeys: Map<String, Map<String, String>>?, fallbackKey: Map<String, Map<String, String>>?): KeysUploadResponse {
private suspend fun uploadOneTimeKeys(oneTimeKeys: Map<String, Map<String, String>>?): KeysUploadResponse {
val oneTimeJson = mutableMapOf<String, Any>()
val curve25519Map = oneTimeKeys?.get(OlmAccount.JSON_KEY_ONE_TIME_KEY).orEmpty()
@ -166,16 +177,19 @@ internal class OneTimeKeysUploader @Inject constructor(
oneTimeJson["signed_curve25519:$key_id"] = k
}
val fallbackJson = mutableMapOf<String, Any>()
val fallbackCurve25519Map = fallbackKey?.get(OlmAccount.JSON_KEY_ONE_TIME_KEY).orEmpty()
fallbackCurve25519Map.forEach { (key_id, key) ->
val k = mutableMapOf<String, Any>()
k["key"] = key
k["fallback"] = true
val canonicalJson = JsonCanonicalizer.getCanonicalJson(Map::class.java, k)
k["signatures"] = objectSigner.signObject(canonicalJson)
fallbackJson["signed_curve25519:$key_id"] = k
val fallbackJson = mutableMapOf<String, Any>()
if (needNewFallbackKey) {
val fallbackCurve25519Map = olmDevice.getFallbackKey()?.get(OlmAccount.JSON_KEY_ONE_TIME_KEY).orEmpty()
fallbackCurve25519Map.forEach { (key_id, key) ->
val k = mutableMapOf<String, Any>()
k["key"] = key
k["fallback"] = true
val canonicalJson = JsonCanonicalizer.getCanonicalJson(Map::class.java, k)
k["signatures"] = objectSigner.signObject(canonicalJson)
fallbackJson["signed_curve25519:$key_id"] = k
}
}
// For now, we set the device id explicitly, as we may not be using the
@ -185,7 +199,7 @@ internal class OneTimeKeysUploader @Inject constructor(
oneTimeKeys = oneTimeJson,
fallbackKeys = fallbackJson.takeIf { fallbackJson.isNotEmpty() }
)
return uploadKeysTask.execute(uploadParams)
return uploadKeysTask.executeRetry(uploadParams, 3)
}
companion object {