root/crypto/zstd.c
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Cryptographic API.
 *
 * Copyright (c) 2017-present, Facebook, Inc.
 */
#include <linux/crypto.h>
#include <linux/init.h>
#include <linux/interrupt.h>
#include <linux/mm.h>
#include <linux/module.h>
#include <linux/net.h>
#include <linux/overflow.h>
#include <linux/vmalloc.h>
#include <linux/zstd.h>
#include <crypto/internal/acompress.h>
#include <crypto/scatterwalk.h>


/* Compression level passed to zstd_get_params() for every stream context. */
#define ZSTD_DEF_LEVEL          3
/*
 * log2 of ZSTD_MAX_SIZE.  ZSTD_MAX_SIZE is used as the source-size hint for
 * zstd_get_params(), to size the decompression workspace, and as the window
 * limit given to zstd_init_dstream().
 */
#define ZSTD_MAX_WINDOWLOG      18
#define ZSTD_MAX_SIZE           BIT(ZSTD_MAX_WINDOWLOG)

/*
 * Per-stream state obtained from crypto_acomp_lock_stream_bh() (s->ctx).
 *
 * @cctx:      compression context, re-initialised inside @wksp on each
 *             compression request
 * @dctx:      decompression context, re-initialised inside @wksp on each
 *             decompression request
 * @wksp_size: size of @wksp in bytes
 * @params:    parameters computed once in zstd_alloc_stream()
 * @wksp:      workspace used by whichever of @cctx / @dctx is active; sized
 *             for the larger of the two
 */
struct zstd_ctx {
        zstd_cctx *cctx;
        zstd_dctx *dctx;
        size_t wksp_size;
        zstd_parameters params;
        u8 wksp[] __aligned(8) __counted_by(wksp_size);
};

/* Serialises crypto_acomp_alloc_streams(&zstd_streams) calls in zstd_init(). */
static DEFINE_MUTEX(zstd_stream_lock);

static void *zstd_alloc_stream(void)
{
        zstd_parameters params;
        struct zstd_ctx *ctx;
        size_t wksp_size;

        params = zstd_get_params(ZSTD_DEF_LEVEL, ZSTD_MAX_SIZE);

        wksp_size = max(zstd_cstream_workspace_bound(&params.cParams),
                        zstd_dstream_workspace_bound(ZSTD_MAX_SIZE));
        if (!wksp_size)
                return ERR_PTR(-EINVAL);

        ctx = kvmalloc_flex(*ctx, wksp, wksp_size);
        if (!ctx)
                return ERR_PTR(-ENOMEM);

        ctx->params = params;
        ctx->wksp_size = wksp_size;

        return ctx;
}

/* Release a context previously returned by zstd_alloc_stream(). */
static void zstd_free_stream(void *ctx)
{
        struct zstd_ctx *zctx = ctx;

        kvfree(zctx);
}

/*
 * Stream contexts used by zstd_compress() and zstd_decompress(); allocated
 * by zstd_init() and released in zstd_mod_fini().
 */
static struct crypto_acomp_streams zstd_streams = {
        .free_ctx       = zstd_free_stream,
        .alloc_ctx      = zstd_alloc_stream,
};

/*
 * zstd_init - transform init hook; makes sure zstd_streams is allocated.
 * @acomp_tfm: the transform being initialised (unused).
 *
 * Return: the result of crypto_acomp_alloc_streams().
 */
static int zstd_init(struct crypto_acomp *acomp_tfm)
{
        int err;

        /* Allocation of the shared stream set is done under zstd_stream_lock. */
        mutex_lock(&zstd_stream_lock);
        err = crypto_acomp_alloc_streams(&zstd_streams);
        mutex_unlock(&zstd_stream_lock);

        return err;
}

static int zstd_compress_one(struct acomp_req *req, struct zstd_ctx *ctx,
                             const void *src, void *dst, unsigned int *dlen)
{
        size_t out_len;

        ctx->cctx = zstd_init_cctx(ctx->wksp, ctx->wksp_size);
        if (!ctx->cctx)
                return -EINVAL;

        out_len = zstd_compress_cctx(ctx->cctx, dst, req->dlen, src, req->slen,
                                     &ctx->params);
        if (zstd_is_error(out_len))
                return -EINVAL;

        *dlen = out_len;

        return 0;
}

/*
 * zstd_compress - compress req->src into req->dst as a zstd frame.
 * @req: acomp request; on return req->dlen holds the number of bytes
 *       written on success, or 0 on failure.
 *
 * If the first src and dst segments cover the whole request, the one-shot
 * zstd_compress_one() path is used.  Otherwise input is fed through the
 * zstd streaming API, walking src and dst segment by segment.  The frame
 * epilogue written by zstd_end_stream() may span several dst segments.
 *
 * Return: 0 on success, -ENOSPC if dst is too small for the whole frame,
 * -EINVAL if the compression context cannot be set up (or the one-shot
 * path fails), -EIO on a streaming error, or an error from acomp_walk_virt().
 */
static int zstd_compress(struct acomp_req *req)
{
        struct crypto_acomp_stream *s;
        unsigned int scur, dcur;
        unsigned int total_out = 0;
        bool data_available = true;
        size_t end_pending = 0;
        zstd_out_buffer outbuf;
        struct acomp_walk walk;
        zstd_in_buffer inbuf;
        struct zstd_ctx *ctx;
        size_t num_bytes;
        int ret;

        s = crypto_acomp_lock_stream_bh(&zstd_streams);
        ctx = s->ctx;

        ret = acomp_walk_virt(&walk, req, true);
        if (ret)
                goto out;

        ctx->cctx = zstd_init_cstream(&ctx->params, 0, ctx->wksp, ctx->wksp_size);
        if (!ctx->cctx) {
                ret = -EINVAL;
                goto out;
        }

        do {
                dcur = acomp_walk_next_dst(&walk);
                if (!dcur) {
                        ret = -ENOSPC;
                        goto out;
                }

                outbuf.pos = 0;
                outbuf.dst = (u8 *)walk.dst.virt.addr;
                outbuf.size = dcur;

                /* Feed input until this dst segment is full or input runs out. */
                while (data_available && outbuf.pos != dcur) {
                        scur = acomp_walk_next_src(&walk);
                        if (dcur == req->dlen && scur == req->slen) {
                                ret = zstd_compress_one(req, ctx, walk.src.virt.addr,
                                                        walk.dst.virt.addr, &total_out);
                                acomp_walk_done_src(&walk, scur);
                                acomp_walk_done_dst(&walk, dcur);
                                goto out;
                        }

                        if (!scur) {
                                data_available = false;
                                break;
                        }

                        inbuf.pos = 0;
                        inbuf.src = walk.src.virt.addr;
                        inbuf.size = scur;

                        num_bytes = zstd_compress_stream(ctx->cctx, &outbuf, &inbuf);
                        if (zstd_is_error(num_bytes)) {
                                ret = -EIO;
                                goto out;
                        }

                        num_bytes = zstd_flush_stream(ctx->cctx, &outbuf);
                        if (zstd_is_error(num_bytes)) {
                                ret = -EIO;
                                goto out;
                        }
                        acomp_walk_done_src(&walk, inbuf.pos);
                }

                /*
                 * Once input is exhausted, write the frame epilogue into this
                 * segment *before* acomp_walk_done_dst() releases it.  A
                 * non-zero return means the epilogue did not fit, so loop for
                 * another dst segment instead of reporting a truncated frame.
                 */
                if (!data_available) {
                        end_pending = zstd_end_stream(ctx->cctx, &outbuf);
                        if (zstd_is_error(end_pending)) {
                                ret = -EIO;
                                goto out;
                        }
                }

                total_out += outbuf.pos;
                acomp_walk_done_dst(&walk, outbuf.pos);
        } while (data_available || end_pending);

out:
        if (ret)
                req->dlen = 0;
        else
                req->dlen = total_out;

        crypto_acomp_unlock_stream_bh(s);

        return ret;
}

static int zstd_decompress_one(struct acomp_req *req, struct zstd_ctx *ctx,
                               const void *src, void *dst, unsigned int *dlen)
{
        size_t out_len;

        ctx->dctx = zstd_init_dctx(ctx->wksp, ctx->wksp_size);
        if (!ctx->dctx)
                return -EINVAL;

        out_len = zstd_decompress_dctx(ctx->dctx, dst, req->dlen, src, req->slen);
        if (zstd_is_error(out_len))
                return -EINVAL;

        *dlen = out_len;

        return 0;
}

static int zstd_decompress(struct acomp_req *req)
{
        struct crypto_acomp_stream *s;
        unsigned int total_out = 0;
        unsigned int scur, dcur;
        zstd_out_buffer outbuf;
        struct acomp_walk walk;
        zstd_in_buffer inbuf;
        struct zstd_ctx *ctx;
        size_t pending_bytes;
        int ret;

        s = crypto_acomp_lock_stream_bh(&zstd_streams);
        ctx = s->ctx;

        ret = acomp_walk_virt(&walk, req, true);
        if (ret)
                goto out;

        ctx->dctx = zstd_init_dstream(ZSTD_MAX_SIZE, ctx->wksp, ctx->wksp_size);
        if (!ctx->dctx) {
                ret = -EINVAL;
                goto out;
        }

        do {
                scur = acomp_walk_next_src(&walk);
                if (scur) {
                        inbuf.pos = 0;
                        inbuf.size = scur;
                        inbuf.src = walk.src.virt.addr;
                } else {
                        break;
                }

                do {
                        dcur = acomp_walk_next_dst(&walk);
                        if (dcur == req->dlen && scur == req->slen) {
                                ret = zstd_decompress_one(req, ctx, walk.src.virt.addr,
                                                          walk.dst.virt.addr, &total_out);
                                acomp_walk_done_dst(&walk, dcur);
                                acomp_walk_done_src(&walk, scur);
                                goto out;
                        }

                        if (!dcur) {
                                ret = -ENOSPC;
                                goto out;
                        }

                        outbuf.pos = 0;
                        outbuf.dst = (u8 *)walk.dst.virt.addr;
                        outbuf.size = dcur;

                        pending_bytes = zstd_decompress_stream(ctx->dctx, &outbuf, &inbuf);
                        if (ZSTD_isError(pending_bytes)) {
                                ret = -EIO;
                                goto out;
                        }

                        total_out += outbuf.pos;

                        acomp_walk_done_dst(&walk, outbuf.pos);
                } while (inbuf.pos != scur);

                acomp_walk_done_src(&walk, scur);
        } while (ret == 0);

out:
        if (ret)
                req->dlen = 0;
        else
                req->dlen = total_out;

        crypto_acomp_unlock_stream_bh(s);

        return ret;
}

/*
 * Algorithm descriptor registered as "zstd" / "zstd-generic".  The
 * CRYPTO_ALG_REQ_VIRT flag presumably advertises virtual-address requests
 * (the handlers walk them with acomp_walk_virt()) — confirm against the
 * acomp core.
 */
static struct acomp_alg zstd_acomp = {
        .init           = zstd_init,
        .compress       = zstd_compress,
        .decompress     = zstd_decompress,
        .base           = {
                .cra_module             = THIS_MODULE,
                .cra_flags              = CRYPTO_ALG_REQ_VIRT,
                .cra_driver_name        = "zstd-generic",
                .cra_name               = "zstd",
        },
};

/* Module load: register the zstd acomp algorithm. */
static int __init zstd_mod_init(void)
{
        int err = crypto_register_acomp(&zstd_acomp);

        return err;
}

/*
 * Module unload: unregister the algorithm, then release the stream
 * contexts that zstd_init() allocated.
 */
static void __exit zstd_mod_fini(void)
{
        struct acomp_alg *alg = &zstd_acomp;

        crypto_unregister_acomp(alg);
        crypto_acomp_free_streams(&zstd_streams);
}

/* Module entry/exit hooks and module metadata. */
module_init(zstd_mod_init);
module_exit(zstd_mod_fini);

MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("Zstd Compression Algorithm");
MODULE_ALIAS_CRYPTO("zstd");