#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <openssl/pem.h>
#include <openssl/pkcs12.h>
#include "err_local.h"
#include "pkcs12_local.h"
#include "x509_local.h"
/*
 * Extract the password-based-encryption parameters from an
 * AlgorithmIdentifier whose parameters are a PBEPARAM SEQUENCE, so the
 * same algorithm, iteration count and salt length can be reused when
 * re-encrypting under a new password.
 *
 * On success returns 1 and stores the algorithm NID, the iteration count
 * and the salt length in *nid, *iter and *salt_len.  On failure returns 0
 * and all three outputs are left set to 0.
 */
static int
alg_get(X509_ALGOR *alg, int *nid, int *iter, int *salt_len)
{
	const ASN1_OBJECT *aobj;
	int param_type;
	const void *param;
	PBEPARAM *pbe = NULL;
	int ret = 0;

	/* Give the outputs defined values even on the failure paths. */
	*nid = *iter = *salt_len = 0;

	X509_ALGOR_get0(&aobj, &param_type, &param, alg);

	/* Anything other than an encoded SEQUENCE cannot be a PBEPARAM. */
	if (param_type != V_ASN1_SEQUENCE)
		goto err;
	if ((pbe = ASN1_item_unpack(param, &PBEPARAM_it)) == NULL)
		goto err;

	*nid = OBJ_obj2nid(aobj);
	*iter = ASN1_INTEGER_get(pbe->iter);
	*salt_len = pbe->salt->length;

	ret = 1;

 err:
	PBEPARAM_free(pbe);

	return ret;
}
/*
 * Re-encrypt a single safe bag under newpass.
 *
 * Only bags of type NID_pkcs8ShroudedKeyBag are touched: the key is
 * decrypted with oldpass and re-encrypted with newpass, reusing the
 * original PBE algorithm, salt length and iteration count.  Any other bag
 * type is left alone and counts as success.
 *
 * Returns 1 on success, 0 on failure.  On failure the bag is unchanged,
 * because its encrypted key is replaced only after re-encryption succeeds.
 */
static int
newpass_bag(PKCS12_SAFEBAG *bag, const char *oldpass, const char *newpass)
{
	PKCS8_PRIV_KEY_INFO *p8inf = NULL;
	X509_SIG *reencrypted;
	int pbe_nid, pbe_iter, pbe_saltlen;
	int ok = 0;

	/* Bags that are not shrouded key bags need no work. */
	if (OBJ_obj2nid(bag->type) != NID_pkcs8ShroudedKeyBag)
		return 1;

	p8inf = PKCS8_decrypt(bag->value.shkeybag, oldpass, -1);
	if (p8inf == NULL)
		goto out;

	if (!alg_get(bag->value.shkeybag->algor, &pbe_nid, &pbe_iter,
	    &pbe_saltlen))
		goto out;

	reencrypted = PKCS8_encrypt(pbe_nid, NULL, newpass, -1, NULL,
	    pbe_saltlen, pbe_iter, p8inf);
	if (reencrypted == NULL)
		goto out;

	/* Swap in the newly encrypted key; the bag now owns it. */
	X509_SIG_free(bag->value.shkeybag);
	bag->value.shkeybag = reencrypted;
	ok = 1;

 out:
	PKCS8_PRIV_KEY_INFO_free(p8inf);
	return ok;
}
/*
 * Run newpass_bag() over every bag in the stack, stopping at the first
 * failure.
 *
 * Returns 1 if every bag was processed, 0 otherwise.  Bags processed
 * before a failure keep their re-encrypted contents.
 */
static int
newpass_bags(STACK_OF(PKCS12_SAFEBAG) *bags, const char *oldpass,
    const char *newpass)
{
	PKCS12_SAFEBAG *cur;
	int idx = 0;

	while (idx < sk_PKCS12_SAFEBAG_num(bags)) {
		cur = sk_PKCS12_SAFEBAG_value(bags, idx);
		if (!newpass_bag(cur, oldpass, newpass))
			return 0;
		idx++;
	}

	return 1;
}
/*
 * Repack an unencrypted (data) authenticated safe.
 *
 * The safe's bags are unpacked and any shrouded key bags are re-encrypted
 * with newpass.  The bags are then packed into a fresh PKCS7 data object,
 * which is appended to safes.  On success, safes owns the new object.
 *
 * Returns 1 on success, 0 on failure.
 */
static int
pkcs7_repack_data(PKCS7 *pkcs7, STACK_OF(PKCS7) *safes, const char *oldpass,
    const char *newpass)
{
	STACK_OF(PKCS12_SAFEBAG) *safebags;
	PKCS7 *packed = NULL;
	int success = 0;

	safebags = PKCS12_unpack_p7data(pkcs7);

	/* Each step runs only if every previous step succeeded. */
	if (safebags != NULL &&
	    newpass_bags(safebags, oldpass, newpass) &&
	    (packed = PKCS12_pack_p7data(safebags)) != NULL &&
	    sk_PKCS7_push(safes, packed) != 0) {
		/* Ownership of packed has moved to safes. */
		packed = NULL;
		success = 1;
	}

	sk_PKCS12_SAFEBAG_pop_free(safebags, PKCS12_SAFEBAG_free);
	PKCS7_free(packed);

	return success;
}
/*
 * Repack an encrypted authenticated safe.
 *
 * The safe's bags are decrypted with oldpass and any shrouded key bags
 * inside are re-encrypted with newpass.  The bags are then packed into a
 * new encrypted PKCS7 object under newpass, reusing the original PBE
 * algorithm, salt length and iteration count.  The new object is appended
 * to safes, which owns it on success.
 *
 * Returns 1 on success, 0 on failure.
 */
static int
pkcs7_repack_encdata(PKCS7 *pkcs7, STACK_OF(PKCS7) *safes, const char *oldpass,
    const char *newpass)
{
	STACK_OF(PKCS12_SAFEBAG) *safebags;
	PKCS7 *packed = NULL;
	int pbe_nid, pbe_iter, pbe_saltlen;
	int success = 0;

	safebags = PKCS12_unpack_p7encdata(pkcs7, oldpass, -1);
	if (safebags == NULL)
		goto done;

	/* Capture the original encryption parameters for reuse. */
	if (!alg_get(pkcs7->d.encrypted->enc_data->algorithm, &pbe_nid,
	    &pbe_iter, &pbe_saltlen))
		goto done;

	if (!newpass_bags(safebags, oldpass, newpass))
		goto done;

	packed = PKCS12_pack_p7encdata(pbe_nid, newpass, -1, NULL,
	    pbe_saltlen, pbe_iter, safebags);
	if (packed == NULL)
		goto done;

	if (sk_PKCS7_push(safes, packed) == 0)
		goto done;
	packed = NULL;		/* now owned by safes */

	success = 1;

 done:
	sk_PKCS12_SAFEBAG_pop_free(safebags, PKCS12_SAFEBAG_free);
	PKCS7_free(packed);

	return success;
}
/*
 * Re-encode the authenticated safes in `safes` into pkcs12 and replace
 * its MAC digest with one generated from newpass.
 *
 * The current encoded authsafes data is detached, and a fresh, empty
 * octet string is installed before PKCS12_pack_authsafes() is called.
 * Presumably that call encodes `safes` into pkcs12->authsafes->d.data;
 * confirm against its implementation.
 *
 * If any later step fails, the new octet string is freed and the
 * original authsafes data is put back; the MAC digest is replaced only
 * once every step has succeeded.  On success the old data and the old
 * MAC digest are freed.
 *
 * NOTE(review): pkcs12->mac->dinfo is dereferenced without a NULL check.
 * This appears to rely on the caller having already verified the MAC;
 * confirm.
 *
 * Returns 1 on success, 0 on failure.
 */
static int
pkcs12_repack_authsafes(PKCS12 *pkcs12, STACK_OF(PKCS7) *safes,
const char *newpass)
{
ASN1_OCTET_STRING *old_data;
ASN1_OCTET_STRING *new_mac = NULL;
unsigned char mac[EVP_MAX_MD_SIZE];
unsigned int mac_len;
int ret = 0;
/* Detach the current encoding (fail if there is none) and install an empty one. */
if ((old_data = pkcs12->authsafes->d.data) == NULL)
goto err;
if ((pkcs12->authsafes->d.data = ASN1_OCTET_STRING_new()) == NULL)
goto err;
if (!PKCS12_pack_authsafes(pkcs12, safes))
goto err;
/* The MAC is computed over the newly packed data, keyed by newpass. */
if (!PKCS12_gen_mac(pkcs12, newpass, -1, mac, &mac_len))
goto err;
if ((new_mac = ASN1_OCTET_STRING_new()) == NULL)
goto err;
if (!ASN1_OCTET_STRING_set(new_mac, mac, mac_len))
goto err;
/* Commit: replace the MAC digest; ownership of new_mac moves to pkcs12. */
ASN1_OCTET_STRING_free(pkcs12->mac->dinfo->digest);
pkcs12->mac->dinfo->digest = new_mac;
new_mac = NULL;
/* The old encoding is no longer needed; NULL tells the error path not to restore it. */
ASN1_OCTET_STRING_free(old_data);
old_data = NULL;
ret = 1;
err:
/* On failure, discard the partially built encoding and restore the original. */
if (old_data != NULL) {
ASN1_OCTET_STRING_free(pkcs12->authsafes->d.data);
pkcs12->authsafes->d.data = old_data;
}
/* Wipe the computed MAC from the local buffer. */
explicit_bzero(mac, sizeof(mac));
ASN1_OCTET_STRING_free(new_mac);
return ret;
}
/*
 * Change the password protecting a PKCS#12 structure, in place.
 *
 * Steps:
 *  1. Verify the existing MAC with oldpass.
 *  2. Repack each authenticated safe of content type data or encrypted
 *     data.  Shrouded key bags are re-encrypted, and encrypted safes are
 *     re-encrypted under newpass, reusing the original PBE algorithm, salt
 *     length and iteration count.
 *  3. Re-encode the authenticated safes and regenerate the MAC with
 *     newpass.
 *
 * Authenticated safes of any other content type are not carried over into
 * the repacked structure.
 *
 * Returns 1 on success and 0 on failure.
 */
int
PKCS12_newpass(PKCS12 *pkcs12, const char *oldpass, const char *newpass)
{
	STACK_OF(PKCS7) *authsafes = NULL, *safes = NULL;
	int i;
	int ret = 0;

	if (pkcs12 == NULL) {
		PKCS12error(PKCS12_R_INVALID_NULL_PKCS12_POINTER);
		goto err;
	}

	if (!PKCS12_verify_mac(pkcs12, oldpass, -1)) {
		PKCS12error(PKCS12_R_MAC_VERIFY_FAILURE);
		goto err;
	}

	if ((authsafes = PKCS12_unpack_authsafes(pkcs12)) == NULL)
		goto err;
	if ((safes = sk_PKCS7_new_null()) == NULL)
		goto err;

	for (i = 0; i < sk_PKCS7_num(authsafes); i++) {
		PKCS7 *pkcs7 = sk_PKCS7_value(authsafes, i);

		/* The repack helpers return 1 on success, so bail out on 0. */
		switch (OBJ_obj2nid(pkcs7->type)) {
		case NID_pkcs7_data:
			if (!pkcs7_repack_data(pkcs7, safes, oldpass, newpass))
				goto err;
			break;
		case NID_pkcs7_encrypted:
			if (!pkcs7_repack_encdata(pkcs7, safes, oldpass,
			    newpass))
				goto err;
			break;
		default:
			/* Other content types are not repacked. */
			break;
		}
	}

	if (!pkcs12_repack_authsafes(pkcs12, safes, newpass))
		goto err;

	ret = 1;

 err:
	sk_PKCS7_pop_free(authsafes, PKCS7_free);
	sk_PKCS7_pop_free(safes, PKCS7_free);

	return ret;
}
LCRYPTO_ALIAS(PKCS12_newpass);