#include "connection.h"
#include "disk.h"
#include <unistd.h>
#include <arpa/inet.h>
#include <stdint.h>
#include <cstring>
#include <vector>
#include "nbdserver.h"
#include <errno.h>
#include <poll.h>
#include "safecom.h"
// Magic values framing the wire protocol: each incoming request header must
// carry NBD_REQUEST_MAGIC (checked in handle()), and every reply header sent
// back carries NBD_REPLY_MAGIC.
#define NBD_REQUEST_MAGIC 0x25609513
#define NBD_REPLY_MAGIC 0x67446698
// Maximum number of payload bytes moved per READ/WRITE chunk: 1 MiB plus
// sizeof(struct nbd_reply). Larger requests are processed in pieces of at
// most this size.
#define BUFSIZE ((1024*1024)+sizeof(struct nbd_reply))
// Selects the command number from the low 16 bits of the request type field
// (the upper bits are presumably per-command flags — confirm against the NBD spec).
#define NBD_CMD_MASK_COMMAND 0x0000ffff
// Largest positive off_t value: every bit set except the sign bit.
// NOTE(review): shifts a 1 into the sign bit of a signed type, which is only
// well-defined from C++14 on; it yields the expected value on common compilers.
#define OFFT_MAX ~((off_t)1<<(sizeof(off_t)*8-1))
// Builds a connection that will serve disk `pDisk` over the client socket `fd`.
// The connection starts in the "running" state; handle() closes `fd` when it
// finishes serving.
Connection::Connection(int fd, Disk *pDisk)
    : m_pDisk(pDisk)
    , m_fd(fd)
    , m_bEnd(false)
{
}
void Connection::close()
{
m_bEnd = true;
}
// bool readit(int f, char *buf, size_t len)
// {
// ssize_t res;
// while (len > 0) {
// res = read(f, buf, len);
// if (res > 0) {
// len -= res;
// buf += res;
// }
// else {
// return false;
// }
// }
// return true;
// }
// Waits for descriptor `f` to become readable, then reads `len` bytes into
// `buf` via saferead() (presumably all-or-nothing — see safecom.h).
//
// The wait is done in 500 ms slices so that a stop request from close()
// (m_bEnd) is noticed even while the client is idle.
//
// Returns the result of saferead() once data is available; returns false if
// the connection is being closed, poll() fails, or the descriptor reports a
// condition other than readability (e.g. hang-up or error without data).
// An interrupted poll() (EINTR, a signal was delivered) is simply retried
// instead of being treated as a connection failure.
bool Connection::pollit(int f, char *buf, size_t len)
{
    const int timeout_msecs = 500;
    while (!m_bEnd) {
        struct pollfd fds[1];
        fds[0].fd = f;
        fds[0].events = POLLIN;
        fds[0].revents = 0;
        int ret = poll(fds, 1, timeout_msecs);
        if (ret < 0) {
            if (errno == EINTR) {
                // A signal interrupted the wait; the socket itself is fine.
                continue;
            }
            return false;
        }
        if (ret == 0) {
            // Timeout: loop around to re-check m_bEnd.
            continue;
        }
        if (fds[0].revents & POLLIN) {
            return saferead(f, buf, len);
        }
        return false;
    }
    return false;
}
// bool writeit(int f, char *buf, size_t len)
// {
// ssize_t res;
// while (len > 0) {
// if ((res = write(f, buf, len)) > 0) {
// len -= res;
// buf += res;
// }
// else {
// return false;
// }
// }
// return true;
// }
// Sends `reply` to the client with its error field set to `errcode` (in
// network byte order). The error field is reset to 0 afterwards, whether or
// not the write succeeded, so the same reply header can be reused for later
// success replies.
// Returns true if the whole reply header was written.
bool Connection::sendError(struct nbd_reply *reply, int errcode)
{
    reply->error = htonl(errcode);
    const bool sent = safewrite(m_fd, (char *)reply, sizeof(struct nbd_reply));
    reply->error = 0;
    return sent;
}
// Returns true when the byte range [from, from + len) lies entirely inside a
// device of `size` bytes and inside the range representable by off_t.
// The comparison is written with a subtraction so that a huge `from` cannot
// wrap around (from + len overflowing) and slip past the check.
static bool requestInRange(uint64_t from, uint64_t len, uint64_t size)
{
    uint64_t limit = (uint64_t)OFFT_MAX;
    if (size < limit) {
        limit = size;
    }
    return (len <= limit) && (from <= limit - len);
}

// Serves NBD requests on m_fd until the client disconnects, a protocol or
// socket error occurs, or close() sets m_bEnd. Closes m_fd before returning.
//
// Per request:
//  - READ/WRITE ranges are validated against the disk size; out-of-range
//    requests get an EINVAL error reply.
//  - WRITE payload is always consumed from the socket, even when the request
//    is rejected, so it is never misparsed as the next request header. A disk
//    write failure is reported to the client as EIO.
//  - READ data is streamed in chunks of at most BUFSIZE bytes. The reply
//    header is sent together with the first chunk. A disk read failure is
//    reported as EIO if nothing has been sent yet. Otherwise the connection
//    is dropped, because a success header has already gone out.
//  - FLUSH syncs the disk, TRIM is acknowledged without doing anything, and
//    DISC or any unknown command ends the session.
void Connection::handle()
{
    struct nbd_request request;
    struct nbd_reply reply;
    reply.magic = htonl(NBD_REPLY_MAGIC);
    reply.error = 0;
    size_t nSize = getDisk()->getSize();

    // The transfer buffer is over 1 MiB, too large to keep safely on a
    // (possibly small) thread stack. It is allocated once and reused for
    // every request.
    std::vector<char> storage(BUFSIZE + sizeof(struct nbd_reply) + 1);
    char *buf = &storage[0];

    while (!m_bEnd) {
        if (!pollit(m_fd, (char *)&request, sizeof(request))) {
            m_bEnd = true;
            break;
        }
        if (request.magic != htonl(NBD_REQUEST_MAGIC)) {
            // Framing with the client is lost; there is no sensible reply.
            m_bEnd = true;
            break;
        }
        request.from = ntohll(request.from);
        request.type = ntohl(request.type);
        uint16_t command = request.type & NBD_CMD_MASK_COMMAND;
        size_t len = ntohl(request.len);
        memcpy(reply.handle, request.handle, sizeof(reply.handle));

        switch (command) {
        case NBD_CMD_DISC:
            m_bEnd = true;
            break;

        case NBD_CMD_WRITE: {
            int errcode = requestInRange(request.from, len, nSize) ? 0 : EINVAL;
            while (len > 0) {
                size_t chunk = (len < BUFSIZE) ? len : BUFSIZE;
                if (!saferead(m_fd, buf, chunk)) {
                    m_bEnd = true;
                    break;
                }
                // After a rejection or the first disk failure, keep draining
                // the payload but stop touching the disk.
                if (errcode == 0 && !getDisk()->write(request.from, buf, chunk)) {
                    errcode = EIO;
                }
                len -= chunk;
                request.from += chunk;
            }
            if (!m_bEnd) {
                if (errcode != 0) {
                    m_bEnd = !sendError(&reply, errcode);
                }
                else {
                    m_bEnd = !safewrite(m_fd, (char *)&reply, sizeof(reply));
                }
            }
            break;
        }

        case NBD_CMD_FLUSH:
            getDisk()->sync();
            m_bEnd = !safewrite(m_fd, (char *)&reply, sizeof(reply));
            break;

        case NBD_CMD_READ: {
            if (!requestInRange(request.from, len, nSize)) {
                m_bEnd = !sendError(&reply, EINVAL);
                break;
            }
            if (len == 0) {
                // Nothing to read, but the client still expects a reply.
                m_bEnd = !safewrite(m_fd, (char *)&reply, sizeof(reply));
                break;
            }
            // The reply header goes out in the same write as the first chunk.
            // headerRoom is non-zero exactly until that first write succeeds.
            memcpy(buf, &reply, sizeof(struct nbd_reply));
            size_t headerRoom = sizeof(struct nbd_reply);
            while (len > 0) {
                size_t chunk = (len < BUFSIZE) ? len : BUFSIZE;
                if (!getDisk()->read(request.from, buf + headerRoom, chunk)) {
                    if (headerRoom != 0) {
                        // Nothing sent yet: the failure can still be reported.
                        m_bEnd = !sendError(&reply, EIO);
                    }
                    else {
                        // A success header and partial data are already on the
                        // wire; the stream cannot be repaired.
                        m_bEnd = true;
                    }
                    break;
                }
                if (!safewrite(m_fd, buf, headerRoom + chunk)) {
                    m_bEnd = true;
                    break;
                }
                len -= chunk;
                request.from += chunk;
                headerRoom = 0;
            }
            break;
        }

        case NBD_CMD_TRIM:
            // We don't expect this, so just acknowledge it.
            m_bEnd = !safewrite(m_fd, (char *)&reply, sizeof(reply));
            break;

        default:
            m_bEnd = true;
            break;
        }
    }
    ::close(m_fd);
}