crypto: Add support to claim one-time keys

This commit is contained in:
Damir Jelić 2021-03-04 17:14:48 +01:00
parent da35c9b6bd
commit 4c44a5e108
7 changed files with 134 additions and 66 deletions

View File

@ -79,6 +79,7 @@ import org.matrix.android.sdk.internal.crypto.model.event.SecretSendEventContent
import org.matrix.android.sdk.internal.crypto.model.rest.DeviceInfo
import org.matrix.android.sdk.internal.crypto.model.rest.DevicesListResponse
import org.matrix.android.sdk.internal.crypto.model.rest.KeysUploadResponse
import org.matrix.android.sdk.internal.crypto.model.rest.KeysClaimResponse
import org.matrix.android.sdk.internal.crypto.model.rest.KeysQueryResponse
import org.matrix.android.sdk.internal.crypto.model.rest.RoomKeyRequestBody
import org.matrix.android.sdk.internal.crypto.repository.WarnOnUnknownDeviceRepository
@ -91,6 +92,7 @@ import org.matrix.android.sdk.internal.crypto.tasks.GetDevicesTask
import org.matrix.android.sdk.internal.crypto.tasks.NewUploadKeysTask
import org.matrix.android.sdk.internal.crypto.tasks.SetDeviceNameTask
import org.matrix.android.sdk.internal.crypto.tasks.UploadKeysTask
import org.matrix.android.sdk.internal.crypto.tasks.ClaimOneTimeKeysForUsersDeviceTask
import org.matrix.android.sdk.internal.crypto.verification.DefaultVerificationService
import org.matrix.android.sdk.internal.di.DeviceId
import org.matrix.android.sdk.internal.di.MoshiProvider
@ -171,6 +173,7 @@ internal class DefaultCryptoService @Inject constructor(
private val deleteDeviceWithUserPasswordTask: DeleteDeviceWithUserPasswordTask,
// Tasks
private val getDevicesTask: GetDevicesTask,
private val oneTimeKeysForUsersDeviceTask: ClaimOneTimeKeysForUsersDeviceTask,
private val downloadKeysForUsersTask: DownloadKeysForUsersTask,
private val getDeviceInfoTask: GetDeviceInfoTask,
private val setDeviceNameTask: SetDeviceNameTask,
@ -691,6 +694,7 @@ internal class DefaultCryptoService @Inject constructor(
val t0 = System.currentTimeMillis()
Timber.v("## CRYPTO | encryptEventContent() starts")
runCatching {
preshareGroupSession(roomId, userIds)
val content = safeAlgorithm.encryptEventContent(eventContent, eventType, userIds)
Timber.v("## CRYPTO | encryptEventContent() : succeeds after ${System.currentTimeMillis() - t0} ms")
MXEncryptEventContentResult(content, EventType.ENCRYPTED)
@ -956,6 +960,26 @@ internal class DefaultCryptoService @Inject constructor(
olmMachine!!.receiveSyncChanges(toDevice, deviceChanges, keyCounts)
}
private suspend fun preshareGroupSession(roomId: String, roomMembers: List<String>) {
val request = olmMachine!!.getMissingSessions(roomMembers)
roomId == "est"
if (request != null) {
when (request) {
is Request.KeysClaim -> {
val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(request.oneTimeKeys)
val response = oneTimeKeysForUsersDeviceTask.execute(claimParams)
val adapter = MoshiProvider.providesMoshi().adapter<KeysClaimResponse>(KeysClaimResponse::class.java)
val json_response = adapter.toJson(response)!!
olmMachine!!.markRequestAsSent(request.requestId, RequestType.KEYS_CLAIM, json_response)
}
}
}
}
// private suspend fun encrypt(roomId: String, eventType: String, content: Content) {
// }
private suspend fun sendOutgoingRequests() {
// TODO these requests should be sent out in parallel
for (outgoingRequest in olmMachine!!.outgoingRequests()) {

View File

@ -50,56 +50,56 @@ internal class EnsureOlmSessionsForDevicesAction @Inject constructor(
}
}
if (devicesWithoutSession.size == 0) {
return results
}
//if (devicesWithoutSession.size == 0) {
// return results
//}
// Prepare the request for claiming one-time keys
val usersDevicesToClaim = MXUsersDevicesMap<String>()
//// Prepare the request for claiming one-time keys
//val usersDevicesToClaim = MXUsersDevicesMap<String>()
val oneTimeKeyAlgorithm = MXKey.KEY_SIGNED_CURVE_25519_TYPE
//val oneTimeKeyAlgorithm = MXKey.KEY_SIGNED_CURVE_25519_TYPE
for (device in devicesWithoutSession) {
usersDevicesToClaim.setObject(device.userId, device.deviceId, oneTimeKeyAlgorithm)
}
//for (device in devicesWithoutSession) {
// usersDevicesToClaim.setObject(device.userId, device.deviceId, oneTimeKeyAlgorithm)
//}
// TODO: this has a race condition - if we try to send another message
// while we are claiming a key, we will end up claiming two and setting up
// two sessions.
//
// That should eventually resolve itself, but it's poor form.
//// TODO: this has a race condition - if we try to send another message
//// while we are claiming a key, we will end up claiming two and setting up
//// two sessions.
////
//// That should eventually resolve itself, but it's poor form.
Timber.i("## CRYPTO | claimOneTimeKeysForUsersDevices() : $usersDevicesToClaim")
//Timber.i("## CRYPTO | claimOneTimeKeysForUsersDevices() : $usersDevicesToClaim")
val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(usersDevicesToClaim)
val oneTimeKeys = oneTimeKeysForUsersDeviceTask.execute(claimParams)
Timber.v("## CRYPTO | claimOneTimeKeysForUsersDevices() : keysClaimResponse.oneTimeKeys: $oneTimeKeys")
for ((userId, deviceInfos) in devicesByUser) {
for (deviceInfo in deviceInfos) {
var oneTimeKey: MXKey? = null
val deviceIds = oneTimeKeys.getUserDeviceIds(userId)
if (null != deviceIds) {
for (deviceId in deviceIds) {
val olmSessionResult = results.getObject(userId, deviceId)
if (olmSessionResult!!.sessionId != null && !force) {
// We already have a result for this device
continue
}
val key = oneTimeKeys.getObject(userId, deviceId)
if (key?.type == oneTimeKeyAlgorithm) {
oneTimeKey = key
}
if (oneTimeKey == null) {
Timber.w("## CRYPTO | ensureOlmSessionsForDevices() : No one-time keys " + oneTimeKeyAlgorithm
+ " for device " + userId + " : " + deviceId)
continue
}
// Update the result for this device in results
olmSessionResult.sessionId = verifyKeyAndStartSession(oneTimeKey, userId, deviceInfo)
}
}
}
}
//val claimParams = ClaimOneTimeKeysForUsersDeviceTask.Params(usersDevicesToClaim)
//val oneTimeKeys = oneTimeKeysForUsersDeviceTask.execute(claimParams)
//Timber.v("## CRYPTO | claimOneTimeKeysForUsersDevices() : keysClaimResponse.oneTimeKeys: $oneTimeKeys")
//for ((userId, deviceInfos) in devicesByUser) {
// for (deviceInfo in deviceInfos) {
// var oneTimeKey: MXKey? = null
// val deviceIds = oneTimeKeys.getUserDeviceIds(userId)
// if (null != deviceIds) {
// for (deviceId in deviceIds) {
// val olmSessionResult = results.getObject(userId, deviceId)
// if (olmSessionResult!!.sessionId != null && !force) {
// // We already have a result for this device
// continue
// }
// val key = oneTimeKeys.getObject(userId, deviceId)
// if (key?.type == oneTimeKeyAlgorithm) {
// oneTimeKey = key
// }
// if (oneTimeKey == null) {
// Timber.w("## CRYPTO | ensureOlmSessionsForDevices() : No one-time keys " + oneTimeKeyAlgorithm
// + " for device " + userId + " : " + deviceId)
// continue
// }
// // Update the result for this device in results
// olmSessionResult.sessionId = verifyKeyAndStartSession(oneTimeKey, userId, deviceInfo)
// }
// }
// }
//}
return results
}

View File

@ -24,6 +24,11 @@ import com.squareup.moshi.JsonClass
*/
@JsonClass(generateAdapter = true)
internal data class KeysClaimResponse(
/// If any remote homeservers could not be reached, they are recorded here.
/// The names of the properties are the names of the unreachable servers.
@Json(name = "failures")
val failures: Map<String, Any>,
/**
* The requested keys ordered by device by user.
* TODO Type does not match spec, should be Map<String, JsonDict>

View File

@ -27,10 +27,10 @@ import org.matrix.android.sdk.internal.task.Task
import timber.log.Timber
import javax.inject.Inject
internal interface ClaimOneTimeKeysForUsersDeviceTask : Task<ClaimOneTimeKeysForUsersDeviceTask.Params, MXUsersDevicesMap<MXKey>> {
internal interface ClaimOneTimeKeysForUsersDeviceTask : Task<ClaimOneTimeKeysForUsersDeviceTask.Params, KeysClaimResponse> {
data class Params(
// a list of users, devices and key types to retrieve keys for.
val usersDevicesKeyTypesMap: MXUsersDevicesMap<String>
val usersDevicesKeyTypesMap: Map<String, Map<String, String>>
)
}
@ -39,26 +39,11 @@ internal class DefaultClaimOneTimeKeysForUsersDevice @Inject constructor(
private val globalErrorReceiver: GlobalErrorReceiver
) : ClaimOneTimeKeysForUsersDeviceTask {
override suspend fun execute(params: ClaimOneTimeKeysForUsersDeviceTask.Params): MXUsersDevicesMap<MXKey> {
val body = KeysClaimBody(oneTimeKeys = params.usersDevicesKeyTypesMap.map)
override suspend fun execute(params: ClaimOneTimeKeysForUsersDeviceTask.Params): KeysClaimResponse {
val body = KeysClaimBody(oneTimeKeys = params.usersDevicesKeyTypesMap)
val keysClaimResponse = executeRequest<KeysClaimResponse>(globalErrorReceiver) {
return executeRequest<KeysClaimResponse>(globalErrorReceiver) {
apiCall = cryptoApi.claimOneTimeKeysForUsersDevices(body)
}
val map = MXUsersDevicesMap<MXKey>()
keysClaimResponse.oneTimeKeys?.let { oneTimeKeys ->
for ((userId, mapByUserId) in oneTimeKeys) {
for ((deviceId, deviceKey) in mapByUserId) {
val mxKey = MXKey.from(deviceKey)
if (mxKey != null) {
map.setObject(userId, deviceId, mxKey)
} else {
Timber.e("## claimOneTimeKeysForUsersDevices : fail to create a MXKey")
}
}
}
}
return map
}
}

View File

@ -87,6 +87,10 @@ internal class OlmMachine(user_id: String, device_id: String, path: File) {
inner.outgoingRequests()
}
suspend fun getMissingSessions(users: List<String>): Request? = withContext(Dispatchers.IO) {
inner.getMissingSessions(users)
}
suspend fun updateTrackedUsers(users: List<String>) = withContext(Dispatchers.IO) {
inner.updateTrackedUsers(users)
}

View File

@ -10,10 +10,12 @@ use tokio::runtime::Runtime;
use matrix_sdk_common::{
api::r0::{
keys::{
claim_keys::Response as KeysClaimResponse, get_keys::Response as KeysQueryResponse,
claim_keys::{Request as KeysClaimRequest, Response as KeysClaimResponse},
get_keys::Response as KeysQueryResponse,
upload_keys::Response as KeysUploadResponse,
},
sync::sync_events::{DeviceLists as RumaDeviceLists, ToDevice},
to_device::send_event_to_device::Response as ToDeviceResponse,
},
assign,
deserialized_responses::events::{AlgorithmInfo, SyncMessageEvent},
@ -67,6 +69,7 @@ enum OwnedResponse {
KeysClaim(KeysClaimResponse),
KeysUpload(KeysUploadResponse),
KeysQuery(KeysQueryResponse),
ToDevice(ToDeviceResponse),
}
impl From<KeysClaimResponse> for OwnedResponse {
@ -87,18 +90,26 @@ impl From<KeysUploadResponse> for OwnedResponse {
}
}
impl From<ToDeviceResponse> for OwnedResponse {
fn from(response: ToDeviceResponse) -> Self {
OwnedResponse::ToDevice(response)
}
}
impl<'a> Into<IncomingResponse<'a>> for &'a OwnedResponse {
fn into(self) -> IncomingResponse<'a> {
match self {
OwnedResponse::KeysClaim(r) => IncomingResponse::KeysClaim(r),
OwnedResponse::KeysQuery(r) => IncomingResponse::KeysQuery(r),
OwnedResponse::KeysUpload(r) => IncomingResponse::KeysUpload(r),
OwnedResponse::ToDevice(r) => IncomingResponse::ToDevice(r),
}
}
}
pub enum RequestType {
KeysQuery,
KeysClaim,
KeysUpload,
ToDevice,
}
@ -130,6 +141,10 @@ pub enum Request {
request_id: String,
users: Vec<String>,
},
KeysClaim {
request_id: String,
one_time_keys: HashMap<String, HashMap<String, String>>,
},
}
impl From<OutgoingRequest> for Request {
@ -268,7 +283,8 @@ impl OlmMachine {
let response: OwnedResponse = match request_type {
RequestType::KeysUpload => KeysUploadResponse::try_from(response).map(Into::into),
RequestType::KeysQuery => KeysQueryResponse::try_from(response).map(Into::into),
RequestType::ToDevice => KeysClaimResponse::try_from(response).map(Into::into),
RequestType::ToDevice => ToDeviceResponse::try_from(response).map(Into::into),
RequestType::KeysClaim => KeysClaimResponse::try_from(response).map(Into::into),
}
.expect("Can't convert json string to response");
@ -342,6 +358,35 @@ impl OlmMachine {
.block_on(self.inner.update_tracked_users(users.iter()));
}
pub fn get_missing_sessions(
&self,
users: Vec<String>,
) -> Result<Option<Request>, CryptoStoreError> {
let users: Vec<UserId> = users
.into_iter()
.filter_map(|u| UserId::try_from(u).ok())
.collect();
Ok(self
.runtime
.block_on(self.inner.get_missing_sessions(users.iter()))?
.map(|(request_id, request)| Request::KeysClaim {
request_id: request_id.to_string(),
one_time_keys: request
.one_time_keys
.into_iter()
.map(|(u, d)| {
(
u.to_string(),
d.into_iter()
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect(),
)
})
.collect(),
}))
}
pub fn decrypt_room_event(
&self,
event: &str,

View File

@ -55,10 +55,12 @@ interface Request {
ToDevice(string request_id, string event_type, string body);
KeysUpload(string request_id, string body);
KeysQuery(string request_id, sequence<string> users);
KeysClaim(string request_id, record<DOMString, record<DOMString, string>> one_time_keys);
};
enum RequestType {
"KeysQuery",
"KeysClaim",
"KeysUpload",
"ToDevice",
};
@ -85,6 +87,9 @@ interface OlmMachine {
sequence<Request> outgoing_requests();
void update_tracked_users(sequence<string> users);
[Throws=CryptoStoreError]
Request? get_missing_sessions(sequence<string> users);
[Throws=CryptoStoreError]
void mark_request_as_sent(
[ByRef] string request_id,