root/src/tests/system/kernel/file_corruption/driver/checksum_device.cpp
/*
 * Copyright 2010, Ingo Weinhold, ingo_weinhold@gmx.de.
 * Distributed under the terms of the MIT License.
 */


#include "checksum_device.h"

#include <ctype.h>
#include <errno.h>
#include <fcntl.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>

#include <algorithm>

#include <device_manager.h>

#include <AutoDeleter.h>
#include <util/AutoLock.h>
#include <util/DoublyLinkedList.h>

#include <fs/KPath.h>
#include <lock.h>
#include <vm/vm.h>

#include "dma_resources.h"
#include "io_requests.h"
#include "IOSchedulerSimple.h"

#include "CheckSum.h"


// Uncomment to enable tracing output via dprintf().
//#define TRACE_CHECK_SUM_DEVICE
#ifdef TRACE_CHECK_SUM_DEVICE
#       define TRACE(x...)      dprintf(x)
#else
        // Variadic like the tracing variant above, so that TRACE() calls with
        // several arguments compile in both configurations.
#       define TRACE(x...) do {} while (false)
#endif


// parameters for the DMA resource (passed to DMAResource::Init() in
// RawDevice::Prepare())
static const uint32 kDMAResourceBufferCount                     = 16;
static const uint32 kDMAResourceBounceBufferCount       = 16;

// size in bytes of a single stored check sum
static const size_t kCheckSumLength = sizeof(CheckSum);
// number of check sums stored in one B_PAGE_SIZE sized check sum block of the
// backing file
static const uint32 kCheckSumsPerBlock = B_PAGE_SIZE / sizeof(CheckSum);

// module names of the driver and of its two device types (control and raw)
static const char* const kDriverModuleName
        = "drivers/disk/virtual/checksum_device/driver_v1";
static const char* const kControlDeviceModuleName
        = "drivers/disk/virtual/checksum_device/control/device_v1";
static const char* const kRawDeviceModuleName
        = "drivers/disk/virtual/checksum_device/raw/device_v1";

// published name of the control device
static const char* const kControlDeviceName
        = "disk/virtual/checksum_device/control";
// raw devices are published as "<kRawDeviceBaseName>/<index>/raw"
static const char* const kRawDeviceBaseName = "disk/virtual/checksum_device";

// device node attribute holding the path of the backing file; only raw device
// nodes carry it, which is how they are told apart from the control device
static const char* const kFilePathItem = "checksum_device/file_path";


struct RawDevice;
typedef DoublyLinkedList<RawDevice> RawDeviceList;

// device manager module -- NOTE(review): presumably filled in through the
// module dependency list further down in the file; confirm
struct device_manager_info* sDeviceManager;

// all raw devices that have been assigned an index, sorted by index; guarded
// by sDeviceListLock
static RawDeviceList sDeviceList;
static mutex sDeviceListLock = MUTEX_INITIALIZER("checksum device list");


// One B_PAGE_SIZE sized block of check sums, cached from the backing file.
struct CheckSumBlock : public DoublyLinkedListLinkImpl<CheckSumBlock> {
        uint64          blockIndex;
                // index (in B_PAGE_SIZE units) of the cached block in the backing
                // file; only meaningful while "used" is true
        bool            used;
                // whether the block currently holds data read from the file
        bool            dirty;
                // whether checkSums was modified since it was read or written back
        CheckSum        checkSums[kCheckSumsPerBlock];

        CheckSumBlock()
                :
                blockIndex(0),
                used(false),
                dirty(false)
        {
        }
};


struct CheckSumCache {
        CheckSumCache()
        {
                mutex_init(&fLock, "check sum cache");
        }

        ~CheckSumCache()
        {
                while (CheckSumBlock* block = fBlocks.RemoveHead())
                        delete block;

                mutex_destroy(&fLock);
        }

        status_t Init(int fd, uint64 blockCount, uint32 cachedBlockCount)
        {
                fBlockCount = blockCount;
                fFD = fd;

                for (uint32 i = 0; i < cachedBlockCount; i++) {
                        CheckSumBlock* block = new(std::nothrow) CheckSumBlock;
                        if (block == NULL)
                                return B_NO_MEMORY;

                        fBlocks.Add(block);
                }

                return B_OK;
        }

        status_t GetCheckSum(uint64 blockIndex, CheckSum& checkSum)
        {
                ASSERT(blockIndex < fBlockCount);

                MutexLocker locker(fLock);

                CheckSumBlock* block;
                status_t error = _GetBlock(
                        fBlockCount + blockIndex / kCheckSumsPerBlock, block);
                if (error != B_OK)
                        return error;

                checkSum = block->checkSums[blockIndex % kCheckSumsPerBlock];

                return B_OK;
        }

        status_t SetCheckSum(uint64 blockIndex, const CheckSum& checkSum)
        {
                ASSERT(blockIndex < fBlockCount);

                MutexLocker locker(fLock);

                CheckSumBlock* block;
                status_t error = _GetBlock(
                        fBlockCount + blockIndex / kCheckSumsPerBlock, block);
                if (error != B_OK)
                        return error;

                block->checkSums[blockIndex % kCheckSumsPerBlock] = checkSum;
                block->dirty = true;

#ifdef TRACE_CHECK_SUM_DEVICE
                TRACE("checksum_device: setting check sum of block %" B_PRIu64 " to: ",
                        blockIndex);
                for (size_t i = 0; i < kCheckSumLength; i++)
                        TRACE("%02x", checkSum.Data()[i]);
                TRACE("\n");
#endif

                return B_OK;
        }

private:
        typedef DoublyLinkedList<CheckSumBlock> BlockList;

private:
        status_t _GetBlock(uint64 blockIndex, CheckSumBlock*& _block)
        {
                // check whether we have already cached the block
                for (BlockList::Iterator it = fBlocks.GetIterator();
                        CheckSumBlock* block = it.Next();) {
                        if (block->used && blockIndex == block->blockIndex) {
                                // we know it -- requeue and return
                                it.Remove();
                                fBlocks.Add(block);
                                _block = block;
                                return B_OK;
                        }
                }

                // flush the least recently used block and recycle it
                CheckSumBlock* block = fBlocks.Head();
                status_t error = _FlushBlock(block);
                if (error != B_OK)
                        return error;

                error = _ReadBlock(block, blockIndex);
                if (error != B_OK)
                        return error;

                // requeue
                fBlocks.Remove(block);
                fBlocks.Add(block);

                _block = block;
                return B_OK;
        }

        status_t _FlushBlock(CheckSumBlock* block)
        {
                if (!block->used || !block->dirty)
                        return B_OK;

                ssize_t written = pwrite(fFD, block->checkSums, B_PAGE_SIZE,
                        block->blockIndex * B_PAGE_SIZE);
                if (written < 0)
                        return errno;
                if (written != B_PAGE_SIZE)
                        return B_ERROR;

                block->dirty = false;
                return B_OK;
        }

        status_t _ReadBlock(CheckSumBlock* block, uint64 blockIndex)
        {
                // mark unused for the failure cases -- reset later
                block->used = false;

                ssize_t bytesRead = pread(fFD, block->checkSums, B_PAGE_SIZE,
                        blockIndex * B_PAGE_SIZE);
                if (bytesRead < 0)
                        return errno;
                if (bytesRead != B_PAGE_SIZE)
                        return B_ERROR;

                block->blockIndex = blockIndex;
                block->used = true;
                block->dirty = false;

                return B_OK;
        }

private:
        mutex           fLock;
        uint64          fBlockCount;
        int                     fFD;
        BlockList       fBlocks;        // LRU first
};


// Common base of the driver cookies (control device and raw devices): keeps
// the device node and a per-device mutex.
struct Device {
        Device(device_node* node)
                :
                fNode(node)
        {
                mutex_init(&fLock, "checksum device");
        }

        virtual ~Device()
        {
                mutex_destroy(&fLock);
        }

        // Returns the device node passed to the constructor.
        device_node* Node() const
        {
                return fNode;
        }

        // Acquires the device mutex. The result of mutex_lock() is ignored,
        // so this always returns true.
        bool Lock()
        {
                mutex_lock(&fLock);
                return true;
        }

        // Releases the device mutex.
        void Unlock()
        {
                mutex_unlock(&fLock);
        }

        // Publishes the device for this node; implemented per device type.
        virtual status_t PublishDevice() = 0;

protected:
        mutex                   fLock;
        device_node*    fNode;
};


// Driver cookie of the control device node. Commands written to the published
// control device register new raw checksum devices.
struct ControlDevice : Device {
        ControlDevice(device_node* node)
                :
                Device(node)
        {
        }

        // Registers a new raw device node backed by \a fileName as a child of
        // this node's parent.
        status_t Register(const char* fileName)
        {
                device_attr attrs[] = {
                        {B_DEVICE_PRETTY_NAME, B_STRING_TYPE,
                                {string: "Checksum Raw Device"}},
                        {kFilePathItem, B_STRING_TYPE, {string: fileName}},
                        {NULL}
                };

                // get_parent_node() acquires a reference to the parent, which has
                // to be released again via put_node().
                device_node* parent = sDeviceManager->get_parent_node(Node());

                status_t error = sDeviceManager->register_node(parent,
                        kDriverModuleName, attrs, NULL, NULL);

                if (parent != NULL)
                        sDeviceManager->put_node(parent);

                return error;
        }

        virtual status_t PublishDevice()
        {
                return sDeviceManager->publish_device(Node(), kControlDeviceName,
                        kControlDeviceModuleName);
        }
};


// Driver cookie of a raw checksum device node.
//
// The backing file (or device) holds the data blocks followed by the check
// sum blocks. Data written to the device is verified against the stored check
// sums (see _CheckCheckSum()). Devices that have been assigned an index are
// kept in sDeviceList, sorted by index.
struct RawDevice : Device, DoublyLinkedListLinkImpl<RawDevice> {
        RawDevice(device_node* node)
                :
                Device(node),
                fIndex(-1),
                fFD(-1),
                fFileSize(0),
                fDeviceSize(0),
                fDeviceName(NULL),
                fDMAResource(NULL),
                fIOScheduler(NULL),
                fTransferBuffer(NULL),
                fCheckSumCache(NULL)
        {
        }

        virtual ~RawDevice()
        {
                if (fIndex >= 0) {
                        MutexLocker locker(sDeviceListLock);
                        sDeviceList.Remove(this);
                }

                // Release the I/O resources in case Unprepare() hasn't been called
                // (it is a no-op otherwise). Do that before closing the file, since
                // the check sum cache operates on fFD.
                Unprepare();

                if (fFD >= 0)
                        close(fFD);

                free(fDeviceName);
        }

        int32 Index() const                             { return fIndex; }
        off_t DeviceSize() const                { return fDeviceSize; }
        const char* DeviceName() const  { return fDeviceName; }

        // Opens \a fileName, computes the usable device size, and assigns the
        // lowest free index and the resulting device name. On error the caller
        // deletes the object.
        status_t Init(const char* fileName)
        {
                // open file/device
                fFD = open(fileName, O_RDWR | O_NOCACHE);
                        // TODO: The O_NOCACHE is a work-around for a page writer problem.
                        // Since it collects pages for writing back without regard for
                        // which caches and file systems they belong to, a deadlock can
                        // result when pages from both the underlying file system and the
                        // one using the checksum device are collected in one run.
                if (fFD < 0)
                        return errno;

                // get the size
                struct stat st;
                if (fstat(fFD, &st) < 0)
                        return errno;

                switch (st.st_mode & S_IFMT) {
                        case S_IFREG:
                                fFileSize = st.st_size;
                                break;
                        case S_IFCHR:
                        case S_IFBLK:
                        {
                                device_geometry geometry;
                                if (ioctl(fFD, B_GET_GEOMETRY, &geometry, sizeof(geometry)) < 0)
                                        return errno;

                                fFileSize = (off_t)geometry.bytes_per_sector
                                        * geometry.sectors_per_track
                                        * geometry.cylinder_count * geometry.head_count;
                                break;
                        }
                        default:
                                return B_BAD_VALUE;
                }

                // Each B_PAGE_SIZE data block needs kCheckSumLength bytes of check
                // sum space at the end of the file.
                fFileSize = fFileSize / B_PAGE_SIZE * B_PAGE_SIZE;
                fDeviceSize = fFileSize / (B_PAGE_SIZE + kCheckSumLength) * B_PAGE_SIZE;

                // find a free slot (the list is sorted by index)
                fIndex = 0;
                RawDevice* nextDevice = NULL;
                MutexLocker locker(sDeviceListLock);
                for (RawDeviceList::Iterator it = sDeviceList.GetIterator();
                                (nextDevice = it.Next()) != NULL;) {
                        if (nextDevice->Index() > fIndex)
                                break;
                        fIndex = nextDevice->Index() + 1;
                }

                sDeviceList.InsertBefore(nextDevice, this);

                // construct our device path
                KPath path(kRawDeviceBaseName);
                char buffer[32];
                snprintf(buffer, sizeof(buffer), "%" B_PRId32 "/raw", fIndex);

                status_t error = path.Append(buffer);
                if (error != B_OK)
                        return error;

                fDeviceName = path.DetachBuffer();

                return B_OK;
        }

        // Allocates the check sum cache, DMA resource, I/O scheduler, and
        // transfer buffer. On failure everything allocated so far is freed.
        status_t Prepare()
        {
                fCheckSumCache = new(std::nothrow) CheckSumCache;
                if (fCheckSumCache == NULL) {
                        Unprepare();
                        return B_NO_MEMORY;
                }

                status_t error = fCheckSumCache->Init(fFD, fDeviceSize / B_PAGE_SIZE,
                        16);
                if (error != B_OK) {
                        Unprepare();
                        return error;
                }

                // no DMA restrictions
                const dma_restrictions restrictions = {};

                fDMAResource = new(std::nothrow) DMAResource;
                if (fDMAResource == NULL) {
                        Unprepare();
                        return B_NO_MEMORY;
                }

                error = fDMAResource->Init(restrictions, B_PAGE_SIZE,
                        kDMAResourceBufferCount, kDMAResourceBounceBufferCount);
                if (error != B_OK) {
                        Unprepare();
                        return error;
                }

                fIOScheduler = new(std::nothrow) IOSchedulerSimple(fDMAResource);
                if (fIOScheduler == NULL) {
                        Unprepare();
                        return B_NO_MEMORY;
                }

                error = fIOScheduler->Init("checksum device scheduler");
                if (error != B_OK) {
                        Unprepare();
                        return error;
                }

                fIOScheduler->SetCallback(&_DoIOEntry, this);

                fTransferBuffer = malloc(B_PAGE_SIZE);
                if (fTransferBuffer == NULL) {
                        Unprepare();
                        return B_NO_MEMORY;
                }

                return B_OK;
        }

        // Frees what Prepare() allocated. Safe to call repeatedly.
        void Unprepare()
        {
                free(fTransferBuffer);
                fTransferBuffer = NULL;

                delete fIOScheduler;
                fIOScheduler = NULL;

                delete fDMAResource;
                fDMAResource = NULL;

                delete fCheckSumCache;
                fCheckSumCache = NULL;
        }

        // Hands \a request to the I/O scheduler. Requires a successful Prepare().
        status_t DoIO(IORequest* request)
        {
                return fIOScheduler->ScheduleRequest(request);
        }

        virtual status_t PublishDevice()
        {
                return sDeviceManager->publish_device(Node(), fDeviceName,
                        kRawDeviceModuleName);
        }

        status_t GetBlockCheckSum(uint64 blockIndex, CheckSum& checkSum)
        {
                return fCheckSumCache->GetCheckSum(blockIndex, checkSum);
        }

        status_t SetBlockCheckSum(uint64 blockIndex, const CheckSum& checkSum)
        {
                return fCheckSumCache->SetCheckSum(blockIndex, checkSum);
        }

private:
        static status_t _DoIOEntry(void* data, IOOperation* operation)
        {
                return ((RawDevice*)data)->_DoIO(operation);
        }

        // I/O scheduler callback: performs \a operation one B_PAGE_SIZE block
        // at a time and reports its completion to the scheduler.
        status_t _DoIO(IOOperation* operation)
        {
                off_t offset = operation->Offset();
                generic_size_t length = operation->Length();

                ASSERT(offset % B_PAGE_SIZE == 0);
                ASSERT(length % B_PAGE_SIZE == 0);

                const generic_io_vec* vecs = operation->Vecs();
                generic_size_t vecOffset = 0;
                bool isWrite = operation->IsWrite();

                while (length > 0) {
                        status_t error = _TransferBlock(vecs, vecOffset, offset, isWrite);
                        if (error != B_OK) {
                                fIOScheduler->OperationCompleted(operation, error, 0);
                                return error;
                        }

                        offset += B_PAGE_SIZE;
                        length -= B_PAGE_SIZE;
                }

                fIOScheduler->OperationCompleted(operation, B_OK, operation->Length());
                return B_OK;
        }

        // Transfers one block at \a offset between the operation's vecs
        // (advancing \a vecs / \a vecOffset) and the backing file, going
        // through fTransferBuffer. Written data is checked against the stored
        // check sum first.
        status_t _TransferBlock(const generic_io_vec*& vecs,
                generic_size_t& vecOffset, off_t offset, bool isWrite)
        {
                if (isWrite) {
                        // write -- copy data to transfer buffer
                        status_t error = _CopyData(vecs, vecOffset, true);
                        if (error != B_OK)
                                return error;
                        _CheckCheckSum(offset / B_PAGE_SIZE);
                }

                ssize_t transferred = isWrite
                        ? pwrite(fFD, fTransferBuffer, B_PAGE_SIZE, offset)
                        : pread(fFD, fTransferBuffer, B_PAGE_SIZE, offset);

                if (transferred < 0)
                        return errno;
                if (transferred != B_PAGE_SIZE)
                        return B_ERROR;

                if (!isWrite) {
                        // read -- copy data from transfer buffer
                        status_t error =_CopyData(vecs, vecOffset, false);
                        if (error != B_OK)
                                return error;
                }

                return B_OK;
        }

        // Copies B_PAGE_SIZE bytes between the physical memory described by
        // the vecs and fTransferBuffer, advancing \a vecs / \a vecOffset.
        status_t _CopyData(const generic_io_vec*& vecs, generic_size_t& vecOffset,
                bool toBuffer)
        {
                uint8* buffer = (uint8*)fTransferBuffer;
                size_t length = B_PAGE_SIZE;
                while (length > 0) {
                        size_t toCopy = std::min((generic_size_t)length,
                                vecs->length - vecOffset);

                        if (toCopy == 0) {
                                vecs++;
                                vecOffset = 0;
                                continue;
                        }

                        phys_addr_t vecAddress = vecs->base + vecOffset;

                        status_t error = toBuffer
                                ? vm_memcpy_from_physical(buffer, vecAddress, toCopy, false)
                                : vm_memcpy_to_physical(vecAddress, buffer, toCopy, false);
                        if (error != B_OK)
                                return error;

                        buffer += toCopy;
                        length -= toCopy;
                        vecOffset += toCopy;
                }

                return B_OK;
        }

        // Panics if the data in fTransferBuffer doesn't match the stored check
        // sum of \a blockIndex. An all-zero stored check sum disables the check.
        void _CheckCheckSum(uint64 blockIndex)
        {
                // get the checksum the block should have
                CheckSum expectedCheckSum;
                status_t error = fCheckSumCache->GetCheckSum(blockIndex,
                        expectedCheckSum);
                if (error != B_OK) {
                        // don't fail the write, but make the unchecked block visible
                        dprintf("checksum_device: failed to get check sum for block %"
                                B_PRIu64 ": %s\n", blockIndex, strerror(error));
                        return;
                }

                // if the checksum is clear, we aren't supposed to check
                if (expectedCheckSum.IsZero()) {
                        dprintf("checksum_device: skipping check sum check for block %"
                                B_PRIu64 "\n", blockIndex);
                        return;
                }

                // compute the transfer buffer check sum
                fSHA256.Init();
                fSHA256.Update(fTransferBuffer, B_PAGE_SIZE);

                if (expectedCheckSum != fSHA256.Digest())
                        panic("Check sum mismatch for block %" B_PRIu64 " (expected at %p"
                                ", actual at %p)", blockIndex, &expectedCheckSum,
                                fSHA256.Digest());
        }

private:
        int32                   fIndex;
        int                             fFD;
        off_t                   fFileSize;
        off_t                   fDeviceSize;
        char*                   fDeviceName;
        DMAResource*    fDMAResource;
        IOScheduler*    fIOScheduler;
        void*                   fTransferBuffer;
        CheckSumCache*  fCheckSumCache;
        SHA256                  fSHA256;
};


// Per-open cookie of a raw device: the device and the mode it was opened with.
struct RawDeviceCookie {
        RawDeviceCookie(RawDevice* device, int openMode)
                :
                fDevice(device),
                fOpenMode(openMode)
        {
        }

        // Returns the raw device this cookie belongs to.
        RawDevice* Device() const
        {
                return fDevice;
        }

        // Returns the open mode passed to open().
        int OpenMode() const
        {
                return fOpenMode;
        }

private:
        RawDevice*      fDevice;
        int                     fOpenMode;
};


// #pragma mark -


/*!     Splits \a buffer in place into whitespace separated arguments.

        Single and double quotes group text (including whitespace) into one
        argument. Inside double quotes and in unquoted text a backslash escapes
        the following character; inside single quotes it is taken literally. A
        trailing backslash is ignored. The processed arguments are written back
        into \a buffer, each terminated by a single null char.

        \param buffer The null-terminated command line; modified in place.
        \param _argv Set on success to a new[]-allocated, NULL-terminated array
                of pointers into \a buffer. The caller must delete[] it.
        \param _argc Set on success to the number of arguments.
        \return \c false if the argument vector could not be allocated.
*/
static bool
parse_command_line(char* buffer, char**& _argv, int& _argc)
{
        char* start = buffer;
        char* out = buffer;
        bool pendingArgument = false;
        int argc = 0;
        while (*start != '\0') {
                // ignore whitespace (isspace() requires an unsigned char value)
                if (isspace((unsigned char)*start)) {
                        if (pendingArgument) {
                                *out = '\0';
                                out++;
                                argc++;
                                pendingArgument = false;
                        }
                        start++;
                        continue;
                }

                pendingArgument = true;

                if (*start == '"' || *start == '\'') {
                        // quoted text -- continue until closing quote
                        char quote = *start;
                        start++;
                        while (*start != '\0' && *start != quote) {
                                if (*start == '\\' && quote == '"') {
                                        start++;
                                        if (*start == '\0')
                                                break;
                                }
                                *out = *start;
                                start++;
                                out++;
                        }

                        if (*start != '\0')
                                start++;
                } else {
                        // unquoted text
                        if (*start == '\\') {
                                // escaped char -- a trailing backslash ends the input
                                // (checking the char, not the pointer, so we never
                                // step past the terminator)
                                start++;
                                if (*start == '\0')
                                        break;
                        }

                        *out = *start;
                        start++;
                        out++;
                }
        }

        if (pendingArgument) {
                *out = '\0';
                argc++;
        }

        // allocate argument vector
        char** argv = new(std::nothrow) char*[argc + 1];
        if (argv == NULL)
                return false;

        // fill vector -- the arguments are stored consecutively in buffer
        start = buffer;
        for (int i = 0; i < argc; i++) {
                argv[i] = start;
                start += strlen(start) + 1;
        }
        argv[argc] = NULL;

        _argv = argv;
        _argc = argc;
        return true;
}


//      #pragma mark - driver


static float
checksum_driver_supports_device(device_node* parent)
{
        // We attach to nodes on the "generic" bus only.
        const char* bus = NULL;
        bool isGenericBus
                = sDeviceManager->get_attr_string(parent, B_DEVICE_BUS, &bus, false)
                        == B_OK
                && strcmp(bus, "generic") == 0;

        return isGenericBus ? 0.8 : -1;
}


static status_t
checksum_driver_register_device(device_node* parent)
{
        // Registers the control device node. Raw device nodes are registered
        // separately through the control device (ControlDevice::Register()).
        device_attr attributes[] = {
                {B_DEVICE_PRETTY_NAME, B_STRING_TYPE,
                        {string: "Checksum Control Device"}},
                {NULL}
        };

        status_t status = sDeviceManager->register_node(parent,
                kDriverModuleName, attributes, NULL, NULL);
        return status;
}


static status_t
checksum_driver_init_driver(device_node* node, void** _driverCookie)
{
        // Nodes carrying a backing file path are raw devices; the node without
        // one is the control device.
        const char* fileName;
        bool isRawDevice = sDeviceManager->get_attr_string(node, kFilePathItem,
                &fileName, false) == B_OK;

        if (!isRawDevice) {
                ControlDevice* controlDevice = new(std::nothrow) ControlDevice(node);
                if (controlDevice == NULL)
                        return B_NO_MEMORY;

                *_driverCookie = (Device*)controlDevice;
                return B_OK;
        }

        RawDevice* rawDevice = new(std::nothrow) RawDevice(node);
        if (rawDevice == NULL)
                return B_NO_MEMORY;

        status_t error = rawDevice->Init(fileName);
        if (error != B_OK) {
                delete rawDevice;
                return error;
        }

        *_driverCookie = (Device*)rawDevice;
        return B_OK;
}


static void
checksum_driver_uninit_driver(void* driverCookie)
{
        // The cookie is the Device created in checksum_driver_init_driver(); its
        // virtual destructor handles the concrete device type.
        delete (Device*)driverCookie;
}


static status_t
checksum_driver_register_child_devices(void* driverCookie)
{
        // Publishing depends on the device type (control vs. raw device).
        return ((Device*)driverCookie)->PublishDevice();
}


//      #pragma mark - control device


static status_t
checksum_control_device_init_device(void* driverCookie, void** _deviceCookie)
{
        // The control device has no extra per-device state: its driver cookie
        // (the ControlDevice) serves as device cookie as well.
        *_deviceCookie = driverCookie;
        return B_OK;
}


static void
checksum_control_device_uninit_device(void* deviceCookie)
{
        // nothing was allocated in init_device()
        (void)deviceCookie;
}


static status_t
checksum_control_device_open(void* deviceCookie, const char* path, int openMode,
        void** _cookie)
{
        // no per-open state -- the device cookie doubles as file cookie
        (void)path;
        (void)openMode;
        *_cookie = deviceCookie;
        return B_OK;
}


static status_t
checksum_control_device_close(void* cookie)
{
        (void)cookie;
        return B_OK;
}


static status_t
checksum_control_device_free(void* cookie)
{
        // the cookie is the ControlDevice, which is deleted by the driver
        (void)cookie;
        return B_OK;
}


static status_t
checksum_control_device_read(void* cookie, off_t position, void* buffer,
        size_t* _length)
{
        // Reading yields no data; commands are only written.
        (void)cookie;
        (void)position;
        (void)buffer;
        *_length = 0;
        return B_OK;
}


/*!     Interprets the data written to the control device as a command line.

        Supported commands are "help", "register <path>", and
        "unregister <device>"; messages go to the kernel debug output via
        dprintf(). Only writes at position 0 are accepted.
*/
static status_t
checksum_control_device_write(void* cookie, off_t position, const void* data,
        size_t* _length)
{
        ControlDevice* device = (ControlDevice*)cookie;

        if (position != 0)
                return B_BAD_VALUE;

        // copy data to a null-terminated heap buffer -- reject lengths for which
        // the "+ 1" for the terminator would overflow
        size_t length = *_length;
        if (length + 1 < length)
                return B_BAD_VALUE;

        char* buffer = (char*)malloc(length + 1);
        if (buffer == NULL)
                return B_NO_MEMORY;
        MemoryDeleter bufferDeleter(buffer);

        if (IS_USER_ADDRESS(data)) {
                if (user_memcpy(buffer, data, length) != B_OK)
                        return B_BAD_ADDRESS;
        } else
                memcpy(buffer, data, length);

        buffer[length] = '\0';

        // parse arguments
        char** argv;
        int argc;
        if (!parse_command_line(buffer, argv, argc))
                return B_NO_MEMORY;
        ArrayDeleter<char*> argvDeleter(argv);

        if (argc == 0) {
                dprintf("\"help\" for usage!\n");
                return B_BAD_VALUE;
        }

        // execute command
        if (strcmp(argv[0], "help") == 0) {
                // help
                dprintf("register <path>\n");
                dprintf("  Registers file <path> as a new checksum device.\n");
                dprintf("unregister <device>\n");
                dprintf("  Unregisters <device>.\n");
        } else if (strcmp(argv[0], "register") == 0) {
                // register
                if (argc != 2) {
                        dprintf("Usage: register <path>\n");
                        return B_BAD_VALUE;
                }

                return device->Register(argv[1]);
        } else if (strcmp(argv[0], "unregister") == 0) {
                // unregister
                if (argc != 2) {
                        dprintf("Usage: unregister <device>\n");
                        return B_BAD_VALUE;
                }

                // accept both "/dev/..." paths and names relative to /dev
                const char* deviceName = argv[1];
                if (strncmp(deviceName, "/dev/", 5) == 0)
                        deviceName += 5;

                // find the device in the list and unregister it
                MutexLocker locker(sDeviceListLock);
                for (RawDeviceList::Iterator it = sDeviceList.GetIterator();
                                RawDevice* rawDevice = it.Next();) {
                        // RawDevice::Init() adds the device to the list before its
                        // name is set, so the name may still be NULL.
                        const char* rawDeviceName = rawDevice->DeviceName();
                        if (rawDeviceName == NULL
                                || strcmp(rawDeviceName, deviceName) != 0) {
                                continue;
                        }

                        // TODO: Race condition: We should mark the device as going to
                        // be unregistered, so no one else can try the same after we
                        // unlock!
                        locker.Unlock();
// TODO: The following doesn't work! unpublish_device(), as per implementation
// (partially commented out) and unregister_node() returns B_BUSY.
                        status_t error = sDeviceManager->unpublish_device(
                                rawDevice->Node(), rawDeviceName);
                        if (error != B_OK) {
                                dprintf("Failed to unpublish device \"%s\": %s\n",
                                        deviceName, strerror(error));
                                return error;
                        }

                        error = sDeviceManager->unregister_node(rawDevice->Node());
                        if (error != B_OK) {
                                dprintf("Failed to unregister node \"%s\": %s\n",
                                        deviceName, strerror(error));
                                return error;
                        }

                        return B_OK;
                }

                dprintf("Device \"%s\" not found!\n", deviceName);
                return B_BAD_VALUE;
        } else {
                dprintf("Invalid command \"%s\"!\n", argv[0]);
                return B_BAD_VALUE;
        }

        return B_OK;
}


static status_t
checksum_control_device_control(void* cookie, uint32 op, void* buffer,
        size_t length)
{
        // The control device is driven by write() alone and supports no ioctls.
        (void)cookie;
        (void)op;
        (void)buffer;
        (void)length;
        return B_BAD_VALUE;
}


//      #pragma mark - raw device


static status_t
checksum_raw_device_init_device(void* driverCookie, void** _deviceCookie)
{
        // The driver cookie of a raw device node is its RawDevice, which also
        // serves as device cookie once its I/O resources are set up.
        RawDevice* device = static_cast<RawDevice*>((Device*)driverCookie);

        status_t error = device->Prepare();
        if (error == B_OK)
                *_deviceCookie = device;

        return error;
}


static void
checksum_raw_device_uninit_device(void* deviceCookie)
{
        RawDevice* device = (RawDevice*)deviceCookie;
        device->Unprepare();
}


static status_t
checksum_raw_device_open(void* deviceCookie, const char* path, int openMode,
        void** _cookie)
{
        RawDevice* device = (RawDevice*)deviceCookie;

        RawDeviceCookie* cookie = new(std::nothrow) RawDeviceCookie(device,
                openMode);
        if (cookie == NULL)
                return B_NO_MEMORY;

        *_cookie = cookie;
        return B_OK;
}


static status_t
checksum_raw_device_close(void* cookie)
{
        return B_OK;
}


static status_t
checksum_raw_device_free(void* _cookie)
{
        RawDeviceCookie* cookie = (RawDeviceCookie*)_cookie;
        delete cookie;
        return B_OK;
}


/*!     Performs a synchronous read or write of \a *_length bytes at \a pos.

        Requests at a negative position or at/beyond the end of the device are
        rejected with \c B_BAD_VALUE; requests extending past the end are
        truncated. On success \a *_length is set to the number of bytes
        transferred.
*/
static status_t
checksum_raw_device_transfer(void* _cookie, off_t pos, addr_t buffer,
        size_t* _length, bool isWrite)
{
        RawDeviceCookie* cookie = (RawDeviceCookie*)_cookie;
        RawDevice* device = cookie->Device();

        if (pos < 0 || pos >= device->DeviceSize())
                return B_BAD_VALUE;

        // Clip to the device size. Comparing against the remaining size (which
        // is positive here) avoids the overflow "pos + length" could cause.
        size_t length = *_length;
        off_t available = device->DeviceSize() - pos;
        if ((uint64)length > (uint64)available)
                length = (size_t)available;

        IORequest request;
        status_t status = request.Init(pos, buffer, length, isWrite, 0);
        if (status != B_OK)
                return status;

        status = device->DoIO(&request);
        if (status != B_OK)
                return status;

        status = request.Wait(0, 0);
        if (status == B_OK)
                *_length = length;

        return status;
}


static status_t
checksum_raw_device_read(void* _cookie, off_t pos, void* buffer,
        size_t* _length)
{
        return checksum_raw_device_transfer(_cookie, pos, (addr_t)buffer, _length,
                false);
}


static status_t
checksum_raw_device_write(void* _cookie, off_t pos, const void* buffer,
        size_t* _length)
{
        return checksum_raw_device_transfer(_cookie, pos, (addr_t)buffer, _length,
                true);
}


/*!	Asynchronous I/O hook: forwards \a request unchanged to the RawDevice
	belonging to the open cookie and returns the result of DoIO().
*/
static status_t
checksum_raw_device_io(void* _cookie, io_request* request)
{
        RawDeviceCookie* openCookie = static_cast<RawDeviceCookie*>(_cookie);
        return openCookie->Device()->DoIO(request);
}


/*!	Handles CHECKSUM_DEVICE_IOCTL_GET_CHECK_SUM.

	\a buffer points to a checksum_device_ioctl_check_sum. Its checkSum field
	is filled in via RawDevice::GetBlockCheckSum() for the given blockIndex.
	For a userland buffer the structure is copied in, updated, and copied back
	with user_memcpy(). A kernel buffer is used in place.
*/
static status_t
checksum_raw_device_get_check_sum(RawDevice* device, void* buffer)
{
        if (IS_USER_ADDRESS(buffer)) {
                // Userland memory may only be touched through user_memcpy().
                checksum_device_ioctl_check_sum getCheckSum;
                if (user_memcpy(&getCheckSum, buffer, sizeof(getCheckSum))
                                != B_OK) {
                        return B_BAD_ADDRESS;
                }

                status_t error = device->GetBlockCheckSum(
                        getCheckSum.blockIndex, getCheckSum.checkSum);
                if (error != B_OK)
                        return error;

                return user_memcpy(buffer, &getCheckSum, sizeof(getCheckSum));
        }

        // Kernel caller: the structure is dereferenced directly, so refuse a
        // NULL buffer instead of faulting on it.
        if (buffer == NULL)
                return B_BAD_ADDRESS;

        checksum_device_ioctl_check_sum* getCheckSum
                = (checksum_device_ioctl_check_sum*)buffer;
        return device->GetBlockCheckSum(getCheckSum->blockIndex,
                getCheckSum->checkSum);
}


/*!	Handles CHECKSUM_DEVICE_IOCTL_SET_CHECK_SUM.

	\a buffer points to a checksum_device_ioctl_check_sum whose checkSum is
	stored for blockIndex via RawDevice::SetBlockCheckSum(). A userland buffer
	is copied in with user_memcpy() first.
*/
static status_t
checksum_raw_device_set_check_sum(RawDevice* device, void* buffer)
{
        if (IS_USER_ADDRESS(buffer)) {
                checksum_device_ioctl_check_sum setCheckSum;
                if (user_memcpy(&setCheckSum, buffer, sizeof(setCheckSum))
                                != B_OK) {
                        return B_BAD_ADDRESS;
                }

                return device->SetBlockCheckSum(setCheckSum.blockIndex,
                        setCheckSum.checkSum);
        }

        // Kernel caller: guard the direct dereference below.
        if (buffer == NULL)
                return B_BAD_ADDRESS;

        checksum_device_ioctl_check_sum* setCheckSum
                = (checksum_device_ioctl_check_sum*)buffer;
        return device->SetBlockCheckSum(setCheckSum->blockIndex,
                setCheckSum->checkSum);
}


/*!	Control (ioctl) hook of the raw device.

	Supports the generic disk ops needed to use the device as a disk, plus the
	driver-specific CHECKSUM_DEVICE_IOCTL_GET/SET_CHECK_SUM ops.
	\a length is not used.

	\return The op's result, or \c B_BAD_VALUE for an unsupported \a op.
*/
static status_t
checksum_raw_device_control(void* _cookie, uint32 op, void* buffer,
        size_t length)
{
        RawDeviceCookie* cookie = (RawDeviceCookie*)_cookie;
        RawDevice* device = cookie->Device();

        switch (op) {
                case B_GET_DEVICE_SIZE:
                {
                        size_t size = device->DeviceSize();
                        return user_memcpy(buffer, &size, sizeof(size_t));
                }

                case B_SET_NONBLOCKING_IO:
                case B_SET_BLOCKING_IO:
                        return B_OK;

                case B_GET_READ_STATUS:
                case B_GET_WRITE_STATUS:
                {
                        // Reads and writes are always possible.
                        bool value = true;
                        return user_memcpy(buffer, &value, sizeof(bool));
                }

                case B_GET_GEOMETRY:
                case B_GET_BIOS_GEOMETRY:
                {
                        // The device is presented as a single track of page-sized
                        // sectors per cylinder.
                        device_geometry geometry;
                        geometry.bytes_per_sector = B_PAGE_SIZE;
                        geometry.sectors_per_track = 1;
                        geometry.cylinder_count = device->DeviceSize() / B_PAGE_SIZE;
                                // TODO: We're limited to 2^32 * B_PAGE_SIZE, if we don't use
                                // sectors_per_track and head_count.
                        geometry.head_count = 1;
                        geometry.device_type = B_DISK;
                        geometry.removable = true;
                        geometry.read_only = false;
                        geometry.write_once = false;

                        return user_memcpy(buffer, &geometry, sizeof(device_geometry));
                }

                case B_GET_MEDIA_STATUS:
                {
                        status_t status = B_OK;
                        return user_memcpy(buffer, &status, sizeof(status_t));
                }

                case B_SET_UNINTERRUPTABLE_IO:
                case B_SET_INTERRUPTABLE_IO:
                case B_FLUSH_DRIVE_CACHE:
                        return B_OK;

                case CHECKSUM_DEVICE_IOCTL_GET_CHECK_SUM:
                        return checksum_raw_device_get_check_sum(device, buffer);

                case CHECKSUM_DEVICE_IOCTL_SET_CHECK_SUM:
                        return checksum_raw_device_set_check_sum(device, buffer);
        }
        return B_BAD_VALUE;
}


// #pragma mark -


// Modules this add-on depends on. The device manager module is made
// available through sDeviceManager, whose address is passed here. The empty
// entry terminates the list.
module_dependency module_dependencies[] = {
        {B_DEVICE_MANAGER_MODULE_NAME, (module_info**)&sDeviceManager},
        {}
};


// Driver module registered under kDriverModuleName. It supplies the
// checksum_driver_* hooks (supports/register device, init/uninit driver,
// register child devices). The initializer is positional, so the hook order
// must match the field order of struct driver_module_info.
static const struct driver_module_info sChecksumDeviceDriverModule = {
        {
                kDriverModuleName,
                0,
                NULL
        },

        checksum_driver_supports_device,
        checksum_driver_register_device,
        checksum_driver_init_driver,
        checksum_driver_uninit_driver,
        checksum_driver_register_child_devices
};

// Device module for the control device (kControlDeviceModuleName). It wires
// up the checksum_control_device_* hooks. It provides no io, select or
// deselect hooks, nor the slot left NULL after uninit_device. The initializer
// is positional; keep the order in sync with struct device_module_info.
static const struct device_module_info sChecksumControlDeviceModule = {
        {
                kControlDeviceModuleName,
                0,
                NULL
        },

        checksum_control_device_init_device,
        checksum_control_device_uninit_device,
        NULL,

        checksum_control_device_open,
        checksum_control_device_close,
        checksum_control_device_free,

        checksum_control_device_read,
        checksum_control_device_write,
        NULL,   // io

        checksum_control_device_control,

        NULL,   // select
        NULL    // deselect
};

// Device module for the raw checksum devices (kRawDeviceModuleName). Unlike
// the control device, it also provides the io hook (checksum_raw_device_io).
// The initializer is positional; keep the order in sync with
// struct device_module_info.
static const struct device_module_info sChecksumRawDeviceModule = {
        {
                kRawDeviceModuleName,
                0,
                NULL
        },

        checksum_raw_device_init_device,
        checksum_raw_device_uninit_device,
        NULL,

        checksum_raw_device_open,
        checksum_raw_device_close,
        checksum_raw_device_free,

        checksum_raw_device_read,
        checksum_raw_device_write,
        checksum_raw_device_io,

        checksum_raw_device_control,

        NULL,   // select
        NULL    // deselect
};

// Modules exported by this add-on: the driver, the control device and the
// raw device module. The list is NULL-terminated.
const module_info* modules[] = {
        (module_info*)&sChecksumDeviceDriverModule,
        (module_info*)&sChecksumControlDeviceModule,
        (module_info*)&sChecksumRawDeviceModule,
        NULL
};