about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorRen Kararou <[email protected]>2025-01-11 23:38:03 -0600
committerRen Kararou <[email protected]>2025-01-11 23:38:03 -0600
commit5d087643d3b4aacc907e119d992fbc4038130941 (patch)
treeb3a35a2d2d336654919c381329ad302ab7cb33b7
parentb1b3e7eca4c8153d7c6a2422923ad4ec2b78a223 (diff)
downloadnbtpd-5d087643d3b4aacc907e119d992fbc4038130941.tar.gz
nbtpd-5d087643d3b4aacc907e119d992fbc4038130941.tar.bz2
nbtpd-5d087643d3b4aacc907e119d992fbc4038130941.zip
working WRQ flow
-rw-r--r--src/handlers.c91
-rw-r--r--src/packet.c4
2 files changed, 75 insertions, 20 deletions
diff --git a/src/handlers.c b/src/handlers.c
index 5bfad60..eed91f4 100644
--- a/src/handlers.c
+++ b/src/handlers.c
@@ -43,6 +43,7 @@ int makesock(nbd_nbtpd_args *argptr) {
 	struct timeval timeout = { 30, 0 };
 	if (setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) {
 		syslog(LOG_ERR, "unable to set socket timeout: %s", strerror(errno));
+		close(s);
 		return -1;
 	}
 	if (connect(s, (struct sockaddr *)&(argptr->client), sizeof(struct sockaddr)) < 0) {
@@ -52,13 +53,15 @@ int makesock(nbd_nbtpd_args *argptr) {
 			inet_ntoa(argptr->client.sin_addr),
 			strerror(errno)
 		);
+		close(s);
 		return -1;
 	}
 	return s;
 }
 
 char *checkpath(nbd_nbtpd_args *argptr, uint8_t read) {
-	char *fname = NULL, *wd = NULL;
+	char *fname = NULL, *wd = NULL, *path = NULL, *ptr = NULL;
+	char **parts = NULL;
 	if (!is_netascii_str((char *)&argptr->path)) {
 		argptr->err = 1;
 		goto cleanup;
@@ -67,7 +70,45 @@ char *checkpath(nbd_nbtpd_args *argptr, uint8_t read) {
 		fname = realpath((char *)&argptr->path, NULL);
 	} else {
 		//TODO: figure out how to canonicalize non-existent path
-		fname = realpath((char *)&argptr->path, NULL);
+		char *tok = NULL;
+		parts = malloc(sizeof(parts) * NBD_NBTPD_ARGS_PATH_MAX);
+		if (parts == NULL) {
+			goto cleanup;
+		}
+		memset(parts, '\0', sizeof(parts) * NBD_NBTPD_ARGS_PATH_MAX);
+		ptr = (char *)&argptr->path;
+		for (int i = 0; ((tok = strsep(&ptr, "/")) != NULL); i++ ) {
+			parts[i] = tok;
+		}
+		path = malloc(NBD_NBTPD_ARGS_PATH_MAX);
+		if (path == NULL) {
+			goto cleanup;
+		}
+		memset(path, '\0', NBD_NBTPD_ARGS_PATH_MAX);
+		int z = 0;
+		for (int i = 0; parts[i] != NULL; i++) {
+			if (strncmp(parts[i], "..", 2) == 0) {
+				continue;
+			}
+			if (strncmp(parts[i], ".", NBD_NBTPD_ARGS_PATH_MAX) == 0) {
+				continue;
+			}
+			for (int x = 0; parts[i][x] != '\0'; x++) {
+				if (z < NBD_NBTPD_ARGS_PATH_MAX) {
+					path[z++] = parts[i][x];
+				} else {
+					goto cleanup;
+				}
+			}
+			if (z < NBD_NBTPD_ARGS_PATH_MAX) {
+				path[z++] = '/';
+			} else {
+				goto cleanup;
+			}
+		}
+		// erase trailing slash
+		path[z - 1] = '\0';
+		fname = strdup(path);
 	}
 	if (fname == NULL) {
 		syslog(LOG_ERR, "unable to get real path: %s", strerror(errno));
@@ -98,23 +139,33 @@ char *checkpath(nbd_nbtpd_args *argptr, uint8_t read) {
 	}
 #endif
 	syslog(LOG_DEBUG, "cwd: %s :: realpath: %s", wd, fname);
-	if (strncmp(wd, fname, strlen(wd))) {
-		syslog(
-			LOG_ERR,
-			"%s:%d requested invalid file %s",
-			inet_ntoa(argptr->client.sin_addr),
-			ntohs(argptr->client.sin_port),
-			fname
-		);
-		argptr->err = 2;
-		goto cleanup;
+	if (read) {
+		if (strncmp(wd, fname, strlen(wd))) {
+			syslog(
+				LOG_ERR,
+				"%s:%d requested invalid file %s",
+				inet_ntoa(argptr->client.sin_addr),
+				ntohs(argptr->client.sin_port),
+				fname
+			);
+			argptr->err = 2;
+			goto cleanup;
+		}
+	} else {
+		strcat(wd, "/");
+		strcat(wd, fname);
+		char *intermediate = wd;
+		wd = fname;
+		fname = intermediate;
 	}
 cleanup:
 	free(wd);
+	free(path);
+	free(parts);
 	return fname;
 }
 
-inline ssize_t senderror(int s, nbd_nbtpd_args *argptr) {
+ssize_t senderror(int s, nbd_nbtpd_args *argptr) {
 	size_t buflen = 4 + strlen(nbd_tftp_error_to_message((nbd_tftp_ecode)argptr->err));
 	char *buf = nbd_tftp_ser_error_from_code((nbd_tftp_ecode)argptr->err);
 	ssize_t sb = send(s, buf, buflen, 0);
@@ -283,6 +334,7 @@ void *write_req_resp(void *args) {
 	);
 	fname = checkpath(argptr, 0);
 	if (fname == NULL) {
+		syslog(LOG_ERR, "fname is NULL");
 		goto pre_socket;
 	}
 	if (!is_netascii_str((char *)&(argptr->mode))) {
@@ -298,6 +350,7 @@ void *write_req_resp(void *args) {
 	}
 	nbd_opmode opmode = get_mode((char *)&(argptr->mode));
 	if ((opmode != NETASCII) && (opmode != OCTET)) {
+		syslog(LOG_ERR, "%s:%d mode is not supported.", inet_ntoa(argptr->client.sin_addr), ntohs(argptr->client.sin_port));
 		argptr->err = 4;
 		goto pre_socket;
 	}
@@ -331,7 +384,7 @@ void *write_req_resp(void *args) {
 	while (lon) {
 		uint8_t verif = 0;
 		uint8_t vcount = 0;
-		char *packet = nbd_tftp_ser_ack_from_block_num(bnum);
+		char *packet = nbd_tftp_ser_ack_from_block_num(bnum++);
 		while (!verif) {
 			syslog(LOG_DEBUG, "sending ack number %d to %s:%d",
 				bnum,
@@ -374,13 +427,14 @@ void *write_req_resp(void *args) {
 						inet_ntoa(argptr->client.sin_addr),
 						htons(argptr->client.sin_port),
 						data.block_num,
-						bnum + 1
+						bnum
 					);
-					if (data.block_num == (bnum + 1)) {
+					if (data.block_num == bnum) {
 						rxon = 0;
 						verif = 1;
 						if (data.datalen > 0) {
-							if (fwrite(data.data, data.datalen, 1, fp) < 1) {
+							if (fwrite(data.data, 1, data.datalen, fp) < data.datalen) {
+								syslog(LOG_ERR, "filewrite failed for %s: %s", fname, strerror(errno));
 								argptr->err = 0;
 								senderror(s, argptr);
 								goto clean_socket;
@@ -388,13 +442,14 @@ void *write_req_resp(void *args) {
 						}
 						break;
 					}
-					if (data.block_num == bnum) {
+					if (data.block_num == (bnum - 1)) {
 						rxon = 0;
 						vcount++;
 						break;
 					}
 				}
 				if (++rxcount > 30) {
+					syslog(LOG_ERR, "retry count exceeded");
 					argptr->err = 0;
 					senderror(s, argptr);
 					goto clean_socket;
diff --git a/src/packet.c b/src/packet.c
index 0663b64..d439ec9 100644
--- a/src/packet.c
+++ b/src/packet.c
@@ -56,7 +56,7 @@ nbd_tftp_packet_data nbd_tftp_de_data(char *data, size_t len) {
 		ret.opcode = ((uint16_t)data[0] << 8) + data[1];
 		ret.block_num = ((uint16_t)data[2] << 8) + data[3];
 		ret.datalen = len - 4;
-		if (ret.datalen <= 0) {
+		if (ret.datalen > 0) {
 			if ((ret.data = malloc(ret.datalen)) != NULL) {
 				memcpy(ret.data, (data + 4), ret.datalen);
 			}
@@ -104,7 +104,7 @@ char *nbd_tftp_ser_ack(nbd_tftp_packet_ack ack) {
 char *nbd_tftp_ser_ack_from_block_num(uint16_t block_num) {
 	char *buf = malloc(sizeof(uint16_t) + sizeof(block_num));
 	if (buf != NULL) {
-		uint16_t netopcode = htons(5);
+		uint16_t netopcode = htons(4);
 		uint16_t netbnum = htons(block_num);
 		memcpy(buf, &netopcode, sizeof(netopcode));
 		memcpy((buf + sizeof(netopcode)), &netbnum, sizeof(netbnum));