wt: patch web-transport-proto header interop for Cloudflare relay

This commit is contained in:
every.channel 2026-02-18 01:28:57 -08:00
parent aa4bddcba0
commit 523c601dc3
No known key found for this signature in database
17 changed files with 2567 additions and 2 deletions

2
.gitignore vendored
View file

@ -15,6 +15,8 @@ result
third_party/*
!third_party/iroh-live
!third_party/iroh-org
!third_party/web-transport-proto
!third_party/web-transport-proto/**
third_party/iroh-org/*
!third_party/iroh-org/iroh-gossip

2
Cargo.lock generated
View file

@ -8312,8 +8312,6 @@ dependencies = [
[[package]]
name = "web-transport-proto"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17633ea7058419f87cbb7f341ab75ac5c1d6d187c154b0bd4c87539e66f4c4e4"
dependencies = [
"bytes",
"http",

View file

@ -34,3 +34,10 @@ serde = { version = "1", features = ["derive"] }
serde_json = "1"
tracing = "0.1"
tracing-subscriber = "0.3"
[patch.crates-io]
# Cloudflare's relay uses standard WebTransport subprotocol negotiation. The upstream
# `web-transport-proto` crate (used by `web-transport-quinn`) currently uses legacy
# header names (`wt-available-protocols` / `wt-protocol`), which prevents negotiating
# `moqt-*` and causes the relay to close after MoQ SETUP.
web-transport-proto = { path = "third_party/web-transport-proto" }

View file

@ -60,6 +60,7 @@ newer draft implementations.
Implementation choice:
- Cloudflare's relay preview currently does **not** support `ANNOUNCE` (namespace-style publishing). `ec-node wt-publish` uses the `moq-lite` publish model via `moq-native` and `moq-mux` (fMP4 ingestion) for Cloudflare relay compatibility.
- On NixOS deployments, we disable `moq-native`'s WebSocket fallback (`MOQ_CLIENT_WEBSOCKET_ENABLED=false`) to ensure WebTransport (QUIC) is used. This avoids the WebSocket path occasionally "winning" the race and then failing MoQ negotiation against the Cloudflare relay, causing rapid reconnect loops.
- For Cloudflare relay interop, we patch `web-transport-proto` to send and accept the standard WebTransport subprotocol negotiation header (`sec-webtransport-protocol`) in addition to the legacy `wt-available-protocols`/`wt-protocol` headers. Without subprotocol negotiation, the relay may not select a `moqt-*` protocol and can close the session immediately after MoQ `SETUP`.
### Share link

View file

@ -0,0 +1,90 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
## [0.2.7](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.2.6...web-transport-proto-v0.2.7) - 2025-09-03
### Other
- Rename the repo. ([#94](https://github.com/moq-dev/web-transport/pull/94))
- Fix clippy warnings ([#91](https://github.com/moq-dev/web-transport/pull/91))
- Add support for session closed capsule ([#86](https://github.com/moq-dev/web-transport/pull/86))
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
## [0.5.2](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.5.1...web-transport-proto-v0.5.2) - 2026-02-13
### Other
- Fix some API mistakes. ([#163](https://github.com/moq-dev/web-transport/pull/163))
## [0.5.1](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.5.0...web-transport-proto-v0.5.1) - 2026-02-11
### Other
- Async accept ([#159](https://github.com/moq-dev/web-transport/pull/159))
## [0.5.0](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.4.0...web-transport-proto-v0.5.0) - 2026-02-10
### Other
- Fix capsule protocol handling ([#152](https://github.com/moq-dev/web-transport/pull/152))
## [0.4.0](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.3.1...web-transport-proto-v0.4.0) - 2026-01-23
### Other
- Sub-protocol negotiation + breaking API changes ([#143](https://github.com/moq-dev/web-transport/pull/143))
## [0.3.1](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.3.0...web-transport-proto-v0.3.1) - 2026-01-07
### Other
- Rename the repo into a new org. ([#132](https://github.com/moq-dev/web-transport/pull/132))
## [0.2.8](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.2.7...web-transport-proto-v0.2.8) - 2025-10-17
### Other
- Make traits compatible with WASM ([#107](https://github.com/moq-dev/web-transport/pull/107))
- Check all feature combinations ([#102](https://github.com/moq-dev/web-transport/pull/102))
## [0.2.6](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.2.5...web-transport-proto-v0.2.6) - 2025-05-15
### Other
- Add (generic) support for learning when a stream is closed. ([#73](https://github.com/moq-dev/web-transport/pull/73))
- Add url query to CONNECT :path request ([#70](https://github.com/moq-dev/web-transport/pull/70))
## [0.2.5](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.2.4...web-transport-proto-v0.2.5) - 2025-03-26
### Other
- Added Ring feature flag ([#68](https://github.com/moq-dev/web-transport/pull/68))
## [0.2.4](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.2.3...web-transport-proto-v0.2.4) - 2025-01-15
### Other
- Bump some deps. ([#55](https://github.com/moq-dev/web-transport/pull/55))
- Clippy fixes. ([#53](https://github.com/moq-dev/web-transport/pull/53))
## [0.2.3](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.2.2...web-transport-proto-v0.2.3) - 2024-09-02
### Other
- Don't set the N bit for literals. ([#41](https://github.com/moq-dev/web-transport/pull/41))
## [0.2.2](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.2.1...web-transport-proto-v0.2.2) - 2024-08-15
### Other
- Some more documentation. ([#34](https://github.com/moq-dev/web-transport/pull/34))

View file

@ -0,0 +1,59 @@
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
#
# When uploading crates to the registry Cargo will automatically
# "normalize" Cargo.toml files for maximal compatibility
# with all versions of Cargo and also rewrite `path` dependencies
# to registry (e.g., crates.io) dependencies.
#
# If you are reading this file be aware that the original Cargo.toml
# will likely look very different (and much more reasonable).
# See Cargo.toml.orig for the original contents.
[package]
edition = "2021"
name = "web-transport-proto"
version = "0.5.2"
authors = ["Luke Curley"]
build = false
autolib = false
autobins = false
autoexamples = false
autotests = false
autobenches = false
description = "WebTransport core protocol"
readme = "README.md"
keywords = [
"quic",
"http3",
"webtransport",
]
categories = [
"network-programming",
"web-programming",
]
license = "MIT OR Apache-2.0"
repository = "https://github.com/moq-dev/web-transport"
[lib]
name = "web_transport_proto"
path = "src/lib.rs"
[dependencies.bytes]
version = "1"
[dependencies.http]
version = "1"
[dependencies.sfv]
version = "0.14"
[dependencies.thiserror]
version = "2"
[dependencies.tokio]
version = "1"
features = ["io-util"]
default-features = false
[dependencies.url]
version = "2"

View file

@ -0,0 +1,6 @@
[![crates.io](https://img.shields.io/crates/v/web-transport-proto)](https://crates.io/crates/web-transport-proto)
[![discord](https://img.shields.io/discord/1124083992740761730)](https://discord.gg/FCYF3p99mr)
# web-transport-proto
The gritty WebTransport protocol implementation.
Not meant to be used directly, but as a dependency for [web-transport-quinn](../web-transport-quinn) and [web-transport-wasm](../web-transport-wasm).

View file

@ -0,0 +1,376 @@
use std::sync::Arc;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::{VarInt, VarIntUnexpectedEnd};
// The spec (draft-ietf-webtrans-http3-06) says the type is 0x2843, which would
// varint-encode to 0x68 0x43. However, actual wire data shows 0x43 0x28 which
// decodes to 808. There may be a discrepancy in implementations or specs.
// Using 0x2843 as specified in the standard.
const CLOSE_WEBTRANSPORT_SESSION_TYPE: u64 = 0x2843;
const MAX_MESSAGE_SIZE: usize = 1024;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Capsule {
CloseWebTransportSession { code: u32, reason: String },
Grease { num: u64 },
Unknown { typ: VarInt, payload: Bytes },
}
impl Capsule {
pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, CapsuleError> {
let typ = VarInt::decode(buf)?;
let length = VarInt::decode(buf)?;
let mut payload = buf.take(length.into_inner() as usize);
// Check declared length first - reject immediately if too large
if payload.limit() > MAX_MESSAGE_SIZE {
return Err(CapsuleError::MessageTooLong);
}
// Then check if all declared bytes are buffered
if payload.remaining() < payload.limit() {
return Err(CapsuleError::UnexpectedEnd);
}
let typ_val = typ.into_inner();
if let Some(num) = is_grease(typ_val) {
payload.advance(payload.remaining());
return Ok(Self::Grease { num });
}
match typ_val {
CLOSE_WEBTRANSPORT_SESSION_TYPE => {
if payload.remaining() < 4 {
return Err(CapsuleError::UnexpectedEnd);
}
let error_code = payload.get_u32();
let message_len = payload.remaining();
if message_len > MAX_MESSAGE_SIZE {
return Err(CapsuleError::MessageTooLong);
}
let mut message_bytes = vec![0u8; message_len];
payload.copy_to_slice(&mut message_bytes);
let error_message =
String::from_utf8(message_bytes).map_err(|_| CapsuleError::InvalidUtf8)?;
Ok(Self::CloseWebTransportSession {
code: error_code,
reason: error_message,
})
}
_ => {
let mut payload_bytes = vec![0u8; payload.remaining()];
payload.copy_to_slice(&mut payload_bytes);
Ok(Self::Unknown {
typ,
payload: Bytes::from(payload_bytes),
})
}
}
}
pub async fn read<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Option<Self>, CapsuleError> {
let mut buf = Vec::new();
loop {
if stream.read_buf(&mut buf).await? == 0 {
if buf.is_empty() {
return Ok(None);
}
return Err(CapsuleError::UnexpectedEnd);
}
let mut limit = std::io::Cursor::new(&buf);
match Self::decode(&mut limit) {
Ok(capsule) => return Ok(Some(capsule)),
Err(CapsuleError::UnexpectedEnd) => continue,
Err(e) => return Err(e),
}
}
}
pub fn encode<B: BufMut>(&self, buf: &mut B) {
match self {
Self::CloseWebTransportSession {
code: error_code,
reason: error_message,
} => {
// Encode the capsule type
VarInt::from_u64(CLOSE_WEBTRANSPORT_SESSION_TYPE)
.unwrap()
.encode(buf);
// Calculate and encode the length
let length = 4 + error_message.len();
VarInt::from_u32(length as u32).encode(buf);
// Encode the error code (32-bit)
buf.put_u32(*error_code);
// Encode the error message
buf.put_slice(error_message.as_bytes());
}
Self::Grease { num } => {
// Generate grease type: 0x29 * N + 0x17
// Check for overflow
let grease_type = num
.checked_mul(0x29)
.and_then(|v| v.checked_add(0x17))
.expect("grease num value would overflow u64");
VarInt::from_u64(grease_type).unwrap().encode(buf);
// Grease capsules have zero-length payload
VarInt::from_u32(0).encode(buf);
}
Self::Unknown { typ, payload } => {
// Encode the capsule type
typ.encode(buf);
// Encode the length
VarInt::try_from(payload.len()).unwrap().encode(buf);
// Encode the payload
buf.put_slice(payload);
}
}
}
pub async fn write<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<(), CapsuleError> {
let mut buf = BytesMut::new();
self.encode(&mut buf);
stream.write_all_buf(&mut buf).await?;
Ok(())
}
}
// RFC 9297 Section 5.4: Capsule types of the form 0x29 * N + 0x17
// Returns Some(N) if the value is a grease type, None otherwise
fn is_grease(val: u64) -> Option<u64> {
if val < 0x17 {
return None;
}
let num = (val - 0x17) / 0x29;
if val == 0x29 * num + 0x17 {
Some(num)
} else {
None
}
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum CapsuleError {
#[error("unexpected end of buffer")]
UnexpectedEnd,
#[error("invalid UTF-8")]
InvalidUtf8,
#[error("message too long")]
MessageTooLong,
#[error("unknown capsule type: {0:?}")]
UnknownType(VarInt),
#[error("varint decode error: {0:?}")]
VarInt(#[from] VarIntUnexpectedEnd),
#[error("io error: {0}")]
Io(Arc<std::io::Error>),
}
impl From<std::io::Error> for CapsuleError {
fn from(err: std::io::Error) -> Self {
CapsuleError::Io(Arc::new(err))
}
}
#[cfg(test)]
mod tests {
use super::*;
use bytes::Bytes;
#[test]
fn test_close_webtransport_session_decode() {
// Test with spec-compliant type 0x2843 (encodes as 0x68 0x43)
let mut data = Vec::new();
VarInt::from_u64(0x2843).unwrap().encode(&mut data);
VarInt::from_u32(8).encode(&mut data);
data.extend_from_slice(b"\x00\x00\x01\xa4test");
let mut buf = data.as_slice();
let capsule = Capsule::decode(&mut buf).unwrap();
match capsule {
Capsule::CloseWebTransportSession {
code: error_code,
reason: error_message,
} => {
assert_eq!(error_code, 420);
assert_eq!(error_message, "test");
}
_ => panic!("Expected CloseWebTransportSession"),
}
assert_eq!(buf.len(), 0); // All bytes consumed
}
#[test]
fn test_close_webtransport_session_encode() {
let capsule = Capsule::CloseWebTransportSession {
code: 420,
reason: "test".to_string(),
};
let mut buf = Vec::new();
capsule.encode(&mut buf);
// Expected format: type(0x2843 as varint = 0x68 0x43) + length(8 as varint) + error_code(420 as u32 BE) + "test"
assert_eq!(buf, b"\x68\x43\x08\x00\x00\x01\xa4test");
}
#[test]
fn test_close_webtransport_session_roundtrip() {
let original = Capsule::CloseWebTransportSession {
code: 12345,
reason: "Connection closed by application".to_string(),
};
let mut buf = Vec::new();
original.encode(&mut buf);
let mut read_buf = buf.as_slice();
let decoded = Capsule::decode(&mut read_buf).unwrap();
assert_eq!(original, decoded);
assert_eq!(read_buf.len(), 0); // All bytes consumed
}
#[test]
fn test_empty_error_message() {
let capsule = Capsule::CloseWebTransportSession {
code: 0,
reason: String::new(),
};
let mut buf = Vec::new();
capsule.encode(&mut buf);
// Type(0x2843 as varint = 0x68 0x43) + Length(4) + error_code(0)
assert_eq!(buf, b"\x68\x43\x04\x00\x00\x00\x00");
let mut read_buf = buf.as_slice();
let decoded = Capsule::decode(&mut read_buf).unwrap();
assert_eq!(capsule, decoded);
}
#[test]
fn test_invalid_utf8() {
// Create a capsule with invalid UTF-8 in the message
let mut data = Vec::new();
VarInt::from_u64(0x2843).unwrap().encode(&mut data); // type
VarInt::from_u32(5).encode(&mut data); // length(5)
data.extend_from_slice(b"\x00\x00\x00\x00"); // error_code(0)
data.push(0xFF); // Invalid UTF-8 byte
let mut buf = data.as_slice();
let result = Capsule::decode(&mut buf);
assert!(matches!(result, Err(CapsuleError::InvalidUtf8)));
}
#[test]
fn test_truncated_error_code() {
// Capsule with length indicating 3 bytes but error code needs 4
let mut data = Vec::new();
VarInt::from_u64(0x2843).unwrap().encode(&mut data); // type
VarInt::from_u32(3).encode(&mut data); // length(3)
data.extend_from_slice(b"\x00\x00\x00"); // incomplete error code
let mut buf = data.as_slice();
let result = Capsule::decode(&mut buf);
assert!(matches!(result, Err(CapsuleError::UnexpectedEnd)));
}
#[test]
fn test_unknown_capsule() {
// Test handling of unknown capsule types
let unknown_type = 0x1234u64;
let payload_data = b"unknown payload";
let mut data = Vec::new();
VarInt::from_u64(unknown_type).unwrap().encode(&mut data);
VarInt::from_u32(payload_data.len() as u32).encode(&mut data);
data.extend_from_slice(payload_data);
let mut buf = data.as_slice();
let capsule = Capsule::decode(&mut buf).unwrap();
match capsule {
Capsule::Unknown { typ, payload } => {
assert_eq!(typ.into_inner(), unknown_type);
assert_eq!(payload.as_ref(), payload_data);
}
_ => panic!("Expected Unknown capsule"),
}
}
#[test]
fn test_unknown_capsule_roundtrip() {
let capsule = Capsule::Unknown {
typ: VarInt::from_u64(0x9999).unwrap(),
payload: Bytes::from("test payload"),
};
let mut buf = Vec::new();
capsule.encode(&mut buf);
let mut read_buf = buf.as_slice();
let decoded = Capsule::decode(&mut read_buf).unwrap();
assert_eq!(capsule, decoded);
assert_eq!(read_buf.len(), 0);
}
#[test]
fn test_grease_capsule() {
// Test grease formula: 0x29 * N + 0x17
for num in [0, 1, 5, 100, 1000] {
let capsule = Capsule::Grease { num };
let mut buf = Vec::new();
capsule.encode(&mut buf);
let mut read_buf = buf.as_slice();
let decoded = Capsule::decode(&mut read_buf).unwrap();
assert_eq!(capsule, decoded);
assert_eq!(read_buf.len(), 0);
}
}
#[test]
fn test_grease_values() {
// Verify specific grease type values
assert_eq!(is_grease(0x17), Some(0)); // N=0
assert_eq!(is_grease(0x40), Some(1)); // N=1: 0x29 + 0x17 = 0x40
assert_eq!(is_grease(0x69), Some(2)); // N=2: 0x29*2 + 0x17 = 0x69
assert_eq!(is_grease(0x18), None); // Not a grease value
assert_eq!(is_grease(0x41), None); // Not a grease value
}
#[test]
#[should_panic(expected = "grease num value would overflow u64")]
fn test_grease_overflow() {
let capsule = Capsule::Grease { num: u64::MAX };
let mut buf = Vec::new();
capsule.encode(&mut buf);
}
}

View file

@ -0,0 +1,402 @@
use std::{str::FromStr, sync::Arc};
use bytes::{Buf, BufMut, BytesMut};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use url::Url;
use super::{qpack, Frame, VarInt};
use thiserror::Error;
// Errors that can occur during the connect request.
#[derive(Error, Debug, Clone)]
pub enum ConnectError {
#[error("unexpected end of input")]
UnexpectedEnd,
#[error("qpack error")]
QpackError(#[from] qpack::DecodeError),
#[error("unexpected frame {0:?}")]
UnexpectedFrame(Frame),
#[error("invalid method")]
InvalidMethod,
#[error("invalid url")]
InvalidUrl(#[from] url::ParseError),
#[error("invalid status")]
InvalidStatus,
#[error("expected 200, got: {0:?}")]
WrongStatus(Option<http::StatusCode>),
#[error("expected connect, got: {0:?}")]
WrongMethod(Option<http::method::Method>),
#[error("expected https, got: {0:?}")]
WrongScheme(Option<String>),
#[error("expected authority header")]
WrongAuthority,
#[error("expected webtransport, got: {0:?}")]
WrongProtocol(Option<String>),
#[error("expected path header")]
WrongPath,
#[error("invalid protocol header")]
InvalidProtocol,
#[error("structured field error: {0}")]
StructuredFieldError(Arc<sfv::Error>),
#[error("non-200 status: {0:?}")]
ErrorStatus(http::StatusCode),
#[error("io error: {0}")]
Io(Arc<std::io::Error>),
}
impl From<std::io::Error> for ConnectError {
fn from(err: std::io::Error) -> Self {
ConnectError::Io(Arc::new(err))
}
}
impl From<sfv::Error> for ConnectError {
fn from(err: sfv::Error) -> Self {
ConnectError::StructuredFieldError(Arc::new(err))
}
}
/// A CONNECT request to initiate a WebTransport session.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ConnectRequest {
/// The URL to connect to.
pub url: Url,
/// The subprotocols requested (if any).
pub protocols: Vec<String>,
}
impl ConnectRequest {
pub fn new(url: impl Into<Url>) -> Self {
Self {
url: url.into(),
protocols: Vec::new(),
}
}
pub fn with_protocol(mut self, protocol: impl Into<String>) -> Self {
self.protocols.push(protocol.into());
self
}
pub fn with_protocols(mut self, protocols: impl IntoIterator<Item = String>) -> Self {
self.protocols.extend(protocols);
self
}
pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, ConnectError> {
let (typ, mut data) = Frame::read(buf).map_err(|_| ConnectError::UnexpectedEnd)?;
if typ != Frame::HEADERS {
return Err(ConnectError::UnexpectedFrame(typ));
}
// We no longer return UnexpectedEnd because we know the buffer should be large enough.
let headers = qpack::Headers::decode(&mut data)?;
let scheme = match headers.get(":scheme") {
Some("https") => "https",
Some(scheme) => Err(ConnectError::WrongScheme(Some(scheme.to_string())))?,
None => return Err(ConnectError::WrongScheme(None)),
};
let authority = headers
.get(":authority")
.ok_or(ConnectError::WrongAuthority)?;
let path_and_query = headers.get(":path").ok_or(ConnectError::WrongPath)?;
let method = headers.get(":method");
match method
.map(|method| method.try_into().map_err(|_| ConnectError::InvalidMethod))
.transpose()?
{
Some(http::Method::CONNECT) => (),
o => return Err(ConnectError::WrongMethod(o)),
};
let protocol = headers.get(":protocol");
if protocol != Some("webtransport") {
return Err(ConnectError::WrongProtocol(protocol.map(|s| s.to_string())));
}
let protocols = headers
.get(protocol_negotiation::AVAILABLE_NAME_STD)
.or_else(|| headers.get(protocol_negotiation::AVAILABLE_NAME))
.map(protocol_negotiation::decode_list)
.transpose()
.map_err(|_| ConnectError::InvalidProtocol)?
.unwrap_or_default();
let url = Url::parse(&format!("{scheme}://{authority}{path_and_query}"))?;
Ok(Self { url, protocols })
}
pub async fn read<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Self, ConnectError> {
let mut buf = Vec::new();
loop {
if stream.read_buf(&mut buf).await? == 0 {
return Err(ConnectError::UnexpectedEnd);
}
let mut limit = std::io::Cursor::new(&buf);
match Self::decode(&mut limit) {
Ok(request) => return Ok(request),
Err(ConnectError::UnexpectedEnd) => continue,
Err(e) => return Err(e),
}
}
}
pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<(), ConnectError> {
let mut headers = qpack::Headers::default();
headers.set(":method", "CONNECT");
headers.set(":scheme", self.url.scheme());
headers.set(":authority", self.url.authority());
let path_and_query = match self.url.query() {
Some(query) => format!("{}?{}", self.url.path(), query),
None => self.url.path().to_string(),
};
headers.set(":path", &path_and_query);
headers.set(":protocol", "webtransport");
if !self.protocols.is_empty() {
let encoded = protocol_negotiation::encode_list(&self.protocols)?;
// Send both the standard and legacy header names to maximize interop.
headers.set(protocol_negotiation::AVAILABLE_NAME_STD, &encoded);
headers.set(protocol_negotiation::AVAILABLE_NAME, &encoded);
}
// Use a temporary buffer so we can compute the size.
let mut tmp = Vec::new();
headers.encode(&mut tmp);
let size = VarInt::from_u32(tmp.len() as u32);
Frame::HEADERS.encode(buf);
size.encode(buf);
buf.put_slice(&tmp);
Ok(())
}
pub async fn write<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<(), ConnectError> {
let mut buf = BytesMut::new();
self.encode(&mut buf)?;
stream.write_all_buf(&mut buf).await?;
Ok(())
}
}
impl From<Url> for ConnectRequest {
fn from(url: Url) -> Self {
Self {
url,
protocols: Vec::new(),
}
}
}
/// A CONNECT response to accept or reject a WebTransport session.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct ConnectResponse {
/// The status code of the response.
pub status: http::status::StatusCode,
/// The subprotocol selected by the server, if any
pub protocol: Option<String>,
}
impl ConnectResponse {
pub const OK: Self = Self {
status: http::StatusCode::OK,
protocol: None,
};
pub fn new(status: http::StatusCode) -> Self {
Self {
status,
protocol: None,
}
}
pub fn with_protocol(mut self, protocol: impl Into<String>) -> Self {
self.protocol = Some(protocol.into());
self
}
pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, ConnectError> {
let (typ, mut data) = Frame::read(buf).map_err(|_| ConnectError::UnexpectedEnd)?;
if typ != Frame::HEADERS {
return Err(ConnectError::UnexpectedFrame(typ));
}
let headers = qpack::Headers::decode(&mut data)?;
let status = match headers
.get(":status")
.map(|status| {
http::StatusCode::from_str(status).map_err(|_| ConnectError::InvalidStatus)
})
.transpose()?
{
Some(status) if status.is_success() => status,
o => return Err(ConnectError::WrongStatus(o)),
};
let protocol = headers
.get(protocol_negotiation::SELECTED_NAME_STD)
.or_else(|| headers.get(protocol_negotiation::SELECTED_NAME))
.map(protocol_negotiation::decode_item)
.transpose()
.map_err(|_| ConnectError::InvalidProtocol)?;
Ok(Self { status, protocol })
}
pub async fn read<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Self, ConnectError> {
let mut buf = Vec::new();
loop {
if stream.read_buf(&mut buf).await? == 0 {
return Err(ConnectError::UnexpectedEnd);
}
let mut limit = std::io::Cursor::new(&buf);
match Self::decode(&mut limit) {
Ok(response) => return Ok(response),
Err(ConnectError::UnexpectedEnd) => continue,
Err(e) => return Err(e),
}
}
}
pub fn encode<B: BufMut>(&self, buf: &mut B) -> Result<(), ConnectError> {
let mut headers = qpack::Headers::default();
headers.set(":status", self.status.as_str());
headers.set("sec-webtransport-http3-draft", "draft02");
if let Some(protocol) = self.protocol.as_ref() {
let encoded = protocol_negotiation::encode_item(protocol)?;
// Send both the standard and legacy header names to maximize interop.
headers.set(protocol_negotiation::SELECTED_NAME_STD, &encoded);
headers.set(protocol_negotiation::SELECTED_NAME, &encoded);
}
// Use a temporary buffer so we can compute the size.
let mut tmp = Vec::new();
headers.encode(&mut tmp);
let size = VarInt::from_u32(tmp.len() as u32);
Frame::HEADERS.encode(buf);
size.encode(buf);
buf.put_slice(&tmp);
Ok(())
}
pub async fn write<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<(), ConnectError> {
let mut buf = BytesMut::new();
self.encode(&mut buf)?;
stream.write_all_buf(&mut buf).await?;
Ok(())
}
}
impl Default for ConnectResponse {
fn default() -> Self {
Self::OK
}
}
impl From<http::StatusCode> for ConnectResponse {
fn from(status: http::StatusCode) -> Self {
Self {
status,
protocol: None,
}
}
}
mod protocol_negotiation {
//! WebTransport sub-protocol negotiation using RFC 8941 Structured Fields,
//!
//! according to [draft 14](https://www.ietf.org/archive/id/draft-ietf-webtrans-http3-14.html#section-3.3)
use sfv::{Item, ItemSerializer, List, ListEntry, ListSerializer, Parser, StringRef};
use crate::ConnectError;
/// The header name for the available protocols, sent within the WebTransport Connect request.
/// Legacy name kept for compatibility with older servers/clients.
pub const AVAILABLE_NAME: &str = "wt-available-protocols";
/// Standard WebTransport sub-protocol negotiation header name.
pub const AVAILABLE_NAME_STD: &str = "sec-webtransport-protocol";
/// The header name for the selected protocol, sent within the WebTransport Connect response.
/// Legacy name kept for compatibility with older servers/clients.
pub const SELECTED_NAME: &str = "wt-protocol";
/// Standard WebTransport sub-protocol negotiation header name.
pub const SELECTED_NAME_STD: &str = "sec-webtransport-protocol";
/// Encode a list of protocol strings as an RFC 8941 Structured Field List.
pub fn encode_list(protocols: &[String]) -> Result<String, ConnectError> {
let mut serializer = ListSerializer::new();
for protocol in protocols {
let s = StringRef::from_str(protocol)?;
let _ = serializer.bare_item(s);
}
serializer.finish().ok_or(ConnectError::InvalidProtocol)
}
/// Decode an RFC 8941 Structured Field List of strings.
pub fn decode_list(value: &str) -> Result<Vec<String>, ConnectError> {
let list = Parser::new(value).parse::<List>()?;
list.iter()
.map(|entry| match entry {
ListEntry::Item(item) => Ok(item
.bare_item
.as_string()
.ok_or(ConnectError::InvalidProtocol)?
.as_str()
.to_string()),
_ => Err(ConnectError::InvalidProtocol),
})
.collect()
}
/// Encode a single string as an RFC 8941 Structured Field Item.
pub fn encode_item(protocol: &str) -> Result<String, ConnectError> {
let s = StringRef::from_str(protocol)?;
Ok(ItemSerializer::new().bare_item(s).finish())
}
/// Decode an RFC 8941 Structured Field Item (single string).
pub fn decode_item(value: &str) -> Result<String, ConnectError> {
let item = Parser::new(value).parse::<Item>()?;
Ok(item
.bare_item
.as_string()
.ok_or(ConnectError::InvalidProtocol)?
.as_str()
.to_string())
}
}

View file

@ -0,0 +1,18 @@
// WebTransport shares with HTTP/3, so we can't start at 0 or use the full VarInt.
const ERROR_FIRST: u64 = 0x52e4a40fa8db;
const ERROR_LAST: u64 = 0x52e5ac983162;
pub const fn error_from_http3(code: u64) -> Option<u32> {
if code < ERROR_FIRST || code > ERROR_LAST {
return None;
}
let code = code - ERROR_FIRST;
let code = code - code / 0x1f;
Some(code as u32)
}
pub const fn error_to_http3(code: u32) -> u64 {
ERROR_FIRST + code as u64 + code as u64 / 0x1e
}

View file

@ -0,0 +1,65 @@
use bytes::{Buf, BufMut};
use crate::{VarInt, VarIntUnexpectedEnd};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Frame(pub VarInt);
impl Frame {
pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, VarIntUnexpectedEnd> {
let typ = VarInt::decode(buf)?;
Ok(Frame(typ))
}
pub fn encode<B: BufMut>(&self, buf: &mut B) {
self.0.encode(buf)
}
pub fn is_grease(&self) -> bool {
let val = self.0.into_inner();
if val < 0x21 {
return false;
}
#[allow(unknown_lints, clippy::manual_is_multiple_of)]
{
(val - 0x21) % 0x1f == 0
}
}
pub fn read<B: Buf>(
buf: &mut B,
) -> Result<(Frame, bytes::buf::Take<&mut B>), VarIntUnexpectedEnd> {
let typ = Frame::decode(buf)?;
let size = VarInt::decode(buf)?;
let mut limit = Buf::take(buf, size.into_inner() as usize);
if limit.remaining() < limit.limit() {
return Err(VarIntUnexpectedEnd);
}
// Try again if this is a GREASE frame we need to ignore
if typ.is_grease() {
limit.advance(limit.limit());
return Self::read(limit.into_inner());
}
Ok((typ, limit))
}
}
macro_rules! frames {
{$($name:ident = $val:expr,)*} => {
impl Frame {
$(pub const $name: Frame = Frame(VarInt::from_u32($val));)*
}
}
}
// Sent at the start of a bidirectional stream.
frames! {
DATA = 0x00,
HEADERS = 0x01,
SETTINGS = 0x04,
WEBTRANSPORT = 0x41,
}

View file

@ -0,0 +1,364 @@
// Huffman encoding is a compression technique that replaces common strings with shorter codes.
// Ugh I wish we didn't have to implement this, but the other endpoint is allowed to use it.
// Taken from https://github.com/hyperium/h3/blob/master/h3/src/qpack/prefix_string/decode.rs
// License: MIT
#[derive(Debug, Default, PartialEq, Clone)]
pub struct BitWindow {
pub byte: u32,
pub bit: u32,
pub count: u32,
}
impl BitWindow {
pub fn new() -> Self {
Self::default()
}
pub fn forwards(&mut self, step: u32) {
self.bit += self.count;
self.byte += self.bit / 8;
self.bit %= 8;
self.count = step;
}
pub fn opposite_bit_window(&self) -> BitWindow {
BitWindow {
byte: self.byte,
bit: self.bit,
count: 8 - (self.bit % 8),
}
}
}
use thiserror::Error;
#[derive(Error, Debug, PartialEq, Clone)]
pub enum Error {
#[error("missing bits: {0:?}")]
MissingBits(BitWindow),
#[error("unhandled: {0:?} {1:?}")]
Unhandled(BitWindow, usize),
}
#[derive(Clone, Debug)]
enum DecodeValue {
Partial(&'static HuffmanDecoder),
Sym(u8),
}
#[derive(Clone, Debug)]
struct HuffmanDecoder {
lookup: u32,
table: &'static [DecodeValue],
}
impl HuffmanDecoder {
fn check_eof(&self, bit_pos: &mut BitWindow, input: &[u8]) -> Result<Option<u32>, Error> {
use std::cmp::Ordering;
match ((bit_pos.byte + 1) as usize).cmp(&input.len()) {
// Position is out-of-range
Ordering::Greater => {
return Ok(None);
}
// Position is on the last byte
Ordering::Equal => {
let side = bit_pos.opposite_bit_window();
let rest = match read_bits(input, side.byte, side.bit, side.count) {
Ok(x) => x,
Err(()) => {
return Err(Error::MissingBits(side));
}
};
let eof_filler = ((2u16 << (side.count - 1)) - 1) as u8;
if rest & eof_filler == eof_filler {
return Ok(None);
}
}
Ordering::Less => {}
}
Err(Error::MissingBits(bit_pos.clone()))
}
fn fetch_value(&self, bit_pos: &mut BitWindow, input: &[u8]) -> Result<Option<u32>, Error> {
match read_bits(input, bit_pos.byte, bit_pos.bit, bit_pos.count) {
Ok(value) => Ok(Some(value as u32)),
Err(()) => self.check_eof(bit_pos, input),
}
}
fn decode_next(&self, bit_pos: &mut BitWindow, input: &[u8]) -> Result<Option<u8>, Error> {
bit_pos.forwards(self.lookup);
let value = match self.fetch_value(bit_pos, input) {
Ok(Some(value)) => value as usize,
Ok(None) => return Ok(None),
Err(err) => return Err(err),
};
let at_value = match (self.table).get(value) {
Some(x) => x,
None => return Err(Error::Unhandled(bit_pos.clone(), value)),
};
match at_value {
DecodeValue::Sym(x) => Ok(Some(*x)),
DecodeValue::Partial(d) => d.decode_next(bit_pos, input),
}
}
}
/// Read `len` bits from the `src` slice at the specified position
///
/// Never read more than 8 bits at a time. `bit_offset` may be larger than 8.
fn read_bits(src: &[u8], mut byte_offset: u32, mut bit_offset: u32, len: u32) -> Result<u8, ()> {
if len == 0 || len > 8 || src.len() as u32 * 8 < (byte_offset * 8) + bit_offset + len {
return Err(());
}
// Deal with `bit_offset` > 8
byte_offset += bit_offset / 8;
bit_offset -= (bit_offset / 8) * 8;
Ok(if bit_offset + len <= 8 {
// Read all the bits from a single byte
(src[byte_offset as usize] << bit_offset) >> (8 - len)
} else {
// The range of bits spans over 2 bytes
let mut result = (src[byte_offset as usize] as u16) << 8;
result |= src[byte_offset as usize + 1] as u16;
((result << bit_offset) >> (16 - len)) as u8
})
}
macro_rules! bits_decode {
// general way
(
lookup: $count:expr, [
$($sym:expr,)*
$(=> $sub:ident,)* ]
) => {
HuffmanDecoder {
lookup: $count,
table: &[
$( DecodeValue::Sym($sym as u8), )*
$( DecodeValue::Partial(&$sub), )*
]
}
};
// 2-final
( $first:expr, $second:expr ) => {
HuffmanDecoder {
lookup: 1,
table: &[
DecodeValue::Sym($first as u8),
DecodeValue::Sym($second as u8),
]
}
};
// 4-final
( $first:expr, $second:expr, $third:expr, $fourth:expr ) => {
HuffmanDecoder {
lookup: 2,
table: &[
DecodeValue::Sym($first as u8),
DecodeValue::Sym($second as u8),
DecodeValue::Sym($third as u8),
DecodeValue::Sym($fourth as u8),
]
}
};
// 2-final-partial
( $first:expr, => $second:ident ) => {
HuffmanDecoder {
lookup: 1,
table: &[
DecodeValue::Sym($first as u8),
DecodeValue::Partial(&$second),
]
}
};
// 2-partial
( => $first:ident, => $second:ident ) => {
HuffmanDecoder {
lookup: 1,
table: &[
DecodeValue::Partial(&$first),
DecodeValue::Partial(&$second),
]
}
};
// 4-partial
( => $first:ident, => $second:ident,
=> $third:ident, => $fourth:ident ) => {
HuffmanDecoder {
lookup: 2,
table: &[
DecodeValue::Partial(&$first),
DecodeValue::Partial(&$second),
DecodeValue::Partial(&$third),
DecodeValue::Partial(&$fourth),
]
}
};
[ $( $name:ident => ( $($value:tt)* ), )* ] => {
$( const $name: HuffmanDecoder = bits_decode!( $( $value )* ); )*
};
}
#[rustfmt::skip]
bits_decode![
HPACK_STRING => (
lookup: 5, [ b'0', b'1', b'2', b'a', b'c', b'e', b'i', b'o', b's', b't',
=> END0_01010, => END0_01011, => END0_01100, => END0_01101,
=> END0_01110, => END0_01111, => END0_10000, => END0_10001,
=> END0_10010, => END0_10011, => END0_10100, => END0_10101,
=> END0_10110, => END0_10111, => END0_11000, => END0_11001,
=> END0_11010, => END0_11011, => END0_11100, => END0_11101,
=> END0_11110, => END0_11111,
]),
END0_01010 => ( 32, b'%'),
END0_01011 => (b'-', b'.'),
END0_01100 => (b'/', b'3'),
END0_01101 => (b'4', b'5'),
END0_01110 => (b'6', b'7'),
END0_01111 => (b'8', b'9'),
END0_10000 => (b'=', b'A'),
END0_10001 => (b'_', b'b'),
END0_10010 => (b'd', b'f'),
END0_10011 => (b'g', b'h'),
END0_10100 => (b'l', b'm'),
END0_10101 => (b'n', b'p'),
END0_10110 => (b'r', b'u'),
END0_10111 => (b':', b'B', b'C', b'D'),
END0_11000 => (b'E', b'F', b'G', b'H'),
END0_11001 => (b'I', b'J', b'K', b'L'),
END0_11010 => (b'M', b'N', b'O', b'P'),
END0_11011 => (b'Q', b'R', b'S', b'T'),
END0_11100 => (b'U', b'V', b'W', b'Y'),
END0_11101 => (b'j', b'k', b'q', b'v'),
END0_11110 => (b'w', b'x', b'y', b'z'),
END0_11111 => (=> END5_00, => END5_01, => END5_10, => END5_11),
END5_00 => (b'&', b'*'),
END5_01 => (b',', 59),
END5_10 => (b'X', b'Z'),
END5_11 => (=> END7_0, => END7_1),
END7_0 => (b'!', b'"', b'(', b')'),
END7_1 => (=> END8_0, => END8_1),
END8_0 => (b'?', => END9A_1),
END9A_1 => (b'\'', b'+'),
END8_1 => (lookup: 2, [b'|', => END9B_01, => END9B_10, => END9B_11,]),
END9B_01 => (b'#', b'>'),
END9B_10 => (0, b'$', b'@', b'['),
END9B_11 => (lookup: 2, [b']', b'~', => END13_10, => END13_11,]),
END13_10 => (b'^', b'}'),
END13_11 => (=> END14_0, => END14_1),
END14_0 => (b'<', b'`'),
END14_1 => (b'{', => END15_1),
END15_1 =>
(lookup: 4, [ b'\\', 195, 208, => END19_0011,
=> END19_0100, => END19_0101, => END19_0110, => END19_0111,
=> END19_1000, => END19_1001, => END19_1010, => END19_1011,
=> END19_1100, => END19_1101, => END19_1110, => END19_1111,
]),
END19_0011 => (128, 130),
END19_0100 => (131, 162),
END19_0101 => (184, 194),
END19_0110 => (224, 226),
END19_0111 => (153, 161, 167, 172),
END19_1000 => (176, 177, 179, 209),
END19_1001 => (216, 217, 227, 229),
END19_1010 => (lookup: 2, [230, => END19_1010_01, => END19_1010_10,
=> END19_1010_11,]),
END19_1010_01 => (129, 132),
END19_1010_10 => (133, 134),
END19_1010_11 => (136, 146),
END19_1011 => (lookup: 3, [154, 156, 160, 163, 164, 169, 170, 173,]),
END19_1100 => (lookup: 3, [178, 181, 185, 186, 187, 189, 190, 196,]),
END19_1101 => (lookup: 3, [198, 228, 232, 233,
=> END23A_100, => END23A_101,
=> END23A_110, => END23A_111,]),
END23A_100 => ( 1, 135),
END23A_101 => (137, 138),
END23A_110 => (139, 140),
END23A_111 => (141, 143),
END19_1110 => (lookup: 4, [147, 149, 150, 151, 152, 155, 157, 158,
165, 166, 168, 174, 175, 180, 182, 183,]),
END19_1111 => (lookup: 4, [188, 191, 197, 231, 239,
=> END23B_0101, => END23B_0110, => END23B_0111,
=> END23B_1000, => END23B_1001, => END23B_1010,
=> END23B_1011, => END23B_1100, => END23B_1101,
=> END23B_1110, => END23B_1111,]),
END23B_0101 => ( 9, 142),
END23B_0110 => (144, 145),
END23B_0111 => (148, 159),
END23B_1000 => (171, 206),
END23B_1001 => (215, 225),
END23B_1010 => (236, 237),
END23B_1011 => (199, 207, 234, 235),
END23B_1100 => (lookup: 3, [192, 193, 200, 201, 202, 205, 210, 213,]),
END23B_1101 => (lookup: 3, [218, 219, 238, 240, 242, 243, 255,
=> END27A_111,]),
END27A_111 => (203, 204),
END23B_1110 => (lookup: 4, [211, 212, 214, 221, 222, 223, 241, 244,
245, 246, 247, 248, 250, 251, 252, 253,]),
END23B_1111 => (lookup: 4, [ 254, => END27B_0001, => END27B_0010,
=> END27B_0011, => END27B_0100, => END27B_0101,
=> END27B_0110, => END27B_0111, => END27B_1000,
=> END27B_1001, => END27B_1010, => END27B_1011,
=> END27B_1100, => END27B_1101, => END27B_1110,
=> END27B_1111,]),
END27B_0001 => (2, 3),
END27B_0010 => (4, 5),
END27B_0011 => (6, 7),
END27B_0100 => (8, 11),
END27B_0101 => (12, 14),
END27B_0110 => (15, 16),
END27B_0111 => (17, 18),
END27B_1000 => (19, 20),
END27B_1001 => (21, 23),
END27B_1010 => (24, 25),
END27B_1011 => (26, 27),
END27B_1100 => (28, 29),
END27B_1101 => (30, 31),
END27B_1110 => (127, 220),
END27B_1111 => (lookup: 1, [249, => END31_1,]),
END31_1 => (lookup: 2, [10, 13, 22, => EOF,]),
EOF => (lookup: 8, []),
];
pub struct DecodeIter<'a> {
bit_pos: BitWindow,
content: &'a Vec<u8>,
}
impl Iterator for DecodeIter<'_> {
type Item = Result<u8, Error>;
fn next(&mut self) -> Option<Self::Item> {
match HPACK_STRING.decode_next(&mut self.bit_pos, self.content) {
Ok(Some(x)) => Some(Ok(x)),
Err(err) => Some(Err(err)),
Ok(None) => None,
}
}
}
pub trait HpackStringDecode {
fn hpack_decode(&self) -> DecodeIter<'_>;
}
impl HpackStringDecode for Vec<u8> {
fn hpack_decode(&self) -> DecodeIter<'_> {
DecodeIter {
bit_pos: BitWindow::new(),
content: self,
}
}
}

View file

@ -0,0 +1,18 @@
mod capsule;
mod connect;
mod error;
mod frame;
mod settings;
mod stream;
mod varint;
pub use capsule::*;
pub use connect::*;
pub use error::*;
pub use frame::*;
pub use settings::*;
pub use stream::*;
pub use varint::*;
mod huffman;
mod qpack;

View file

@ -0,0 +1,627 @@
// This is a minimal QPACK implementation that only supports the static table and literals.
// By refusing to acknowledge the QPACK encoder, we can avoid implementing the dynamic table altogether.
// This is not recommended for a full HTTP/3 implementation but it's literally more efficient for handling a single WebTransport CONNECT request.
use std::collections::HashMap;
use bytes::{Buf, BufMut};
use super::huffman::{self, HpackStringDecode};
use thiserror::Error;
#[derive(Error, Debug, Clone)]
pub enum DecodeError {
#[error("unexpected end of input")]
UnexpectedEnd,
#[error("varint bounds exceeded")]
BoundsExceeded,
#[error("dynamic references not supported")]
DynamicEntry,
#[error("unknown entry")]
UnknownEntry,
#[error("huffman decoding error")]
HuffmanError(#[from] huffman::Error),
#[error("invalid utf8 header")] // technically not required by the HTTP spec
Utf8Error(#[from] std::str::Utf8Error),
}
#[cfg(target_pointer_width = "64")]
const MAX_POWER: usize = 10 * 7;
#[cfg(target_pointer_width = "32")]
const MAX_POWER: usize = 5 * 7;
// Simple QPACK implementation that ONLY supports the static table and literals.
#[derive(Debug, Default)]
pub struct Headers {
fields: HashMap<String, String>,
}
impl Headers {
pub fn get(&self, name: &str) -> Option<&str> {
self.fields.get(name).map(|v| v.as_str())
}
pub fn set(&mut self, name: &str, value: &str) {
self.fields.insert(name.to_string(), value.to_string());
}
pub fn decode<B: Buf>(mut buf: &mut B) -> Result<Self, DecodeError> {
// We don't support dynamic entries so we can skip these.
let (_, _insert_count) = decode_prefix(buf, 8)?;
let (_sign, _delta_base) = decode_prefix(buf, 7)?;
let mut fields = HashMap::new();
while buf.has_remaining() {
// Read the first byte;
let peek = buf.get_u8();
// Read the byte again by chaining Bufs.
let first = [peek];
let mut chain = first.chain(buf);
// See: https://www.rfc-editor.org/rfc/rfc9204.html#section-4.5.2
// This is over-engineered, LUL
let (name, value) = match peek & 0b1100_0000 {
// Indexed line field from static table
0b1100_0000 => Self::decode_index(&mut chain)?,
// Indexed line field from dynamic table
0b1000_0000 => return Err(DecodeError::DynamicEntry),
_ => match peek & 0b1101_0000 {
// Indexed with literal name ref from static table
0b0101_0000 => Self::decode_literal_value(&mut chain)?,
// Indexed with literal name ref from dynamic table
0b0100_0000 => return Err(DecodeError::DynamicEntry),
// Literal
_ if peek & 0b1110_0000 == 0b0010_0000 => Self::decode_literal(&mut chain)?,
_ => match peek & 0b1111_0000 {
// Indexed with post base
0b0001_0000 => return Err(DecodeError::DynamicEntry),
// Indexed with post base name ref
0b0000_0000 => return Err(DecodeError::DynamicEntry),
// ugh
_ => return Err(DecodeError::UnknownEntry),
},
},
};
fields.insert(name, value);
// Get the buffer back.
(_, buf) = chain.into_inner();
}
Ok(Self { fields })
}
fn decode_index<B: Buf>(buf: &mut B) -> Result<(String, String), DecodeError> {
/*
0 1 2 3 4 5 6 7
+---+---+---+---+---+---+---+---+
| 1 | 1 | Index (6+) |
+---+---+-----------------------+
*/
let (_, index) = decode_prefix(buf, 6)?;
let (name, value) = StaticTable::get(index)?;
Ok((name.to_string(), value.to_string()))
}
fn decode_literal_value<B: Buf>(buf: &mut B) -> Result<(String, String), DecodeError> {
/*
0 1 2 3 4 5 6 7
+---+---+---+---+---+---+---+---+
| 0 | 1 | N | 1 |Name Index (4+)|
+---+---+---+---+---------------+
| H | Value Length (7+) |
+---+---------------------------+
| Value String (Length bytes) |
+-------------------------------+
*/
let (_, name) = decode_prefix(buf, 4)?;
let (name, _) = StaticTable::get(name)?;
let value = decode_string(buf, 8)?;
let value = std::str::from_utf8(&value)?;
Ok((name.to_string(), value.to_string()))
}
fn decode_literal<B: Buf>(buf: &mut B) -> Result<(String, String), DecodeError> {
/*
0 1 2 3 4 5 6 7
+---+---+---+---+---+---+---+---+
| 0 | 0 | 1 | N | H |NameLen(3+)|
+---+---+---+---+---+-----------+
| Name String (Length bytes) |
+---+---------------------------+
| H | Value Length (7+) |
+---+---------------------------+
| Value String (Length bytes) |
+-------------------------------+
*/
let name = decode_string(buf, 4)?;
let name = std::str::from_utf8(&name)?;
let value = decode_string(buf, 8)?;
let value = std::str::from_utf8(&value)?;
Ok((name.to_string(), value.to_string()))
}
pub fn encode<B: BufMut>(&self, buf: &mut B) {
// We don't support dynamic entries so we can skip these.
encode_prefix(buf, 8, 0, 0);
encode_prefix(buf, 7, 0, 0);
// We must encode pseudo-headers first.
// https://datatracker.ietf.org/doc/html/rfc9114#section-4.1.2
let mut headers: Vec<_> = self.fields.iter().collect();
headers.sort_by_key(|&(key, _)| !key.starts_with(':'));
for (name, value) in headers.iter() {
if let Some(index) = StaticTable::find(name, value) {
Self::encode_index(buf, index)
} else if let Some(index) = StaticTable::find_name(name) {
Self::encode_literal_value(buf, index, value)
} else {
Self::encode_literal(buf, name, value)
}
}
}
fn encode_index<B: BufMut>(buf: &mut B, index: usize) {
/*
0 1 2 3 4 5 6 7
+---+---+---+---+---+---+---+---+
| 1 | 1 | Index (6+) |
+---+---+-----------------------+
*/
encode_prefix(buf, 6, 0b11, index);
}
fn encode_literal_value<B: BufMut>(buf: &mut B, name: usize, value: &str) {
/*
0 1 2 3 4 5 6 7
+---+---+---+---+---+---+---+---+
| 0 | 1 | N | 1 |Name Index (4+)|
+---+---+---+---+---------------+
| H | Value Length (7+) |
+---+---------------------------+
| Value String (Length bytes) |
+-------------------------------+
*/
encode_prefix(buf, 4, 0b0101, name);
encode_prefix(buf, 7, 0b0, value.len());
buf.put_slice(value.as_bytes());
}
fn encode_literal<B: BufMut>(buf: &mut B, name: &str, value: &str) {
/*
0 1 2 3 4 5 6 7
+---+---+---+---+---+---+---+---+
| 0 | 0 | 1 | N | H |NameLen(3+)|
+---+---+---+---+---+-----------+
| Name String (Length bytes) |
+---+---------------------------+
| H | Value Length (7+) |
+---+---------------------------+
| Value String (Length bytes) |
+-------------------------------+
*/
encode_prefix(buf, 3, 0b00100, name.len());
buf.put_slice(name.as_bytes());
encode_prefix(buf, 7, 0b0, value.len());
buf.put_slice(value.as_bytes());
}
}
// An integer that uses a fixed number of bits, otherwise a variable number of bytes if it's too large.
// https://www.rfc-editor.org/rfc/rfc7541#section-5.1
// Based on : https://github.com/hyperium/h3/blob/master/h3/src/qpack/prefix_int.rs
// License: MIT
pub fn decode_prefix<B: Buf>(buf: &mut B, size: u8) -> Result<(u8, usize), DecodeError> {
assert!(size <= 8);
if !buf.has_remaining() {
return Err(DecodeError::UnexpectedEnd);
}
let mut first = buf.get_u8();
// NOTE: following casts to u8 intend to trim the most significant bits, they are used as a
// workaround for shiftoverflow errors when size == 8.
let flags = ((first as usize) >> size) as u8;
let mask = 0xFF >> (8 - size);
first &= mask;
// if first < 2usize.pow(size) - 1
if first < mask {
return Ok((flags, first as usize));
}
let mut value = mask as usize;
let mut power = 0usize;
loop {
if !buf.has_remaining() {
return Err(DecodeError::UnexpectedEnd);
}
let byte = buf.get_u8() as usize;
value += (byte & 127) << power;
power += 7;
if byte & 128 == 0 {
break;
}
if power >= MAX_POWER {
return Err(DecodeError::BoundsExceeded);
}
}
Ok((flags, value))
}
pub fn encode_prefix<B: BufMut>(buf: &mut B, size: u8, flags: u8, value: usize) {
assert!(size > 0 && size <= 8);
// NOTE: following casts to u8 intend to trim the most significant bits, they are used as a
// workaround for shiftoverflow errors when size == 8.
let mask = !(0xFF << size) as u8;
let flags = ((flags as usize) << size) as u8;
// if value < 2usize.pow(size) - 1
if value < (mask as usize) {
buf.put_u8(flags | value as u8);
return;
}
buf.put_u8(mask | flags);
let mut remaining = value - mask as usize;
while remaining >= 128 {
let rest = (remaining % 128) as u8;
buf.put_u8(rest + 128);
remaining /= 128;
}
buf.put_u8(remaining as u8);
}
pub fn decode_string<B: Buf>(buf: &mut B, size: u8) -> Result<Vec<u8>, DecodeError> {
if !buf.has_remaining() {
return Err(DecodeError::UnexpectedEnd);
}
let (flags, len) = decode_prefix(buf, size - 1)?;
if buf.remaining() < len {
return Err(DecodeError::UnexpectedEnd);
}
let payload = buf.copy_to_bytes(len);
let value: Vec<u8> = if flags & 1 == 0 {
payload.into_iter().collect()
} else {
let mut decoded = Vec::new();
for byte in payload.into_iter().collect::<Vec<u8>>().hpack_decode() {
decoded.push(byte?);
}
decoded
};
Ok(value)
}
// Based on https://github.com/hyperium/h3/blob/master/h3/src/qpack/static_.rs
// I switched over to str because it's nicer in Rust... even though HTTP doesn't use utf8.
struct StaticTable {}
impl StaticTable {
pub fn get(index: usize) -> Result<(&'static str, &'static str), DecodeError> {
match PREDEFINED_HEADERS.get(index) {
Some(v) => Ok(*v),
None => Err(DecodeError::UnknownEntry),
}
}
// TODO combine find and find_name to do a single lookup
pub fn find(name: &str, value: &str) -> Option<usize> {
match (name, value) {
(":authority", "") => Some(0),
(":path", "/") => Some(1),
("age", "0") => Some(2),
("content-disposition", "") => Some(3),
("content-length", "0") => Some(4),
("cookie", "") => Some(5),
("date", "") => Some(6),
("etag", "") => Some(7),
("if-modified-since", "") => Some(8),
("if-none-match", "") => Some(9),
("last-modified", "") => Some(10),
("link", "") => Some(11),
("location", "") => Some(12),
("referer", "") => Some(13),
("set-cookie", "") => Some(14),
(":method", "CONNECT") => Some(15),
(":method", "DELETE") => Some(16),
(":method", "GET") => Some(17),
(":method", "HEAD") => Some(18),
(":method", "OPTIONS") => Some(19),
(":method", "POST") => Some(20),
(":method", "PUT") => Some(21),
(":scheme", "http") => Some(22),
(":scheme", "https") => Some(23),
(":status", "103") => Some(24),
(":status", "200") => Some(25),
(":status", "304") => Some(26),
(":status", "404") => Some(27),
(":status", "503") => Some(28),
("accept", "*/*") => Some(29),
("accept", "application/dns-message") => Some(30),
("accept-encoding", "gzip, deflate, br") => Some(31),
("accept-ranges", "bytes") => Some(32),
("access-control-allow-headers", "cache-control") => Some(33),
("access-control-allow-headers", "content-type") => Some(34),
("access-control-allow-origin", "*") => Some(35),
("cache-control", "max-age=0") => Some(36),
("cache-control", "max-age=2592000") => Some(37),
("cache-control", "max-age=604800") => Some(38),
("cache-control", "no-cache") => Some(39),
("cache-control", "no-store") => Some(40),
("cache-control", "public, max-age=31536000") => Some(41),
("content-encoding", "br") => Some(42),
("content-encoding", "gzip") => Some(43),
("content-type", "application/dns-message") => Some(44),
("content-type", "application/javascript") => Some(45),
("content-type", "application/json") => Some(46),
("content-type", "application/x-www-form-urlencoded") => Some(47),
("content-type", "image/gif") => Some(48),
("content-type", "image/jpeg") => Some(49),
("content-type", "image/png") => Some(50),
("content-type", "text/css") => Some(51),
("content-type", "text/html; charset=utf-8") => Some(52),
("content-type", "text/plain") => Some(53),
("content-type", "text/plain;charset=utf-8") => Some(54),
("range", "bytes=0-") => Some(55),
("strict-transport-security", "max-age=31536000") => Some(56),
("strict-transport-security", "max-age=31536000; includesubdomains") => Some(57),
("strict-transport-security", "max-age=31536000; includesubdomains; preload") => {
Some(58)
}
("vary", "accept-encoding") => Some(59),
("vary", "origin") => Some(60),
("x-content-type-options", "nosniff") => Some(61),
("x-xss-protection", "1; mode=block") => Some(62),
(":status", "100") => Some(63),
(":status", "204") => Some(64),
(":status", "206") => Some(65),
(":status", "302") => Some(66),
(":status", "400") => Some(67),
(":status", "403") => Some(68),
(":status", "421") => Some(69),
(":status", "425") => Some(70),
(":status", "500") => Some(71),
("accept-language", "") => Some(72),
("access-control-allow-credentials", "FALSE") => Some(73),
("access-control-allow-credentials", "TRUE") => Some(74),
("access-control-allow-headers", "*") => Some(75),
("access-control-allow-methods", "get") => Some(76),
("access-control-allow-methods", "get, post, options") => Some(77),
("access-control-allow-methods", "options") => Some(78),
("access-control-expose-headers", "content-length") => Some(79),
("access-control-request-headers", "content-type") => Some(80),
("access-control-request-method", "get") => Some(81),
("access-control-request-method", "post") => Some(82),
("alt-svc", "clear") => Some(83),
("authorization", "") => Some(84),
(
"content-security-policy",
"script-src 'none'; object-src 'none'; base-uri 'none'",
) => Some(85),
("early-data", "1") => Some(86),
("expect-ct", "") => Some(87),
("forwarded", "") => Some(88),
("if-range", "") => Some(89),
("origin", "") => Some(90),
("purpose", "prefetch") => Some(91),
("server", "") => Some(92),
("timing-allow-origin", "*") => Some(93),
("upgrade-insecure-requests", "1") => Some(94),
("user-agent", "") => Some(95),
("x-forwarded-for", "") => Some(96),
("x-frame-options", "deny") => Some(97),
("x-frame-options", "sameorigin") => Some(98),
_ => None,
}
}
pub fn find_name(name: &str) -> Option<usize> {
match name {
":authority" => Some(0),
":path" => Some(1),
"age" => Some(2),
"content-disposition" => Some(3),
"content-length" => Some(4),
"cookie" => Some(5),
"date" => Some(6),
"etag" => Some(7),
"if-modified-since" => Some(8),
"if-none-match" => Some(9),
"last-modified" => Some(10),
"link" => Some(11),
"location" => Some(12),
"referer" => Some(13),
"set-cookie" => Some(14),
":method" => Some(15),
":scheme" => Some(22),
":status" => Some(24),
"accept" => Some(29),
"accept-encoding" => Some(31),
"accept-ranges" => Some(32),
"access-control-allow-headers" => Some(33),
"access-control-allow-origin" => Some(35),
"cache-control" => Some(36),
"content-encoding" => Some(42),
"content-type" => Some(44),
"range" => Some(55),
"strict-transport-security" => Some(56),
"vary" => Some(59),
"x-content-type-options" => Some(61),
"x-xss-protection" => Some(62),
"accept-language" => Some(72),
"access-control-allow-credentials" => Some(73),
"access-control-allow-methods" => Some(76),
"access-control-expose-headers" => Some(79),
"access-control-request-headers" => Some(80),
"access-control-request-method" => Some(81),
"alt-svc" => Some(83),
"authorization" => Some(84),
"content-security-policy" => Some(85),
"early-data" => Some(86),
"expect-ct" => Some(87),
"forwarded" => Some(88),
"if-range" => Some(89),
"origin" => Some(90),
"purpose" => Some(91),
"server" => Some(92),
"timing-allow-origin" => Some(93),
"upgrade-insecure-requests" => Some(94),
"user-agent" => Some(95),
"x-forwarded-for" => Some(96),
"x-frame-options" => Some(97),
_ => None,
}
}
}
const PREDEFINED_HEADERS: [(&str, &str); 99] = [
(":authority", ""),
(":path", "/"),
("age", "0"),
("content-disposition", ""),
("content-length", "0"),
("cookie", ""),
("date", ""),
("etag", ""),
("if-modified-since", ""),
("if-none-match", ""),
("last-modified", ""),
("link", ""),
("location", ""),
("referer", ""),
("set-cookie", ""),
(":method", "CONNECT"),
(":method", "DELETE"),
(":method", "GET"),
(":method", "HEAD"),
(":method", "OPTIONS"),
(":method", "POST"),
(":method", "PUT"),
(":scheme", "http"),
(":scheme", "https"),
(":status", "103"),
(":status", "200"),
(":status", "304"),
(":status", "404"),
(":status", "503"),
("accept", "*/*"),
("accept", "application/dns-message"),
("accept-encoding", "gzip, deflate, br"),
("accept-ranges", "bytes"),
("access-control-allow-headers", "cache-control"),
("access-control-allow-headers", "content-type"),
("access-control-allow-origin", "*"),
("cache-control", "max-age=0"),
("cache-control", "max-age=2592000"),
("cache-control", "max-age=604800"),
("cache-control", "no-cache"),
("cache-control", "no-store"),
("cache-control", "public, max-age=31536000"),
("content-encoding", "br"),
("content-encoding", "gzip"),
("content-type", "application/dns-message"),
("content-type", "application/javascript"),
("content-type", "application/json"),
("content-type", "application/x-www-form-urlencoded"),
("content-type", "image/gif"),
("content-type", "image/jpeg"),
("content-type", "image/png"),
("content-type", "text/css"),
("content-type", "text/html; charset=utf-8"),
("content-type", "text/plain"),
("content-type", "text/plain;charset=utf-8"),
("range", "bytes=0-"),
("strict-transport-security", "max-age=31536000"),
(
"strict-transport-security",
"max-age=31536000; includesubdomains",
),
(
"strict-transport-security",
"max-age=31536000; includesubdomains; preload",
),
("vary", "accept-encoding"),
("vary", "origin"),
("x-content-type-options", "nosniff"),
("x-xss-protection", "1; mode=block"),
(":status", "100"),
(":status", "204"),
(":status", "206"),
(":status", "302"),
(":status", "400"),
(":status", "403"),
(":status", "421"),
(":status", "425"),
(":status", "500"),
("accept-language", ""),
("access-control-allow-credentials", "FALSE"),
("access-control-allow-credentials", "TRUE"),
("access-control-allow-headers", "*"),
("access-control-allow-methods", "get"),
("access-control-allow-methods", "get, post, options"),
("access-control-allow-methods", "options"),
("access-control-expose-headers", "content-length"),
("access-control-request-headers", "content-type"),
("access-control-request-method", "get"),
("access-control-request-method", "post"),
("alt-svc", "clear"),
("authorization", ""),
(
"content-security-policy",
"script-src 'none'; object-src 'none'; base-uri 'none'",
),
("early-data", "1"),
("expect-ct", ""),
("forwarded", ""),
("if-range", ""),
("origin", ""),
("purpose", "prefetch"),
("server", ""),
("timing-allow-origin", "*"),
("upgrade-insecure-requests", "1"),
("user-agent", ""),
("x-forwarded-for", ""),
("x-frame-options", "deny"),
("x-frame-options", "sameorigin"),
];

View file

@ -0,0 +1,254 @@
use std::{
collections::HashMap,
fmt::Debug,
ops::{Deref, DerefMut},
sync::Arc,
};
use bytes::{Buf, BufMut, BytesMut};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use super::{Frame, StreamUni, VarInt, VarIntUnexpectedEnd};
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Setting(pub VarInt);
impl Setting {
pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, VarIntUnexpectedEnd> {
Ok(Setting(VarInt::decode(buf)?))
}
pub fn encode<B: BufMut>(&self, buf: &mut B) {
self.0.encode(buf)
}
// Reference : https://datatracker.ietf.org/doc/html/rfc9114#section-7.2.4.1
pub fn is_grease(&self) -> bool {
let val = self.0.into_inner();
if val < 0x21 {
return false;
}
#[allow(unknown_lints, clippy::manual_is_multiple_of)]
{
(val - 0x21) % 0x1f == 0
}
}
}
impl Debug for Setting {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match *self {
Setting::QPACK_MAX_TABLE_CAPACITY => write!(f, "QPACK_MAX_TABLE_CAPACITY"),
Setting::MAX_FIELD_SECTION_SIZE => write!(f, "MAX_FIELD_SECTION_SIZE"),
Setting::QPACK_BLOCKED_STREAMS => write!(f, "QPACK_BLOCKED_STREAMS"),
Setting::ENABLE_CONNECT_PROTOCOL => write!(f, "ENABLE_CONNECT_PROTOCOL"),
Setting::ENABLE_DATAGRAM => write!(f, "ENABLE_DATAGRAM"),
Setting::ENABLE_DATAGRAM_DEPRECATED => write!(f, "ENABLE_DATAGRAM_DEPRECATED"),
Setting::WEBTRANSPORT_ENABLE_DEPRECATED => write!(f, "WEBTRANSPORT_ENABLE_DEPRECATED"),
Setting::WEBTRANSPORT_MAX_SESSIONS_DEPRECATED => {
write!(f, "WEBTRANSPORT_MAX_SESSIONS_DEPRECATED")
}
Setting::WEBTRANSPORT_MAX_SESSIONS => write!(f, "WEBTRANSPORT_MAX_SESSIONS"),
x if x.is_grease() => write!(f, "GREASE SETTING [{:x?}]", x.0.into_inner()),
x => write!(f, "UNKNOWN_SETTING [{:x?}]", x.0.into_inner()),
}
}
}
macro_rules! settings {
{$($name:ident = $val:expr,)*} => {
impl Setting {
$(pub const $name: Setting = Setting(VarInt::from_u32($val));)*
}
}
}
settings! {
// These are for HTTP/3 and we can ignore them
QPACK_MAX_TABLE_CAPACITY = 0x1, // default is 0, which disables QPACK dynamic table
MAX_FIELD_SECTION_SIZE = 0x6,
QPACK_BLOCKED_STREAMS = 0x7,
// Both of these are required for WebTransport
ENABLE_CONNECT_PROTOCOL = 0x8,
ENABLE_DATAGRAM = 0x33,
ENABLE_DATAGRAM_DEPRECATED = 0xFFD277, // still used by Chrome
// Removed in draft 06
WEBTRANSPORT_ENABLE_DEPRECATED = 0x2b603742,
WEBTRANSPORT_MAX_SESSIONS_DEPRECATED = 0x2b603743,
// New way to enable WebTransport
WEBTRANSPORT_MAX_SESSIONS = 0xc671706a,
}
#[derive(Error, Debug, Clone)]
pub enum SettingsError {
#[error("unexpected end of input")]
UnexpectedEnd,
#[error("unexpected stream type {0:?}")]
UnexpectedStreamType(StreamUni),
#[error("unexpected frame {0:?}")]
UnexpectedFrame(Frame),
#[error("invalid size")]
InvalidSize,
#[error("io error: {0}")]
Io(Arc<std::io::Error>),
}
impl From<std::io::Error> for SettingsError {
fn from(err: std::io::Error) -> Self {
SettingsError::Io(Arc::new(err))
}
}
// A map of settings to values.
#[derive(Default, Debug)]
pub struct Settings(HashMap<Setting, VarInt>);
impl Settings {
pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, SettingsError> {
let typ = StreamUni::decode(buf).map_err(|_| SettingsError::UnexpectedEnd)?;
if typ != StreamUni::CONTROL {
return Err(SettingsError::UnexpectedStreamType(typ));
}
let (typ, mut data) = Frame::read(buf).map_err(|_| SettingsError::UnexpectedEnd)?;
if typ != Frame::SETTINGS {
return Err(SettingsError::UnexpectedFrame(typ));
}
let mut settings = Settings::default();
while data.has_remaining() {
// These return a different error because retrying won't help.
let id = Setting::decode(&mut data).map_err(|_| SettingsError::InvalidSize)?;
let value = VarInt::decode(&mut data).map_err(|_| SettingsError::InvalidSize)?;
// Only add if it is not grease
if !id.is_grease() {
settings.0.insert(id, value);
}
}
Ok(settings)
}
pub async fn read<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Self, SettingsError> {
let mut buf = Vec::new();
loop {
if stream.read_buf(&mut buf).await? == 0 {
return Err(SettingsError::UnexpectedEnd);
}
// Look at the buffer we've already read.
let mut limit = std::io::Cursor::new(&buf);
match Settings::decode(&mut limit) {
Ok(settings) => return Ok(settings),
Err(SettingsError::UnexpectedEnd) => continue, // More data needed.
Err(e) => return Err(e),
};
}
}
pub fn encode<B: BufMut>(&self, buf: &mut B) {
StreamUni::CONTROL.encode(buf);
Frame::SETTINGS.encode(buf);
// Encode to a temporary buffer so we can learn the length.
// TODO avoid doing this, just use a fixed size varint.
let mut tmp = Vec::new();
for (id, value) in &self.0 {
id.encode(&mut tmp);
value.encode(&mut tmp);
}
VarInt::from_u32(tmp.len() as u32).encode(buf);
buf.put_slice(&tmp);
}
pub async fn write<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<(), SettingsError> {
// TODO avoid allocating to the heap
let mut buf = BytesMut::new();
self.encode(&mut buf);
stream.write_all_buf(&mut buf).await?;
Ok(())
}
pub fn enable_webtransport(&mut self, max_sessions: u32) {
let max = VarInt::from_u32(max_sessions);
self.insert(Setting::ENABLE_CONNECT_PROTOCOL, VarInt::from_u32(1));
self.insert(Setting::ENABLE_DATAGRAM, VarInt::from_u32(1));
self.insert(Setting::ENABLE_DATAGRAM_DEPRECATED, VarInt::from_u32(1));
self.insert(Setting::WEBTRANSPORT_MAX_SESSIONS, max);
// TODO remove when 07 is in the wild
self.insert(Setting::WEBTRANSPORT_MAX_SESSIONS_DEPRECATED, max);
self.insert(Setting::WEBTRANSPORT_ENABLE_DEPRECATED, VarInt::from_u32(1));
}
// Returns the maximum number of sessions supported.
pub fn supports_webtransport(&self) -> u64 {
// Sent by Chrome 114.0.5735.198 (July 19, 2023)
// Setting(1): 65536, // qpack_max_table_capacity
// Setting(6): 16384, // max_field_section_size
// Setting(7): 100, // qpack_blocked_streams
// Setting(51): 1, // enable_datagram
// Setting(16765559): 1 // enable_datagram_deprecated
// Setting(727725890): 1, // webtransport_max_sessions_deprecated
// Setting(4445614305): 454654587, // grease
// NOTE: The presence of ENABLE_WEBTRANSPORT implies ENABLE_CONNECT is supported.
let datagram = self
.get(&Setting::ENABLE_DATAGRAM)
.or(self.get(&Setting::ENABLE_DATAGRAM_DEPRECATED))
.map(|v| v.into_inner());
if datagram != Some(1) {
return 0;
}
// The deprecated (before draft-07) way of enabling WebTransport was to send two parameters.
// Both would send ENABLE=1 and the server would send MAX_SESSIONS=N to limit the sessions.
// Now both just send MAX_SESSIONS, and a non-zero value means WebTransport is enabled.
if let Some(max) = self.get(&Setting::WEBTRANSPORT_MAX_SESSIONS) {
return max.into_inner();
}
let enabled = self
.get(&Setting::WEBTRANSPORT_ENABLE_DEPRECATED)
.map(|v| v.into_inner());
if enabled != Some(1) {
return 0;
}
// Only the server is allowed to set this one, so if it's None we assume it's 1.
self.get(&Setting::WEBTRANSPORT_MAX_SESSIONS_DEPRECATED)
.map(|v| v.into_inner())
.unwrap_or(1)
}
}
impl Deref for Settings {
type Target = HashMap<Setting, VarInt>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for Settings {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

View file

@ -0,0 +1,45 @@
use bytes::{Buf, BufMut};
use super::{VarInt, VarIntUnexpectedEnd};
// Sent as the first bytes of a unidirectional stream to identify the type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StreamUni(pub VarInt);
impl StreamUni {
pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, VarIntUnexpectedEnd> {
Ok(StreamUni(VarInt::decode(buf)?))
}
pub fn encode<B: BufMut>(&self, buf: &mut B) {
self.0.encode(buf)
}
pub fn is_grease(&self) -> bool {
let val = self.0.into_inner();
if val < 0x21 {
return false;
}
#[allow(unknown_lints, clippy::manual_is_multiple_of)]
{
(val - 0x21) % 0x1f == 0
}
}
}
macro_rules! streams_uni {
{$($name:ident = $val:expr,)*} => {
impl StreamUni {
$(pub const $name: StreamUni = StreamUni(VarInt::from_u32($val));)*
}
}
}
streams_uni! {
CONTROL = 0x00,
PUSH = 0x01,
QPACK_ENCODER = 0x02,
QPACK_DECODER = 0x03,
WEBTRANSPORT = 0x54,
}

View file

@ -0,0 +1,233 @@
// Based on Quinn: https://github.com/quinn-rs/quinn/tree/main/quinn-proto/src
// Licensed under Apache-2.0 OR MIT
use std::{convert::TryInto, fmt, io::Cursor};
use bytes::{Buf, BufMut};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// An integer less than 2^62
///
/// Values of this type are suitable for encoding as QUIC variable-length integer.
// It would be neat if we could express to Rust that the top two bits are available for use as enum
// discriminants
#[derive(Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct VarInt(pub(crate) u64);
impl VarInt {
/// The largest representable value
pub const MAX: Self = Self((1 << 62) - 1);
/// The largest encoded value length
pub const MAX_SIZE: usize = 8;
/// Construct a `VarInt` infallibly
pub const fn from_u32(x: u32) -> Self {
Self(x as u64)
}
/// Succeeds if `x` < 2^62
pub fn from_u64(x: u64) -> Result<Self, VarIntBoundsExceeded> {
if x < 2u64.pow(62) {
Ok(Self(x))
} else {
Err(VarIntBoundsExceeded)
}
}
/// Create a VarInt without ensuring it's in range
///
/// # Safety
///
/// `x` must be less than 2^62.
pub const unsafe fn from_u64_unchecked(x: u64) -> Self {
Self(x)
}
/// Extract the integer value
pub const fn into_inner(self) -> u64 {
self.0
}
/// Compute the number of bytes needed to encode this value
pub fn size(self) -> usize {
let x = self.0;
if x < 2u64.pow(6) {
1
} else if x < 2u64.pow(14) {
2
} else if x < 2u64.pow(30) {
4
} else if x < 2u64.pow(62) {
8
} else {
unreachable!("malformed VarInt");
}
}
}
impl From<VarInt> for u64 {
fn from(x: VarInt) -> Self {
x.0
}
}
impl From<u8> for VarInt {
fn from(x: u8) -> Self {
Self(x.into())
}
}
impl From<u16> for VarInt {
fn from(x: u16) -> Self {
Self(x.into())
}
}
impl From<u32> for VarInt {
fn from(x: u32) -> Self {
Self(x.into())
}
}
impl std::convert::TryFrom<u64> for VarInt {
type Error = VarIntBoundsExceeded;
/// Succeeds iff `x` < 2^62
fn try_from(x: u64) -> Result<Self, VarIntBoundsExceeded> {
Self::from_u64(x)
}
}
impl std::convert::TryFrom<u128> for VarInt {
type Error = VarIntBoundsExceeded;
/// Succeeds iff `x` < 2^62
fn try_from(x: u128) -> Result<Self, VarIntBoundsExceeded> {
Self::from_u64(x.try_into().map_err(|_| VarIntBoundsExceeded)?)
}
}
impl std::convert::TryFrom<usize> for VarInt {
type Error = VarIntBoundsExceeded;
/// Succeeds iff `x` < 2^62
fn try_from(x: usize) -> Result<Self, VarIntBoundsExceeded> {
Self::try_from(x as u64)
}
}
impl fmt::Debug for VarInt {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl fmt::Display for VarInt {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl VarInt {
pub fn decode<B: Buf>(r: &mut B) -> Result<Self, VarIntUnexpectedEnd> {
if !r.has_remaining() {
return Err(VarIntUnexpectedEnd);
}
let mut buf = [0; 8];
buf[0] = r.get_u8();
let tag = buf[0] >> 6;
buf[0] &= 0b0011_1111;
let x = match tag {
0b00 => u64::from(buf[0]),
0b01 => {
if r.remaining() < 1 {
return Err(VarIntUnexpectedEnd);
}
r.copy_to_slice(&mut buf[1..2]);
u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap()))
}
0b10 => {
if r.remaining() < 3 {
return Err(VarIntUnexpectedEnd);
}
r.copy_to_slice(&mut buf[1..4]);
u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap()))
}
0b11 => {
if r.remaining() < 7 {
return Err(VarIntUnexpectedEnd);
}
r.copy_to_slice(&mut buf[1..8]);
u64::from_be_bytes(buf)
}
_ => unreachable!(),
};
Ok(Self(x))
}
// Read a varint from the stream.
pub async fn read<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Self, VarIntUnexpectedEnd> {
// 8 bytes is the max size of a varint
let mut buf = [0; 8];
// Read the first byte because it includes the length.
stream
.read_exact(&mut buf[0..1])
.await
.map_err(|_| VarIntUnexpectedEnd)?;
// 0b00 = 1, 0b01 = 2, 0b10 = 4, 0b11 = 8
let size = 1 << (buf[0] >> 6);
stream
.read_exact(&mut buf[1..size])
.await
.map_err(|_| VarIntUnexpectedEnd)?;
// Use a cursor to read the varint on the stack.
let mut cursor = Cursor::new(&buf[..size]);
let v = VarInt::decode(&mut cursor).unwrap();
Ok(v)
}
pub fn encode<B: BufMut>(&self, w: &mut B) {
let x = self.0;
if x < 2u64.pow(6) {
w.put_u8(x as u8);
} else if x < 2u64.pow(14) {
w.put_u16((0b01 << 14) | x as u16);
} else if x < 2u64.pow(30) {
w.put_u32((0b10 << 30) | x as u32);
} else if x < 2u64.pow(62) {
w.put_u64((0b11 << 62) | x);
} else {
unreachable!("malformed VarInt")
}
}
pub async fn write<S: AsyncWrite + Unpin>(
&self,
stream: &mut S,
) -> Result<(), VarIntUnexpectedEnd> {
// Super jaink but keeps everything on the stack.
let mut buf = [0u8; 8];
let mut cursor: &mut [u8] = &mut buf;
self.encode(&mut cursor);
let size = 8 - cursor.len();
let mut cursor = &buf[..size];
stream
.write_all_buf(&mut cursor)
.await
.map_err(|_| VarIntUnexpectedEnd)?;
Ok(())
}
}
/// Error returned when constructing a `VarInt` from a value >= 2^62
#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
#[error("value too large for varint encoding")]
pub struct VarIntBoundsExceeded;
#[derive(Error, Debug, Copy, Clone, Eq, PartialEq)]
#[error("unexpected end of buffer")]
pub struct VarIntUnexpectedEnd;