Use byte-per-byte accesses when accessing possibly unaligned data.

This commit is contained in:
Reki 2022-05-03 10:20:42 +03:00
parent 1bbe9c5700
commit 7ffa952d01
4 changed files with 34 additions and 21 deletions

View File

@ -23,3 +23,16 @@ bool set_socket_buffers(int fd, int rcvbuf, int sndbuf);
uint64_t pntoh64(const void *p); uint64_t pntoh64(const void *p);
void phton64(uint8_t *p, uint64_t v); void phton64(uint8_t *p, uint64_t v);
static inline uint16_t pntoh16(const uint8_t *p) {
return ((uint16_t)p[0] << 8) | (uint16_t)p[1];
}
static inline void phton16(uint8_t *p, uint16_t v) {
p[0] = (uint8_t)(v >> 8);
p[1] = v & 0xFF;
}
static inline uint32_t pntoh32(const uint8_t *p) {
return ((uint32_t)p[0] << 24) | ((uint32_t)p[1] << 16) | ((uint32_t)p[2] << 8) | (uint32_t)p[3];
}

View File

@ -44,8 +44,6 @@ bool HttpExtractHost(const uint8_t *data, size_t len, char *host, size_t len_hos
return false; return false;
} }
static uint8_t tvb_get_varint(const uint8_t *tvb, uint64_t *value) static uint8_t tvb_get_varint(const uint8_t *tvb, uint64_t *value)
{ {
switch (*tvb >> 6) switch (*tvb >> 6)
@ -54,10 +52,10 @@ static uint8_t tvb_get_varint(const uint8_t *tvb, uint64_t *value)
if (value) *value = *tvb & 0x3F; if (value) *value = *tvb & 0x3F;
return 1; return 1;
case 1: /* 0b01 => 2 bytes length (14 bits Usable) */ case 1: /* 0b01 => 2 bytes length (14 bits Usable) */
if (value) *value = ntohs(*(uint16_t*)tvb) & 0x3FFF; if (value) *value = pntoh16(tvb) & 0x3FFF;
return 2; return 2;
case 2: /* 0b10 => 4 bytes length (30 bits Usable) */ case 2: /* 0b10 => 4 bytes length (30 bits Usable) */
if (value) *value = ntohl(*(uint32_t*)tvb) & 0x3FFFFFFF; if (value) *value = pntoh32(tvb) & 0x3FFFFFFF;
return 4; return 4;
case 3: /* 0b11 => 8 bytes length (62 bits Usable) */ case 3: /* 0b11 => 8 bytes length (62 bits Usable) */
if (value) *value = pntoh64(tvb) & 0x3FFFFFFFFFFFFFFF; if (value) *value = pntoh64(tvb) & 0x3FFFFFFFFFFFFFFF;
@ -87,7 +85,7 @@ bool IsQUICCryptoHello(const uint8_t *data, size_t len, size_t *hello_offset, si
} }
bool IsTLSClientHello(const uint8_t *data, size_t len) bool IsTLSClientHello(const uint8_t *data, size_t len)
{ {
return len >= 6 && data[0] == 0x16 && data[1] == 0x03 && data[2] >= 0x01 && data[2] <= 0x03 && data[5] == 0x01 && (ntohs(*(uint16_t*)(data + 3)) + 5) <= len; return len >= 6 && data[0] == 0x16 && data[1] == 0x03 && data[2] >= 0x01 && data[2] <= 0x03 && data[5] == 0x01 && (pntoh16(data + 3) + 5) <= len;
} }
bool TLSFindExtInHandshake(const uint8_t *data, size_t len, uint16_t type, const uint8_t **ext, size_t *len_ext) bool TLSFindExtInHandshake(const uint8_t *data, size_t len, uint16_t type, const uint8_t **ext, size_t *len_ext)
{ {
@ -114,7 +112,7 @@ bool TLSFindExtInHandshake(const uint8_t *data, size_t len, uint16_t type, const
l += data[l] + 1; l += data[l] + 1;
// CipherSuitesLength // CipherSuitesLength
if (len < (l + 2)) return false; if (len < (l + 2)) return false;
l += ntohs(*(uint16_t*)(data + l)) + 2; l += pntoh16(data + l) + 2;
// CompressionMethodsLength // CompressionMethodsLength
if (len < (l + 1)) return false; if (len < (l + 1)) return false;
l += data[l] + 1; l += data[l] + 1;
@ -122,18 +120,17 @@ bool TLSFindExtInHandshake(const uint8_t *data, size_t len, uint16_t type, const
if (len < (l + 2)) return false; if (len < (l + 2)) return false;
data += l; len -= l; data += l; len -= l;
l = ntohs(*(uint16_t*)data); l = pntoh16(data);
data += 2; len -= 2; data += 2; len -= 2;
if (l < len) return false; if (l < len) return false;
uint16_t ntype = htons(type);
while (l >= 4) while (l >= 4)
{ {
uint16_t etype = *(uint16_t*)data; uint16_t etype = pntoh16(data);
size_t elen = ntohs(*(uint16_t*)(data + 2)); size_t elen = pntoh16(data + 2);
data += 4; l -= 4; data += 4; l -= 4;
if (l < elen) break; if (l < elen) break;
if (etype == ntype) if (etype == type)
{ {
if (ext && len_ext) if (ext && len_ext)
{ {
@ -162,7 +159,7 @@ static bool TLSExtractHostFromExt(const uint8_t *ext, size_t elen, char *host, s
// u8 data+2 - server name type. 0=host_name // u8 data+2 - server name type. 0=host_name
// u16 data+3 - server name length // u16 data+3 - server name length
if (elen < 5 || ext[2] != 0) return false; if (elen < 5 || ext[2] != 0) return false;
size_t slen = ntohs(*(uint16_t*)(ext + 3)); size_t slen = pntoh16(ext + 3);
ext += 5; elen -= 5; ext += 5; elen -= 5;
if (slen < elen) return false; if (slen < elen) return false;
if (ext && len_host) if (ext && len_host)
@ -262,7 +259,7 @@ static bool quic_hkdf_expand_label(const uint8_t *secret, uint8_t secret_len, co
size_t hkdflabel_size = 2 + 1 + label_size + 1; size_t hkdflabel_size = 2 + 1 + label_size + 1;
if (hkdflabel_size > sizeof(hkdflabel)) return false; if (hkdflabel_size > sizeof(hkdflabel)) return false;
*(uint16_t*)hkdflabel = htons(out_len); phton16(hkdflabel, out_len);
hkdflabel[2] = (uint8_t)label_size; hkdflabel[2] = (uint8_t)label_size;
memcpy(hkdflabel + 3, label, label_size); memcpy(hkdflabel + 3, label, label_size);
hkdflabel[3 + label_size] = 0; hkdflabel[3 + label_size] = 0;

View File

@ -27,3 +27,7 @@ bool is_private6(const struct sockaddr_in6* a);
int set_keepalive(int fd); int set_keepalive(int fd);
int get_so_error(int fd); int get_so_error(int fd);
static inline uint16_t pntoh16(const uint8_t *p) {
return ((uint16_t)p[0] << 8) | (uint16_t)p[1];
}

View File

@ -45,7 +45,7 @@ bool HttpExtractHost(const uint8_t *data, size_t len, char *host, size_t len_hos
} }
bool IsTLSClientHello(const uint8_t *data, size_t len) bool IsTLSClientHello(const uint8_t *data, size_t len)
{ {
return len>=6 && data[0]==0x16 && data[1]==0x03 && data[2]>=0x01 && data[2]<=0x03 && data[5]==0x01 && (ntohs(*(uint16_t*)(data+3))+5)<=len; return len>=6 && data[0]==0x16 && data[1]==0x03 && data[2]>=0x01 && data[2]<=0x03 && data[5]==0x01 && (pntoh16(data+3)+5)<=len;
} }
bool TLSFindExt(const uint8_t *data, size_t len, uint16_t type, const uint8_t **ext, size_t *len_ext) bool TLSFindExt(const uint8_t *data, size_t len, uint16_t type, const uint8_t **ext, size_t *len_ext)
{ {
@ -76,7 +76,7 @@ bool TLSFindExt(const uint8_t *data, size_t len, uint16_t type, const uint8_t **
l += data[l]+1; l += data[l]+1;
// CipherSuitesLength // CipherSuitesLength
if (len<(l+2)) return false; if (len<(l+2)) return false;
l += ntohs(*(uint16_t*)(data+l))+2; l += pntoh16(data+l)+2;
// CompressionMethodsLength // CompressionMethodsLength
if (len<(l+1)) return false; if (len<(l+1)) return false;
l += data[l]+1; l += data[l]+1;
@ -84,18 +84,17 @@ bool TLSFindExt(const uint8_t *data, size_t len, uint16_t type, const uint8_t **
if (len<(l+2)) return false; if (len<(l+2)) return false;
data+=l; len-=l; data+=l; len-=l;
l=ntohs(*(uint16_t*)data); l=pntoh16(data);
data+=2; len-=2; data+=2; len-=2;
if (l<len) return false; if (l<len) return false;
uint16_t ntype=htons(type);
while(l>=4) while(l>=4)
{ {
uint16_t etype=*(uint16_t*)data; uint16_t etype=pntoh16(data);
size_t elen=ntohs(*(uint16_t*)(data+2)); size_t elen=pntoh16(data+2);
data+=4; l-=4; data+=4; l-=4;
if (l<elen) break; if (l<elen) break;
if (etype==ntype) if (etype==type)
{ {
if (ext && len_ext) if (ext && len_ext)
{ {
@ -119,7 +118,7 @@ bool TLSHelloExtractHost(const uint8_t *data, size_t len, char *host, size_t len
// u8 data+2 - server name type. 0=host_name // u8 data+2 - server name type. 0=host_name
// u16 data+3 - server name length // u16 data+3 - server name length
if (elen<5 || ext[2]!=0) return false; if (elen<5 || ext[2]!=0) return false;
size_t slen = ntohs(*(uint16_t*)(ext+3)); size_t slen = pntoh16(ext+3);
ext+=5; elen-=5; ext+=5; elen-=5;
if (slen<elen) return false; if (slen<elen) return false;
if (ext && len_host) if (ext && len_host)