root/usr/src/common/crypto/modes/ctr.c
/*
 * CDDL HEADER START
 *
 * The contents of this file are subject to the terms of the
 * Common Development and Distribution License (the "License").
 * You may not use this file except in compliance with the License.
 *
 * You can obtain a copy of the license at usr/src/OPENSOLARIS.LICENSE
 * or http://www.opensolaris.org/os/licensing.
 * See the License for the specific language governing permissions
 * and limitations under the License.
 *
 * When distributing Covered Code, include this CDDL HEADER in each
 * file and include the License file at usr/src/OPENSOLARIS.LICENSE.
 * If applicable, add the following below this CDDL HEADER, with the
 * fields enclosed by brackets "[]" replaced with your own identifying
 * information: Portions Copyright [yyyy] [name of copyright owner]
 *
 * CDDL HEADER END
 */
/*
 * Copyright 2008 Sun Microsystems, Inc.  All rights reserved.
 * Use is subject to license terms.
 *
 * Copyright 2019 Joyent, Inc.
 */

#ifndef _KERNEL
#include <strings.h>
#include <limits.h>
#include <assert.h>
#include <security/cryptoki.h>
#endif

#include <sys/debug.h>
#include <sys/types.h>
#include <modes/modes.h>
#include <sys/crypto/common.h>
#include <sys/crypto/impl.h>
#include <sys/byteorder.h>

/*
 * CTR (counter mode) is a stream cipher.  That is, it generates a
 * pseudo-random keystream that is used to XOR with the input to
 * encrypt or decrypt.  The pseudo-random keystream is generated by
 * concatenating a nonce (supplied during initialization) with a
 * counter (initialized to zero) to form an input block to the cipher
 * mechanism.  The resulting output of the cipher is used as a chunk
 * of the pseudo-random keystream.  Once all of the bytes of the
 * keystream block have been used, the counter is incremented and
 * the process repeats.
 *
 * Since this is a stream cipher, we do not accumulate input cipher
 * text like we do for block modes.  Instead we use ctr_ctx_t->ctr_offset
 * to track the amount of bytes used in the current keystream block.
 */

/*
 * Step the counter block to its next value and run it through 'cipher'
 * to produce a fresh block of keystream in ctx->ctr_keystream.
 *
 * The counter block is kept as two 64-bit words that are converted with
 * ntohll()/htonll() around the arithmetic.  ctr_lower_mask and
 * ctr_upper_mask select which bits of ctr_cb[1] and ctr_cb[0] are counter
 * bits; every bit outside the masks is left untouched.  On return
 * ctr_offset is 0, meaning none of the new keystream has been consumed.
 */
static void
ctr_new_keyblock(ctr_ctx_t *ctx,
    int (*cipher)(const void *ks, const uint8_t *pt, uint8_t *ct))
{
        uint64_t lo, hi;

        /*
         * Low word: pull out the counter bits, add one in host order,
         * convert back and discard any carry that escaped the mask.
         */
        lo = htonll(ntohll(ctx->ctr_cb[1] & ctx->ctr_lower_mask) + 1) &
            ctx->ctr_lower_mask;
        ctx->ctr_cb[1] = (ctx->ctr_cb[1] & ~(ctx->ctr_lower_mask)) | lo;

        /* Low counter bits rolled over to zero: carry into the high word. */
        if (lo == 0) {
                hi = htonll(ntohll(ctx->ctr_cb[0] & ctx->ctr_upper_mask) + 1) &
                    ctx->ctr_upper_mask;
                ctx->ctr_cb[0] =
                    (ctx->ctr_cb[0] & ~(ctx->ctr_upper_mask)) | hi;
        }

        /* Encrypt the updated counter block to get the next keystream. */
        cipher(ctx->ctr_keysched, (uint8_t *)ctx->ctr_cb,
            (uint8_t *)ctx->ctr_keystream);
        ctx->ctr_offset = 0;
}

/*
 * XOR 'outlen' bytes of 'in' with the keystream and store the result in
 * 'out'.  The caller must supply at least 'outlen' bytes of input
 * (ctr_mode_contiguous_blocks() arranges this before calling us).
 *
 * CTR mode behaves as a stream cipher, so a cipher's xxx_xor_block
 * routine (e.g. aes_xor_block()) is of no use here: input of any length
 * must be handled without buffering partial blocks between calls.
 * ctx->ctr_offset records how much of the current keystream block has
 * already been consumed and is updated as we go.
 */
static void
ctr_xor(ctr_ctx_t *ctx, const uint8_t *in, uint8_t *out, size_t outlen,
    size_t block_size,
    int (*cipher)(const void *ks, const uint8_t *pt, uint8_t *ct))
{
        while (outlen > 0) {
                const uint8_t *ks;
                size_t avail;

                /*
                 * The current keystream block is used up, so make a new
                 * one.  ctr_init_ctx() generates the very first block, so
                 * there is always key data available on entry.
                 */
                if (ctx->ctr_offset == block_size)
                        ctr_new_keyblock(ctx, cipher);

                ks = (uint8_t *)ctx->ctr_keystream + ctx->ctr_offset;
                avail = block_size - ctx->ctr_offset;

                /*
                 * Go a byte at a time for as long as in, out and ks are
                 * all off a 32-bit boundary; stop as soon as any one of
                 * them lands on one (or we run out of key or output).
                 */
                for (; avail > 0 && outlen > 0; avail--, outlen--) {
                        if (IS_P2ALIGNED(in, sizeof (uint32_t)) ||
                            IS_P2ALIGNED(out, sizeof (uint32_t)) ||
                            IS_P2ALIGNED(ks, sizeof (uint32_t)))
                                break;
                        *out++ = *in++ ^ *ks++;
                }

                /*
                 * When all three pointers are 32-bit aligned and at least
                 * one whole word remains on both sides, process as many
                 * whole words as the smaller of the two lengths allows.
                 */
                if (avail > 3 && outlen > 3 &&
                    IS_P2ALIGNED(in, sizeof (uint32_t)) &&
                    IS_P2ALIGNED(out, sizeof (uint32_t)) &&
                    IS_P2ALIGNED(ks, sizeof (uint32_t))) {
                        const uint32_t *src32 = (const uint32_t *)in;
                        const uint32_t *ks32 = (const uint32_t *)ks;
                        uint32_t *dst32 = (uint32_t *)out;
                        size_t nbytes = (avail < outlen) ? avail : outlen;
                        size_t nwords = nbytes / sizeof (uint32_t);
                        size_t i;

                        for (i = 0; i < nwords; i++)
                                dst32[i] = src32[i] ^ ks32[i];

                        nbytes = nwords * sizeof (uint32_t);
                        in += nbytes;
                        out += nbytes;
                        ks += nbytes;
                        avail -= nbytes;
                        outlen -= nbytes;
                }

                /* Whatever is left over is handled a byte at a time. */
                for (; avail > 0 && outlen > 0; avail--, outlen--)
                        *out++ = *in++ ^ *ks++;

                /* Remember how far into the keystream block we got. */
                ctx->ctr_offset = block_size - avail;
        }
}

/*
 * Encrypt or decrypt 'in_length' bytes of 'in' in counter mode, writing
 * the result into 'out' starting at out->cd_offset.  On success
 * out->cd_offset is advanced by 'in_length'.
 *
 * Returns:
 *   CRYPTO_ARGUMENTS_BAD             'out' is NULL or 'block_size' exceeds
 *                                    the size of ctx->ctr_keystream
 *   CRYPTO_DATA_LEN_RANGE            out->cd_offset is negative
 *   CRYPTO_ENCRYPTED_DATA_LEN_RANGE  out->cd_offset + in_length overflows
 *   CRYPTO_BUFFER_TOO_SMALL          'out' cannot hold the result
 *   CRYPTO_SUCCESS                   otherwise
 */
int
ctr_mode_contiguous_blocks(ctr_ctx_t *ctx, char *in, size_t in_length,
    crypto_data_t *out, size_t block_size,
    int (*cipher)(const void *ks, const uint8_t *pt, uint8_t *ct))
{
        uint8_t *src = (uint8_t *)in;
        size_t todo;
        void *cursor;
        offset_t pos;
        uint8_t *dst;
        uint8_t *dst_next;
        size_t dst_len;

        if (block_size > sizeof (ctx->ctr_keystream) || out == NULL)
                return (CRYPTO_ARGUMENTS_BAD);

        /*
         * Reject anything that would make 'out->cd_offset + in_length'
         * overflow.
         * NOTE(review): the two guards return different error codes;
         * confirm that the distinction is intentional.
         */
        if (out->cd_offset < 0)
                return (CRYPTO_DATA_LEN_RANGE);
        if (SIZE_MAX - in_length < (size_t)out->cd_offset)
                return (CRYPTO_ENCRYPTED_DATA_LEN_RANGE);

        /* Ensure 'out' has room for everything we are about to write. */
        if (out->cd_offset + in_length > out->cd_length)
                return (CRYPTO_BUFFER_TOO_SMALL);

        crypto_init_ptrs(out, &cursor, &pos);

        /*
         * XOR the input with the keystream, one output segment at a time.
         *
         * When 'out' is a uio_t or an mblk_t, the remaining input may be
         * larger than a single iovec_t or mblk_t.  crypto_get_ptrs() uses
         * 'pos' to point 'dst' at the right address for writing, sets
         * 'dst_len' to the most data (up to 'todo') that can be written
         * there, and advances 'pos' by that amount.  It also reports the
         * start of the following segment in 'dst_next'; that value is
         * unused here because the updated 'pos' locates the next
         * mblk_t/iovec_t on the next pass.  As the total space in 'out'
         * was verified above, the loop is expected to terminate.
         */
        for (todo = in_length; todo > 0; todo -= dst_len, src += dst_len) {
                crypto_get_ptrs(out, &cursor, &pos, &dst, &dst_len,
                    &dst_next, todo);

                /*
                 * crypto_get_ptrs() should already guarantee both of
                 * these; assert them in case its behavior ever changes.
                 */
                ASSERT3U(dst_len, <=, todo);
                ASSERT3U(dst_len, >, 0);

                ctr_xor(ctx, src, dst, dst_len, block_size, cipher);
        }

        out->cd_offset += in_length;

        return (CRYPTO_SUCCESS);
}

/*
 * Initialize a CTR mode context.
 *
 * ctr_ctx	context to initialize; ctr_ctx->ctr_keysched is handed to
 *		'cipher' here, so it must already be set up by the caller
 * count	number of counter bits in the counter block (1 - 128)
 * cb		initial counter block, copied into ctr_ctx->ctr_cb
 * cipher	block cipher used to produce the first keystream block
 * copy_block	routine used to copy 'cb' into the context
 *
 * The low 'count' bits of the counter block form the counter.  Masks
 * selecting those bits are stored (converted with htonll()) in
 * ctr_lower_mask for ctr_cb[1] and ctr_upper_mask for ctr_cb[0].
 *
 * Returns CRYPTO_MECHANISM_PARAM_INVALID if 'count' is 0 or greater than
 * 128, CRYPTO_SUCCESS otherwise.
 */
int
ctr_init_ctx(ctr_ctx_t *ctr_ctx, ulong_t count, uint8_t *cb,
    int (*cipher)(const void *ks, const uint8_t *pt, uint8_t *ct),
    void (*copy_block)(uint8_t *, uint8_t *))
{
        uint64_t upper_mask = 0;
        uint64_t lower_mask = 0;

        if (count == 0 || count > 128) {
                return (CRYPTO_MECHANISM_PARAM_INVALID);
        }
        /* upper 64 bits of the mask */
        if (count >= 64) {
                /*
                 * The counter covers all of the low word; whatever is
                 * left of 'count' spills into the upper word.  A shift by
                 * 64 is undefined, hence the explicit UINT64_MAX case.
                 */
                count -= 64;
                upper_mask = (count == 64) ? UINT64_MAX : (1ULL << count) - 1;
                lower_mask = UINT64_MAX;
        } else {
                /* now the lower 63 bits */
                lower_mask = (1ULL << count) - 1;
        }
        ctr_ctx->ctr_lower_mask = htonll(lower_mask);
        ctr_ctx->ctr_upper_mask = htonll(upper_mask);

        copy_block(cb, (uchar_t *)ctr_ctx->ctr_cb);
        ctr_ctx->ctr_lastp = (uint8_t *)&ctr_ctx->ctr_cb[0];

        /* Generate the first block of the keystream */
        cipher(ctr_ctx->ctr_keysched, (uint8_t *)ctr_ctx->ctr_cb,
            (uint8_t *)ctr_ctx->ctr_keystream);

        /*
         * None of the freshly generated keystream block has been used.
         * Reset the offset explicitly (as ctr_new_keyblock() does) rather
         * than depending on the caller having zeroed the context, so that
         * a context that is initialized again starts at the beginning of
         * the first keystream block.
         */
        ctr_ctx->ctr_offset = 0;

        ctr_ctx->ctr_flags |= CTR_MODE;
        return (CRYPTO_SUCCESS);
}

/* ARGSUSED */
void *
ctr_alloc_ctx(int kmflag)
{
        ctr_ctx_t *ctr_ctx;

#ifdef _KERNEL
        if ((ctr_ctx = kmem_zalloc(sizeof (ctr_ctx_t), kmflag)) == NULL)
#else
        if ((ctr_ctx = calloc(1, sizeof (ctr_ctx_t))) == NULL)
#endif
                return (NULL);

        ctr_ctx->ctr_flags = CTR_MODE;
        return (ctr_ctx);
}