#include "ICBase.h"
#ifdef TRACE_HYPERV_IC
# define TRACE(x...) dprintf("\33[94mhyperv_ic:\33[0m " x)
#else
# define TRACE(x...) ;
#endif
#define TRACE_ALWAYS(x...) dprintf("\33[94mhyperv_ic:\33[0m " x)
#define ERROR(x...) dprintf("\33[94mhyperv_ic:\33[0m " x)
#define CALLED(x...) TRACE("CALLED %s\n", __PRETTY_FUNCTION__)
ICBase::ICBase(device_node* node, uint32 packetLength, uint16 messageType,
const uint32* messageVersions, uint32 messageVersionCount)
:
fStatus(B_NO_INIT),
fNode(node),
fFrameworkVersion(0),
fMessageVersion(0),
fHyperV(NULL),
fHyperVCookie(NULL),
fPacket(NULL),
fPacketLength(packetLength),
fMessageType(messageType),
fMessageVersions(messageVersions),
fMessageVersionCount(messageVersionCount)
{
CALLED();
device_node* parent = gDeviceManager->get_parent_node(node);
gDeviceManager->get_driver(parent, (driver_module_info**)&fHyperV, (void**)&fHyperVCookie);
gDeviceManager->put_node(parent);
fPacket = static_cast<uint8*>(malloc(fPacketLength));
if (fPacket == NULL) {
fStatus = B_NO_MEMORY;
return;
}
fStatus = B_OK;
}
ICBase::~ICBase()
{
CALLED();
free(fPacket);
}
status_t
ICBase::Connect(uint32 txLength, uint32 rxLength)
{
CALLED();
status_t status = fHyperV->open(fHyperVCookie, txLength, rxLength, _CallbackHandler, this);
if (status != B_OK) {
ERROR("Failed to open channel");
return status;
}
return B_OK;
}
void
ICBase::Disconnect()
{
CALLED();
fHyperV->close(fHyperVCookie);
}
void
ICBase::OnProtocolNegotiated()
{
}
void
ICBase::OnMessageSent(hv_ic_msg* icMessage)
{
}
status_t
ICBase::_NegotiateProtocol(hv_ic_msg_negotiate* message)
{
CALLED();
if (message->header.data_length < offsetof(hv_ic_msg_negotiate, versions[2])
- sizeof(message->header)) {
ERROR("IC[%u] invalid negotiate msg length 0x%X\n", fMessageType,
message->header.data_length);
return B_BAD_VALUE;
}
if (message->framework_version_count == 0 || message->message_version_count == 0) {
ERROR("IC[%u] invalid negotiate msg version count\n", fMessageType);
return B_BAD_VALUE;
}
uint32 versionCount = message->framework_version_count + message->message_version_count;
if (versionCount < 2) {
ERROR("IC[%u] invalid negotiate msg version count\n", fMessageType);
return B_BAD_VALUE;
}
bool foundFrameworkVersion = false;
for (uint32 i = 0; i < hv_framework_version_count; i++) {
for (uint32 j = 0; j < message->framework_version_count; j++) {
TRACE("IC[%u] checking fw version %u.%u against %u.%u\n", fMessageType,
GET_IC_VERSION_MAJOR(hv_framework_versions[i]),
GET_IC_VERSION_MINOR(hv_framework_versions[i]),
GET_IC_VERSION_MAJOR(message->versions[j]),
GET_IC_VERSION_MINOR(message->versions[j]));
if (hv_framework_versions[i] == message->versions[j]) {
fFrameworkVersion = hv_framework_versions[i];
foundFrameworkVersion = true;
break;
}
}
if (foundFrameworkVersion)
break;
}
bool foundMessageVersion = false;
for (uint32 i = 0; i < fMessageVersionCount; i++) {
for (uint32 j = message->message_version_count; j < versionCount; j++) {
TRACE("IC[%u] checking msg version %u.%u against %u.%u\n", fMessageType,
GET_IC_VERSION_MAJOR(fMessageVersions[i]),
GET_IC_VERSION_MINOR(fMessageVersions[i]),
GET_IC_VERSION_MAJOR(message->versions[j]),
GET_IC_VERSION_MINOR(message->versions[j]));
if (fMessageVersions[i] == message->versions[j]) {
fMessageVersion = fMessageVersions[i];
foundMessageVersion = true;
break;
}
}
if (foundMessageVersion)
break;
}
if (!foundFrameworkVersion || !foundMessageVersion) {
ERROR("IC%u unsupported versions\n", fMessageType);
message->framework_version_count = 0;
message->message_version_count = 0;
return B_UNSUPPORTED;
}
TRACE("IC[%u] found supported fw version %u.%u msg version %u.%u\n", fMessageType,
GET_IC_VERSION_MAJOR(fFrameworkVersion), GET_IC_VERSION_MINOR(fFrameworkVersion),
GET_IC_VERSION_MAJOR(fMessageVersion), GET_IC_VERSION_MINOR(fMessageVersion));
message->framework_version_count = 1;
message->message_version_count = 1;
message->versions[0] = fFrameworkVersion;
message->versions[1] = fMessageVersion;
return B_OK;
}
void
ICBase::_CallbackHandler(void* arg)
{
ICBase* icBaseDevice = reinterpret_cast<ICBase*>(arg);
icBaseDevice->_Callback();
}
void
ICBase::_Callback()
{
while (true) {
uint32 length = fPacketLength;
uint32 headerLength;
uint32 dataLength;
status_t status = fHyperV->read_packet(fHyperVCookie, fPacket, &length,
&headerLength, &dataLength);
if (status == B_DEV_NOT_READY) {
break;
} else if (status != B_OK) {
ERROR("IC[%u] failed to read packet (%s)\n", fMessageType, strerror(status));
break;
}
if (dataLength < sizeof(hv_ic_msg_header)) {
ERROR("IC[%u] invalid packet\n", fMessageType);
continue;
}
vmbus_pkt_header* header = reinterpret_cast<vmbus_pkt_header*>(fPacket);
hv_ic_msg* message = reinterpret_cast<hv_ic_msg*>(fPacket + headerLength);
if (message->header.data_length <= dataLength - sizeof(message->header)) {
if (message->header.type == HV_IC_MSGTYPE_NEGOTIATE) {
status = _NegotiateProtocol(&message->negotiate);
if (status == B_OK) {
OnProtocolNegotiated();
} else {
ERROR("IC[%u] protocol negotiation failed (%s)\n", fMessageType,
strerror(status));
message->header.status = HV_IC_STATUS_FAILED;
}
} else if (message->header.type == fMessageType) {
OnMessageReceived(message);
} else {
ERROR("IC[%u] unknown message type %u\n", fMessageType, message->header.type);
message->header.status = HV_IC_STATUS_FAILED;
}
} else {
ERROR("IC[%u] invalid msg data length 0x%X pkt length 0x%X\n", fMessageType,
message->header.data_length, dataLength);
message->header.status = HV_IC_STATUS_FAILED;
}
message->header.flags = HV_IC_FLAG_TRANSACTION | HV_IC_FLAG_RESPONSE;
status = fHyperV->write_packet(fHyperVCookie, VMBUS_PKTTYPE_DATA_INBAND, message,
sizeof(hv_ic_msg_header) + message->header.data_length, false, header->transaction_id);
if (status == B_OK)
OnMessageSent(message);
else
ERROR("IC[%u] failed to send message (%s)\n", fMessageType, strerror(status));
}
}