#include <linux/kernel.h>
#include <linux/tnum.h>
#include <linux/swab.h>
/* Build a struct tnum compound literal from a (value, mask) pair. */
#define TNUM(_v, _m) (struct tnum){ .value = (_v), .mask = (_m) }
/* Every bit unknown (mask all ones): matches any 64-bit value. */
const struct tnum tnum_unknown = { .mask = ~0ULL, .value = 0 };
/*
 * tnum_const - tnum that holds exactly one value.
 * @value: the known value
 *
 * All bits are known (mask is zero).
 */
struct tnum tnum_const(u64 value)
{
	struct tnum known = { .value = value, .mask = 0 };

	return known;
}
/*
 * tnum_range - tnum covering every value between @min and @max.
 *
 * Bits above the highest bit where @min and @max differ are known and are
 * taken from @min. That bit and every bit below it become unknown. When
 * @min == @max nothing differs and the result is the constant @min. When
 * the highest differing bit is bit 63, tnum_unknown is returned; this also
 * avoids the undefined shift 1ULL << 64.
 */
struct tnum tnum_range(u64 min, u64 max)
{
	/* fls64(): 1-based index of the highest set bit, 0 for 0 */
	u8 width = fls64(min ^ max);
	u64 low;

	if (width >= 64)
		return tnum_unknown;

	low = (1ULL << width) - 1;
	return TNUM(min & ~low, low);
}
/*
 * tnum_lshift - shift known and unknown bits left by @shift.
 *
 * Vacated low bits become known zeros. Callers must keep @shift below 64
 * (a C shift by the type width or more is undefined); this is not checked
 * here.
 */
struct tnum tnum_lshift(struct tnum a, u8 shift)
{
	u64 shifted_value = a.value << shift;
	u64 shifted_mask = a.mask << shift;

	return TNUM(shifted_value, shifted_mask);
}
/*
 * tnum_rshift - logical right shift of known and unknown bits by @shift.
 *
 * Vacated high bits become known zeros. Callers must keep @shift below 64;
 * this is not checked here.
 */
struct tnum tnum_rshift(struct tnum a, u8 shift)
{
	u64 shifted_value = a.value >> shift;
	u64 shifted_mask = a.mask >> shift;

	return TNUM(shifted_value, shifted_mask);
}
/*
 * tnum_arshift - arithmetic right shift of @a by @min_shift.
 * @insn_bitness: 32 shifts the low 32 bits as an s32 and zero-extends the
 *                result; any other value shifts all 64 bits as an s64.
 *
 * Shifting value and mask arithmetically copies the sign bit's status
 * into the vacated high bits. An unknown sign gives unknown high bits; a
 * known sign gives known copies. This relies on the compiler implementing
 * >> on negative signed values as an arithmetic shift. @min_shift is
 * assumed to be below the operand width (not checked here).
 */
struct tnum tnum_arshift(struct tnum a, u8 min_shift, u8 insn_bitness)
{
	if (insn_bitness != 32)
		return TNUM((s64)a.value >> min_shift,
			    (s64)a.mask >> min_shift);

	/* 32-bit: sign from bit 31; upper 32 bits of the result are known 0 */
	return TNUM((u32)((s32)a.value >> min_shift),
		    (u32)((s32)a.mask >> min_shift));
}
/*
 * tnum_add - tnum covering every sum of a member of @a and a member of @b.
 *
 * Given the value & mask == 0 invariant, "value" is a tnum's smallest
 * member and "value + mask" its largest. A bit is unknown in the result if
 * it is unknown in either input, or if it differs between the smallest and
 * largest possible sums (carry effects).
 */
struct tnum tnum_add(struct tnum a, struct tnum b)
{
	u64 low_sum = a.value + b.value;                 /* unknowns all 0 */
	u64 high_sum = low_sum + a.mask + b.mask;        /* unknowns all 1 */
	u64 unknown = (high_sum ^ low_sum) | a.mask | b.mask;

	return TNUM(low_sum & ~unknown, unknown);
}
/*
 * tnum_sub - tnum covering every difference (member of @a) - (member of @b).
 *
 * Compares the extreme differences (largest a minus smallest b, and
 * smallest a minus largest b). Bits that differ between them, or are
 * unknown in either input, become unknown. This assumes value & mask == 0
 * in both inputs.
 */
struct tnum tnum_sub(struct tnum a, struct tnum b)
{
	u64 diff = a.value - b.value;
	u64 widest = diff + a.mask;      /* max(a) - min(b) */
	u64 narrowest = diff - b.mask;   /* min(a) - max(b) */
	u64 unknown = (widest ^ narrowest) | a.mask | b.mask;

	return TNUM(diff & ~unknown, unknown);
}
struct tnum tnum_neg(struct tnum a)
{
return tnum_sub(TNUM(0, 0), a);
}
/*
 * tnum_and - bitwise AND of two tnums.
 *
 * A result bit is a known 1 when it is a known 1 in both inputs. It is
 * unknown when it could be 1 in both (known 1 or unknown) but is not
 * known 1 in both. Otherwise it is a known 0.
 */
struct tnum tnum_and(struct tnum a, struct tnum b)
{
	u64 certain_one = a.value & b.value;
	u64 possible_one = (a.value | a.mask) & (b.value | b.mask);

	return TNUM(certain_one, possible_one & ~certain_one);
}
/*
 * tnum_or - bitwise OR of two tnums.
 *
 * A known 1 in either input forces a known 1. Otherwise a bit unknown in
 * either input stays unknown.
 */
struct tnum tnum_or(struct tnum a, struct tnum b)
{
	u64 certain_one = a.value | b.value;

	return TNUM(certain_one, (a.mask | b.mask) & ~certain_one);
}
/*
 * tnum_xor - bitwise XOR of two tnums.
 *
 * Any bit unknown in either input is unknown in the result. The remaining
 * bits are the XOR of the known values.
 */
struct tnum tnum_xor(struct tnum a, struct tnum b)
{
	u64 unknown = a.mask | b.mask;

	return TNUM((a.value ^ b.value) & ~unknown, unknown);
}
/*
 * tnum_mul - tnum covering every product of a member of @a and of @b.
 *
 * This is shift-and-add multiplication over tnums. @a is consumed from its
 * least significant bit while @b is shifted left in step:
 *   - known 1 bit in @a:  the shifted @b is always added;
 *   - unknown bit in @a:  the product either includes the shifted @b or
 *                         not, so the accumulator becomes the union of
 *                         both possibilities;
 *   - known 0 bit in @a:  nothing is added.
 * The loop stops once @a has no known-1 or unknown bits left, i.e. after
 * at most 64 iterations.
 */
struct tnum tnum_mul(struct tnum a, struct tnum b)
{
	struct tnum acc = TNUM(0, 0);

	for (; a.value | a.mask; a = tnum_rshift(a, 1), b = tnum_lshift(b, 1)) {
		if (a.value & 1) {
			acc = tnum_add(acc, b);
		} else if (a.mask & 1) {
			/* union of "acc + 0 * b" and "acc + 1 * b" */
			acc = tnum_union(acc, tnum_add(acc, b));
		}
	}
	return acc;
}
/*
 * tnum_overlap - true when @a and @b agree on every bit known in both.
 *
 * Bits unknown in either tnum are ignored. Assuming value & mask == 0 in
 * each, this means the two tnums have at least one member in common.
 */
bool tnum_overlap(struct tnum a, struct tnum b)
{
	u64 known_in_both = ~(a.mask | b.mask);

	return !((a.value ^ b.value) & known_in_both);
}
/*
 * tnum_intersect - combine the knowledge of @a and @b.
 *
 * A bit stays unknown only if it is unknown in both inputs. A bit known in
 * either input takes the known value (the values are OR-ed). The result is
 * meaningful only when the inputs do not contradict each other on bits
 * both know (presumably checked by callers, e.g. via tnum_overlap()).
 */
struct tnum tnum_intersect(struct tnum a, struct tnum b)
{
	u64 unknown = a.mask & b.mask;
	u64 ones = a.value | b.value;

	return TNUM(ones & ~unknown, unknown);
}
/*
 * tnum_union - smallest tnum containing every member of @a and of @b.
 *
 * A bit becomes unknown if it is unknown in either input or if the known
 * values disagree. Bits known 1 in both stay known 1.
 */
struct tnum tnum_union(struct tnum a, struct tnum b)
{
	u64 unknown = a.mask | b.mask | (a.value ^ b.value);

	return TNUM(a.value & b.value & ~unknown, unknown);
}
/*
 * tnum_cast - truncate @a to its low @size bytes (zero-extension).
 * @a:    tnum to truncate
 * @size: number of low-order bytes to keep
 *
 * Bits above the first @size bytes become known zeros in the result.
 * A @size of 8 or more keeps all 64 bits and returns @a unchanged.
 * That case is handled explicitly because 1ULL << 64 (or more) is
 * undefined behavior in C. On common hardware it would produce a keep-mask
 * of 0 and wrongly report every bit as a known 0.
 */
struct tnum tnum_cast(struct tnum a, u8 size)
{
	u64 keep;

	if (size >= 8)
		return a;

	keep = (1ULL << (size * 8)) - 1;
	a.value &= keep;
	a.mask &= keep;
	return a;
}
/*
 * tnum_is_aligned - true when every member of @a is a multiple of @size.
 *
 * It checks that the bits of (@size - 1) are known zeros in @a, which is
 * the alignment test when @size is a power of two. Presumably callers only
 * pass powers of two; other sizes are not rejected. A @size of 0 is
 * always treated as aligned.
 */
bool tnum_is_aligned(struct tnum a, u64 size)
{
	u64 low_bits;

	if (size == 0)
		return true;

	low_bits = size - 1;
	return ((a.value | a.mask) & low_bits) == 0;
}
/*
 * tnum_in - true when every member of @b is also a member of @a.
 *
 * Two conditions must hold. First, @b may be uncertain only where @a is
 * uncertain too. Second, on every bit @a knows, @b's known value must
 * match it. This assumes value & mask == 0 in @a.
 */
bool tnum_in(struct tnum a, struct tnum b)
{
	/* b has an unknown bit that a has pinned down: b can escape a */
	if ((b.mask & ~a.mask) != 0)
		return false;

	return a.value == (b.value & ~a.mask);
}
/*
 * tnum_sbin - render @a as a 64-character binary string into @str.
 * @str:  destination buffer
 * @size: size of @str in bytes, including room for the terminating NUL
 * @a:    tnum to print
 *
 * Bits are printed most-significant first: 'x' for an unknown (mask) bit,
 * '1' for a known one and '0' for a known zero. If @size is too small,
 * only the first @size - 1 characters (the most significant bits) are
 * written, followed by a NUL. If @size is 0, nothing is written.
 *
 * Returns 64, the length of the untruncated string (excluding the NUL).
 */
int tnum_sbin(char *str, size_t size, struct tnum a)
{
	size_t n;

	/*
	 * No room even for the terminator. Without this check "size - 1"
	 * below wraps to SIZE_MAX and the NUL would be stored at str[64],
	 * past the end of the buffer.
	 */
	if (!size)
		return 64;

	/* Consume @a from its LSB while filling @str from the right end. */
	for (n = 64; n; n--) {
		if (n < size) {
			if (a.mask & 1)
				str[n - 1] = 'x';
			else if (a.value & 1)
				str[n - 1] = '1';
			else
				str[n - 1] = '0';
		}
		a.mask >>= 1;
		a.value >>= 1;
	}
	str[min(size - 1, (size_t)64)] = 0;
	return 64;
}
/* tnum_subreg - keep only the low 32 bits; the upper 32 become known 0. */
struct tnum tnum_subreg(struct tnum a)
{
	const int subreg_bytes = 4;

	return tnum_cast(a, subreg_bytes);
}
struct tnum tnum_clear_subreg(struct tnum a)
{
return tnum_lshift(tnum_rshift(a, 32), 32);
}
struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg)
{
return tnum_or(tnum_clear_subreg(reg), tnum_subreg(subreg));
}
/* tnum_const_subreg - keep @a's upper 32 bits, make the low 32 bits @value. */
struct tnum tnum_const_subreg(struct tnum a, u32 value)
{
	struct tnum low = tnum_const(value);

	return tnum_with_subreg(a, low);
}
struct tnum tnum_bswap16(struct tnum a)
{
return TNUM(swab16(a.value & 0xFFFF), swab16(a.mask & 0xFFFF));
}
struct tnum tnum_bswap32(struct tnum a)
{
return TNUM(swab32(a.value & 0xFFFFFFFF), swab32(a.mask & 0xFFFFFFFF));
}
struct tnum tnum_bswap64(struct tnum a)
{
return TNUM(swab64(a.value), swab64(a.mask));
}
/*
 * tnum_step - smallest member of @t strictly greater than @z.
 *
 * Intended for t.value <= @z < tmax, where tmax = t.value | t.mask is the
 * largest member of @t. Outside that range:
 *   - @z >= tmax returns tmax;
 *   - @z <  t.value returns t.value.
 *
 * First @z's unknown-position bits are copied into @t's known bits
 * ("candidate" j below). Then the highest known bit where j and @z differ
 * decides the direction:
 *   - j > @z: @t has a known 1 where @z has 0. Keep @z above that bit and
 *     take the minimum (t.value) from that bit down.
 *   - j <= @z: @t has a known 0 where @z has 1 (or j == @z, which means
 *     "flip" is 0). The unknown bits above that point must be incremented
 *     by one. Known bits there are forced to 1 so the carry passes through
 *     them; they are then replaced by t.value.
 * The early returns guarantee "flip" never reaches 64, so no shift below
 * is undefined.
 */
u64 tnum_step(struct tnum t, u64 z)
{
	u64 tmax = t.value | t.mask;
	u64 above, bumped;
	u8 flip;

	if (z >= tmax)
		return tmax;
	if (z < t.value)
		return t.value;

	if ((t.value | (z & t.mask)) > z) {
		/* highest known-1 bit of t where z is 0 (1-based position) */
		flip = fls64(~z & t.value & ~t.mask);
		above = U64_MAX << flip;
		return (above & z) | (~above & t.value);
	}

	/* highest known-0 bit of t where z is 1 (1-based; 0 if none) */
	flip = fls64(z & ~t.value & ~t.mask);
	above = U64_MAX << flip;
	bumped = (above & t.mask & z) | (above & ~t.mask);
	bumped += 1ULL << flip;
	return (bumped & t.mask) | t.value;
}