diff --git a/src/apps/mdns/mdns.c b/src/apps/mdns/mdns.c index 6a399376..f38709db 100644 --- a/src/apps/mdns/mdns.c +++ b/src/apps/mdns/mdns.c @@ -261,15 +261,8 @@ struct mdns_answer { u16_t rd_offset; }; -/** - * Add a label part to a domain - * @param domain The domain to add a label to - * @param label The label to add, like <hostname>, 'local', 'com' or '' - * @param len The length of the label - * @return ERR_OK on success, an err_t otherwise if label too long - */ err_t -mdns_domain_add_label(struct mdns_domain *domain, const char *label, u8_t len) +mdns_domain_add_label_base(struct mdns_domain *domain, u8_t len) { if (len > MDNS_LABEL_MAXLEN) { return ERR_VAL; @@ -283,6 +276,23 @@ mdns_domain_add_label(struct mdns_domain *domain, const char *label, u8_t len) } domain->name[domain->length] = len; domain->length++; + return ERR_OK; +} + +/** + * Add a label part to a domain + * @param domain The domain to add a label to + * @param label The label to add, like <hostname>, 'local', 'com' or '' + * @param len The length of the label + * @return ERR_OK on success, an err_t otherwise if label too long + */ +err_t +mdns_domain_add_label(struct mdns_domain *domain, const char *label, u8_t len) +{ + err_t err = mdns_domain_add_label_base(domain, len); + if (err != ERR_OK) { + return err; + } if (len) { MEMCPY(&domain->name[domain->length], label, len); domain->length += len; @@ -290,6 +300,27 @@ mdns_domain_add_label(struct mdns_domain *domain, const char *label, u8_t len) return ERR_OK; } +/** + * Add a label part to a domain (@see mdns_domain_add_label but copy directly from pbuf) + */ +static err_t +mdns_domain_add_label_pbuf(struct mdns_domain *domain, const struct pbuf *p, u16_t offset, u8_t len) +{ + err_t err = mdns_domain_add_label_base(domain, len); + if (err != ERR_OK) { + return err; + } + if (len) { + if (pbuf_copy_partial(p, &domain->name[domain->length], len, offset) != len) { + /* take back the ++ done before */ + domain->length--; + return ERR_ARG; + } + domain->length += len; + } + return ERR_OK; +} + /** * Internal readname function with max 6 levels of recursion following jumps * while decompressing name @@ -333,22 +364,16 @@ mdns_readname_loop(struct pbuf *p, u16_t offset, struct mdns_domain *domain, uns /* normal label */ if (c <= MDNS_LABEL_MAXLEN) { - u8_t label[MDNS_LABEL_MAXLEN]; err_t res; if (c + domain->length >= MDNS_DOMAIN_MAXLEN) { return MDNS_READNAME_ERROR; } - if (c != 0) { - if (pbuf_copy_partial(p, label, c, offset) != c) { - return MDNS_READNAME_ERROR; - } - offset += c; - } - res = mdns_domain_add_label(domain, (char *) label, c); + res = mdns_domain_add_label_pbuf(domain, p, offset, c); if (res != ERR_OK) { return MDNS_READNAME_ERROR; } + offset += c; } else { /* bad length byte */ return MDNS_READNAME_ERROR;