Implement mm_ptable_defrag to collapse the page table hierarchy when possible.

Bug: 115321017
Change-Id: I0ad21d7a6dbb762d29e62c283d466886afad184e
diff --git a/src/arch/aarch64/inc/hf/arch/mm.h b/src/arch/aarch64/inc/hf/arch/mm.h
index 3437b71..8f7028d 100644
--- a/src/arch/aarch64/inc/hf/arch/mm.h
+++ b/src/arch/aarch64/inc/hf/arch/mm.h
@@ -203,3 +203,11 @@
 uint64_t arch_mm_mode_to_attrs(int mode);
 bool arch_mm_init(paddr_t table, bool first);
 int arch_mm_max_level(int mode);
+
+/**
+ * Given the attrs from a table at some level and the attrs from all the blocks
+ * in that table, return equivalent attrs to use for a block which will replace
+ * the entire table.
+ */
+uint64_t arch_mm_combine_table_entry_attrs(uint64_t table_attrs,
+					   uint64_t block_attrs);
diff --git a/src/arch/aarch64/mm.c b/src/arch/aarch64/mm.c
index 303814e..a416a82 100644
--- a/src/arch/aarch64/mm.c
+++ b/src/arch/aarch64/mm.c
@@ -30,11 +30,14 @@
 #define INNER_SHAREABLE UINT64_C(3)
 
 #define STAGE1_XN          (UINT64_C(1) << 54)
+#define STAGE1_PXN         (UINT64_C(1) << 53)
 #define STAGE1_CONTIGUOUS  (UINT64_C(1) << 52)
 #define STAGE1_DBM         (UINT64_C(1) << 51)
 #define STAGE1_NG          (UINT64_C(1) << 11)
 #define STAGE1_AF          (UINT64_C(1) << 10)
 #define STAGE1_SH(x)       ((x) << 8)
+#define STAGE1_AP2         (UINT64_C(1) << 7)
+#define STAGE1_AP1         (UINT64_C(1) << 6)
 #define STAGE1_AP(x)       ((x) << 6)
 #define STAGE1_NS          (UINT64_C(1) << 5)
 #define STAGE1_ATTRINDX(x) ((x) << 2)
@@ -58,6 +61,13 @@
 #define STAGE2_EXECUTE_NONE UINT64_C(2)
 #define STAGE2_EXECUTE_EL1  UINT64_C(3)
 
+/* Table attributes only apply to stage 1 translations. */
+#define TABLE_NSTABLE  (UINT64_C(1) << 63)
+#define TABLE_APTABLE1 (UINT64_C(1) << 62)
+#define TABLE_APTABLE0 (UINT64_C(1) << 61)
+#define TABLE_XNTABLE  (UINT64_C(1) << 60)
+#define TABLE_PXNTABLE (UINT64_C(1) << 59)
+
 /* The following are stage-2 memory attributes for normal memory. */
 #define STAGE2_NONCACHEABLE UINT64_C(1)
 #define STAGE2_WRITETHROUGH UINT64_C(2)
@@ -258,3 +268,28 @@
 
 	return true;
 }
+
+uint64_t arch_mm_combine_table_entry_attrs(uint64_t table_attrs,
+					   uint64_t block_attrs)
+{
+	/*
+	 * Only stage 1 table descriptors have attributes, but the bits are res0
+	 * for stage 2 table descriptors so this code is safe for both.
+	 */
+	if (table_attrs & TABLE_NSTABLE) {
+		block_attrs |= STAGE1_NS;
+	}
+	if (table_attrs & TABLE_APTABLE1) {
+		block_attrs |= STAGE1_AP2;
+	}
+	if (table_attrs & TABLE_APTABLE0) {
+		block_attrs &= ~STAGE1_AP1;
+	}
+	if (table_attrs & TABLE_XNTABLE) {
+		block_attrs |= STAGE1_XN;
+	}
+	if (table_attrs & TABLE_PXNTABLE) {
+		block_attrs |= STAGE1_PXN;
+	}
+	return block_attrs;
+}
diff --git a/src/mm.c b/src/mm.c
index fa8d52c..3538f45 100644
--- a/src/mm.c
+++ b/src/mm.c
@@ -24,6 +24,13 @@
 #include "hf/dlog.h"
 #include "hf/layout.h"
 
+/**
+ * This file has functions for managing the stage 1 and 2 page tables used by
+ * Hafnium. There is a stage 1 mapping used by Hafnium itself to access memory,
+ * and then a stage 2 mapping per VM. The design assumes that all page tables
+ * contain only 1-1 mappings, aligned on the block boundaries.
+ */
+
 /* The type of addresses stored in the page table. */
 typedef uintvaddr_t ptable_addr_t;
 
@@ -44,9 +51,20 @@
 
 /* clang-format on */
 
+#define NUM_ENTRIES (PAGE_SIZE / sizeof(pte_t))
+
 static struct mm_ptable ptable;
 
 /**
+ * Casts a physical address to a pointer. This assumes that it is mapped (to the
+ * same address), so should only be used within the mm code.
+ */
+static inline void *ptr_from_pa(paddr_t pa)
+{
+	return ptr_from_va(va_from_pa(pa));
+}
+
+/**
  * Rounds an address down to a page boundary.
  */
 static ptable_addr_t mm_round_down_to_page(ptable_addr_t addr)
@@ -108,7 +126,7 @@
 
 	/* Just return pointer to table if it's already populated. */
 	if (arch_mm_pte_is_table(v, level)) {
-		return ptr_from_va(va_from_pa(arch_mm_table_from_pte(v)));
+		return ptr_from_pa(arch_mm_table_from_pte(v));
 	}
 
 	/* Allocate a new table. */
@@ -131,7 +149,7 @@
 	}
 
 	/* Initialise entries in the new table. */
-	for (i = 0; i < PAGE_SIZE / sizeof(paddr_t); i++) {
+	for (i = 0; i < NUM_ENTRIES; i++) {
 		ntable[i] = new_pte;
 		new_pte += inc;
 	}
@@ -159,9 +177,9 @@
 		return;
 	}
 
-	table = ptr_from_va(va_from_pa(arch_mm_table_from_pte(pte)));
+	table = ptr_from_pa(arch_mm_table_from_pte(pte));
 	/* Recursively free any subtables. */
-	for (i = 0; i < PAGE_SIZE / sizeof(pte_t); ++i) {
+	for (i = 0; i < NUM_ENTRIES; ++i) {
 		mm_free_page_pte(table[i], level - 1);
 	}
 
@@ -247,7 +265,7 @@
 	uint64_t attrs = arch_mm_mode_to_attrs(mode);
 	int flags = (mode & MM_MODE_NOSYNC) ? 0 : MAP_FLAG_SYNC;
 	int level = arch_mm_max_level(mode);
-	pte_t *table = ptr_from_va(va_from_pa(t->table));
+	pte_t *table = ptr_from_pa(t->table);
 	ptable_addr_t begin;
 	ptable_addr_t end;
 
@@ -284,7 +302,7 @@
 {
 	int flags = (mode & MM_MODE_NOSYNC) ? 0 : MAP_FLAG_SYNC;
 	int level = arch_mm_max_level(mode);
-	pte_t *table = ptr_from_va(va_from_pa(t->table));
+	pte_t *table = ptr_from_pa(t->table);
 	ptable_addr_t begin;
 	ptable_addr_t end;
 
@@ -318,7 +336,7 @@
 {
 	size_t i;
 	uint64_t attrs = arch_mm_mode_to_attrs(mode);
-	pte_t *table = ptr_from_va(va_from_pa(t->table));
+	pte_t *table = ptr_from_pa(t->table);
 	bool sync = !(mode & MM_MODE_NOSYNC);
 	ptable_addr_t addr;
 
@@ -345,7 +363,7 @@
 static void mm_dump_table_recursive(pte_t *table, int level, int max_level)
 {
 	uint64_t i;
-	for (i = 0; i < PAGE_SIZE / sizeof(pte_t); i++) {
+	for (i = 0; i < NUM_ENTRIES; i++) {
 		if (!arch_mm_pte_is_present(table[i], level)) {
 			continue;
 		}
@@ -366,20 +384,136 @@
  */
 void mm_ptable_dump(struct mm_ptable *t, int mode)
 {
-	pte_t *table = ptr_from_va(va_from_pa(t->table));
+	pte_t *table = ptr_from_pa(t->table);
 	int max_level = arch_mm_max_level(mode);
 	mm_dump_table_recursive(table, max_level, max_level);
 }
 
 /**
+ * Given that `entry` is a subtable but its entries are all absent, return the
+ * absent entry with which it can be replaced. Note that `entry` will no longer
+ * be valid after calling this function as the subtable will have been freed.
+ */
+static pte_t mm_table_pte_to_absent(pte_t entry, int level)
+{
+	pte_t *subtable = ptr_from_pa(arch_mm_table_from_pte(entry));
+	/*
+	 * Free the subtable. This is safe to do directly (rather than
+	 * using mm_free_page_pte) because we know by this point that it
+	 * doesn't have any subtables of its own.
+	 */
+	hfree(subtable);
+	/* Replace subtable with a single absent entry. */
+	return arch_mm_absent_pte(level);
+}
+
+/**
+ * Given that `entry` is a subtable whose entries are all blocks with identical
+ * attributes, return the single block entry with which it can be replaced. If
+ * a block is not allowed at `level`, `entry` is returned unchanged and nothing
+ * is freed. Otherwise the subtable is freed, so `entry` is no longer valid
+ * after calling this function.
+ */
+static pte_t mm_table_pte_to_block(pte_t entry, int level)
+{
+	pte_t *subtable;
+	uint64_t block_attrs;
+	uint64_t table_attrs;
+	uint64_t combined_attrs;
+	paddr_t block_address;
+
+	/* Keep the subtable if a block is not permitted at this level. */
+	if (!arch_mm_is_block_allowed(level)) {
+		return entry;
+	}
+
+	subtable = ptr_from_pa(arch_mm_table_from_pte(entry));
+	/* Replace subtable with a single block, with equivalent attributes. */
+	block_attrs = arch_mm_pte_attrs(subtable[0]);
+	table_attrs = arch_mm_pte_attrs(entry);
+	combined_attrs =
+		arch_mm_combine_table_entry_attrs(table_attrs, block_attrs);
+	block_address = arch_mm_block_from_pte(subtable[0]);
+	/*
+	 * Free the subtable; everything needed from it has been read above.
+	 */
+	hfree(subtable);
+	/*
+	 * We can assume that the block is aligned properly because all
+	 * virtual addresses are aligned by definition, and we have a 1-1
+	 * mapping from virtual to physical addresses.
+	 */
+	return arch_mm_block_pte(level, block_address, combined_attrs);
+}
+
+/**
+ * Defragment the given ptable entry by recursively replacing any tables with
+ * block or absent entries where possible, returning the entry to use instead.
+ */
+static pte_t mm_ptable_defrag_entry(pte_t entry, int level)
+{
+	pte_t *table;
+	uint64_t i;
+	uint64_t attrs = 0;
+	bool identical_blocks_so_far = true;
+	bool all_absent_so_far = true;
+
+	if (!arch_mm_pte_is_table(entry, level)) {
+		return entry;
+	}
+
+	table = ptr_from_pa(arch_mm_table_from_pte(entry));
+
+	/*
+	 * Check if all entries are blocks with the same flags or are all
+	 * absent.
+	 */
+	for (i = 0; i < NUM_ENTRIES; ++i) {
+		/* First try to defrag the entry, in case it is a subtable. */
+		table[i] = mm_ptable_defrag_entry(table[i], level - 1);
+		if (i == 0) {
+			/* Entry 0 may only just have become a block. */
+			attrs = arch_mm_pte_attrs(table[0]);
+		}
+
+		if (arch_mm_pte_is_present(table[i], level - 1)) {
+			all_absent_so_far = false;
+		}
+
+		/*
+		 * Check the entry is a block with the same flags as the first.
+		 */
+		if (!arch_mm_pte_is_block(table[i], level - 1) ||
+		    arch_mm_pte_attrs(table[i]) != attrs) {
+			identical_blocks_so_far = false;
+		}
+	}
+	if (identical_blocks_so_far) {
+		return mm_table_pte_to_block(entry, level);
+	}
+	if (all_absent_so_far) {
+		return mm_table_pte_to_absent(entry, level);
+	}
+	return entry;
+}
+
+/**
  * Defragments the given page table by converting page table references to
  * blocks whenever possible.
  */
 void mm_ptable_defrag(struct mm_ptable *t, int mode)
 {
-	/* TODO: Implement. */
-	(void)t;
-	(void)mode;
+	pte_t *table = ptr_from_pa(t->table);
+	int level = arch_mm_max_level(mode);
+	uint64_t i;
+
+	/*
+	 * Loop through each entry in the table. If it points to another table,
+	 * check if that table can be replaced by a block or an absent entry.
+	 */
+	for (i = 0; i < NUM_ENTRIES; ++i) {
+		table[i] = mm_ptable_defrag_entry(table[i], level);
+	}
 }
 
 /**
@@ -418,8 +552,8 @@
 
 	if (arch_mm_pte_is_table(pte, level)) {
 		return mm_is_mapped_recursive(
-			ptr_from_va(va_from_pa(arch_mm_table_from_pte(pte))),
-			addr, level - 1);
+			ptr_from_pa(arch_mm_table_from_pte(pte)), addr,
+			level - 1);
 	}
 
 	/* The entry is not present. */
@@ -432,7 +566,7 @@
 static bool mm_ptable_is_mapped(struct mm_ptable *t, ptable_addr_t addr,
 				int mode)
 {
-	pte_t *table = ptr_from_va(va_from_pa(t->table));
+	pte_t *table = ptr_from_pa(t->table);
 	int level = arch_mm_max_level(mode);
 
 	addr = mm_round_down_to_page(addr);
@@ -458,7 +592,7 @@
 		return false;
 	}
 
-	for (i = 0; i < PAGE_SIZE / sizeof(pte_t); i++) {
+	for (i = 0; i < NUM_ENTRIES; i++) {
 		table[i] = arch_mm_absent_pte(arch_mm_max_level(mode));
 	}
 
@@ -548,7 +682,7 @@
 {
 	if (mm_ptable_identity_map(&ptable, begin, end,
 				   mode | MM_MODE_STAGE1)) {
-		return ptr_from_va(va_from_pa(begin));
+		return ptr_from_pa(begin);
 	}
 
 	return NULL;