diff --git a/src/apps/tftp/tftp.c b/src/apps/tftp/tftp.c index 46b7ae3a..3c175a6a 100644 --- a/src/apps/tftp/tftp.c +++ b/src/apps/tftp/tftp.c @@ -89,10 +89,10 @@ struct tftp_state { u16_t blknum; u8_t retries; u8_t mode_write; + u8_t tftp_mode; }; static struct tftp_state tftp_state; -static u8_t tftp_mode; static void tftp_tmr(void *arg); @@ -117,12 +117,12 @@ close_handle(void) } static struct pbuf* -init_packet(int opcode, int extra, int size) +init_packet(u16_t opcode, u16_t extra, size_t size) { struct pbuf* p = pbuf_alloc(PBUF_TRANSPORT, (u16_t)(TFTP_HEADER_LENGTH + size), PBUF_RAM); u16_t* payload; - if(p != NULL) { + if (p != NULL) { payload = (u16_t*) p->payload; payload[0] = PP_HTONS(opcode); payload[1] = lwip_htons(extra); @@ -131,74 +131,83 @@ init_packet(int opcode, int extra, int size) return p; } -static void -send_request(const ip_addr_t *addr, u16_t port, int opcode, const char* fname, const char* mode) +static err_t +send_request(const ip_addr_t *addr, u16_t port, u16_t opcode, const char* fname, const char* mode) { size_t fname_length = strlen(fname)+1; size_t mode_length = strlen(mode)+1; - struct pbuf* p = init_packet(opcode, 0, -2 + fname_length + mode_length); + struct pbuf* p = init_packet(opcode, 0, fname_length + mode_length - 2); char* payload; + err_t ret; - if(p == NULL) { - return; + if (p == NULL) { + return ERR_MEM; } payload = (char*) p->payload; MEMCPY(payload+2, fname, fname_length); MEMCPY(payload+2+fname_length, mode, mode_length); - udp_sendto(tftp_state.upcb, p, addr, port); + ret = udp_sendto(tftp_state.upcb, p, addr, port); pbuf_free(p); + return ret; } -static void +static err_t send_error(const ip_addr_t *addr, u16_t port, enum tftp_error code, const char *str) { int str_length = strlen(str); struct pbuf *p; u16_t *payload; + err_t ret; p = init_packet(TFTP_ERROR, code, str_length + 1); if (p == NULL) { - return; + return ERR_MEM; } payload = (u16_t *) p->payload; MEMCPY(&payload[2], str, str_length + 1); - udp_sendto(tftp_state.upcb, p, addr, port); + ret = udp_sendto(tftp_state.upcb, p, addr, port); pbuf_free(p); + return ret; } -static void +static err_t send_ack(const ip_addr_t *addr, u16_t port, u16_t blknum) { struct pbuf *p; + err_t ret; p = init_packet(TFTP_ACK, blknum, 0); if (p == NULL) { - return; + return ERR_MEM; } - udp_sendto(tftp_state.upcb, p, addr, port); + ret = udp_sendto(tftp_state.upcb, p, addr, port); pbuf_free(p); + return ret; } -static void +static err_t resend_data(const ip_addr_t *addr, u16_t port) { + err_t ret; struct pbuf *p = pbuf_alloc(PBUF_TRANSPORT, tftp_state.last_data->len, PBUF_RAM); if (p == NULL) { - return; + return ERR_MEM; } - if (pbuf_copy(p, tftp_state.last_data) != ERR_OK) { + ret = pbuf_copy(p, tftp_state.last_data); + if (ret != ERR_OK) { pbuf_free(p); - return; + return ret; } - udp_sendto(tftp_state.upcb, p, addr, port); + ret = udp_sendto(tftp_state.upcb, p, addr, port); pbuf_free(p); + return ret; } static void @@ -264,7 +273,7 @@ recv(void *arg, struct udp_pcb *upcb, struct pbuf *p, const ip_addr_t *addr, u16 break; } - if ((tftp_mode & LWIP_TFTP_MODE_SERVER) == 0) { + if ((tftp_state.tftp_mode & LWIP_TFTP_MODE_SERVER) == 0) { send_error(addr, port, TFTP_ERROR_ACCESS_VIOLATION, "TFTP server not enabled"); break; } @@ -444,8 +453,6 @@ tftp_init_common(u8_t mode, const struct tftp_context *ctx) return ERR_MEM; } - tftp_mode = mode; - ret = udp_bind(pcb, IP_ANY_TYPE, TFTP_PORT); if (ret != ERR_OK) { udp_remove(pcb); @@ -458,6 +465,7 @@ tftp_init_common(u8_t mode, const struct tftp_context *ctx) tftp_state.timer = 0; tftp_state.last_data = NULL; tftp_state.upcb = pcb; + tftp_state.tftp_mode = mode; udp_recv(pcb, recv, NULL); @@ -498,25 +506,23 @@ void tftp_cleanup(void) err_t tftp_get(void* handle, const ip_addr_t *addr, u16_t port, const char* fname, const char* mode) { - LWIP_ERROR("TFTP client is not enabled (tftp_init)", (tftp_mode & LWIP_TFTP_MODE_CLIENT) == 0, return ERR_VAL); + LWIP_ERROR("TFTP client is not enabled (tftp_init)", (tftp_state.tftp_mode & LWIP_TFTP_MODE_CLIENT) != 0, return ERR_VAL); tftp_state.handle = handle; tftp_state.blknum = 1; tftp_state.mode_write = 1; /* We want to receive data */ - send_request(addr, port, TFTP_RRQ, fname, mode); - return ERR_OK; + return send_request(addr, port, TFTP_RRQ, fname, mode); } err_t tftp_put(void* handle, const ip_addr_t *addr, u16_t port, const char* fname, const char* mode) { - LWIP_ERROR("TFTP client is not enabled (tftp_init)", (tftp_mode & LWIP_TFTP_MODE_CLIENT) == 0, return ERR_VAL); + LWIP_ERROR("TFTP client is not enabled (tftp_init)", (tftp_state.tftp_mode & LWIP_TFTP_MODE_CLIENT) != 0, return ERR_VAL); tftp_state.handle = handle; tftp_state.blknum = 1; tftp_state.mode_write = 0; /* We want to send data */ - send_request(addr, port, TFTP_WRQ, fname, mode); - return ERR_OK; + return send_request(addr, port, TFTP_WRQ, fname, mode); } #endif /* LWIP_UDP */