#include <sys/types.h>
#include <sys/debug.h>
#include <sys/sunddi.h>
#include <sys/kmem.h>
#include <sys/sysmacros.h>
#include <smbsrv/smb_kproto.h>
#include <smbsrv/alloc.h>
/*
 * Value stored in smh_magic while a block from smb_alloc() is live.
 * Read as big-endian bytes, 0x53 0x4D 0x48 0x5F spell "SMH_".
 * smb_free() resets it to 0.
 */
#define	SMB_SMH_MAGIC		0x534D485F

/*
 * Assert that a header still carries the live magic value.  ASSERT3U
 * reports the observed value on failure, so a magic of 0 (what smb_free()
 * leaves behind) can be told apart from arbitrary corruption.
 */
#define	SMB_SMH_VALID(_smh_)	ASSERT3U((_smh_)->smh_magic, ==, SMB_SMH_MAGIC)

/* Map a pointer returned by smb_alloc() back to the header preceding it. */
#define	SMB_MEM2SMH(_mem_)	((smb_mem_header_t *)(_mem_) - 1)
/*
 * Header placed immediately in front of every buffer returned by
 * smb_alloc().  SMB_MEM2SMH() steps back one header from the caller's
 * pointer to reach it.  Because the caller's buffer starts right after
 * this struct, sizeof (smb_mem_header_t) also decides where the usable
 * region begins relative to the kmem allocation.  Do not reorder fields
 * casually.
 */
typedef struct smb_mem_header {
/* SMB_SMH_MAGIC while the block is live; set to 0 by smb_free(). */
uint32_t smh_magic;
/*
 * Size requested at allocation time, excluding this header.  smb_realloc()
 * does not lower it on an in-place shrink, so it always matches the
 * amount actually allocated (used by smb_free() to size kmem_free()).
 */
size_t smh_size;
/* Owning request, or NULL for the request-independent smb_mem_*() API. */
smb_request_t *smh_sr;
/* Linkage on smh_sr->sr_storage when smh_sr is non-NULL. */
list_node_t smh_lnd;
} smb_mem_header_t;
/*
 * Common helpers shared by the smb_mem_*() (sr == NULL) and smb_srm_*()
 * (request-owned) entry points below.
 */
static void *smb_alloc(smb_request_t *sr, size_t size, boolean_t zero);
static void smb_free(smb_request_t *sr, void *ptr, boolean_t zero);
static void *smb_realloc(smb_request_t *sr, void *ptr, size_t size,
    boolean_t zero);
/*
 * Allocate `size` bytes not associated with any request.  The contents
 * are left uninitialized.  Release with smb_mem_free() or smb_mem_zfree().
 */
void *
smb_mem_alloc(size_t size)
{
	void *buf;

	buf = smb_alloc(NULL, size, B_FALSE);
	return (buf);
}
/*
 * Allocate `size` zero-filled bytes not associated with any request.
 * Release with smb_mem_free() or smb_mem_zfree().
 */
void *
smb_mem_zalloc(size_t size)
{
	void *buf;

	buf = smb_alloc(NULL, size, B_TRUE);
	return (buf);
}
/*
 * Resize a request-independent buffer.
 *
 * A NULL `ptr` behaves like smb_mem_alloc().  A `size` of 0 frees `ptr`
 * and returns NULL.  Bytes gained by growing are left uninitialized.
 */
void *
smb_mem_realloc(void *ptr, size_t size)
{
	void *resized;

	resized = smb_realloc(NULL, ptr, size, B_FALSE);
	return (resized);
}
/*
 * Like smb_mem_realloc(), but with zeroing:
 * - bytes gained by growing are zero-filled;
 * - an in-place shrink zeroes the now-unused tail;
 * - a moved-from buffer is zeroed before it is freed.
 */
void *
smb_mem_rezalloc(void *ptr, size_t size)
{
	void *resized;

	resized = smb_realloc(NULL, ptr, size, B_TRUE);
	return (resized);
}
void
smb_mem_free(void *ptr)
{
smb_free(NULL, ptr, B_FALSE);
}
void
smb_mem_zfree(void *ptr)
{
smb_free(NULL, ptr, B_TRUE);
}
/*
 * Return a request-independent copy of the NUL-terminated string `ptr`.
 * `ptr` must not be NULL.  Release the copy with smb_mem_free().
 */
char *
smb_mem_strdup(const char *ptr)
{
	size_t nbytes = strlen(ptr) + 1;	/* include the terminating NUL */
	char *dst = smb_alloc(NULL, nbytes, B_FALSE);

	bcopy(ptr, dst, nbytes);
	return (dst);
}
/*
 * Prepare sr->sr_storage, the list on which every request-owned
 * allocation header is recorded (linked through smh_lnd).
 */
void
smb_srm_init(smb_request_t *sr)
{
	size_t node_off = offsetof(smb_mem_header_t, smh_lnd);

	list_create(&sr->sr_storage, sizeof (smb_mem_header_t), node_off);
}
/*
 * Free every allocation still owned by `sr`, then tear down the list.
 * smb_free() unlinks each header, so the loop always re-reads the head.
 */
void
smb_srm_fini(smb_request_t *sr)
{
	smb_mem_header_t *hdr;

	for (hdr = list_head(&sr->sr_storage); hdr != NULL;
	    hdr = list_head(&sr->sr_storage)) {
		/* smb_free() expects the user pointer just past the header. */
		smb_free(sr, hdr + 1, B_FALSE);
	}

	list_destroy(&sr->sr_storage);
}
/*
 * Allocate `size` uninitialized bytes owned by request `sr`.  The block
 * is recorded on sr->sr_storage and released by smb_srm_fini() if it
 * has not been freed earlier.
 */
void *
smb_srm_alloc(smb_request_t *sr, size_t size)
{
	void *buf;

	buf = smb_alloc(sr, size, B_FALSE);
	return (buf);
}
/*
 * Allocate `size` zero-filled bytes owned by request `sr`.
 * The block is tracked the same way as smb_srm_alloc() memory.
 */
void *
smb_srm_zalloc(smb_request_t *sr, size_t size)
{
	void *buf;

	buf = smb_alloc(sr, size, B_TRUE);
	return (buf);
}
/*
 * Resize a request-owned buffer.
 *
 * A NULL `p` allocates a new request-owned block.  A `size` of 0 frees
 * `p` and returns NULL.  Bytes gained by growing are left uninitialized.
 */
void *
smb_srm_realloc(smb_request_t *sr, void *p, size_t size)
{
	void *resized;

	resized = smb_realloc(sr, p, size, B_FALSE);
	return (resized);
}
/*
 * Like smb_srm_realloc(), but with zeroing:
 * - bytes gained by growing are zero-filled;
 * - an in-place shrink zeroes the now-unused tail;
 * - a moved-from buffer is zeroed before it is freed.
 */
void *
smb_srm_rezalloc(smb_request_t *sr, void *p, size_t size)
{
	void *resized;

	resized = smb_realloc(sr, p, size, B_TRUE);
	return (resized);
}
/*
 * Return a copy of the NUL-terminated string `s`, owned by request `sr`.
 * `s` must not be NULL.
 */
char *
smb_srm_strdup(smb_request_t *sr, const char *s)
{
	size_t len = strlen(s);
	char *newstr = smb_srm_alloc(sr, len + 1);	/* room for the NUL */

	bcopy(s, newstr, len + 1);
	return (newstr);
}
/*
 * Common allocator behind the smb_mem_*() and smb_srm_*() entry points.
 *
 * Allocates `size` usable bytes preceded by an smb_mem_header_t.  The
 * header records:
 * - the owning request `sr` (NULL for request-independent memory);
 * - the usable size;
 * - SMB_SMH_MAGIC, which SMB_SMH_VALID() checks.
 *
 * When `sr` is non-NULL, the header is appended to sr->sr_storage so
 * that smb_srm_fini() can release anything the request did not free.
 *
 * `zero` selects kmem_zalloc(); otherwise the usable bytes are left
 * uninitialized.  The kmem result is not NULL-checked; this relies on
 * KM_SLEEP semantics.
 *
 * Returns a pointer to the first usable byte, just past the header.
 */
static void *
smb_alloc(smb_request_t *sr, size_t size, boolean_t zero)
{
	smb_mem_header_t *smh;
	size_t total = size + sizeof (smb_mem_header_t);

	/*
	 * If the addition wraps, kmem would hand back a buffer too small
	 * even for the header, and the huge `size` recorded below would
	 * later drive out-of-bounds bzero/bcopy/kmem_free.  Refuse outright.
	 */
	VERIFY(total > size);

	if (zero) {
		smh = kmem_zalloc(total, KM_SLEEP);
	} else {
		smh = kmem_alloc(total, KM_SLEEP);
		/*
		 * Clear the list linkage so the node does not carry stale
		 * pointers from recycled memory.  Presumably this keeps
		 * list_link_active()/list debug checks happy for sr == NULL
		 * blocks that are never inserted; confirm against list.c.
		 */
		bzero(&smh->smh_lnd, sizeof (smh->smh_lnd));
	}

	smh->smh_sr = sr;
	smh->smh_size = size;
	smh->smh_magic = SMB_SMH_MAGIC;

	if (sr != NULL) {
		SMB_REQ_VALID(sr);
		list_insert_tail(&sr->sr_storage, smh);
	}

	return (smh + 1);
}
/*
 * Release a block obtained from smb_alloc()/smb_realloc().
 *
 * `sr` must match the owner recorded in the header (NULL for smb_mem_*()
 * memory).  A request-owned block is unlinked from sr->sr_storage.  When
 * `zero` is set, the usable bytes are cleared before the memory is
 * returned to kmem.  The magic is reset to 0 so stale pointers fail
 * SMB_SMH_VALID().  A NULL `ptr` is ignored.
 */
static void
smb_free(smb_request_t *sr, void *ptr, boolean_t zero)
{
	smb_mem_header_t *hdr;
	size_t len;

	if (ptr == NULL)
		return;

	hdr = SMB_MEM2SMH(ptr);
	SMB_SMH_VALID(hdr);
	ASSERT(sr == hdr->smh_sr);

	if (sr != NULL) {
		SMB_REQ_VALID(sr);
		list_remove(&sr->sr_storage, hdr);
	}

	len = hdr->smh_size;
	if (zero)
		bzero(ptr, len);

	hdr->smh_magic = 0;
	kmem_free(hdr, len + sizeof (smb_mem_header_t));
}
/*
 * Resize a block obtained from smb_alloc().
 *
 * - NULL `ptr`: behaves like smb_alloc(sr, size, zero).
 * - `size` == 0: frees the block (scrubbing it if `zero`) and returns NULL.
 * - `size` <= current size: the block is reused in place.  With `zero`,
 *   the bytes past `size` are cleared.  smh_size is left unchanged, so
 *   it still describes the real allocation.
 * - Otherwise: a new block is allocated and the old contents are copied.
 *   With `zero`, the newly gained bytes are cleared.  The old block is
 *   then freed (scrubbed if `zero`).
 *
 * `sr` must match the owner recorded in the header.
 */
static void *
smb_realloc(smb_request_t *sr, void *ptr, size_t size, boolean_t zero)
{
	smb_mem_header_t *hdr;
	size_t oldsz;
	caddr_t grown;

	if (ptr == NULL)
		return (smb_alloc(sr, size, zero));

	hdr = SMB_MEM2SMH(ptr);
	SMB_SMH_VALID(hdr);
	ASSERT(sr == hdr->smh_sr);

	if (size == 0) {
		smb_free(sr, ptr, zero);
		return (NULL);
	}

	oldsz = hdr->smh_size;
	if (size <= oldsz) {
		/* Fits already: optionally scrub the unused tail. */
		if (zero && size < oldsz)
			bzero((caddr_t)ptr + size, oldsz - size);
		return (ptr);
	}

	grown = smb_alloc(sr, size, B_FALSE);
	bcopy(ptr, grown, oldsz);
	if (zero)
		bzero(grown + oldsz, size - oldsz);

	smb_free(sr, ptr, zero);
	return (grown);
}