root/drivers/net/ethernet/mellanox/mlxfw/mlxfw_mfa2.c
// SPDX-License-Identifier: BSD-3-Clause OR GPL-2.0
/* Copyright (c) 2017-2019 Mellanox Technologies. All rights reserved */

#define pr_fmt(fmt) "mlxfw_mfa2: " fmt

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/netlink.h>
#include <linux/vmalloc.h>
#include <linux/xz.h>
#include "mlxfw_mfa2.h"
#include "mlxfw_mfa2_file.h"
#include "mlxfw_mfa2_tlv.h"
#include "mlxfw_mfa2_format.h"
#include "mlxfw_mfa2_tlv_multi.h"

/*               MFA2 FILE
 *  +----------------------------------+
 *  |        MFA2 finger print         |
 *  +----------------------------------+
 *  |   package descriptor multi_tlv   |
 *  | +------------------------------+ |     +-----------------+
 *  | |    package descriptor tlv    +-----> |num_devices=n    |
 *  | +------------------------------+ |     |num_components=m |
 *  +----------------------------------+     |CB offset        |
 *  |    device descriptor multi_tlv   |     |...              |
 *  | +------------------------------+ |     |                 |
 *  | |           PSID tlv           | |     +-----------------+
 *  | +------------------------------+ |
 *  | |     component index tlv      | |
 *  | +------------------------------+ |
 *  +----------------------------------+
 *  |  component descriptor multi_tlv  |
 *  | +------------------------------+ |     +-----------------+
 *  | |  component descriptor tlv    +-----> |Among others:    |
 *  | +------------------------------+ |     |CB offset=o      |
 *  +----------------------------------+     |comp index=i     |
 *  |                                  |     |...              |
 *  |                                  |     |                 |
 *  |                                  |     +-----------------+
 *  |        COMPONENT BLOCK (CB)      |
 *  |                                  |
 *  |                                  |
 *  |                                  |
 *  +----------------------------------+
 *
 * On the top level, an MFA2 file contains:
 *  - Fingerprint
 *  - Several multi_tlvs (TLVs of type MLXFW_MFA2_TLV_MULTI, as defined in
 *    mlxfw_mfa2_format.h)
 *  - Compressed component block (CB)
 *
 * The first multi_tlv
 * -------------------
 * The first multi TLV is treated as the package descriptor and is expected to
 * have a first TLV child of type MLXFW_MFA2_TLV_PACKAGE_DESCRIPTOR, which
 * contains all the global information needed to parse the file. Among others,
 * it contains the number of device descriptors and component descriptors
 * following this multi TLV.
 *
 * The device descriptor multi_tlv
 * -------------------------------
 * The multi TLVs following the package descriptor are treated as device
 * descriptors, and each is expected to have the following children:
 *  - PSID TLV child of type MLXFW_MFA2_TLV_PSID containing that device PSID.
 *  - Component index of type MLXFW_MFA2_TLV_COMPONENT_PTR that contains that
 *    device component index.
 *
 * The component descriptor multi_tlv
 * ----------------------------------
 * The multi TLVs following the device descriptor multi TLVs are treated as
 * component descriptors, and each is expected to have a first child of type
 * MLXFW_MFA2_TLV_COMPONENT_DESCRIPTOR that contains, among others, the
 * component index needed for the flash process and the offset of the binary
 * within the component block.
 */

/* Magic string expected at the very start of an MFA2 file. The *_len value
 * excludes the string literal's terminating NUL.
 */
static const u8 mlxfw_mfa2_fingerprint[] = "MLNX.MFA2.XZ.00!";
static const int mlxfw_mfa2_fingerprint_len = sizeof(mlxfw_mfa2_fingerprint) - 1;

/* Magic string that must precede each component's payload in the
 * decompressed component block. The *_len value excludes the NUL.
 */
static const u8 mlxfw_mfa2_comp_magic[] = "#BIN.COMPONENT!#";
static const int mlxfw_mfa2_comp_magic_len = sizeof(mlxfw_mfa2_comp_magic) - 1;

/*
 * mlxfw_mfa2_check - tell whether @fw looks like an MFA2 image.
 *
 * The size test uses sizeof() of the fingerprint array, which counts the
 * literal's NUL terminator, so at least fingerprint_len + 1 bytes are
 * required. Only the first fingerprint_len bytes are compared.
 *
 * Return: true when @fw starts with the MFA2 fingerprint.
 */
bool mlxfw_mfa2_check(const struct firmware *fw)
{
        bool big_enough = fw->size >= sizeof(mlxfw_mfa2_fingerprint);

        if (!big_enough)
                return false;
        return !memcmp(fw->data, mlxfw_mfa2_fingerprint,
                       mlxfw_mfa2_fingerprint_len);
}

/*
 * mlxfw_mfa2_tlv_multi_validate - make sure every child of @multi parses.
 *
 * The iterator yields NULL for a child that could not be obtained; the first
 * such child makes the whole multi TLV invalid.
 *
 * Return: true if all children were obtained, false otherwise.
 */
static bool
mlxfw_mfa2_tlv_multi_validate(const struct mlxfw_mfa2_file *mfa2_file,
                              const struct mlxfw_mfa2_tlv_multi *multi)
{
        const struct mlxfw_mfa2_tlv *child;
        u16 child_idx;

        mlxfw_mfa2_tlv_multi_foreach(mfa2_file, child, child_idx, multi) {
                if (child)
                        continue;
                pr_err("Multi has invalid child");
                return false;
        }
        return true;
}

/*
 * mlxfw_mfa2_file_dev_validate - sanity-check one device descriptor.
 * @mfa2_file: parsed file
 * @dev_tlv: the device descriptor multi TLV
 * @dev_idx: position of the device, used only for log messages
 *
 * The device must be a valid multi TLV whose children all parse. It must
 * carry a PSID child and at least one COMPONENT_PTR child, and every
 * COMPONENT_PTR child must parse.
 *
 * Return: true if the device descriptor is well formed.
 */
static bool
mlxfw_mfa2_file_dev_validate(const struct mlxfw_mfa2_file *mfa2_file,
                             const struct mlxfw_mfa2_tlv *dev_tlv,
                             u16 dev_idx)
{
        const struct mlxfw_mfa2_tlv_component_ptr *comp_ptr;
        const struct mlxfw_mfa2_tlv_multi *dev_multi;
        const struct mlxfw_mfa2_tlv_psid *dev_psid;
        const struct mlxfw_mfa2_tlv *psid_tlv;
        const struct mlxfw_mfa2_tlv *ptr_tlv;
        u16 ptr_count;
        u16 i;

        pr_debug("Device %d\n", dev_idx);

        dev_multi = mlxfw_mfa2_tlv_multi_get(mfa2_file, dev_tlv);
        if (!dev_multi) {
                pr_err("Device %d is not a valid TLV error\n", dev_idx);
                return false;
        }
        if (!mlxfw_mfa2_tlv_multi_validate(mfa2_file, dev_multi))
                return false;

        /* The first PSID child identifies the device. */
        psid_tlv = mlxfw_mfa2_tlv_multi_child_find(mfa2_file, dev_multi,
                                                   MLXFW_MFA2_TLV_PSID, 0);
        if (!psid_tlv) {
                pr_err("Device %d does not have PSID\n", dev_idx);
                return false;
        }
        dev_psid = mlxfw_mfa2_tlv_psid_get(mfa2_file, psid_tlv);
        if (!dev_psid) {
                pr_err("Device %d PSID TLV is not valid\n", dev_idx);
                return false;
        }
        print_hex_dump_debug("  -- Device PSID ", DUMP_PREFIX_NONE, 16, 16,
                             dev_psid->psid, be16_to_cpu(psid_tlv->len), true);

        /* At least one COMPONENT_PTR child is required. */
        if (mlxfw_mfa2_tlv_multi_child_count(mfa2_file, dev_multi,
                                             MLXFW_MFA2_TLV_COMPONENT_PTR,
                                             &ptr_count))
                return false;
        if (!ptr_count) {
                pr_err("Device %d has no components\n", dev_idx);
                return false;
        }

        /* Every COMPONENT_PTR child must be present and parse. */
        for (i = 0; i < ptr_count; i++) {
                ptr_tlv = mlxfw_mfa2_tlv_multi_child_find(mfa2_file, dev_multi,
                                                          MLXFW_MFA2_TLV_COMPONENT_PTR,
                                                          i);
                if (!ptr_tlv)
                        return false;

                comp_ptr = mlxfw_mfa2_tlv_component_ptr_get(mfa2_file, ptr_tlv);
                if (!comp_ptr) {
                        pr_err("Device %d COMPONENT_PTR TLV is not valid\n",
                               dev_idx);
                        return false;
                }
                pr_debug("  -- Component index %d\n",
                         be16_to_cpu(comp_ptr->component_index));
        }
        return true;
}

/*
 * mlxfw_mfa2_file_comp_validate - sanity-check one component descriptor.
 * @mfa2_file: parsed file
 * @comp_tlv: the component descriptor multi TLV
 * @comp_idx: position of the component, used only for log messages
 *
 * The component must be a valid multi TLV whose children all parse. Its
 * first child must be a valid COMPONENT_DESCRIPTOR TLV.
 *
 * Return: true if the component descriptor is well formed.
 */
static bool
mlxfw_mfa2_file_comp_validate(const struct mlxfw_mfa2_file *mfa2_file,
                              const struct mlxfw_mfa2_tlv *comp_tlv,
                              u16 comp_idx)
{
        const struct mlxfw_mfa2_tlv_component_descriptor *cdesc;
        const struct mlxfw_mfa2_tlv_multi *multi;
        const struct mlxfw_mfa2_tlv *tlv;

        pr_debug("Component %d\n", comp_idx);

        multi = mlxfw_mfa2_tlv_multi_get(mfa2_file, comp_tlv);
        if (!multi) {
                pr_err("Component %d is not a valid TLV error\n", comp_idx);
                return false;
        }

        if (!mlxfw_mfa2_tlv_multi_validate(mfa2_file, multi))
                return false;

        /* Check that the component has COMPONENT_DESCRIPTOR as first child */
        tlv = mlxfw_mfa2_tlv_multi_child(mfa2_file, multi);
        if (!tlv) {
                pr_err("Component descriptor %d multi TLV error\n", comp_idx);
                return false;
        }

        cdesc = mlxfw_mfa2_tlv_component_descriptor_get(mfa2_file, tlv);
        if (!cdesc) {
                pr_err("Component %d does not have a valid descriptor\n",
                       comp_idx);
                return false;
        }
        pr_debug("  -- Component type %d\n", be16_to_cpu(cdesc->identifier));
        /* The CB offset is split into two big-endian 32-bit halves. The size
         * is a u32, hence %u.
         */
        pr_debug("  -- Offset 0x%llx and size %u\n",
                 ((u64) be32_to_cpu(cdesc->cb_offset_h) << 32)
                 | be32_to_cpu(cdesc->cb_offset_l), be32_to_cpu(cdesc->size));

        return true;
}

/*
 * mlxfw_mfa2_file_validate - validate every device and component descriptor.
 *
 * Walks mfa2_file->dev_count TLVs starting at first_dev, then
 * mfa2_file->component_count TLVs starting at first_component. Each TLV must
 * be present and pass the per-device or per-component validator.
 *
 * Return: true if all descriptors are well formed.
 */
static bool mlxfw_mfa2_file_validate(const struct mlxfw_mfa2_file *mfa2_file)
{
        const struct mlxfw_mfa2_tlv *tlv;
        u16 idx;

        pr_debug("Validating file\n");

        /* check that all the devices exist */
        mlxfw_mfa2_tlv_foreach(mfa2_file, tlv, idx, mfa2_file->first_dev,
                               mfa2_file->dev_count) {
                if (!tlv) {
                        pr_err("Device TLV error\n");
                        return false;
                }

                /* Check each device */
                if (!mlxfw_mfa2_file_dev_validate(mfa2_file, tlv, idx))
                        return false;
        }

        /* check that all the components exist */
        mlxfw_mfa2_tlv_foreach(mfa2_file, tlv, idx, mfa2_file->first_component,
                               mfa2_file->component_count) {
                if (!tlv) {
                        /* Was a copy-paste of the device message above. */
                        pr_err("Component TLV error\n");
                        return false;
                }

                /* Check each component */
                if (!mlxfw_mfa2_file_comp_validate(mfa2_file, tlv, idx))
                        return false;
        }
        return true;
}

/*
 * mlxfw_mfa2_file_init - parse and validate the MFA2 headers of @fw.
 * @fw: firmware image; it is referenced, not copied, by the returned object
 *
 * Locates the package descriptor (first child of the first multi TLV, which
 * follows the aligned fingerprint). It records where the device and component
 * descriptors start and how many there are. It also records the location
 * and size of the compressed component block, then validates all
 * descriptors.
 *
 * Return: a kzalloc'ed descriptor to be released with mlxfw_mfa2_file_fini(),
 * ERR_PTR(-ENOMEM) on allocation failure, or ERR_PTR(-EINVAL) if the file is
 * malformed.
 */
struct mlxfw_mfa2_file *mlxfw_mfa2_file_init(const struct firmware *fw)
{
        const struct mlxfw_mfa2_tlv_package_descriptor *pd;
        const struct mlxfw_mfa2_tlv_multi *multi;
        const struct mlxfw_mfa2_tlv *multi_child;
        const struct mlxfw_mfa2_tlv *first_tlv;
        struct mlxfw_mfa2_file *mfa2_file;
        const void *first_tlv_ptr;
        const void *cb_top_ptr;
        size_t cb_avail;

        mfa2_file = kzalloc_obj(*mfa2_file);
        if (!mfa2_file)
                return ERR_PTR(-ENOMEM);

        mfa2_file->fw = fw;
        first_tlv_ptr = fw->data + NLA_ALIGN(mlxfw_mfa2_fingerprint_len);
        first_tlv = mlxfw_mfa2_tlv_get(mfa2_file, first_tlv_ptr);
        if (!first_tlv) {
                pr_err("Could not parse package descriptor TLV\n");
                goto err_out;
        }

        multi = mlxfw_mfa2_tlv_multi_get(mfa2_file, first_tlv);
        if (!multi) {
                pr_err("First TLV is not of valid multi type\n");
                goto err_out;
        }

        multi_child = mlxfw_mfa2_tlv_multi_child(mfa2_file, multi);
        if (!multi_child)
                goto err_out;

        pd = mlxfw_mfa2_tlv_package_descriptor_get(mfa2_file, multi_child);
        if (!pd) {
                pr_err("Could not parse package descriptor TLV\n");
                goto err_out;
        }

        mfa2_file->first_dev = mlxfw_mfa2_tlv_next(mfa2_file, first_tlv);
        if (!mfa2_file->first_dev) {
                pr_err("First device TLV is not valid\n");
                goto err_out;
        }

        mfa2_file->dev_count = be16_to_cpu(pd->num_devices);
        mfa2_file->first_component = mlxfw_mfa2_tlv_advance(mfa2_file,
                                                            mfa2_file->first_dev,
                                                            mfa2_file->dev_count);
        mfa2_file->component_count = be16_to_cpu(pd->num_components);
        mfa2_file->cb = fw->data + NLA_ALIGN(be32_to_cpu(pd->cb_offset));
        if (!mlxfw_mfa2_valid_ptr(mfa2_file, mfa2_file->cb)) {
                pr_err("Component block is out side the file\n");
                goto err_out;
        }
        mfa2_file->cb_archive_size = be32_to_cpu(pd->cb_archive_size);

        /* Bound the untrusted archive size numerically before doing pointer
         * arithmetic with it. Otherwise a huge value can wrap cb_top_ptr
         * back inside the file (e.g. on 32-bit) and later let the xz decoder
         * read past the end of the image. cb was just checked to lie inside
         * the file, so the subtraction below cannot underflow.
         */
        cb_avail = fw->size - (size_t)((const u8 *)mfa2_file->cb - fw->data);
        if (!mfa2_file->cb_archive_size) {
                pr_err("Component block is empty\n");
                goto err_out;
        }
        if (mfa2_file->cb_archive_size > cb_avail) {
                pr_err("Component block size is too big\n");
                goto err_out;
        }
        cb_top_ptr = mfa2_file->cb + mfa2_file->cb_archive_size - 1;
        if (!mlxfw_mfa2_valid_ptr(mfa2_file, cb_top_ptr)) {
                pr_err("Component block size is too big\n");
                goto err_out;
        }

        if (!mlxfw_mfa2_file_validate(mfa2_file))
                goto err_out;
        return mfa2_file;
err_out:
        kfree(mfa2_file);
        return ERR_PTR(-EINVAL);
}

/*
 * mlxfw_mfa2_tlv_dev_get - find the device descriptor whose PSID matches.
 * @mfa2_file: parsed file
 * @psid: PSID bytes to look for
 * @psid_size: number of bytes in @psid
 *
 * Devices whose PSID TLV length differs from @psid_size are skipped. Any
 * parse failure while walking the devices aborts the search.
 *
 * Return: the matching device multi TLV, or NULL if none matches or the
 * device list is malformed.
 */
static const struct mlxfw_mfa2_tlv_multi *
mlxfw_mfa2_tlv_dev_get(const struct mlxfw_mfa2_file *mfa2_file,
                       const char *psid, u16 psid_size)
{
        const struct mlxfw_mfa2_tlv_multi *dev_multi;
        const struct mlxfw_mfa2_tlv_psid *dev_psid;
        const struct mlxfw_mfa2_tlv *psid_tlv;
        const struct mlxfw_mfa2_tlv *dev_tlv;
        u32 idx;

        mlxfw_mfa2_tlv_foreach(mfa2_file, dev_tlv, idx, mfa2_file->first_dev,
                               mfa2_file->dev_count) {
                if (!dev_tlv)
                        return NULL;
                dev_multi = mlxfw_mfa2_tlv_multi_get(mfa2_file, dev_tlv);
                if (!dev_multi)
                        return NULL;

                psid_tlv = mlxfw_mfa2_tlv_multi_child_find(mfa2_file, dev_multi,
                                                           MLXFW_MFA2_TLV_PSID, 0);
                if (!psid_tlv)
                        return NULL;

                /* Only a PSID of exactly the requested length can match. */
                if (be16_to_cpu(psid_tlv->len) == psid_size) {
                        dev_psid = mlxfw_mfa2_tlv_psid_get(mfa2_file, psid_tlv);
                        if (!dev_psid)
                                return NULL;
                        if (!memcmp(psid, dev_psid->psid, psid_size))
                                return dev_multi;
                }
        }

        return NULL;
}

/*
 * mlxfw_mfa2_file_component_count - count components listed for a device.
 * @mfa2_file: parsed file
 * @psid: PSID of the device
 * @psid_size: number of bytes in @psid
 * @p_count: output; written only on success
 *
 * Return: 0 on success, -EINVAL if no device matches @psid, or the error
 * from counting the device's COMPONENT_PTR children.
 */
int mlxfw_mfa2_file_component_count(const struct mlxfw_mfa2_file *mfa2_file,
                                    const char *psid, u32 psid_size,
                                    u32 *p_count)
{
        const struct mlxfw_mfa2_tlv_multi *dev_multi;
        u16 ptr_count;
        int ret;

        dev_multi = mlxfw_mfa2_tlv_dev_get(mfa2_file, psid, psid_size);
        if (!dev_multi)
                return -EINVAL;

        ret = mlxfw_mfa2_tlv_multi_child_count(mfa2_file, dev_multi,
                                               MLXFW_MFA2_TLV_COMPONENT_PTR,
                                               &ptr_count);
        if (!ret)
                *p_count = ptr_count;
        return ret;
}

static int mlxfw_mfa2_xz_dec_run(struct xz_dec *xz_dec, struct xz_buf *xz_buf,
                                 bool *finished)
{
        enum xz_ret xz_ret;

        xz_ret = xz_dec_run(xz_dec, xz_buf);

        switch (xz_ret) {
        case XZ_STREAM_END:
                *finished = true;
                return 0;
        case XZ_OK:
                *finished = false;
                return 0;
        case XZ_MEM_ERROR:
                pr_err("xz no memory\n");
                return -ENOMEM;
        case XZ_DATA_ERROR:
                pr_err("xz file corrupted\n");
                return -EINVAL;
        case XZ_FORMAT_ERROR:
                pr_err("xz format not found\n");
                return -EINVAL;
        case XZ_OPTIONS_ERROR:
                pr_err("unsupported xz option\n");
                return -EINVAL;
        case XZ_MEMLIMIT_ERROR:
                pr_err("xz dictionary too small\n");
                return -EINVAL;
        default:
                pr_err("xz error %d\n", xz_ret);
                return -EINVAL;
        }
}

/*
 * mlxfw_mfa2_file_cb_offset_xz - decompress @size bytes at offset @off of
 * the component block into @buf.
 * @mfa2_file: parsed file (provides cb and cb_archive_size)
 * @off: offset inside the decompressed stream
 * @size: number of bytes to produce
 * @buf: destination, at least @size bytes; also used as scratch space while
 *       skipping the first @off decompressed bytes
 *
 * Return: 0 when exactly @size bytes were written to @buf; -EINVAL if the
 * decoder cannot be created, the stream ends before @off + @size bytes, or
 * the data is invalid; -ENOMEM on decoder memory errors.
 */
static int mlxfw_mfa2_file_cb_offset_xz(const struct mlxfw_mfa2_file *mfa2_file,
                                        off_t off, size_t size, u8 *buf)
{
        struct xz_dec *xz_dec;
        struct xz_buf dec_buf;
        off_t curr_off = 0;
        bool finished;
        int err;

        xz_dec = xz_dec_init(XZ_DYNALLOC, (u32) -1);
        if (!xz_dec)
                return -EINVAL;

        dec_buf.in_size = mfa2_file->cb_archive_size;
        dec_buf.in = mfa2_file->cb;
        dec_buf.in_pos = 0;
        dec_buf.out = buf;

        /* decode up to the offset, discarding the output */
        do {
                dec_buf.out_pos = 0;
                dec_buf.out_size = min_t(size_t, size, off - curr_off);
                if (dec_buf.out_size == 0)
                        break;

                err = mlxfw_mfa2_xz_dec_run(xz_dec, &dec_buf, &finished);
                if (err)
                        goto out;
                if (finished) {
                        pr_err("xz section too short\n");
                        err = -EINVAL;
                        goto out;
                }
                curr_off += dec_buf.out_pos;
        } while (curr_off != off);

        /* Decode the needed section. A single xz_dec_run() call is not
         * guaranteed to fill the whole output buffer, so keep going until it
         * is full or the stream ends. A repeated no-progress call makes
         * xz_dec_run() return an error, so this loop terminates.
         */
        dec_buf.out_pos = 0;
        dec_buf.out_size = size;
        do {
                err = mlxfw_mfa2_xz_dec_run(xz_dec, &dec_buf, &finished);
                if (err)
                        goto out;
        } while (!finished && dec_buf.out_pos < size);

        /* The stream ended before the whole component was produced. */
        if (dec_buf.out_pos < size) {
                pr_err("xz section too short\n");
                err = -EINVAL;
        }
out:
        xz_dec_end(xz_dec);
        return err;
}

/*
 * mlxfw_mfa2_file_component_tlv_get - get the descriptor of component
 * @comp_index.
 * @mfa2_file: parsed file
 * @comp_index: zero-based index into the component descriptor list
 *
 * Return: the COMPONENT_DESCRIPTOR TLV (first child of the component's multi
 * TLV), or NULL if the index is out of range or the TLVs do not parse.
 */
static const struct mlxfw_mfa2_tlv_component_descriptor *
mlxfw_mfa2_file_component_tlv_get(const struct mlxfw_mfa2_file *mfa2_file,
                                  u16 comp_index)
{
        const struct mlxfw_mfa2_tlv_multi *multi;
        const struct mlxfw_mfa2_tlv *multi_child;
        const struct mlxfw_mfa2_tlv *comp_tlv;

        /* Valid indices are 0..component_count - 1. Index == count would
         * advance past the last component TLV into the component block.
         */
        if (comp_index >= mfa2_file->component_count)
                return NULL;

        comp_tlv = mlxfw_mfa2_tlv_advance(mfa2_file, mfa2_file->first_component,
                                          comp_index);
        if (!comp_tlv)
                return NULL;

        multi = mlxfw_mfa2_tlv_multi_get(mfa2_file, comp_tlv);
        if (!multi)
                return NULL;

        multi_child = mlxfw_mfa2_tlv_multi_child(mfa2_file, multi);
        if (!multi_child)
                return NULL;

        return mlxfw_mfa2_tlv_component_descriptor_get(mfa2_file, multi_child);
}

/*
 * Single vzalloc() allocation made by mlxfw_mfa2_file_component_get().
 * @comp: descriptor handed to callers; callers give it back through
 *        mlxfw_mfa2_file_component_put(), which recovers this container
 *        with container_of()
 * @buff: decompressed bytes, i.e. the component magic followed by the
 *        payload; comp.data points into this array just past the magic
 */
struct mlxfw_mfa2_comp_data {
        struct mlxfw_mfa2_component comp;
        u8 buff[];
};

/*
 * mlxfw_mfa2_file_component_find - resolve a device's Nth component.
 * @mfa2_file: parsed file
 * @psid: PSID of the device
 * @psid_size: number of bytes in @psid
 * @component_index: which COMPONENT_PTR child of the device to follow
 *
 * Return: the component descriptor referenced by that COMPONENT_PTR, or
 * NULL if the device, the pointer, or the descriptor cannot be found.
 */
static const struct mlxfw_mfa2_tlv_component_descriptor *
mlxfw_mfa2_file_component_find(const struct mlxfw_mfa2_file *mfa2_file,
                               const char *psid, int psid_size,
                               int component_index)
{
        const struct mlxfw_mfa2_tlv_component_ptr *ptr;
        const struct mlxfw_mfa2_tlv_multi *dev_multi;
        const struct mlxfw_mfa2_tlv *ptr_tlv;

        dev_multi = mlxfw_mfa2_tlv_dev_get(mfa2_file, psid, psid_size);
        if (!dev_multi)
                return NULL;

        ptr_tlv = mlxfw_mfa2_tlv_multi_child_find(mfa2_file, dev_multi,
                                                  MLXFW_MFA2_TLV_COMPONENT_PTR,
                                                  component_index);
        ptr = ptr_tlv ? mlxfw_mfa2_tlv_component_ptr_get(mfa2_file, ptr_tlv) :
                        NULL;
        if (!ptr)
                return NULL;

        /* The pointer holds a big-endian index into the component list. */
        return mlxfw_mfa2_file_component_tlv_get(mfa2_file,
                                                 be16_to_cpu(ptr->component_index));
}

struct mlxfw_mfa2_component *
mlxfw_mfa2_file_component_get(const struct mlxfw_mfa2_file *mfa2_file,
                              const char *psid, int psid_size,
                              int component_index)
{
        const struct mlxfw_mfa2_tlv_component_descriptor *comp;
        struct mlxfw_mfa2_comp_data *comp_data;
        u32 comp_buf_size;
        off_t cb_offset;
        u32 comp_size;
        int err;

        comp = mlxfw_mfa2_file_component_find(mfa2_file, psid, psid_size,
                                              component_index);
        if (!comp)
                return ERR_PTR(-EINVAL);

        cb_offset = (u64) be32_to_cpu(comp->cb_offset_h) << 32 |
                    be32_to_cpu(comp->cb_offset_l);
        comp_size = be32_to_cpu(comp->size);
        comp_buf_size = comp_size + mlxfw_mfa2_comp_magic_len;

        comp_data = vzalloc(sizeof(*comp_data) + comp_buf_size);
        if (!comp_data)
                return ERR_PTR(-ENOMEM);
        comp_data->comp.data_size = comp_size;
        comp_data->comp.index = be16_to_cpu(comp->identifier);
        err = mlxfw_mfa2_file_cb_offset_xz(mfa2_file, cb_offset, comp_buf_size,
                                           comp_data->buff);
        if (err) {
                pr_err("Component could not be reached in CB\n");
                goto err_out;
        }

        if (memcmp(comp_data->buff, mlxfw_mfa2_comp_magic,
                   mlxfw_mfa2_comp_magic_len) != 0) {
                pr_err("Component has wrong magic\n");
                err = -EINVAL;
                goto err_out;
        }

        comp_data->comp.data = comp_data->buff + mlxfw_mfa2_comp_magic_len;
        return &comp_data->comp;
err_out:
        vfree(comp_data);
        return ERR_PTR(err);
}

/*
 * mlxfw_mfa2_file_component_put - release a component returned by
 * mlxfw_mfa2_file_component_get().
 *
 * @comp is embedded in a struct mlxfw_mfa2_comp_data, so the enclosing
 * allocation is recovered and vfree()d as a whole.
 */
void mlxfw_mfa2_file_component_put(struct mlxfw_mfa2_component *comp)
{
        vfree(container_of(comp, struct mlxfw_mfa2_comp_data, comp));
}

/*
 * mlxfw_mfa2_file_fini - release a descriptor from mlxfw_mfa2_file_init().
 *
 * Only the descriptor itself is freed. The firmware image it references
 * (mfa2_file->fw) is not touched here.
 */
void mlxfw_mfa2_file_fini(struct mlxfw_mfa2_file *mfa2_file)
{
        kfree(mfa2_file);
}