/* SPDX-License-Identifier: Apache-2.0 OR BSD-2-Clause */
//
// Copyright 2025 Google LLC
//
// Author: Eric Biggers <ebiggers@google.com>
//
// This file is dual-licensed, meaning that you can use it under your choice of
// either of the following two licenses:
//
// Licensed under the Apache License 2.0 (the "License").  You may obtain a copy
// of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// or
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice,
//    this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
//    notice, this list of conditions and the following disclaimer in the
//    documentation and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
//------------------------------------------------------------------------------
//
// This file contains x86_64 assembly implementations of AES-CTR and AES-XCTR
// using the following sets of CPU features:
//      - AES-NI && AVX
//      - VAES && AVX2
//      - VAES && AVX512BW && AVX512VL && BMI2
//
// See the function definitions at the bottom of the file for more information.

#include <linux/linkage.h>
#include <linux/cfi_types.h>

.section .rodata
.p2align 4

.Lbswap_mask:
        .octa   0x000102030405060708090a0b0c0d0e0f

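// The .quad pairs from .Lctr_pattern through ".quad 3, 0" deliberately form
// one contiguous array of 128-bit values {0, 1, 2, 3}, so that a VL-byte load
// from .Lctr_pattern yields the per-lane counter offsets.  .Lone, .Ltwo, and
// .Lfour also serve as broadcastable 128-bit counter increments.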
.Lctr_pattern:
        .quad   0, 0
.Lone:
        .quad   1, 0
.Ltwo:
        .quad   2, 0
        .quad   3, 0

.Lfour:
        .quad   4, 0

.text

// Move a vector between memory and a register.
.macro  _vmovdqu        src, dst
.if VL < 64
        vmovdqu         \src, \dst
.else
        vmovdqu8        \src, \dst
.endif
.endm

// Move a vector between registers.
.macro  _vmovdqa        src, dst
.if VL < 64
        vmovdqa         \src, \dst
.else
        vmovdqa64       \src, \dst
.endif
.endm

// Broadcast a 128-bit value from memory to all 128-bit lanes of a vector
// register.
.macro  _vbroadcast128  src, dst
.if VL == 16
        vmovdqu         \src, \dst
.elseif VL == 32
        vbroadcasti128  \src, \dst
.else
        vbroadcasti32x4 \src, \dst
.endif
.endm

// XOR two vectors together.
.macro  _vpxor  src1, src2, dst
.if VL < 64
        vpxor           \src1, \src2, \dst
.else
        vpxord          \src1, \src2, \dst
.endif
.endm

// Load 1 <= %ecx <= 15 bytes from the pointer \src into the xmm register \dst
// and zeroize any remaining bytes.  Clobbers %rax, %rcx, and \tmp{64,32}.
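//
// For example, for LEN == 12: the vmovq loads bytes 0-7, the overlapping
// 8-byte load at \src + 4 loads bytes 4-11, the shift right by
// (16 - 12)*8 == 32 bits discards the overlapping bytes 4-7, and vpinsrq
// places bytes 8-11 in the upper qword of \dst, leaving bytes 12-15 zero.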
.macro  _load_partial_block     src, dst, tmp64, tmp32
        sub             $8, %ecx                // LEN - 8
        jle             .Lle8\@

        // Load 9 <= LEN <= 15 bytes.
        vmovq           (\src), \dst            // Load first 8 bytes
        mov             (\src, %rcx), %rax      // Load last 8 bytes
        neg             %ecx
        shl             $3, %ecx
        shr             %cl, %rax               // Discard overlapping bytes
        vpinsrq         $1, %rax, \dst, \dst
        jmp             .Ldone\@

.Lle8\@:
        add             $4, %ecx                // LEN - 4
        jl              .Llt4\@

        // Load 4 <= LEN <= 8 bytes.
        mov             (\src), %eax            // Load first 4 bytes
        mov             (\src, %rcx), \tmp32    // Load last 4 bytes
        jmp             .Lcombine\@

.Llt4\@:
        // Load 1 <= LEN <= 3 bytes.
        add             $2, %ecx                // LEN - 2
        movzbl          (\src), %eax            // Load first byte
        jl              .Lmovq\@
        movzwl          (\src, %rcx), \tmp32    // Load last 2 bytes
.Lcombine\@:
        shl             $3, %ecx
        shl             %cl, \tmp64
        or              \tmp64, %rax            // Combine the two parts
.Lmovq\@:
        vmovq           %rax, \dst
.Ldone\@:
.endm

// Store 1 <= %ecx <= 15 bytes from the xmm register \src to the pointer \dst.
// Clobbers %rax, %rcx, and \tmp{64,32}.
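//
// The 8-15 and 4-7 byte cases use a rotate-and-overlap trick: the second
// store is placed at \dst + LEN - 8 (or \dst + LEN - 4), rotated so that its
// bytes land at the right offsets, and the bytes that overlap the first store
// are then simply overwritten by it.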
.macro  _store_partial_block    src, dst, tmp64, tmp32
        sub             $8, %ecx                // LEN - 8
        jl              .Llt8\@

        // Store 8 <= LEN <= 15 bytes.
        vpextrq         $1, \src, %rax
        mov             %ecx, \tmp32
        shl             $3, %ecx
        ror             %cl, %rax
        mov             %rax, (\dst, \tmp64)    // Store last LEN - 8 bytes
        vmovq           \src, (\dst)            // Store first 8 bytes
        jmp             .Ldone\@

.Llt8\@:
        add             $4, %ecx                // LEN - 4
        jl              .Llt4\@

        // Store 4 <= LEN <= 7 bytes.
        vpextrd         $1, \src, %eax
        mov             %ecx, \tmp32
        shl             $3, %ecx
        ror             %cl, %eax
        mov             %eax, (\dst, \tmp64)    // Store last LEN - 4 bytes
        vmovd           \src, (\dst)            // Store first 4 bytes
        jmp             .Ldone\@

.Llt4\@:
        // Store 1 <= LEN <= 3 bytes.
        vpextrb         $0, \src, 0(\dst)
        cmp             $-2, %ecx               // LEN - 4 == -2, i.e. LEN == 2?
        jl              .Ldone\@
        vpextrb         $1, \src, 1(\dst)
        je              .Ldone\@
        vpextrb         $2, \src, 2(\dst)
.Ldone\@:
.endm

// Prepare the next two vectors of AES inputs in AESDATA\i0 and AESDATA\i1, and
// XOR each with the zero-th round key.  Also update LE_CTR if !\final.
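// For XCTR, the AES input for each block is IV ^ le128(ctr), and on AVX512
// the vpternlogd with truth-table immediate 0x96 folds that XOR together with
// the zero-th round key XOR (a ^ b ^ c) into a single instruction.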
.macro  _prepare_2_ctr_vecs     is_xctr, i0, i1, final=0
.if \is_xctr
  .if USE_AVX512
        vmovdqa64       LE_CTR, AESDATA\i0
        vpternlogd      $0x96, XCTR_IV, RNDKEY0, AESDATA\i0
  .else
        vpxor           XCTR_IV, LE_CTR, AESDATA\i0
        vpxor           RNDKEY0, AESDATA\i0, AESDATA\i0
  .endif
        vpaddq          LE_CTR_INC1, LE_CTR, AESDATA\i1

  .if USE_AVX512
        vpternlogd      $0x96, XCTR_IV, RNDKEY0, AESDATA\i1
  .else
        vpxor           XCTR_IV, AESDATA\i1, AESDATA\i1
        vpxor           RNDKEY0, AESDATA\i1, AESDATA\i1
  .endif
.else
        vpshufb         BSWAP_MASK, LE_CTR, AESDATA\i0
        _vpxor          RNDKEY0, AESDATA\i0, AESDATA\i0
        vpaddq          LE_CTR_INC1, LE_CTR, AESDATA\i1
        vpshufb         BSWAP_MASK, AESDATA\i1, AESDATA\i1
        _vpxor          RNDKEY0, AESDATA\i1, AESDATA\i1
.endif
.if !\final
        vpaddq          LE_CTR_INC2, LE_CTR, LE_CTR
.endif
.endm

// Do all AES rounds on the data in the given AESDATA vectors, excluding the
// zero-th and last rounds.
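// Note that KEY has already been advanced to point to the round-1 key, so the
// loop covers rounds 1 through N-1 and stops at RNDKEYLAST_PTR.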
.macro  _aesenc_loop    vecs:vararg
        mov             KEY, %rax
1:
        _vbroadcast128  (%rax), RNDKEY
.irp i, \vecs
        vaesenc         RNDKEY, AESDATA\i, AESDATA\i
.endr
        add             $16, %rax
        cmp             %rax, RNDKEYLAST_PTR
        jne             1b
.endm

// Finalize the keystream blocks in the given AESDATA vectors by doing the last
// AES round, then XOR those keystream blocks with the corresponding data.
// Reduce latency by doing the XOR before the vaesenclast, utilizing the
// property vaesenclast(key, a) ^ b == vaesenclast(key ^ b, a).
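// That property holds because the last AES round is just ShiftRows and
// SubBytes followed by an XOR with the round key:
// vaesenclast(key, a) == SubBytes(ShiftRows(a)) ^ key.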
.macro  _aesenclast_and_xor     vecs:vararg
.irp i, \vecs
        _vpxor          \i*VL(SRC), RNDKEYLAST, RNDKEY
        vaesenclast     RNDKEY, AESDATA\i, AESDATA\i
.endr
.irp i, \vecs
        _vmovdqu        AESDATA\i, \i*VL(DST)
.endr
.endm

// XOR the keystream blocks in the specified AESDATA vectors with the
// corresponding data.
.macro  _xor_data       vecs:vararg
.irp i, \vecs
        _vpxor          \i*VL(SRC), AESDATA\i, AESDATA\i
.endr
.irp i, \vecs
        _vmovdqu        AESDATA\i, \i*VL(DST)
.endr
.endm

.macro  _aes_ctr_crypt          is_xctr

        // Define register aliases V0-V15 that map to the xmm, ymm, or zmm
        // registers according to the selected Vector Length (VL).
.irp i, 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15
  .if VL == 16
        .set    V\i,            %xmm\i
  .elseif VL == 32
        .set    V\i,            %ymm\i
  .elseif VL == 64
        .set    V\i,            %zmm\i
  .else
        .error "Unsupported Vector Length (VL)"
  .endif
.endr

        // Function arguments
        .set    KEY,            %rdi    // Initially points to the start of the
                                        // crypto_aes_ctx, then is advanced to
                                        // point to the index 1 round key
        .set    KEY32,          %edi    // Available as temp register after all
                                        // keystream blocks have been generated
        .set    SRC,            %rsi    // Pointer to next source data
        .set    DST,            %rdx    // Pointer to next destination data
        .set    LEN,            %ecx    // Remaining length in bytes.
                                        // Note: _load_partial_block relies on
                                        // this being in %ecx.
        .set    LEN64,          %rcx    // Zero-extend LEN before using!
        .set    LEN8,           %cl
.if \is_xctr
        .set    XCTR_IV_PTR,    %r8     // const u8 iv[AES_BLOCK_SIZE];
        .set    XCTR_CTR,       %r9     // u64 ctr;
.else
        .set    LE_CTR_PTR,     %r8     // const u64 le_ctr[2];
.endif

        // Additional local variables
        .set    RNDKEYLAST_PTR, %r10
        .set    AESDATA0,       V0
        .set    AESDATA0_XMM,   %xmm0
        .set    AESDATA1,       V1
        .set    AESDATA1_XMM,   %xmm1
        .set    AESDATA2,       V2
        .set    AESDATA3,       V3
        .set    AESDATA4,       V4
        .set    AESDATA5,       V5
        .set    AESDATA6,       V6
        .set    AESDATA7,       V7
.if \is_xctr
        .set    XCTR_IV,        V8
.else
        .set    BSWAP_MASK,     V8
.endif
        .set    LE_CTR,         V9
        .set    LE_CTR_XMM,     %xmm9
        .set    LE_CTR_INC1,    V10
        .set    LE_CTR_INC2,    V11
        .set    RNDKEY0,        V12
        .set    RNDKEYLAST,     V13
        .set    RNDKEY,         V14

        // Create the first vector of counters.
.if \is_xctr
  .if VL == 16
        vmovq           XCTR_CTR, LE_CTR
  .elseif VL == 32
        vmovq           XCTR_CTR, LE_CTR_XMM
        inc             XCTR_CTR
        vmovq           XCTR_CTR, AESDATA0_XMM
        vinserti128     $1, AESDATA0_XMM, LE_CTR, LE_CTR
  .else
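        // Broadcast the counter into both qwords of every 128-bit lane, zero
        // the upper qword of each lane, then add the per-lane block offsets
        // {0, 1, 2, 3} from .Lctr_pattern.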
        vpbroadcastq    XCTR_CTR, LE_CTR
        vpsrldq         $8, LE_CTR, LE_CTR
        vpaddq          .Lctr_pattern(%rip), LE_CTR, LE_CTR
  .endif
        _vbroadcast128  (XCTR_IV_PTR), XCTR_IV
.else
        _vbroadcast128  (LE_CTR_PTR), LE_CTR
  .if VL > 16
        vpaddq          .Lctr_pattern(%rip), LE_CTR, LE_CTR
  .endif
        _vbroadcast128  .Lbswap_mask(%rip), BSWAP_MASK
.endif

.if VL == 16
        _vbroadcast128  .Lone(%rip), LE_CTR_INC1
.elseif VL == 32
        _vbroadcast128  .Ltwo(%rip), LE_CTR_INC1
.else
        _vbroadcast128  .Lfour(%rip), LE_CTR_INC1
.endif
        vpsllq          $1, LE_CTR_INC1, LE_CTR_INC2
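        // LE_CTR_INC1 is the number of AES blocks per vector (VL/16), and
        // LE_CTR_INC2 is twice that, since _prepare_2_ctr_vecs advances
        // LE_CTR past two vectors of counters per invocation.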

        // Load the AES key length in bytes: 16 (AES-128), 24 (AES-192), or
        // 32 (AES-256).
        movl            480(KEY), %eax

        // Compute the pointer to the last round key.
        lea             6*16(KEY, %rax, 4), RNDKEYLAST_PTR
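        // E.g. for AES-128, keylen == 16 gives offset 6*16 + 16*4 == 160 ==
        // 10*16, i.e. the round-10 key; AES-192 and AES-256 similarly give
        // offsets 192 (round 12) and 224 (round 14).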

        // Load the zero-th and last round keys.
        _vbroadcast128  (KEY), RNDKEY0
        _vbroadcast128  (RNDKEYLAST_PTR), RNDKEYLAST

        // Make KEY point to the first round key.
        add             $16, KEY

        // This is the main loop, which encrypts 8 vectors of data at a time.
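        // ("add $-8*VL" and "sub $-8*VL" are used instead of "sub $8*VL" and
        // "add $8*VL" because, for VL == 16, the constant -128 fits in a
        // sign-extended 8-bit immediate while +128 does not.)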
        add             $-8*VL, LEN
        jl              .Lloop_8x_done\@
.Lloop_8x\@:
        _prepare_2_ctr_vecs     \is_xctr, 0, 1
        _prepare_2_ctr_vecs     \is_xctr, 2, 3
        _prepare_2_ctr_vecs     \is_xctr, 4, 5
        _prepare_2_ctr_vecs     \is_xctr, 6, 7
        _aesenc_loop    0,1,2,3,4,5,6,7
        _aesenclast_and_xor 0,1,2,3,4,5,6,7
        sub             $-8*VL, SRC
        sub             $-8*VL, DST
        add             $-8*VL, LEN
        jge             .Lloop_8x\@
.Lloop_8x_done\@:
        sub             $-8*VL, LEN
        jz              .Ldone\@

        // 1 <= LEN < 8*VL.  Generate 2, 4, or 8 more vectors of keystream
        // blocks, depending on the remaining LEN.

        _prepare_2_ctr_vecs     \is_xctr, 0, 1
        _prepare_2_ctr_vecs     \is_xctr, 2, 3
        cmp             $4*VL, LEN
        jle             .Lenc_tail_atmost4vecs\@

        // 4*VL < LEN < 8*VL.  Generate 8 vectors of keystream blocks.  Use the
        // first 4 to XOR 4 full vectors of data.  Then XOR the remaining data.
        _prepare_2_ctr_vecs     \is_xctr, 4, 5
        _prepare_2_ctr_vecs     \is_xctr, 6, 7, final=1
        _aesenc_loop    0,1,2,3,4,5,6,7
        _aesenclast_and_xor 0,1,2,3
        vaesenclast     RNDKEYLAST, AESDATA4, AESDATA0
        vaesenclast     RNDKEYLAST, AESDATA5, AESDATA1
        vaesenclast     RNDKEYLAST, AESDATA6, AESDATA2
        vaesenclast     RNDKEYLAST, AESDATA7, AESDATA3
        sub             $-4*VL, SRC
        sub             $-4*VL, DST
        add             $-4*VL, LEN
        cmp             $1*VL-1, LEN
        jle             .Lxor_tail_partial_vec_0\@
        _xor_data       0
        cmp             $2*VL-1, LEN
        jle             .Lxor_tail_partial_vec_1\@
        _xor_data       1
        cmp             $3*VL-1, LEN
        jle             .Lxor_tail_partial_vec_2\@
        _xor_data       2
        cmp             $4*VL-1, LEN
        jle             .Lxor_tail_partial_vec_3\@
        _xor_data       3
        jmp             .Ldone\@

.Lenc_tail_atmost4vecs\@:
        cmp             $2*VL, LEN
        jle             .Lenc_tail_atmost2vecs\@

        // 2*VL < LEN <= 4*VL.  Generate 4 vectors of keystream blocks.  Use the
        // first 2 to XOR 2 full vectors of data.  Then XOR the remaining data.
        _aesenc_loop    0,1,2,3
        _aesenclast_and_xor 0,1
        vaesenclast     RNDKEYLAST, AESDATA2, AESDATA0
        vaesenclast     RNDKEYLAST, AESDATA3, AESDATA1
        sub             $-2*VL, SRC
        sub             $-2*VL, DST
        add             $-2*VL, LEN
        jmp             .Lxor_tail_upto2vecs\@

.Lenc_tail_atmost2vecs\@:
        // 1 <= LEN <= 2*VL.  Generate 2 vectors of keystream blocks.  Then XOR
        // the remaining data.
        _aesenc_loop    0,1
        vaesenclast     RNDKEYLAST, AESDATA0, AESDATA0
        vaesenclast     RNDKEYLAST, AESDATA1, AESDATA1

.Lxor_tail_upto2vecs\@:
        cmp             $1*VL-1, LEN
        jle             .Lxor_tail_partial_vec_0\@
        _xor_data       0
        cmp             $2*VL-1, LEN
        jle             .Lxor_tail_partial_vec_1\@
        _xor_data       1
        jmp             .Ldone\@

.Lxor_tail_partial_vec_1\@:
        add             $-1*VL, LEN
        jz              .Ldone\@
        sub             $-1*VL, SRC
        sub             $-1*VL, DST
        _vmovdqa        AESDATA1, AESDATA0
        jmp             .Lxor_tail_partial_vec_0\@

.Lxor_tail_partial_vec_2\@:
        add             $-2*VL, LEN
        jz              .Ldone\@
        sub             $-2*VL, SRC
        sub             $-2*VL, DST
        _vmovdqa        AESDATA2, AESDATA0
        jmp             .Lxor_tail_partial_vec_0\@

.Lxor_tail_partial_vec_3\@:
        add             $-3*VL, LEN
        jz              .Ldone\@
        sub             $-3*VL, SRC
        sub             $-3*VL, DST
        _vmovdqa        AESDATA3, AESDATA0

.Lxor_tail_partial_vec_0\@:
        // XOR the remaining 1 <= LEN < VL bytes.  It's easy if masked
        // loads/stores are available; otherwise it's a bit harder...
.if USE_AVX512
        mov             $-1, %rax
        bzhi            LEN64, %rax, %rax
        kmovq           %rax, %k1
        vmovdqu8        (SRC), AESDATA1{%k1}{z}
        vpxord          AESDATA1, AESDATA0, AESDATA0
        vmovdqu8        AESDATA0, (DST){%k1}
.else
  .if VL == 32
        cmp             $16, LEN
        jl              1f
        vpxor           (SRC), AESDATA0_XMM, AESDATA1_XMM
        vmovdqu         AESDATA1_XMM, (DST)
        add             $16, SRC
        add             $16, DST
        sub             $16, LEN
        jz              .Ldone\@
        vextracti128    $1, AESDATA0, AESDATA0_XMM
1:
  .endif
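        // All keystream has been generated, so KEY/KEY32 and RNDKEYLAST_PTR
        // (%r10) are dead and can be reused as temporaries.  LEN is saved in
        // %r10d because _load_partial_block clobbers %ecx.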
        mov             LEN, %r10d
        _load_partial_block     SRC, AESDATA1_XMM, KEY, KEY32
        vpxor           AESDATA1_XMM, AESDATA0_XMM, AESDATA0_XMM
        mov             %r10d, %ecx
        _store_partial_block    AESDATA0_XMM, DST, KEY, KEY32
.endif

.Ldone\@:
.if VL > 16
        vzeroupper
.endif
        RET
.endm

// Below are the definitions of the functions generated by the above macro.
// They have the following prototypes:
//
// void aes_ctr64_crypt_##suffix(const struct crypto_aes_ctx *key,
//                               const u8 *src, u8 *dst, int len,
//                               const u64 le_ctr[2]);
//
// void aes_xctr_crypt_##suffix(const struct crypto_aes_ctx *key,
//                              const u8 *src, u8 *dst, int len,
//                              const u8 iv[AES_BLOCK_SIZE], u64 ctr);
//
// Both functions generate |len| bytes of keystream, XOR it with the data from
// |src|, and write the result to |dst|.  On non-final calls, |len| must be a
// multiple of the AES block size (16 bytes).  On the final call, |len| can be
// any value.
//
// aes_ctr64_crypt_* implement "regular" CTR, where the keystream is generated
// from a 128-bit big endian counter that increments by 1 for each AES block.
// However, to keep the assembly code simple, some of the counter management is
// left to the caller.  aes_ctr64_crypt_* take the counter in little endian
// form, increment only its low 64 bits internally, do the conversion to big
// endian internally, and don't write the updated counter back to memory.  The
// caller is responsible for converting the starting IV into the little endian
// le_ctr, for detecting the (very rare) case where a carry out of the low 64
// bits is needed and splitting the message at that point (applying the carry
// in between), and for updating le_ctr after each part of a multi-part
// message.
//
// aes_xctr_crypt_* implement XCTR as specified in "Length-preserving encryption
// with HCTR2" (https://eprint.iacr.org/2021/1441.pdf).  XCTR is an
// easier-to-implement variant of CTR that uses little endian byte order and
// eliminates carries.  |ctr| is the per-message block counter starting at 1.
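//
// As a rough illustration only (this logic belongs in the C glue code, and
// the names below are hypothetical), a caller of aes_ctr64_crypt_* might
// handle the low-64-bit carry along these lines:
//
//	u64 nblocks = DIV_ROUND_UP(len, AES_BLOCK_SIZE);
//	u64 blocks_before_wrap = 0 - le_ctr[0];	/* 0 means 2^64 */
//
//	if (blocks_before_wrap != 0 && nblocks > blocks_before_wrap) {
//		/* Rare case: the low 64 bits wrap in mid-message. */
//		size_t part1 = blocks_before_wrap * AES_BLOCK_SIZE;
//
//		aes_ctr64_crypt(key, src, dst, part1, le_ctr);
//		le_ctr[0] = 0;
//		le_ctr[1]++;	/* apply the carry */
//		src += part1;
//		dst += part1;
//		len -= part1;
//	}
//	aes_ctr64_crypt(key, src, dst, len, le_ctr);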

.set    VL, 16
.set    USE_AVX512, 0
SYM_TYPED_FUNC_START(aes_ctr64_crypt_aesni_avx)
        _aes_ctr_crypt  0
SYM_FUNC_END(aes_ctr64_crypt_aesni_avx)
SYM_TYPED_FUNC_START(aes_xctr_crypt_aesni_avx)
        _aes_ctr_crypt  1
SYM_FUNC_END(aes_xctr_crypt_aesni_avx)

.set    VL, 32
.set    USE_AVX512, 0
SYM_TYPED_FUNC_START(aes_ctr64_crypt_vaes_avx2)
        _aes_ctr_crypt  0
SYM_FUNC_END(aes_ctr64_crypt_vaes_avx2)
SYM_TYPED_FUNC_START(aes_xctr_crypt_vaes_avx2)
        _aes_ctr_crypt  1
SYM_FUNC_END(aes_xctr_crypt_vaes_avx2)

.set    VL, 64
.set    USE_AVX512, 1
SYM_TYPED_FUNC_START(aes_ctr64_crypt_vaes_avx512)
        _aes_ctr_crypt  0
SYM_FUNC_END(aes_ctr64_crypt_vaes_avx512)
SYM_TYPED_FUNC_START(aes_xctr_crypt_vaes_avx512)
        _aes_ctr_crypt  1
SYM_FUNC_END(aes_xctr_crypt_vaes_avx512)