#ifndef OSSL_INTERNAL_SAFE_MATH_H
#define OSSL_INTERNAL_SAFE_MATH_H
#pragma once
#include <openssl/e_os2.h>
/*
 * has(func) tells the rest of this header whether the compiler builtin named
 * func may be used.  Defining OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING leaves
 * has() undefined in the first group below, so it ends up answering 0 and the
 * portable fallbacks are selected.
 */
#if !defined(OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING)
# if defined(__has_builtin)
/* The compiler can be asked directly. */
#  define has(func) __has_builtin(func)
# elif defined(__GNUC__) && __GNUC__ > 5
/* No __has_builtin available: GCC newer than 5 answers "yes" for any func. */
#  define has(func) 1
# endif
#endif

/* Anything not covered above gets a definition that always answers "no". */
#if !defined(has)
# define has(func) 0
#endif
#if has(__builtin_add_overflow)
/*
 * OSSL_SAFE_MATH_ADDS - signed addition backed by __builtin_add_overflow.
 *
 * Expands to safe_add_<type_name>(a, b, err).  The function returns a + b
 * when the sum fits in type.  When it does not, bit 0 of *err is set (the
 * flag is only ever OR-ed in, never cleared) and the result saturates to
 * min if a is negative and to max otherwise.
 */
#define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \
    static ossl_inline ossl_unused type \
    safe_add_##type_name(type a, type b, int *err) \
    { \
        type sum; \
        \
        if (__builtin_add_overflow(a, b, &sum)) { \
            *err |= 1; \
            return a < 0 ? (min) : (max); \
        } \
        return sum; \
    }
/*
 * OSSL_SAFE_MATH_ADDU - unsigned addition backed by __builtin_add_overflow.
 *
 * Expands to safe_add_<type_name>(a, b, err).  The function returns a + b
 * when the sum fits in type.  When it does not, bit 0 of *err is set and the
 * value of the plain expression a + b is returned.  The max argument is not
 * used by this variant.
 */
#define OSSL_SAFE_MATH_ADDU(type_name, type, max) \
    static ossl_inline ossl_unused type \
    safe_add_##type_name(type a, type b, int *err) \
    { \
        type sum; \
        \
        if (__builtin_add_overflow(a, b, &sum)) { \
            *err |= 1; \
            return a + b; \
        } \
        return sum; \
    }
#else
/*
 * OSSL_SAFE_MATH_ADDS - signed addition without compiler builtins.
 *
 * Expands to safe_add_<type_name>(a, b, err).  The range check is done
 * before the addition so that a + b is only evaluated when it fits:
 *   - a zero a, or operands of differing sign, can never overflow;
 *   - a positive a needs b <= max - a;
 *   - a negative a needs b >= min - a.
 * On overflow bit 0 of *err is set and the result saturates to min if a is
 * negative and to max otherwise.
 */
#define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \
    static ossl_inline ossl_unused type \
    safe_add_##type_name(type a, type b, int *err) \
    { \
        int fits; \
        \
        if (a == 0 || (a < 0) != (b < 0)) \
            fits = 1; \
        else if (a > 0) \
            fits = b <= (max) - a; \
        else \
            fits = b >= (min) - a; \
        if (!fits) { \
            *err |= 1; \
            return a < 0 ? (min) : (max); \
        } \
        return a + b; \
    }
/*
 * OSSL_SAFE_MATH_ADDU - unsigned addition without compiler builtins.
 *
 * Expands to safe_add_<type_name>(a, b, err).  The function always returns
 * the value of a + b; bit 0 of *err is additionally set when b is larger
 * than the room left between a and max.
 */
#define OSSL_SAFE_MATH_ADDU(type_name, type, max) \
    static ossl_inline ossl_unused type \
    safe_add_##type_name(type a, type b, int *err) \
    { \
        /* max - a is the largest b that can still be added to a */ \
        if ((max) - a < b) \
            *err |= 1; \
        return a + b; \
    }
#endif
#if has(__builtin_sub_overflow)
/*
 * OSSL_SAFE_MATH_SUBS - signed subtraction backed by __builtin_sub_overflow.
 *
 * Expands to safe_sub_<type_name>(a, b, err).  The function returns a - b
 * when the difference fits in type.  When it does not, bit 0 of *err is set
 * and the result saturates to min if a is negative and to max otherwise.
 */
#define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \
    static ossl_inline ossl_unused type \
    safe_sub_##type_name(type a, type b, int *err) \
    { \
        type diff; \
        \
        if (__builtin_sub_overflow(a, b, &diff)) { \
            *err |= 1; \
            return a < 0 ? (min) : (max); \
        } \
        return diff; \
    }
#else
/*
 * OSSL_SAFE_MATH_SUBS - signed subtraction without compiler builtins.
 *
 * Expands to safe_sub_<type_name>(a, b, err).  The range check is done
 * before the subtraction so that a - b is only evaluated when it fits:
 *   - a zero b, or operands of equal sign, can never overflow;
 *   - a positive b needs a >= min + b;
 *   - a negative b needs a <= max + b.
 * On overflow bit 0 of *err is set and the result saturates to min if a is
 * negative and to max otherwise.
 */
#define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \
    static ossl_inline ossl_unused type \
    safe_sub_##type_name(type a, type b, int *err) \
    { \
        int fits; \
        \
        if (b == 0 || (a < 0) == (b < 0)) \
            fits = 1; \
        else if (b > 0) \
            fits = a >= (min) + b; \
        else \
            fits = a <= (max) + b; \
        if (!fits) { \
            *err |= 1; \
            return a < 0 ? (min) : (max); \
        } \
        return a - b; \
    }
#endif
/*
 * OSSL_SAFE_MATH_SUBU - unsigned subtraction.
 *
 * Expands to safe_sub_<type_name>(a, b, err).  The function always returns
 * the value of a - b; bit 0 of *err is additionally set when a is smaller
 * than b, i.e. when the true difference would be negative.
 */
#define OSSL_SAFE_MATH_SUBU(type_name, type) \
    static ossl_inline ossl_unused type \
    safe_sub_##type_name(type a, type b, int *err) \
    { \
        if (a < b) \
            *err |= 1; \
        return a - b; \
    }
#if has(__builtin_mul_overflow)
/*
 * OSSL_SAFE_MATH_MULS - signed multiplication backed by
 * __builtin_mul_overflow.
 *
 * Expands to safe_mul_<type_name>(a, b, err).  The function returns a * b
 * when the product fits in type.  When it does not, bit 0 of *err is set and
 * the result saturates to min if exactly one operand is negative and to max
 * otherwise.
 */
#define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \
    static ossl_inline ossl_unused type \
    safe_mul_##type_name(type a, type b, int *err) \
    { \
        type product; \
        \
        if (__builtin_mul_overflow(a, b, &product)) { \
            *err |= 1; \
            return (a < 0) != (b < 0) ? (min) : (max); \
        } \
        return product; \
    }
/*
 * OSSL_SAFE_MATH_MULU - unsigned multiplication backed by
 * __builtin_mul_overflow.
 *
 * Expands to safe_mul_<type_name>(a, b, err).  The function returns a * b
 * when the product fits in type.  When it does not, bit 0 of *err is set and
 * the value of the plain expression a * b is returned.  The max argument is
 * not used by this variant.
 */
#define OSSL_SAFE_MATH_MULU(type_name, type, max) \
    static ossl_inline ossl_unused type \
    safe_mul_##type_name(type a, type b, int *err) \
    { \
        type product; \
        \
        if (__builtin_mul_overflow(a, b, &product)) { \
            *err |= 1; \
            return a * b; \
        } \
        return product; \
    }
#else
/*
 * OSSL_SAFE_MATH_MULS - signed multiplication without compiler builtins.
 *
 * Expands to safe_mul_<type_name>(a, b, err).  The function returns a * b
 * when the product fits in type.  When it does not, bit 0 of *err is set and
 * the result saturates to min if exactly one operand is negative and to max
 * otherwise.
 *
 * The range check divides the relevant limit by one operand and compares the
 * other operand against the quotient, choosing the limit by the sign of the
 * product: max for a positive product, min for a negative one.  Comparing
 * against min (rather than against -max) means a product that is exactly
 * min, such as (min / 2) * 2, is accepted, matching the builtin variant.
 * The divisor is always chosen so that min / -1 is never evaluated.
 */
#define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \
    static ossl_inline ossl_unused type \
    safe_mul_##type_name(type a, type b, int *err) \
    { \
        int overflow; \
        \
        if (a == 0 || b == 0) \
            return 0; \
        if (a > 0) { \
            if (b > 0) \
                overflow = a > (max) / b;   /* positive product */ \
            else \
                overflow = b < (min) / a;   /* negative product */ \
        } else { \
            if (b > 0) \
                overflow = a < (min) / b;   /* negative product */ \
            else \
                overflow = b < (max) / a;   /* positive product */ \
        } \
        if (!overflow) \
            return a * b; \
        *err |= 1; \
        return (a < 0) ^ (b < 0) ? (min) : (max); \
    }
/*
 * OSSL_SAFE_MATH_MULU - unsigned multiplication without compiler builtins.
 *
 * Expands to safe_mul_<type_name>(a, b, err).  The function always returns
 * the value of a * b; bit 0 of *err is additionally set when b is non-zero
 * and a is larger than max / b.
 */
#define OSSL_SAFE_MATH_MULU(type_name, type, max) \
    static ossl_inline ossl_unused type \
    safe_mul_##type_name(type a, type b, int *err) \
    { \
        /* b == 0 is tested first so that max / b is never a division by 0 */ \
        if (b != 0 && (max) / b < a) \
            *err |= 1; \
        return a * b; \
    }
#endif
/*
 * OSSL_SAFE_MATH_DIVS - signed division.
 *
 * Expands to safe_div_<type_name>(a, b, err), which returns a / b.
 * Two cases are reported by setting bit 0 of *err instead of dividing:
 *   - b == 0: returns min if a is negative, max otherwise;
 *   - a == min with b == -1: returns max.
 */
#define OSSL_SAFE_MATH_DIVS(type_name, type, min, max) \
    static ossl_inline ossl_unused type \
    safe_div_##type_name(type a, type b, int *err) \
    { \
        if (b == 0) { \
            *err |= 1; \
            if (a < 0) \
                return (min); \
            return (max); \
        } \
        if (a == (min) && b == -1) { \
            *err |= 1; \
            return (max); \
        } \
        return a / b; \
    }
/*
 * OSSL_SAFE_MATH_DIVU - unsigned division.
 *
 * Expands to safe_div_<type_name>(a, b, err), which returns a / b.
 * Division by zero is not performed; instead bit 0 of *err is set and max is
 * returned.
 */
#define OSSL_SAFE_MATH_DIVU(type_name, type, max) \
    static ossl_inline ossl_unused type \
    safe_div_##type_name(type a, type b, int *err) \
    { \
        if (b == 0) { \
            *err |= 1; \
            return (max); \
        } \
        return a / b; \
    }
/*
 * OSSL_SAFE_MATH_MODS - signed remainder.
 *
 * Expands to safe_mod_<type_name>(a, b, err), which returns a % b.
 * Two cases are reported by setting bit 0 of *err instead of evaluating %:
 *   - b == 0: returns 0;
 *   - a == min with b == -1: returns max.
 */
#define OSSL_SAFE_MATH_MODS(type_name, type, min, max) \
    static ossl_inline ossl_unused type \
    safe_mod_##type_name(type a, type b, int *err) \
    { \
        if (b == 0) { \
            *err |= 1; \
            return 0; \
        } \
        if (a == (min) && b == -1) { \
            *err |= 1; \
            return (max); \
        } \
        return a % b; \
    }
/*
 * OSSL_SAFE_MATH_MODU - unsigned remainder.
 *
 * Expands to safe_mod_<type_name>(a, b, err), which returns a % b.
 * A zero b is not evaluated; instead bit 0 of *err is set and 0 is returned.
 */
#define OSSL_SAFE_MATH_MODU(type_name, type) \
    static ossl_inline ossl_unused type \
    safe_mod_##type_name(type a, type b, int *err) \
    { \
        if (b == 0) { \
            *err |= 1; \
            return 0; \
        } \
        return a % b; \
    }
/*
 * OSSL_SAFE_MATH_NEGS - signed negation.
 *
 * Expands to safe_neg_<type_name>(a, err), which returns -a.  The one input
 * that is not negated is min: bit 0 of *err is set and min itself is
 * returned.
 */
#define OSSL_SAFE_MATH_NEGS(type_name, type, min) \
    static ossl_inline ossl_unused type \
    safe_neg_##type_name(type a, int *err) \
    { \
        if (a == (min)) { \
            *err |= 1; \
            return (min); \
        } \
        return -a; \
    }
/*
 * OSSL_SAFE_MATH_NEGU - unsigned negation.
 *
 * Expands to safe_neg_<type_name>(a, err).  Zero is returned unchanged with
 * no error.  For any other input bit 0 of *err is set and the value of
 * 1 + ~a is returned.
 */
#define OSSL_SAFE_MATH_NEGU(type_name, type) \
    static ossl_inline ossl_unused type \
    safe_neg_##type_name(type a, int *err) \
    { \
        if (a != 0) { \
            *err |= 1; \
            return 1 + ~a; \
        } \
        return a; \
    }
/*
 * OSSL_SAFE_MATH_ABSS - signed absolute value.
 *
 * Expands to safe_abs_<type_name>(a, err), which returns -a for negative a
 * and a otherwise.  The one input that is not negated is min: bit 0 of *err
 * is set and min itself is returned.
 */
#define OSSL_SAFE_MATH_ABSS(type_name, type, min) \
    static ossl_inline ossl_unused type \
    safe_abs_##type_name(type a, int *err) \
    { \
        if (a == (min)) { \
            *err |= 1; \
            return (min); \
        } \
        if (a < 0) \
            return -a; \
        return a; \
    }
/*
 * OSSL_SAFE_MATH_ABSU - unsigned absolute value.
 *
 * Expands to safe_abs_<type_name>(a, err), which returns a unchanged.  The
 * err argument is part of the signature but is never read or written.
 */
#define OSSL_SAFE_MATH_ABSU(type_name, type) \
    static ossl_inline ossl_unused type \
    safe_abs_##type_name(type a, int *err) \
    { \
        return a; \
    }
/*
 * OSSL_SAFE_MATH_MULDIVS - signed a * b / c.
 *
 * Expands to safe_muldiv_<type_name>(a, b, c, err).  The expansion calls
 * safe_mul_, safe_div_, safe_mod_ and safe_add_ for the same type_name, so
 * those must already be defined.
 *
 * - c == 0: bit 0 of *err is set; the result is 0 when a or b is zero and
 *   max otherwise.
 * - If a * b does not overflow, the result is safe_div(a * b, c).  The
 *   overflow probe uses a private flag so that a failed probe does not touch
 *   *err.
 * - Otherwise the operands are swapped if needed so that a >= b, and the
 *   result is assembled as (a / c) * b + ((a % c) * b) / c, with every step
 *   going through the safe_ helpers so any error they detect lands in *err.
 */
#define OSSL_SAFE_MATH_MULDIVS(type_name, type, max) \
    static ossl_inline ossl_unused type \
    safe_muldiv_##type_name(type a, type b, type c, int *err) \
    { \
        int mul_err = 0; \
        type quot, rem, part, whole; \
        \
        if (c == 0) { \
            *err |= 1; \
            if (a == 0 || b == 0) \
                return 0; \
            return (max); \
        } \
        /* Easy case: the direct product fits */ \
        part = safe_mul_##type_name(a, b, &mul_err); \
        if (mul_err == 0) \
            return safe_div_##type_name(part, c, err); \
        /* Make a the larger operand before splitting it by c */ \
        if (a < b) { \
            part = a; \
            a = b; \
            b = part; \
        } \
        quot = safe_div_##type_name(a, c, err); \
        rem = safe_mod_##type_name(a, c, err); \
        part = safe_mul_##type_name(rem, b, err); \
        whole = safe_mul_##type_name(quot, b, err); \
        part = safe_div_##type_name(part, c, err); \
        return safe_add_##type_name(whole, part, err); \
    }
/*
 * OSSL_SAFE_MATH_MULDIVU - unsigned a * b / c.
 *
 * Expands to safe_muldiv_<type_name>(a, b, c, err).  The expansion calls
 * safe_mul_ and safe_add_ for the same type_name, so those must already be
 * defined.
 *
 * - c == 0: bit 0 of *err is set; the result is 0 when a or b is zero and
 *   max otherwise.
 * - If a * b does not overflow, the result is (a * b) / c.  The overflow
 *   probe uses a private flag so that a failed probe does not touch *err.
 * - Otherwise the operands are swapped if needed so that a >= b, and the
 *   result is assembled as (a / c) * b + ((a % c) * b) / c; overflow in the
 *   two products or in the final sum is reported through *err.
 */
#define OSSL_SAFE_MATH_MULDIVU(type_name, type, max) \
    static ossl_inline ossl_unused type \
    safe_muldiv_##type_name(type a, type b, type c, int *err) \
    { \
        int mul_err = 0; \
        type low, high; \
        \
        if (c == 0) { \
            *err |= 1; \
            if (a == 0 || b == 0) \
                return 0; \
            return (max); \
        } \
        /* Easy case: the direct product fits */ \
        low = safe_mul_##type_name(a, b, &mul_err); \
        if (mul_err == 0) \
            return low / c; \
        /* Make a the larger operand before splitting it by c */ \
        if (a < b) { \
            low = a; \
            a = b; \
            b = low; \
        } \
        low = safe_mul_##type_name(a % c, b, err); \
        high = safe_mul_##type_name(a / c, b, err); \
        return safe_add_##type_name(high, low / c, err); \
    }
/*
 * OSSL_SAFE_MATH_DIV_ROUND_UP - division rounding towards positive infinity.
 *
 * Expands to safe_div_round_up_<type_name>(a, b, errp) and is shared by the
 * signed and unsigned families.  The expansion calls safe_div_, safe_mod_
 * and safe_add_ for the same type_name, so those must already be defined.
 *
 * errp may be NULL, in which case errors are recorded in a local that is
 * discarded.  Otherwise errors are reported by setting bit 0 of *errp.
 *
 * - b == 0: error; the result is 0 when a is zero and max otherwise.
 * - a == 0: the result is 0.
 * - Both operands positive: computed directly, using (a + b - 1) / b when
 *   that sum is known not to overflow and a / b + (a % b != 0) otherwise.
 * - At least one operand negative (signed types only): C division truncates
 *   towards zero, which is already "rounded up" whenever the exact quotient
 *   is negative.  One is therefore added only when the remainder is non-zero
 *   and both operands are negative (positive quotient).  For example
 *   -7 / 2 yields -3 and -7 / -2 yields 4.
 */
#define OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, max) \
    static ossl_inline ossl_unused type \
    safe_div_round_up_##type_name(type a, type b, int *errp) \
    { \
        type q, r; \
        int *err, err_local = 0; \
        \
        /* Allow errors to be ignored by callers */ \
        err = errp != NULL ? errp : &err_local; \
        \
        /* Fast path, both positive */ \
        if (b > 0 && a > 0) { \
            /* Faster path: a + b - 1 cannot overflow */ \
            if (a < (max) - b) \
                return (a + b - 1) / b; \
            return a / b + (a % b != 0); \
        } \
        if (b == 0) { \
            *err |= 1; \
            return a == 0 ? 0 : (max); \
        } \
        if (a == 0) \
            return 0; \
        \
        /* Slow path: a and b are non-zero and at least one is negative */ \
        q = safe_div_##type_name(a, b, err); \
        r = safe_mod_##type_name(a, b, err); \
        /* \
         * "Not positive" means negative here because zero was excluded \
         * above; written this way so unsigned expansions never compare \
         * against < 0. \
         */ \
        if (r != 0 && !(a > 0) && !(b > 0)) \
            q = safe_add_##type_name(q, 1, err); \
        return q; \
    }
/*
 * Limits of an integer type, computed from its size.  All three assume 8-bit
 * bytes, no padding bits and, for the signed pair, two's complement.
 *
 * OSSL_SAFE_MATH_MAXS(type) - largest value of the signed type.  Built as
 *     ((2^(bits-2) - 1) << 1) + 1 so that no intermediate value ever
 *     reaches the sign bit.
 * OSSL_SAFE_MATH_MINS(type) - smallest value of the signed type, derived as
 *     -max - 1.
 * OSSL_SAFE_MATH_MAXU(type) - largest value of the unsigned type.
 *
 * Every result is cast back to type so that types narrower than int, which
 * are promoted to int inside the expression, still yield their own limits.
 */
#define OSSL_SAFE_MATH_MAXS(type) \
    ((type)(((((type)1 << (sizeof(type) * 8 - 2)) - 1) << 1) + 1))
#define OSSL_SAFE_MATH_MINS(type) ((type)(-OSSL_SAFE_MATH_MAXS(type) - 1))
#define OSSL_SAFE_MATH_MAXU(type) ((type)~(type)0)
/*
 * OSSL_SAFE_MATH_SIGNED - instantiate the whole safe_* family for one signed
 * integer type.
 *
 * type_name is the suffix pasted onto every generated function name
 * (safe_add_<type_name>, safe_sub_<type_name>, ...) and type is the C type
 * operated on.  The limits handed to the individual generators come from
 * OSSL_SAFE_MATH_MINS(type) and OSSL_SAFE_MATH_MAXS(type).
 *
 * The div_round_up and muldiv generators are expanded last because their
 * bodies call the add/mul/div/mod functions generated before them.
 */
#define OSSL_SAFE_MATH_SIGNED(type_name, type) \
    OSSL_SAFE_MATH_ADDS(type_name, type, \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_SUBS(type_name, type, \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_MULS(type_name, type, \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_DIVS(type_name, type, \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_MODS(type_name, type, \
                        OSSL_SAFE_MATH_MINS(type), OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_NEGS(type_name, type, OSSL_SAFE_MATH_MINS(type)) \
    OSSL_SAFE_MATH_ABSS(type_name, type, OSSL_SAFE_MATH_MINS(type)) \
    OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, OSSL_SAFE_MATH_MAXS(type)) \
    OSSL_SAFE_MATH_MULDIVS(type_name, type, OSSL_SAFE_MATH_MAXS(type))
/*
 * OSSL_SAFE_MATH_UNSIGNED - instantiate the whole safe_* family for one
 * unsigned integer type.
 *
 * type_name is the suffix pasted onto every generated function name
 * (safe_add_<type_name>, safe_sub_<type_name>, ...) and type is the C type
 * operated on.  The upper limit handed to the individual generators comes
 * from OSSL_SAFE_MATH_MAXU(type).
 *
 * The div_round_up and muldiv generators are expanded last because their
 * bodies call the add/mul/div/mod functions generated before them.
 */
#define OSSL_SAFE_MATH_UNSIGNED(type_name, type) \
    OSSL_SAFE_MATH_ADDU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
    OSSL_SAFE_MATH_SUBU(type_name, type) \
    OSSL_SAFE_MATH_MULU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
    OSSL_SAFE_MATH_DIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
    OSSL_SAFE_MATH_MODU(type_name, type) \
    OSSL_SAFE_MATH_NEGU(type_name, type) \
    OSSL_SAFE_MATH_ABSU(type_name, type) \
    OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
    OSSL_SAFE_MATH_MULDIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type))
#endif