std.c: Refactor, bring semantics closer to spec.
Small refactor replacing the pattern of verifying inputs in
safe equivalents of stdlib functions with a CHECK-like macro
which adheres to the C11 semantics of filling the destination
buffer when a constraint is not satisfied at runtime.
Semantic changes:
memcpy_s: More permissive memory ranges. Used to check that source does
not overlap with the entire destination buffer, when only the first
`count` bytes matter.
memcpy_s: Used to allow `dest==src`. Safe under our implementation in
aarch64 but not allowed under C11.
strlen_s: Should return 'strsz' if NULL character not found.
Change-Id: If483a97e6ee1c64c7f2afed9a0af1d3087da7002
diff --git a/inc/hf/std.h b/inc/hf/std.h
index be3b168..1c4458c 100644
--- a/inc/hf/std.h
+++ b/inc/hf/std.h
@@ -31,7 +31,7 @@
* Only the safer versions of these functions are exposed to reduce the chance
* of misusing the versions without bounds checking or null pointer checks.
*
- * These functions don't return errno_t as per the specification and implicity
+ * These functions don't return errno_t as per the specification and implicitly
* have a constraint handler that panics.
*/
void memset_s(void *dest, rsize_t destsz, int ch, rsize_t count);
diff --git a/src/std.c b/src/std.c
index 9f829e8..a467384 100644
--- a/src/std.c
+++ b/src/std.c
@@ -23,25 +23,33 @@
void *memcpy(void *dst, const void *src, size_t n);
void *memmove(void *dst, const void *src, size_t n);
+/*
+ * As per the C11 specification, mem*_s() operations fill the destination buffer
+ * if runtime constraint validation fails, assuming that `dest` and `destsz`
+ * are both valid.
+ */
+#define CHECK_OR_FILL(cond, dest, destsz, ch) \
+ do { \
+ if (!(cond)) { \
+ if ((dest) != NULL && (destsz) <= RSIZE_MAX) { \
+ memset_s((dest), (destsz), (ch), (destsz)); \
+ } \
+ panic("%s failed: " #cond, __func__); \
+ } \
+ } while (0)
+
+#define CHECK_OR_ZERO_FILL(cond, dest, destsz) \
+ CHECK_OR_FILL(cond, dest, destsz, '\0')
+
void memset_s(void *dest, rsize_t destsz, int ch, rsize_t count)
{
- if (dest == NULL) {
- goto fail;
- }
+ CHECK_OR_FILL(dest != NULL, dest, destsz, ch);
- if (destsz > RSIZE_MAX || count > RSIZE_MAX) {
- goto fail;
- }
-
- if (count > destsz) {
- goto fail;
- }
+ /* Check count <= destsz <= RSIZE_MAX. */
+ CHECK_OR_FILL(destsz <= RSIZE_MAX, dest, destsz, ch);
+ CHECK_OR_FILL(count <= destsz, dest, destsz, ch);
memset(dest, ch, count);
- return;
-
-fail:
- panic("memset_s failure");
}
void memcpy_s(void *dest, rsize_t destsz, const void *src, rsize_t count)
@@ -49,69 +57,56 @@
uintptr_t d = (uintptr_t)dest;
uintptr_t s = (uintptr_t)src;
- if (dest == NULL || src == NULL) {
- goto fail;
- }
+ CHECK_OR_ZERO_FILL(dest != NULL, dest, destsz);
+ CHECK_OR_ZERO_FILL(src != NULL, dest, destsz);
- if (destsz > RSIZE_MAX || count > RSIZE_MAX) {
- goto fail;
- }
+ /* Check count <= destsz <= RSIZE_MAX. */
+ CHECK_OR_ZERO_FILL(destsz <= RSIZE_MAX, dest, destsz);
+ CHECK_OR_ZERO_FILL(count <= destsz, dest, destsz);
- if (count > destsz) {
- goto fail;
- }
-
- /* Destination overlaps the end of source. */
- if (d > s && d < (s + count)) {
- goto fail;
- }
-
- /* Source overlaps the end of destination. */
- if (s > d && s < (d + destsz)) {
- goto fail;
- }
-
- /* TODO: consider wrapping? */
+ /*
+ * Buffer overlap test.
+ * case a) `d < s` implies `s >= d+count`
+ * case b) `d > s` implies `d >= s+count`
+ */
+ CHECK_OR_ZERO_FILL(d != s, dest, destsz);
+ CHECK_OR_ZERO_FILL(d < s || d >= (s + count), dest, destsz);
+ CHECK_OR_ZERO_FILL(d > s || s >= (d + count), dest, destsz);
memcpy(dest, src, count);
- return;
-
-fail:
- panic("memcpy_s failure");
}
void memmove_s(void *dest, rsize_t destsz, const void *src, rsize_t count)
{
- if (dest == NULL || src == NULL) {
- goto fail;
- }
+ CHECK_OR_ZERO_FILL(dest != NULL, dest, destsz);
+ CHECK_OR_ZERO_FILL(src != NULL, dest, destsz);
- if (destsz > RSIZE_MAX || count > RSIZE_MAX) {
- goto fail;
- }
-
- if (count > destsz) {
- goto fail;
- }
+ /* Check count <= destsz <= RSIZE_MAX. */
+ CHECK_OR_ZERO_FILL(destsz <= RSIZE_MAX, dest, destsz);
+ CHECK_OR_ZERO_FILL(count <= destsz, dest, destsz);
memmove(dest, src, count);
- return;
-
-fail:
- panic("memmove_s failure");
}
+/**
+ * Returns the length of the null-terminated byte string `str`, examining at
+ * most `strsz` bytes.
+ *
+ * If `str` is a NULL pointer, it returns zero.
+ * If a NULL character is not found, it returns `strsz`.
+ */
size_t strnlen_s(const char *str, size_t strsz)
{
- const char *p = str;
-
if (str == NULL) {
return 0;
}
- while (*p && strsz--) {
- p++;
+ for (size_t i = 0; i < strsz; ++i) {
+ if (str[i] == '\0') {
+ return i;
+ }
}
- return p - str;
+ /* NULL character not found. */
+ return strsz;
}