root/tools/perf/util/zstd.c
// SPDX-License-Identifier: GPL-2.0

#include <string.h>

#include "util/compress.h"
#include "util/debug.h"

/*
 * Prepare @data for use: no compression or decompression stream is
 * allocated yet (both pointers are cleared) and @level is stored as the
 * compression level. Always returns 0.
 */
int zstd_init(struct zstd_data *data, int level)
{
        data->cstream = NULL;
        data->dstream = NULL;
        data->comp_level = level;

        return 0;
}

/*
 * Release the decompression and compression streams held by @data, if any,
 * and clear both pointers so that @data no longer refers to freed streams.
 * Always returns 0.
 */
int zstd_fini(struct zstd_data *data)
{
        if (data->dstream != NULL)
                ZSTD_freeDStream(data->dstream);
        data->dstream = NULL;

        if (data->cstream != NULL)
                ZSTD_freeCStream(data->cstream);
        data->cstream = NULL;

        return 0;
}

/*
 * Compress @src_size bytes from @src into a sequence of records laid out
 * back to back in @dst.
 *
 * For each record, process_header(record, 0) is called first; its return
 * value is the number of bytes to skip before the payload. The compressed
 * payload is then written, limited to @max_record_size bytes or to what is
 * left of @dst_size, whichever is smaller. Finally
 * process_header(record, payload_size) is called and its return value is
 * used to advance past the payload.
 *
 * The compression stream is created on first use with data->comp_level and
 * is kept in @data for later calls.
 *
 * Returns:
 *  - the sum of all values returned by process_header() on success;
 *  - -1 if the compression stream cannot be created or initialised, or if
 *    compression fails and the raw input does not fit in what is left of @dst;
 *  - @src_size if compression fails and @src was copied verbatim to the
 *    current write position in @dst.
 */
ssize_t zstd_compress_stream_to_records(struct zstd_data *data, void *dst, size_t dst_size,
                                       void *src, size_t src_size, size_t max_record_size,
                                       size_t process_header(void *record, size_t increment))
{
        size_t ret, size, compressed = 0;
        ZSTD_inBuffer input = { src, src_size, 0 };
        ZSTD_outBuffer output;
        void *record;

        if (!data->cstream) {
                data->cstream = ZSTD_createCStream();
                if (data->cstream == NULL) {
                        pr_err("Couldn't create compression stream.\n");
                        return -1;
                }

                ret = ZSTD_initCStream(data->cstream, data->comp_level);
                if (ZSTD_isError(ret)) {
                        pr_err("Failed to initialize compression stream: %s\n",
                                ZSTD_getErrorName(ret));
                        /*
                         * Drop the stream so that the next call retries the
                         * whole setup instead of compressing with a stream
                         * whose initialisation failed.
                         */
                        ZSTD_freeCStream(data->cstream);
                        data->cstream = NULL;
                        return -1;
                }
        }

        while (input.pos < input.size) {
                /* Reserve room for the record header in front of the payload. */
                record = dst;
                size = process_header(record, 0);
                compressed += size;
                dst += size;
                dst_size -= size;

                output = (ZSTD_outBuffer){ dst, (dst_size > max_record_size) ?
                                                max_record_size : dst_size, 0 };
                ret = ZSTD_compressStream(data->cstream, &output, &input);
                /* Only flush after a successful compress; report flush errors too. */
                if (!ZSTD_isError(ret))
                        ret = ZSTD_flushStream(data->cstream, &output);
                if (ZSTD_isError(ret)) {
                        pr_err("failed to compress %ld bytes: %s\n",
                                (long)src_size, ZSTD_getErrorName(ret));
                        /* Never copy the raw input past the end of @dst. */
                        if (src_size > dst_size)
                                return -1;
                        memcpy(dst, src, src_size);
                        return src_size;
                }

                /* Let the callback account for the payload just written. */
                size = output.pos;
                size = process_header(record, size);
                compressed += size;
                dst += size;
                dst_size -= size;
        }

        return compressed;
}

/*
 * Decompress @src_size bytes from @src into @dst, which can hold @dst_size
 * bytes.
 *
 * The decompression stream is created on first use and is kept in @data for
 * later calls.
 *
 * Returns the number of bytes written to @dst. Returns 0 if the stream cannot
 * be created or initialised. If decompression fails or stalls part-way, an
 * error is logged and the number of bytes produced up to that point is
 * returned.
 */
size_t zstd_decompress_stream(struct zstd_data *data, void *src, size_t src_size,
                              void *dst, size_t dst_size)
{
        size_t ret;
        ZSTD_inBuffer input = { src, src_size, 0 };
        ZSTD_outBuffer output = { dst, dst_size, 0 };

        if (!data->dstream) {
                data->dstream = ZSTD_createDStream();
                if (data->dstream == NULL) {
                        pr_err("Couldn't create decompression stream.\n");
                        return 0;
                }

                ret = ZSTD_initDStream(data->dstream);
                if (ZSTD_isError(ret)) {
                        pr_err("Failed to initialize decompression stream: %s\n",
                                ZSTD_getErrorName(ret));
                        /*
                         * Drop the stream so that the next call retries the
                         * whole setup instead of reusing a stream whose
                         * initialisation failed.
                         */
                        ZSTD_freeDStream(data->dstream);
                        data->dstream = NULL;
                        return 0;
                }
        }

        /*
         * zstd writes at output.dst + output.pos and advances output.pos
         * itself, so the output descriptor must not be rebased between
         * calls: doing so would skip output.pos bytes of @dst on the next
         * iteration.
         */
        while (input.pos < input.size) {
                size_t in_before = input.pos, out_before = output.pos;

                ret = ZSTD_decompressStream(data->dstream, &output, &input);
                if (ZSTD_isError(ret)) {
                        pr_err("failed to decompress (B): %zu -> %zu, dst_size %zu : %s\n",
                               src_size, output.size, dst_size, ZSTD_getErrorName(ret));
                        break;
                }

                /* Input is left but nothing moved: stop instead of spinning. */
                if (input.pos == in_before && output.pos == out_before) {
                        pr_err("failed to decompress (B): %zu -> %zu, dst_size %zu : no forward progress\n",
                               src_size, output.size, dst_size);
                        break;
                }
        }

        return output.pos;
}