/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, AES-NI/AVX optimized.
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (c) 2021, Alibaba Group.
 * Copyright (c) 2021 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */
#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>
#include "sm4-avx.h"
#define SM4_CRYPT8_BLOCK_SIZE	(SM4_BLOCK_SIZE * 8)
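/*
 * Low-level block routines, implemented in the accompanying AES-NI/AVX
 * assembly (presumably sm4-aesni-avx-asm_64.S). crypt4/crypt8 process
 * up to 4/8 ECB blocks per call; the _blk8 variants each handle one
 * 8-block chunk of CTR/CBC/CFB and update the IV/counter in place.
 */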
asmlinkage void sm4_aesni_avx_crypt4(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_crypt8(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_ctr_enc_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cbc_dec_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cfb_dec_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
static int sm4_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
			unsigned int key_len)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_expandkey(ctx, key, key_len);
}
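/*
 * ECB is fully parallel: process 8-block chunks with the AVX routine,
 * then hand any remaining blocks to crypt4, which takes up to four at
 * a time. The FPU context is held only across the SIMD calls.
 */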
static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();
		while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
			sm4_aesni_avx_crypt8(rkey, dst, src, 8);
			dst += SM4_CRYPT8_BLOCK_SIZE;
			src += SM4_CRYPT8_BLOCK_SIZE;
			nbytes -= SM4_CRYPT8_BLOCK_SIZE;
		}
		while (nbytes >= SM4_BLOCK_SIZE) {
			unsigned int nblocks = min(nbytes >> 4, 4u);
			sm4_aesni_avx_crypt4(rkey, dst, src, nblocks);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}
		kernel_fpu_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
int sm4_avx_ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return ecb_do_crypt(req, ctx->rkey_enc);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);
int sm4_avx_ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return ecb_do_crypt(req, ctx->rkey_dec);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);
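/*
 * CBC encryption is inherently sequential (each ciphertext block feeds
 * into the next), so it runs one block at a time through the generic
 * sm4_crypt_block() without touching the FPU.
 */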
int sm4_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
			sm4_crypt_block(ctx->rkey_enc, dst, dst);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);
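/*
 * CBC decryption parallelizes: full bsize chunks go to the caller's
 * bulk routine, and any 1-7 leftover blocks are decrypted with crypt8,
 * then un-chained by XORing each with the preceding ciphertext block,
 * walking backwards so the in-place case (src == dst) stays correct.
 */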
int sm4_avx_cbc_decrypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_dec, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			u8 iv[SM4_BLOCK_SIZE];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			int i;

			sm4_aesni_avx_crypt8(ctx->rkey_dec, keystream,
					src, nblocks);

			src += ((int)nblocks - 2) * SM4_BLOCK_SIZE;
			dst += (nblocks - 1) * SM4_BLOCK_SIZE;
			/* save last ciphertext block: it becomes the next IV */
			memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);

			/* un-chain in reverse so in-place operation is safe */
			for (i = nblocks - 1; i > 0; i--) {
				crypto_xor_cpy(dst, src,
					&keystream[i * SM4_BLOCK_SIZE],
					SM4_BLOCK_SIZE);
				src -= SM4_BLOCK_SIZE;
				dst -= SM4_BLOCK_SIZE;
			}
			crypto_xor_cpy(dst, walk.iv, keystream, SM4_BLOCK_SIZE);
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += (nblocks + 1) * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();
		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);
static int cbc_decrypt(struct skcipher_request *req)
{
	return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
				sm4_aesni_avx_cbc_dec_blk8);
}
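/*
 * CFB encryption is sequential for the same reason as CBC: each
 * keystream block is the encryption of the previous ciphertext block.
 * A partial final block is handled with a plain keystream XOR, so no
 * padding is needed.
 */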
int sm4_cfb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		u8 keystream[SM4_BLOCK_SIZE];
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			sm4_crypt_block(ctx->rkey_enc, keystream, iv);
			crypto_xor_cpy(dst, src, keystream, SM4_BLOCK_SIZE);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_cfb_encrypt);
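/*
 * CFB decryption parallelizes because the keystream is just the
 * encryption of the ciphertext shifted by one block: prepend the IV to
 * the ciphertext, encrypt that, and XOR with the ciphertext.
 */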
int sm4_avx_cfb_decrypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_enc, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			unsigned int nblocks = min(nbytes >> 4, 8u);

			/* keystream input: IV, then ciphertext blocks 0..n-2 */
			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
			if (nblocks > 1)
				memcpy(&keystream[SM4_BLOCK_SIZE], src,
					(nblocks - 1) * SM4_BLOCK_SIZE);
			/* the last ciphertext block becomes the next IV */
			memcpy(walk.iv, src + (nblocks - 1) * SM4_BLOCK_SIZE,
				SM4_BLOCK_SIZE);

			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
					keystream, nblocks);

			crypto_xor_cpy(dst, src, keystream,
					nblocks * SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cfb_decrypt);
static int cfb_decrypt(struct skcipher_request *req)
{
	return sm4_avx_cfb_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
				sm4_aesni_avx_cfb_dec_blk8);
}
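/*
 * CTR mode: full chunks go to the bulk routine, which also advances
 * the counter; leftovers build up to 8 counter blocks by hand with
 * crypto_inc(), encrypt them in one crypt8 call, and XOR them into the
 * data. Encryption and decryption are identical.
 */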
int sm4_avx_ctr_crypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_enc, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			int i;

			for (i = 0; i < nblocks; i++) {
				memcpy(&keystream[i * SM4_BLOCK_SIZE],
					walk.iv, SM4_BLOCK_SIZE);
				crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			}
			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
					keystream, nblocks);

			crypto_xor_cpy(dst, src, keystream,
					nblocks * SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);

			sm4_crypt_block(ctx->rkey_enc, keystream, keystream);

			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);
static int ctr_crypt(struct skcipher_request *req)
{
	return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
				sm4_aesni_avx_ctr_enc_blk8);
}
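/*
 * The "__" prefix and CRYPTO_ALG_INTERNAL mark these algorithms as
 * internal-only: they must not be selected directly, since they are
 * usable only where the FPU is available. sm4_init() wraps them with
 * simd_register_skciphers_compat(), which exposes the plain
 * ecb(sm4)/cbc(sm4)/cfb(sm4)/ctr(sm4) names and defers to a cryptd
 * worker when SIMD cannot be used in the calling context.
 */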
static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
	{
		.base = {
			.cra_name		= "__ecb(sm4)",
			.cra_driver_name	= "__ecb-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_avx_ecb_encrypt,
		.decrypt	= sm4_avx_ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "__cbc(sm4)",
			.cra_driver_name	= "__cbc-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "__cfb(sm4)",
			.cra_driver_name	= "__cfb-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_cfb_encrypt,
		.decrypt	= cfb_decrypt,
	}, {
		.base = {
			.cra_name		= "__ctr(sm4)",
			.cra_driver_name	= "__ctr-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= ctr_crypt,
		.decrypt	= ctr_crypt,
	}
};
static struct simd_skcipher_alg *
simd_sm4_aesni_avx_skciphers[ARRAY_SIZE(sm4_aesni_avx_skciphers)];
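/*
 * Registration requires AVX, AES-NI and OSXSAVE, plus OS-enabled XSAVE
 * state for the SSE and YMM registers; otherwise loading fails with
 * -ENODEV and the generic sm4 implementation remains in use.
 */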
static int __init sm4_init(void)
{
	const char *feature_name;

	if (!boot_cpu_has(X86_FEATURE_AVX) ||
	    !boot_cpu_has(X86_FEATURE_AES) ||
	    !boot_cpu_has(X86_FEATURE_OSXSAVE)) {
		pr_info("AVX or AES-NI instructions are not detected.\n");
		return -ENODEV;
	}

	if (!cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM,
				&feature_name)) {
		pr_info("CPU feature '%s' is not supported.\n", feature_name);
		return -ENODEV;
	}

	return simd_register_skciphers_compat(sm4_aesni_avx_skciphers,
					ARRAY_SIZE(sm4_aesni_avx_skciphers),
					simd_sm4_aesni_avx_skciphers);
}
static void __exit sm4_exit(void)
{
	simd_unregister_skciphers(sm4_aesni_avx_skciphers,
				ARRAY_SIZE(sm4_aesni_avx_skciphers),
				simd_sm4_aesni_avx_skciphers);
}
module_init(sm4_init);
module_exit(sm4_exit);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_DESCRIPTION("SM4 Cipher Algorithm, AES-NI/AVX optimized");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("sm4-aesni-avx");