Roll back page table updates if TEE fails to complete memory send.

Bug: 132429380
Change-Id: I969adb14bf34cd13aeb72d4f7abdd05c3701be10
diff --git a/src/ffa_memory.c b/src/ffa_memory.c
index 849a7cb..89c372d 100644
--- a/src/ffa_memory.c
+++ b/src/ffa_memory.c
@@ -862,7 +862,7 @@
 	struct ffa_memory_region_constituent **fragments,
 	uint32_t *fragment_constituent_counts, uint32_t fragment_count,
 	uint32_t share_func, ffa_memory_access_permissions_t permissions,
-	struct mpool *page_pool, bool clear)
+	struct mpool *page_pool, bool clear, uint32_t *orig_from_mode_ret)
 {
 	struct vm *from = from_locked.vm;
 	uint32_t i;
@@ -896,6 +896,10 @@
 		return ret;
 	}
 
+	if (orig_from_mode_ret != NULL) {
+		*orig_from_mode_ret = orig_from_mode;
+	}
+
 	/*
 	 * Create a local pool so any freed memory can't be used by another
 	 * thread. This is to ensure the original mapping can be restored if the
@@ -1270,7 +1274,8 @@
  */
 static struct ffa_value ffa_memory_send_complete(
 	struct vm_locked from_locked, struct share_states_locked share_states,
-	struct ffa_memory_share_state *share_state, struct mpool *page_pool)
+	struct ffa_memory_share_state *share_state, struct mpool *page_pool,
+	uint32_t *orig_from_mode_ret)
 {
 	struct ffa_memory_region *memory_region = share_state->memory_region;
 	struct ffa_value ret;
@@ -1284,7 +1289,8 @@
 		share_state->fragment_constituent_counts,
 		share_state->fragment_count, share_state->share_func,
 		memory_region->receivers[0].receiver_permissions.permissions,
-		page_pool, memory_region->flags & FFA_MEMORY_REGION_FLAG_CLEAR);
+		page_pool, memory_region->flags & FFA_MEMORY_REGION_FLAG_CLEAR,
+		orig_from_mode_ret);
 	if (ret.func != FFA_SUCCESS_32) {
 		/*
 		 * Free share state, it failed to send so it can't be retrieved.
@@ -1636,7 +1642,7 @@
 	if (fragment_length == memory_share_length) {
 		/* No more fragments to come, everything fit in one message. */
 		ret = ffa_memory_send_complete(from_locked, share_states,
-					       share_state, page_pool);
+					       share_state, page_pool, NULL);
 	} else {
 		ret = (struct ffa_value){
 			.func = FFA_MEM_FRAG_RX_32,
@@ -1691,13 +1697,22 @@
 			ffa_memory_region_get_composite(memory_region, 0);
 		struct ffa_memory_region_constituent *constituents =
 			composite->constituents;
+		struct mpool local_page_pool;
+		uint32_t orig_from_mode;
+
+		/*
+		 * Use a local page pool so that we can roll back if necessary.
+		 */
+		mpool_init_with_fallback(&local_page_pool, page_pool);
 
 		ret = ffa_send_check_update(
 			from_locked, &constituents,
 			&composite->constituent_count, 1, share_func,
-			permissions, page_pool,
-			memory_region->flags & FFA_MEMORY_REGION_FLAG_CLEAR);
+			permissions, &local_page_pool,
+			memory_region->flags & FFA_MEMORY_REGION_FLAG_CLEAR,
+			&orig_from_mode);
 		if (ret.func != FFA_SUCCESS_32) {
+			mpool_fini(&local_page_pool);
 			goto out;
 		}
 
@@ -1705,6 +1720,27 @@
 		ret = memory_send_tee_forward(
 			to_locked, from_locked.vm->id, share_func,
 			memory_region, memory_share_length, fragment_length);
+
+		if (ret.func != FFA_SUCCESS_32) {
+			dlog_verbose(
+				"TEE didn't successfully complete memory send "
+				"operation; returned %#x (%d). Rolling back.\n",
+				ret.func, ret.arg2);
+
+			/*
+			 * The TEE failed to complete the send operation, so
+			 * roll back the page table update for the VM. This
+			 * can't fail because it won't try to allocate more
+			 * memory than was freed into the `local_page_pool` by
+			 * `ffa_send_check_update` in the initial update.
+			 */
+			CHECK(ffa_region_group_identity_map(
+				from_locked, &constituents,
+				&composite->constituent_count, 1,
+				orig_from_mode, &local_page_pool, true));
+		}
+
+		mpool_fini(&local_page_pool);
 	} else {
 		struct share_states_locked share_states = share_states_lock();
 		ffa_memory_handle_t handle;
@@ -1824,7 +1860,7 @@
 	/* Check whether the memory send operation is now ready to complete. */
 	if (share_state_sending_complete(share_states, share_state)) {
 		ret = ffa_memory_send_complete(from_locked, share_states,
-					       share_state, page_pool);
+					       share_state, page_pool, NULL);
 	} else {
 		ret = (struct ffa_value){
 			.func = FFA_MEM_FRAG_RX_32,
@@ -1908,8 +1944,17 @@
 
 	/* Check whether the memory send operation is now ready to complete. */
 	if (share_state_sending_complete(share_states, share_state)) {
+		struct mpool local_page_pool;
+		uint32_t orig_from_mode;
+
+		/*
+		 * Use a local page pool so that we can roll back if necessary.
+		 */
+		mpool_init_with_fallback(&local_page_pool, page_pool);
+
 		ret = ffa_memory_send_complete(from_locked, share_states,
-					       share_state, page_pool);
+					       share_state, &local_page_pool,
+					       &orig_from_mode);
 
 		if (ret.func == FFA_SUCCESS_32) {
 			/*
@@ -1928,9 +1973,27 @@
 				dlog_verbose(
 					"TEE didn't successfully complete "
 					"memory send operation; returned %#x "
-					"(%d).\n",
+					"(%d). Rolling back.\n",
 					ret.func, ret.arg2);
+
+				/*
+				 * The TEE failed to complete the send
+				 * operation, so roll back the page table update
+				 * for the VM. This can't fail because it won't
+				 * try to allocate more memory than was freed
+				 * into the `local_page_pool` by
+				 * `ffa_send_check_update` in the initial
+				 * update.
+				 */
+				CHECK(ffa_region_group_identity_map(
+					from_locked, share_state->fragments,
+					share_state
+						->fragment_constituent_counts,
+					share_state->fragment_count,
+					orig_from_mode, &local_page_pool,
+					true));
 			}
+
 			/* Free share state. */
 			share_state_free(share_states, share_state, page_pool);
 		} else {
@@ -1957,6 +2020,8 @@
 			 * because ffa_memory_send_complete does that already.
 			 */
 		}
+
+		mpool_fini(&local_page_pool);
 	} else {
 		uint32_t next_fragment_offset =
 			share_state_next_fragment_offset(share_states,