root/drivers/media/platform/allegro-dvt/nal-rbsp.c
// SPDX-License-Identifier: GPL-2.0
/*
 * Copyright (C) 2019-2020 Pengutronix, Michael Tretter <kernel@pengutronix.de>
 *
 * Helper functions to generate a raw byte sequence payload from values.
 */

#include <linux/kernel.h>
#include <linux/types.h>
#include <linux/string.h>
#include <linux/v4l2-controls.h>

#include <linux/device.h>
#include <linux/export.h>
#include <linux/log2.h>

#include "nal-rbsp.h"

void rbsp_init(struct rbsp *rbsp, void *addr, size_t size,
               struct nal_rbsp_ops *ops)
{
        if (!rbsp)
                return;

        rbsp->data = addr;
        rbsp->size = size;
        rbsp->pos = 0;
        rbsp->ops = ops;
        rbsp->error = 0;
}

void rbsp_unsupported(struct rbsp *rbsp)
{
        rbsp->error = -EINVAL;
}

static int rbsp_read_bits(struct rbsp *rbsp, int n, unsigned int *value);
static int rbsp_write_bits(struct rbsp *rbsp, int n, unsigned int value);

/*
 * When reading or writing, the emulation_prevention_three_byte is detected
 * only when the 2 one bits need to be inserted. Therefore, we are not
 * actually adding the 0x3 byte, but the 2 one bits and the six 0 bits of the
 * next byte.
 */
#define EMULATION_PREVENTION_THREE_BYTE (0x3 << 6)

static int add_emulation_prevention_three_byte(struct rbsp *rbsp)
{
        rbsp->num_consecutive_zeros = 0;
        rbsp_write_bits(rbsp, 8, EMULATION_PREVENTION_THREE_BYTE);

        return 0;
}

static int discard_emulation_prevention_three_byte(struct rbsp *rbsp)
{
        unsigned int tmp = 0;

        rbsp->num_consecutive_zeros = 0;
        rbsp_read_bits(rbsp, 8, &tmp);
        if (tmp != EMULATION_PREVENTION_THREE_BYTE)
                return -EINVAL;

        return 0;
}

static inline int rbsp_read_bit(struct rbsp *rbsp)
{
        int shift;
        int ofs;
        int bit;
        int err;

        if (rbsp->num_consecutive_zeros == 22) {
                err = discard_emulation_prevention_three_byte(rbsp);
                if (err)
                        return err;
        }

        shift = 7 - (rbsp->pos % 8);
        ofs = rbsp->pos / 8;
        if (ofs >= rbsp->size)
                return -EINVAL;

        bit = (rbsp->data[ofs] >> shift) & 1;

        rbsp->pos++;

        if (bit == 1 ||
            (rbsp->num_consecutive_zeros < 7 && (rbsp->pos % 8 == 0)))
                rbsp->num_consecutive_zeros = 0;
        else
                rbsp->num_consecutive_zeros++;

        return bit;
}

static inline int rbsp_write_bit(struct rbsp *rbsp, bool value)
{
        int shift;
        int ofs;

        if (rbsp->num_consecutive_zeros == 22)
                add_emulation_prevention_three_byte(rbsp);

        shift = 7 - (rbsp->pos % 8);
        ofs = rbsp->pos / 8;
        if (ofs >= rbsp->size)
                return -EINVAL;

        rbsp->data[ofs] &= ~(1 << shift);
        rbsp->data[ofs] |= value << shift;

        rbsp->pos++;

        if (value ||
            (rbsp->num_consecutive_zeros < 7 && (rbsp->pos % 8 == 0))) {
                rbsp->num_consecutive_zeros = 0;
        } else {
                rbsp->num_consecutive_zeros++;
        }

        return 0;
}

static inline int rbsp_read_bits(struct rbsp *rbsp, int n, unsigned int *value)
{
        int i;
        int bit;
        unsigned int tmp = 0;

        if (n > 8 * sizeof(*value))
                return -EINVAL;

        for (i = n; i > 0; i--) {
                bit = rbsp_read_bit(rbsp);
                if (bit < 0)
                        return bit;
                tmp |= bit << (i - 1);
        }

        if (value)
                *value = tmp;

        return 0;
}

static int rbsp_write_bits(struct rbsp *rbsp, int n, unsigned int value)
{
        int ret;

        if (n > 8 * sizeof(value))
                return -EINVAL;

        while (n--) {
                ret = rbsp_write_bit(rbsp, (value >> n) & 1);
                if (ret)
                        return ret;
        }

        return 0;
}

static int rbsp_read_uev(struct rbsp *rbsp, unsigned int *value)
{
        int leading_zero_bits = 0;
        unsigned int tmp = 0;
        int ret;

        while ((ret = rbsp_read_bit(rbsp)) == 0)
                leading_zero_bits++;
        if (ret < 0)
                return ret;

        if (leading_zero_bits > 0) {
                ret = rbsp_read_bits(rbsp, leading_zero_bits, &tmp);
                if (ret)
                        return ret;
        }

        if (value)
                *value = (1 << leading_zero_bits) - 1 + tmp;

        return 0;
}

static int rbsp_write_uev(struct rbsp *rbsp, unsigned int *value)
{
        int ret;
        int leading_zero_bits;

        if (!value)
                return -EINVAL;

        leading_zero_bits = ilog2(*value + 1);

        ret = rbsp_write_bits(rbsp, leading_zero_bits, 0);
        if (ret)
                return ret;

        return rbsp_write_bits(rbsp, leading_zero_bits + 1, *value + 1);
}

static int rbsp_read_sev(struct rbsp *rbsp, int *value)
{
        int ret;
        unsigned int tmp;

        ret = rbsp_read_uev(rbsp, &tmp);
        if (ret)
                return ret;

        if (value) {
                if (tmp & 1)
                        *value = (tmp + 1) / 2;
                else
                        *value = -(tmp / 2);
        }

        return 0;
}

static int rbsp_write_sev(struct rbsp *rbsp, int *value)
{
        unsigned int tmp;

        if (!value)
                return -EINVAL;

        if (*value > 0)
                tmp = (2 * (*value)) | 1;
        else
                tmp = -2 * (*value);

        return rbsp_write_uev(rbsp, &tmp);
}

static int __rbsp_write_bit(struct rbsp *rbsp, int *value)
{
        return rbsp_write_bit(rbsp, *value);
}

static int __rbsp_write_bits(struct rbsp *rbsp, int n, unsigned int *value)
{
        return rbsp_write_bits(rbsp, n, *value);
}

struct nal_rbsp_ops write = {
        .rbsp_bit = __rbsp_write_bit,
        .rbsp_bits = __rbsp_write_bits,
        .rbsp_uev = rbsp_write_uev,
        .rbsp_sev = rbsp_write_sev,
};

static int __rbsp_read_bit(struct rbsp *rbsp, int *value)
{
        int tmp = rbsp_read_bit(rbsp);

        if (tmp < 0)
                return tmp;
        *value = tmp;

        return 0;
}

struct nal_rbsp_ops read = {
        .rbsp_bit = __rbsp_read_bit,
        .rbsp_bits = rbsp_read_bits,
        .rbsp_uev = rbsp_read_uev,
        .rbsp_sev = rbsp_read_sev,
};

void rbsp_bit(struct rbsp *rbsp, int *value)
{
        if (rbsp->error)
                return;
        rbsp->error = rbsp->ops->rbsp_bit(rbsp, value);
}

void rbsp_bits(struct rbsp *rbsp, int n, int *value)
{
        if (rbsp->error)
                return;
        rbsp->error = rbsp->ops->rbsp_bits(rbsp, n, value);
}

void rbsp_uev(struct rbsp *rbsp, unsigned int *value)
{
        if (rbsp->error)
                return;
        rbsp->error = rbsp->ops->rbsp_uev(rbsp, value);
}

void rbsp_sev(struct rbsp *rbsp, int *value)
{
        if (rbsp->error)
                return;
        rbsp->error = rbsp->ops->rbsp_sev(rbsp, value);
}

void rbsp_trailing_bits(struct rbsp *rbsp)
{
        unsigned int rbsp_stop_one_bit = 1;
        unsigned int rbsp_alignment_zero_bit = 0;

        rbsp_bit(rbsp, &rbsp_stop_one_bit);
        rbsp_bits(rbsp, round_up(rbsp->pos, 8) - rbsp->pos,
                  &rbsp_alignment_zero_bit);
}