From 523c601dc323cda80ddd600fe27fb7b9478ff67a Mon Sep 17 00:00:00 2001 From: "every.channel" Date: Wed, 18 Feb 2026 01:28:57 -0800 Subject: [PATCH] wt: patch web-transport-proto header interop for Cloudflare relay --- .gitignore | 2 + Cargo.lock | 2 - Cargo.toml | 7 + .../ECP-0063-cloudflare-moq-webtransport.md | 1 + third_party/web-transport-proto/CHANGELOG.md | 90 +++ third_party/web-transport-proto/Cargo.toml | 59 ++ third_party/web-transport-proto/README.md | 6 + .../web-transport-proto/src/capsule.rs | 376 +++++++++++ .../web-transport-proto/src/connect.rs | 402 +++++++++++ third_party/web-transport-proto/src/error.rs | 18 + third_party/web-transport-proto/src/frame.rs | 65 ++ .../web-transport-proto/src/huffman.rs | 364 ++++++++++ third_party/web-transport-proto/src/lib.rs | 18 + third_party/web-transport-proto/src/qpack.rs | 627 ++++++++++++++++++ .../web-transport-proto/src/settings.rs | 254 +++++++ third_party/web-transport-proto/src/stream.rs | 45 ++ third_party/web-transport-proto/src/varint.rs | 233 +++++++ 17 files changed, 2567 insertions(+), 2 deletions(-) create mode 100644 third_party/web-transport-proto/CHANGELOG.md create mode 100644 third_party/web-transport-proto/Cargo.toml create mode 100644 third_party/web-transport-proto/README.md create mode 100644 third_party/web-transport-proto/src/capsule.rs create mode 100644 third_party/web-transport-proto/src/connect.rs create mode 100644 third_party/web-transport-proto/src/error.rs create mode 100644 third_party/web-transport-proto/src/frame.rs create mode 100644 third_party/web-transport-proto/src/huffman.rs create mode 100644 third_party/web-transport-proto/src/lib.rs create mode 100644 third_party/web-transport-proto/src/qpack.rs create mode 100644 third_party/web-transport-proto/src/settings.rs create mode 100644 third_party/web-transport-proto/src/stream.rs create mode 100644 third_party/web-transport-proto/src/varint.rs diff --git a/.gitignore b/.gitignore index 45c7b5e..f347bb5 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,8 @@ result third_party/* !third_party/iroh-live !third_party/iroh-org +!third_party/web-transport-proto +!third_party/web-transport-proto/** third_party/iroh-org/* !third_party/iroh-org/iroh-gossip diff --git a/Cargo.lock b/Cargo.lock index 7eed7bb..b577eac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8312,8 +8312,6 @@ dependencies = [ [[package]] name = "web-transport-proto" version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17633ea7058419f87cbb7f341ab75ac5c1d6d187c154b0bd4c87539e66f4c4e4" dependencies = [ "bytes", "http", diff --git a/Cargo.toml b/Cargo.toml index a1d7c77..eb761a2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,3 +34,10 @@ serde = { version = "1", features = ["derive"] } serde_json = "1" tracing = "0.1" tracing-subscriber = "0.3" + +[patch.crates-io] +# Cloudflare's relay uses standard WebTransport subprotocol negotiation. The upstream +# `web-transport-proto` crate (used by `web-transport-quinn`) currently uses legacy +# header names (`wt-available-protocols` / `wt-protocol`), which prevents negotiating +# `moqt-*` and causes the relay to close after MoQ SETUP. +web-transport-proto = { path = "third_party/web-transport-proto" } diff --git a/evolution/proposals/ECP-0063-cloudflare-moq-webtransport.md b/evolution/proposals/ECP-0063-cloudflare-moq-webtransport.md index 443db33..fe314bc 100644 --- a/evolution/proposals/ECP-0063-cloudflare-moq-webtransport.md +++ b/evolution/proposals/ECP-0063-cloudflare-moq-webtransport.md @@ -60,6 +60,7 @@ newer draft implementations. Implementation choice: - Cloudflare's relay preview currently does **not** support `ANNOUNCE` (namespace-style publishing). `ec-node wt-publish` uses the `moq-lite` publish model via `moq-native` and `moq-mux` (fMP4 ingestion) for Cloudflare relay compatibility. - On NixOS deployments, we disable `moq-native`'s WebSocket fallback (`MOQ_CLIENT_WEBSOCKET_ENABLED=false`) to ensure WebTransport (QUIC) is used. This avoids the WebSocket path occasionally "winning" the race and then failing MoQ negotiation against the Cloudflare relay, causing rapid reconnect loops. +- For Cloudflare relay interop, we patch `web-transport-proto` to send and accept the standard WebTransport subprotocol negotiation header (`sec-webtransport-protocol`) in addition to the legacy `wt-available-protocols`/`wt-protocol` headers. Without subprotocol negotiation, the relay may not select a `moqt-*` protocol and can close the session immediately after MoQ `SETUP`. ### Share link diff --git a/third_party/web-transport-proto/CHANGELOG.md b/third_party/web-transport-proto/CHANGELOG.md new file mode 100644 index 0000000..38045a6 --- /dev/null +++ b/third_party/web-transport-proto/CHANGELOG.md @@ -0,0 +1,90 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [0.2.7](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.2.6...web-transport-proto-v0.2.7) - 2025-09-03 + +### Other + +- Rename the repo. ([#94](https://github.com/moq-dev/web-transport/pull/94)) +- Fix clippy warnings ([#91](https://github.com/moq-dev/web-transport/pull/91)) +- Add support for session closed capsule ([#86](https://github.com/moq-dev/web-transport/pull/86)) +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [0.5.2](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.5.1...web-transport-proto-v0.5.2) - 2026-02-13 + +### Other + +- Fix some API mistakes. ([#163](https://github.com/moq-dev/web-transport/pull/163)) + +## [0.5.1](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.5.0...web-transport-proto-v0.5.1) - 2026-02-11 + +### Other + +- Async accept ([#159](https://github.com/moq-dev/web-transport/pull/159)) + +## [0.5.0](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.4.0...web-transport-proto-v0.5.0) - 2026-02-10 + +### Other + +- Fix capsule protocol handling ([#152](https://github.com/moq-dev/web-transport/pull/152)) + +## [0.4.0](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.3.1...web-transport-proto-v0.4.0) - 2026-01-23 + +### Other + +- Sub-protocol negotiation + breaking API changes ([#143](https://github.com/moq-dev/web-transport/pull/143)) + +## [0.3.1](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.3.0...web-transport-proto-v0.3.1) - 2026-01-07 + +### Other + +- Rename the repo into a new org. ([#132](https://github.com/moq-dev/web-transport/pull/132)) + +## [0.2.8](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.2.7...web-transport-proto-v0.2.8) - 2025-10-17 + +### Other + +- Make traits compatible with WASM ([#107](https://github.com/moq-dev/web-transport/pull/107)) +- Check all feature combinations ([#102](https://github.com/moq-dev/web-transport/pull/102)) + +## [0.2.6](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.2.5...web-transport-proto-v0.2.6) - 2025-05-15 + +### Other + +- Add (generic) support for learning when a stream is closed. ([#73](https://github.com/moq-dev/web-transport/pull/73)) +- Add url query to CONNECT :path request ([#70](https://github.com/moq-dev/web-transport/pull/70)) + +## [0.2.5](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.2.4...web-transport-proto-v0.2.5) - 2025-03-26 + +### Other + +- Added Ring feature flag ([#68](https://github.com/moq-dev/web-transport/pull/68)) + +## [0.2.4](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.2.3...web-transport-proto-v0.2.4) - 2025-01-15 + +### Other + +- Bump some deps. ([#55](https://github.com/moq-dev/web-transport/pull/55)) +- Clippy fixes. ([#53](https://github.com/moq-dev/web-transport/pull/53)) + +## [0.2.3](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.2.2...web-transport-proto-v0.2.3) - 2024-09-02 + +### Other +- Don't set the N bit for literals. ([#41](https://github.com/moq-dev/web-transport/pull/41)) + +## [0.2.2](https://github.com/moq-dev/web-transport/compare/web-transport-proto-v0.2.1...web-transport-proto-v0.2.2) - 2024-08-15 + +### Other +- Some more documentation. ([#34](https://github.com/moq-dev/web-transport/pull/34)) diff --git a/third_party/web-transport-proto/Cargo.toml b/third_party/web-transport-proto/Cargo.toml new file mode 100644 index 0000000..e723106 --- /dev/null +++ b/third_party/web-transport-proto/Cargo.toml @@ -0,0 +1,59 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. + +[package] +edition = "2021" +name = "web-transport-proto" +version = "0.5.2" +authors = ["Luke Curley"] +build = false +autolib = false +autobins = false +autoexamples = false +autotests = false +autobenches = false +description = "WebTransport core protocol" +readme = "README.md" +keywords = [ + "quic", + "http3", + "webtransport", +] +categories = [ + "network-programming", + "web-programming", +] +license = "MIT OR Apache-2.0" +repository = "https://github.com/moq-dev/web-transport" + +[lib] +name = "web_transport_proto" +path = "src/lib.rs" + +[dependencies.bytes] +version = "1" + +[dependencies.http] +version = "1" + +[dependencies.sfv] +version = "0.14" + +[dependencies.thiserror] +version = "2" + +[dependencies.tokio] +version = "1" +features = ["io-util"] +default-features = false + +[dependencies.url] +version = "2" diff --git a/third_party/web-transport-proto/README.md b/third_party/web-transport-proto/README.md new file mode 100644 index 0000000..fc9bcc3 --- /dev/null +++ b/third_party/web-transport-proto/README.md @@ -0,0 +1,6 @@ +[![crates.io](https://img.shields.io/crates/v/web-transport-proto)](https://crates.io/crates/web-transport-proto) +[![discord](https://img.shields.io/discord/1124083992740761730)](https://discord.gg/FCYF3p99mr) + +# web-transport-proto +The gritty WebTransport protocol implementation. +Not meant to be used directly, but as a dependency for [web-transport-quinn](../web-transport-quinn) and [web-transport-wasm](../web-transport-wasm). diff --git a/third_party/web-transport-proto/src/capsule.rs b/third_party/web-transport-proto/src/capsule.rs new file mode 100644 index 0000000..cd1b1ab --- /dev/null +++ b/third_party/web-transport-proto/src/capsule.rs @@ -0,0 +1,376 @@ +use std::sync::Arc; + +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use crate::{VarInt, VarIntUnexpectedEnd}; + +// The spec (draft-ietf-webtrans-http3-06) says the type is 0x2843, which would +// varint-encode to 0x68 0x43. However, actual wire data shows 0x43 0x28 which +// decodes to 808. There may be a discrepancy in implementations or specs. +// Using 0x2843 as specified in the standard. +const CLOSE_WEBTRANSPORT_SESSION_TYPE: u64 = 0x2843; +const MAX_MESSAGE_SIZE: usize = 1024; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Capsule { + CloseWebTransportSession { code: u32, reason: String }, + Grease { num: u64 }, + Unknown { typ: VarInt, payload: Bytes }, +} + +impl Capsule { + pub fn decode(buf: &mut B) -> Result { + let typ = VarInt::decode(buf)?; + let length = VarInt::decode(buf)?; + + let mut payload = buf.take(length.into_inner() as usize); + + // Check declared length first - reject immediately if too large + if payload.limit() > MAX_MESSAGE_SIZE { + return Err(CapsuleError::MessageTooLong); + } + + // Then check if all declared bytes are buffered + if payload.remaining() < payload.limit() { + return Err(CapsuleError::UnexpectedEnd); + } + + let typ_val = typ.into_inner(); + + if let Some(num) = is_grease(typ_val) { + payload.advance(payload.remaining()); + return Ok(Self::Grease { num }); + } + + match typ_val { + CLOSE_WEBTRANSPORT_SESSION_TYPE => { + if payload.remaining() < 4 { + return Err(CapsuleError::UnexpectedEnd); + } + + let error_code = payload.get_u32(); + + let message_len = payload.remaining(); + if message_len > MAX_MESSAGE_SIZE { + return Err(CapsuleError::MessageTooLong); + } + + let mut message_bytes = vec![0u8; message_len]; + payload.copy_to_slice(&mut message_bytes); + + let error_message = + String::from_utf8(message_bytes).map_err(|_| CapsuleError::InvalidUtf8)?; + + Ok(Self::CloseWebTransportSession { + code: error_code, + reason: error_message, + }) + } + _ => { + let mut payload_bytes = vec![0u8; payload.remaining()]; + payload.copy_to_slice(&mut payload_bytes); + Ok(Self::Unknown { + typ, + payload: Bytes::from(payload_bytes), + }) + } + } + } + + pub async fn read(stream: &mut S) -> Result, CapsuleError> { + let mut buf = Vec::new(); + loop { + if stream.read_buf(&mut buf).await? == 0 { + if buf.is_empty() { + return Ok(None); + } + return Err(CapsuleError::UnexpectedEnd); + } + + let mut limit = std::io::Cursor::new(&buf); + match Self::decode(&mut limit) { + Ok(capsule) => return Ok(Some(capsule)), + Err(CapsuleError::UnexpectedEnd) => continue, + Err(e) => return Err(e), + } + } + } + + pub fn encode(&self, buf: &mut B) { + match self { + Self::CloseWebTransportSession { + code: error_code, + reason: error_message, + } => { + // Encode the capsule type + VarInt::from_u64(CLOSE_WEBTRANSPORT_SESSION_TYPE) + .unwrap() + .encode(buf); + + // Calculate and encode the length + let length = 4 + error_message.len(); + VarInt::from_u32(length as u32).encode(buf); + + // Encode the error code (32-bit) + buf.put_u32(*error_code); + + // Encode the error message + buf.put_slice(error_message.as_bytes()); + } + Self::Grease { num } => { + // Generate grease type: 0x29 * N + 0x17 + // Check for overflow + let grease_type = num + .checked_mul(0x29) + .and_then(|v| v.checked_add(0x17)) + .expect("grease num value would overflow u64"); + + VarInt::from_u64(grease_type).unwrap().encode(buf); + + // Grease capsules have zero-length payload + VarInt::from_u32(0).encode(buf); + } + Self::Unknown { typ, payload } => { + // Encode the capsule type + typ.encode(buf); + + // Encode the length + VarInt::try_from(payload.len()).unwrap().encode(buf); + + // Encode the payload + buf.put_slice(payload); + } + } + } + + pub async fn write(&self, stream: &mut S) -> Result<(), CapsuleError> { + let mut buf = BytesMut::new(); + self.encode(&mut buf); + stream.write_all_buf(&mut buf).await?; + Ok(()) + } +} + +// RFC 9297 Section 5.4: Capsule types of the form 0x29 * N + 0x17 +// Returns Some(N) if the value is a grease type, None otherwise +fn is_grease(val: u64) -> Option { + if val < 0x17 { + return None; + } + let num = (val - 0x17) / 0x29; + if val == 0x29 * num + 0x17 { + Some(num) + } else { + None + } +} + +#[derive(Debug, Clone, thiserror::Error)] +pub enum CapsuleError { + #[error("unexpected end of buffer")] + UnexpectedEnd, + + #[error("invalid UTF-8")] + InvalidUtf8, + + #[error("message too long")] + MessageTooLong, + + #[error("unknown capsule type: {0:?}")] + UnknownType(VarInt), + + #[error("varint decode error: {0:?}")] + VarInt(#[from] VarIntUnexpectedEnd), + + #[error("io error: {0}")] + Io(Arc), +} + +impl From for CapsuleError { + fn from(err: std::io::Error) -> Self { + CapsuleError::Io(Arc::new(err)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::Bytes; + + #[test] + fn test_close_webtransport_session_decode() { + // Test with spec-compliant type 0x2843 (encodes as 0x68 0x43) + let mut data = Vec::new(); + VarInt::from_u64(0x2843).unwrap().encode(&mut data); + VarInt::from_u32(8).encode(&mut data); + data.extend_from_slice(b"\x00\x00\x01\xa4test"); + + let mut buf = data.as_slice(); + let capsule = Capsule::decode(&mut buf).unwrap(); + + match capsule { + Capsule::CloseWebTransportSession { + code: error_code, + reason: error_message, + } => { + assert_eq!(error_code, 420); + assert_eq!(error_message, "test"); + } + _ => panic!("Expected CloseWebTransportSession"), + } + + assert_eq!(buf.len(), 0); // All bytes consumed + } + + #[test] + fn test_close_webtransport_session_encode() { + let capsule = Capsule::CloseWebTransportSession { + code: 420, + reason: "test".to_string(), + }; + + let mut buf = Vec::new(); + capsule.encode(&mut buf); + + // Expected format: type(0x2843 as varint = 0x68 0x43) + length(8 as varint) + error_code(420 as u32 BE) + "test" + assert_eq!(buf, b"\x68\x43\x08\x00\x00\x01\xa4test"); + } + + #[test] + fn test_close_webtransport_session_roundtrip() { + let original = Capsule::CloseWebTransportSession { + code: 12345, + reason: "Connection closed by application".to_string(), + }; + + let mut buf = Vec::new(); + original.encode(&mut buf); + + let mut read_buf = buf.as_slice(); + let decoded = Capsule::decode(&mut read_buf).unwrap(); + + assert_eq!(original, decoded); + assert_eq!(read_buf.len(), 0); // All bytes consumed + } + + #[test] + fn test_empty_error_message() { + let capsule = Capsule::CloseWebTransportSession { + code: 0, + reason: String::new(), + }; + + let mut buf = Vec::new(); + capsule.encode(&mut buf); + + // Type(0x2843 as varint = 0x68 0x43) + Length(4) + error_code(0) + assert_eq!(buf, b"\x68\x43\x04\x00\x00\x00\x00"); + + let mut read_buf = buf.as_slice(); + let decoded = Capsule::decode(&mut read_buf).unwrap(); + assert_eq!(capsule, decoded); + } + + #[test] + fn test_invalid_utf8() { + // Create a capsule with invalid UTF-8 in the message + let mut data = Vec::new(); + VarInt::from_u64(0x2843).unwrap().encode(&mut data); // type + VarInt::from_u32(5).encode(&mut data); // length(5) + data.extend_from_slice(b"\x00\x00\x00\x00"); // error_code(0) + data.push(0xFF); // Invalid UTF-8 byte + + let mut buf = data.as_slice(); + let result = Capsule::decode(&mut buf); + assert!(matches!(result, Err(CapsuleError::InvalidUtf8))); + } + + #[test] + fn test_truncated_error_code() { + // Capsule with length indicating 3 bytes but error code needs 4 + let mut data = Vec::new(); + VarInt::from_u64(0x2843).unwrap().encode(&mut data); // type + VarInt::from_u32(3).encode(&mut data); // length(3) + data.extend_from_slice(b"\x00\x00\x00"); // incomplete error code + + let mut buf = data.as_slice(); + let result = Capsule::decode(&mut buf); + assert!(matches!(result, Err(CapsuleError::UnexpectedEnd))); + } + + #[test] + fn test_unknown_capsule() { + // Test handling of unknown capsule types + let unknown_type = 0x1234u64; + let payload_data = b"unknown payload"; + + let mut data = Vec::new(); + VarInt::from_u64(unknown_type).unwrap().encode(&mut data); + VarInt::from_u32(payload_data.len() as u32).encode(&mut data); + data.extend_from_slice(payload_data); + + let mut buf = data.as_slice(); + let capsule = Capsule::decode(&mut buf).unwrap(); + + match capsule { + Capsule::Unknown { typ, payload } => { + assert_eq!(typ.into_inner(), unknown_type); + assert_eq!(payload.as_ref(), payload_data); + } + _ => panic!("Expected Unknown capsule"), + } + } + + #[test] + fn test_unknown_capsule_roundtrip() { + let capsule = Capsule::Unknown { + typ: VarInt::from_u64(0x9999).unwrap(), + payload: Bytes::from("test payload"), + }; + + let mut buf = Vec::new(); + capsule.encode(&mut buf); + + let mut read_buf = buf.as_slice(); + let decoded = Capsule::decode(&mut read_buf).unwrap(); + + assert_eq!(capsule, decoded); + assert_eq!(read_buf.len(), 0); + } + + #[test] + fn test_grease_capsule() { + // Test grease formula: 0x29 * N + 0x17 + for num in [0, 1, 5, 100, 1000] { + let capsule = Capsule::Grease { num }; + + let mut buf = Vec::new(); + capsule.encode(&mut buf); + + let mut read_buf = buf.as_slice(); + let decoded = Capsule::decode(&mut read_buf).unwrap(); + + assert_eq!(capsule, decoded); + assert_eq!(read_buf.len(), 0); + } + } + + #[test] + fn test_grease_values() { + // Verify specific grease type values + assert_eq!(is_grease(0x17), Some(0)); // N=0 + assert_eq!(is_grease(0x40), Some(1)); // N=1: 0x29 + 0x17 = 0x40 + assert_eq!(is_grease(0x69), Some(2)); // N=2: 0x29*2 + 0x17 = 0x69 + assert_eq!(is_grease(0x18), None); // Not a grease value + assert_eq!(is_grease(0x41), None); // Not a grease value + } + + #[test] + #[should_panic(expected = "grease num value would overflow u64")] + fn test_grease_overflow() { + let capsule = Capsule::Grease { num: u64::MAX }; + let mut buf = Vec::new(); + capsule.encode(&mut buf); + } +} diff --git a/third_party/web-transport-proto/src/connect.rs b/third_party/web-transport-proto/src/connect.rs new file mode 100644 index 0000000..8320c8b --- /dev/null +++ b/third_party/web-transport-proto/src/connect.rs @@ -0,0 +1,402 @@ +use std::{str::FromStr, sync::Arc}; + +use bytes::{Buf, BufMut, BytesMut}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use url::Url; + +use super::{qpack, Frame, VarInt}; + +use thiserror::Error; + +// Errors that can occur during the connect request. +#[derive(Error, Debug, Clone)] +pub enum ConnectError { + #[error("unexpected end of input")] + UnexpectedEnd, + + #[error("qpack error")] + QpackError(#[from] qpack::DecodeError), + + #[error("unexpected frame {0:?}")] + UnexpectedFrame(Frame), + + #[error("invalid method")] + InvalidMethod, + + #[error("invalid url")] + InvalidUrl(#[from] url::ParseError), + + #[error("invalid status")] + InvalidStatus, + + #[error("expected 200, got: {0:?}")] + WrongStatus(Option), + + #[error("expected connect, got: {0:?}")] + WrongMethod(Option), + + #[error("expected https, got: {0:?}")] + WrongScheme(Option), + + #[error("expected authority header")] + WrongAuthority, + + #[error("expected webtransport, got: {0:?}")] + WrongProtocol(Option), + + #[error("expected path header")] + WrongPath, + + #[error("invalid protocol header")] + InvalidProtocol, + + #[error("structured field error: {0}")] + StructuredFieldError(Arc), + + #[error("non-200 status: {0:?}")] + ErrorStatus(http::StatusCode), + + #[error("io error: {0}")] + Io(Arc), +} + +impl From for ConnectError { + fn from(err: std::io::Error) -> Self { + ConnectError::Io(Arc::new(err)) + } +} + +impl From for ConnectError { + fn from(err: sfv::Error) -> Self { + ConnectError::StructuredFieldError(Arc::new(err)) + } +} + +/// A CONNECT request to initiate a WebTransport session. +#[non_exhaustive] +#[derive(Debug, Clone)] +pub struct ConnectRequest { + /// The URL to connect to. + pub url: Url, + + /// The subprotocols requested (if any). + pub protocols: Vec, +} + +impl ConnectRequest { + pub fn new(url: impl Into) -> Self { + Self { + url: url.into(), + protocols: Vec::new(), + } + } + + pub fn with_protocol(mut self, protocol: impl Into) -> Self { + self.protocols.push(protocol.into()); + self + } + + pub fn with_protocols(mut self, protocols: impl IntoIterator) -> Self { + self.protocols.extend(protocols); + self + } + + pub fn decode(buf: &mut B) -> Result { + let (typ, mut data) = Frame::read(buf).map_err(|_| ConnectError::UnexpectedEnd)?; + if typ != Frame::HEADERS { + return Err(ConnectError::UnexpectedFrame(typ)); + } + + // We no longer return UnexpectedEnd because we know the buffer should be large enough. + + let headers = qpack::Headers::decode(&mut data)?; + + let scheme = match headers.get(":scheme") { + Some("https") => "https", + Some(scheme) => Err(ConnectError::WrongScheme(Some(scheme.to_string())))?, + None => return Err(ConnectError::WrongScheme(None)), + }; + + let authority = headers + .get(":authority") + .ok_or(ConnectError::WrongAuthority)?; + + let path_and_query = headers.get(":path").ok_or(ConnectError::WrongPath)?; + + let method = headers.get(":method"); + match method + .map(|method| method.try_into().map_err(|_| ConnectError::InvalidMethod)) + .transpose()? + { + Some(http::Method::CONNECT) => (), + o => return Err(ConnectError::WrongMethod(o)), + }; + + let protocol = headers.get(":protocol"); + if protocol != Some("webtransport") { + return Err(ConnectError::WrongProtocol(protocol.map(|s| s.to_string()))); + } + + let protocols = headers + .get(protocol_negotiation::AVAILABLE_NAME_STD) + .or_else(|| headers.get(protocol_negotiation::AVAILABLE_NAME)) + .map(protocol_negotiation::decode_list) + .transpose() + .map_err(|_| ConnectError::InvalidProtocol)? + .unwrap_or_default(); + + let url = Url::parse(&format!("{scheme}://{authority}{path_and_query}"))?; + + Ok(Self { url, protocols }) + } + + pub async fn read(stream: &mut S) -> Result { + let mut buf = Vec::new(); + loop { + if stream.read_buf(&mut buf).await? == 0 { + return Err(ConnectError::UnexpectedEnd); + } + + let mut limit = std::io::Cursor::new(&buf); + match Self::decode(&mut limit) { + Ok(request) => return Ok(request), + Err(ConnectError::UnexpectedEnd) => continue, + Err(e) => return Err(e), + } + } + } + + pub fn encode(&self, buf: &mut B) -> Result<(), ConnectError> { + let mut headers = qpack::Headers::default(); + headers.set(":method", "CONNECT"); + headers.set(":scheme", self.url.scheme()); + headers.set(":authority", self.url.authority()); + let path_and_query = match self.url.query() { + Some(query) => format!("{}?{}", self.url.path(), query), + None => self.url.path().to_string(), + }; + headers.set(":path", &path_and_query); + headers.set(":protocol", "webtransport"); + + if !self.protocols.is_empty() { + let encoded = protocol_negotiation::encode_list(&self.protocols)?; + // Send both the standard and legacy header names to maximize interop. + headers.set(protocol_negotiation::AVAILABLE_NAME_STD, &encoded); + headers.set(protocol_negotiation::AVAILABLE_NAME, &encoded); + } + + // Use a temporary buffer so we can compute the size. + let mut tmp = Vec::new(); + headers.encode(&mut tmp); + let size = VarInt::from_u32(tmp.len() as u32); + + Frame::HEADERS.encode(buf); + size.encode(buf); + buf.put_slice(&tmp); + + Ok(()) + } + + pub async fn write(&self, stream: &mut S) -> Result<(), ConnectError> { + let mut buf = BytesMut::new(); + self.encode(&mut buf)?; + stream.write_all_buf(&mut buf).await?; + Ok(()) + } +} + +impl From for ConnectRequest { + fn from(url: Url) -> Self { + Self { + url, + protocols: Vec::new(), + } + } +} + +/// A CONNECT response to accept or reject a WebTransport session. +#[non_exhaustive] +#[derive(Debug, Clone)] +pub struct ConnectResponse { + /// The status code of the response. + pub status: http::status::StatusCode, + + /// The subprotocol selected by the server, if any + pub protocol: Option, +} + +impl ConnectResponse { + pub const OK: Self = Self { + status: http::StatusCode::OK, + protocol: None, + }; + + pub fn new(status: http::StatusCode) -> Self { + Self { + status, + protocol: None, + } + } + + pub fn with_protocol(mut self, protocol: impl Into) -> Self { + self.protocol = Some(protocol.into()); + self + } + + pub fn decode(buf: &mut B) -> Result { + let (typ, mut data) = Frame::read(buf).map_err(|_| ConnectError::UnexpectedEnd)?; + if typ != Frame::HEADERS { + return Err(ConnectError::UnexpectedFrame(typ)); + } + + let headers = qpack::Headers::decode(&mut data)?; + + let status = match headers + .get(":status") + .map(|status| { + http::StatusCode::from_str(status).map_err(|_| ConnectError::InvalidStatus) + }) + .transpose()? + { + Some(status) if status.is_success() => status, + o => return Err(ConnectError::WrongStatus(o)), + }; + + let protocol = headers + .get(protocol_negotiation::SELECTED_NAME_STD) + .or_else(|| headers.get(protocol_negotiation::SELECTED_NAME)) + .map(protocol_negotiation::decode_item) + .transpose() + .map_err(|_| ConnectError::InvalidProtocol)?; + + Ok(Self { status, protocol }) + } + + pub async fn read(stream: &mut S) -> Result { + let mut buf = Vec::new(); + loop { + if stream.read_buf(&mut buf).await? == 0 { + return Err(ConnectError::UnexpectedEnd); + } + + let mut limit = std::io::Cursor::new(&buf); + match Self::decode(&mut limit) { + Ok(response) => return Ok(response), + Err(ConnectError::UnexpectedEnd) => continue, + Err(e) => return Err(e), + } + } + } + + pub fn encode(&self, buf: &mut B) -> Result<(), ConnectError> { + let mut headers = qpack::Headers::default(); + headers.set(":status", self.status.as_str()); + headers.set("sec-webtransport-http3-draft", "draft02"); + + if let Some(protocol) = self.protocol.as_ref() { + let encoded = protocol_negotiation::encode_item(protocol)?; + // Send both the standard and legacy header names to maximize interop. + headers.set(protocol_negotiation::SELECTED_NAME_STD, &encoded); + headers.set(protocol_negotiation::SELECTED_NAME, &encoded); + } + + // Use a temporary buffer so we can compute the size. + let mut tmp = Vec::new(); + headers.encode(&mut tmp); + let size = VarInt::from_u32(tmp.len() as u32); + + Frame::HEADERS.encode(buf); + size.encode(buf); + buf.put_slice(&tmp); + + Ok(()) + } + + pub async fn write(&self, stream: &mut S) -> Result<(), ConnectError> { + let mut buf = BytesMut::new(); + self.encode(&mut buf)?; + stream.write_all_buf(&mut buf).await?; + Ok(()) + } +} + +impl Default for ConnectResponse { + fn default() -> Self { + Self::OK + } +} + +impl From for ConnectResponse { + fn from(status: http::StatusCode) -> Self { + Self { + status, + protocol: None, + } + } +} + +mod protocol_negotiation { + //! WebTransport sub-protocol negotiation using RFC 8941 Structured Fields, + //! + //! according to [draft 14](https://www.ietf.org/archive/id/draft-ietf-webtrans-http3-14.html#section-3.3) + + use sfv::{Item, ItemSerializer, List, ListEntry, ListSerializer, Parser, StringRef}; + + use crate::ConnectError; + + /// The header name for the available protocols, sent within the WebTransport Connect request. + /// Legacy name kept for compatibility with older servers/clients. + pub const AVAILABLE_NAME: &str = "wt-available-protocols"; + /// Standard WebTransport sub-protocol negotiation header name. + pub const AVAILABLE_NAME_STD: &str = "sec-webtransport-protocol"; + + /// The header name for the selected protocol, sent within the WebTransport Connect response. + /// Legacy name kept for compatibility with older servers/clients. + pub const SELECTED_NAME: &str = "wt-protocol"; + /// Standard WebTransport sub-protocol negotiation header name. + pub const SELECTED_NAME_STD: &str = "sec-webtransport-protocol"; + + /// Encode a list of protocol strings as an RFC 8941 Structured Field List. + pub fn encode_list(protocols: &[String]) -> Result { + let mut serializer = ListSerializer::new(); + for protocol in protocols { + let s = StringRef::from_str(protocol)?; + let _ = serializer.bare_item(s); + } + serializer.finish().ok_or(ConnectError::InvalidProtocol) + } + + /// Decode an RFC 8941 Structured Field List of strings. + pub fn decode_list(value: &str) -> Result, ConnectError> { + let list = Parser::new(value).parse::()?; + + list.iter() + .map(|entry| match entry { + ListEntry::Item(item) => Ok(item + .bare_item + .as_string() + .ok_or(ConnectError::InvalidProtocol)? + .as_str() + .to_string()), + _ => Err(ConnectError::InvalidProtocol), + }) + .collect() + } + + /// Encode a single string as an RFC 8941 Structured Field Item. + pub fn encode_item(protocol: &str) -> Result { + let s = StringRef::from_str(protocol)?; + Ok(ItemSerializer::new().bare_item(s).finish()) + } + + /// Decode an RFC 8941 Structured Field Item (single string). + pub fn decode_item(value: &str) -> Result { + let item = Parser::new(value).parse::()?; + Ok(item + .bare_item + .as_string() + .ok_or(ConnectError::InvalidProtocol)? + .as_str() + .to_string()) + } +} diff --git a/third_party/web-transport-proto/src/error.rs b/third_party/web-transport-proto/src/error.rs new file mode 100644 index 0000000..30ff7c0 --- /dev/null +++ b/third_party/web-transport-proto/src/error.rs @@ -0,0 +1,18 @@ +// WebTransport shares with HTTP/3, so we can't start at 0 or use the full VarInt. +const ERROR_FIRST: u64 = 0x52e4a40fa8db; +const ERROR_LAST: u64 = 0x52e5ac983162; + +pub const fn error_from_http3(code: u64) -> Option { + if code < ERROR_FIRST || code > ERROR_LAST { + return None; + } + + let code = code - ERROR_FIRST; + let code = code - code / 0x1f; + + Some(code as u32) +} + +pub const fn error_to_http3(code: u32) -> u64 { + ERROR_FIRST + code as u64 + code as u64 / 0x1e +} diff --git a/third_party/web-transport-proto/src/frame.rs b/third_party/web-transport-proto/src/frame.rs new file mode 100644 index 0000000..a9fe454 --- /dev/null +++ b/third_party/web-transport-proto/src/frame.rs @@ -0,0 +1,65 @@ +use bytes::{Buf, BufMut}; + +use crate::{VarInt, VarIntUnexpectedEnd}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct Frame(pub VarInt); + +impl Frame { + pub fn decode(buf: &mut B) -> Result { + let typ = VarInt::decode(buf)?; + Ok(Frame(typ)) + } + + pub fn encode(&self, buf: &mut B) { + self.0.encode(buf) + } + + pub fn is_grease(&self) -> bool { + let val = self.0.into_inner(); + if val < 0x21 { + return false; + } + + #[allow(unknown_lints, clippy::manual_is_multiple_of)] + { + (val - 0x21) % 0x1f == 0 + } + } + + pub fn read( + buf: &mut B, + ) -> Result<(Frame, bytes::buf::Take<&mut B>), VarIntUnexpectedEnd> { + let typ = Frame::decode(buf)?; + let size = VarInt::decode(buf)?; + + let mut limit = Buf::take(buf, size.into_inner() as usize); + if limit.remaining() < limit.limit() { + return Err(VarIntUnexpectedEnd); + } + + // Try again if this is a GREASE frame we need to ignore + if typ.is_grease() { + limit.advance(limit.limit()); + return Self::read(limit.into_inner()); + } + + Ok((typ, limit)) + } +} + +macro_rules! frames { + {$($name:ident = $val:expr,)*} => { + impl Frame { + $(pub const $name: Frame = Frame(VarInt::from_u32($val));)* + } + } +} + +// Sent at the start of a bidirectional stream. +frames! { + DATA = 0x00, + HEADERS = 0x01, + SETTINGS = 0x04, + WEBTRANSPORT = 0x41, +} diff --git a/third_party/web-transport-proto/src/huffman.rs b/third_party/web-transport-proto/src/huffman.rs new file mode 100644 index 0000000..333ed3a --- /dev/null +++ b/third_party/web-transport-proto/src/huffman.rs @@ -0,0 +1,364 @@ +// Huffman encoding is a compression technique that replaces common strings with shorter codes. +// Ugh I wish we didn't have to implement this, but the other endpoint is allowed to use it. + +// Taken from https://github.com/hyperium/h3/blob/master/h3/src/qpack/prefix_string/decode.rs +// License: MIT + +#[derive(Debug, Default, PartialEq, Clone)] +pub struct BitWindow { + pub byte: u32, + pub bit: u32, + pub count: u32, +} + +impl BitWindow { + pub fn new() -> Self { + Self::default() + } + + pub fn forwards(&mut self, step: u32) { + self.bit += self.count; + + self.byte += self.bit / 8; + self.bit %= 8; + + self.count = step; + } + + pub fn opposite_bit_window(&self) -> BitWindow { + BitWindow { + byte: self.byte, + bit: self.bit, + count: 8 - (self.bit % 8), + } + } +} + +use thiserror::Error; + +#[derive(Error, Debug, PartialEq, Clone)] +pub enum Error { + #[error("missing bits: {0:?}")] + MissingBits(BitWindow), + + #[error("unhandled: {0:?} {1:?}")] + Unhandled(BitWindow, usize), +} + +#[derive(Clone, Debug)] +enum DecodeValue { + Partial(&'static HuffmanDecoder), + Sym(u8), +} + +#[derive(Clone, Debug)] +struct HuffmanDecoder { + lookup: u32, + table: &'static [DecodeValue], +} + +impl HuffmanDecoder { + fn check_eof(&self, bit_pos: &mut BitWindow, input: &[u8]) -> Result, Error> { + use std::cmp::Ordering; + match ((bit_pos.byte + 1) as usize).cmp(&input.len()) { + // Position is out-of-range + Ordering::Greater => { + return Ok(None); + } + // Position is on the last byte + Ordering::Equal => { + let side = bit_pos.opposite_bit_window(); + + let rest = match read_bits(input, side.byte, side.bit, side.count) { + Ok(x) => x, + Err(()) => { + return Err(Error::MissingBits(side)); + } + }; + + let eof_filler = ((2u16 << (side.count - 1)) - 1) as u8; + if rest & eof_filler == eof_filler { + return Ok(None); + } + } + Ordering::Less => {} + } + Err(Error::MissingBits(bit_pos.clone())) + } + + fn fetch_value(&self, bit_pos: &mut BitWindow, input: &[u8]) -> Result, Error> { + match read_bits(input, bit_pos.byte, bit_pos.bit, bit_pos.count) { + Ok(value) => Ok(Some(value as u32)), + Err(()) => self.check_eof(bit_pos, input), + } + } + + fn decode_next(&self, bit_pos: &mut BitWindow, input: &[u8]) -> Result, Error> { + bit_pos.forwards(self.lookup); + + let value = match self.fetch_value(bit_pos, input) { + Ok(Some(value)) => value as usize, + Ok(None) => return Ok(None), + Err(err) => return Err(err), + }; + + let at_value = match (self.table).get(value) { + Some(x) => x, + None => return Err(Error::Unhandled(bit_pos.clone(), value)), + }; + + match at_value { + DecodeValue::Sym(x) => Ok(Some(*x)), + DecodeValue::Partial(d) => d.decode_next(bit_pos, input), + } + } +} + +/// Read `len` bits from the `src` slice at the specified position +/// +/// Never read more than 8 bits at a time. `bit_offset` may be larger than 8. +fn read_bits(src: &[u8], mut byte_offset: u32, mut bit_offset: u32, len: u32) -> Result { + if len == 0 || len > 8 || src.len() as u32 * 8 < (byte_offset * 8) + bit_offset + len { + return Err(()); + } + + // Deal with `bit_offset` > 8 + byte_offset += bit_offset / 8; + bit_offset -= (bit_offset / 8) * 8; + + Ok(if bit_offset + len <= 8 { + // Read all the bits from a single byte + (src[byte_offset as usize] << bit_offset) >> (8 - len) + } else { + // The range of bits spans over 2 bytes + let mut result = (src[byte_offset as usize] as u16) << 8; + result |= src[byte_offset as usize + 1] as u16; + ((result << bit_offset) >> (16 - len)) as u8 + }) +} + +macro_rules! bits_decode { + // general way + ( + lookup: $count:expr, [ + $($sym:expr,)* + $(=> $sub:ident,)* ] + ) => { + HuffmanDecoder { + lookup: $count, + table: &[ + $( DecodeValue::Sym($sym as u8), )* + $( DecodeValue::Partial(&$sub), )* + ] + } + }; + // 2-final + ( $first:expr, $second:expr ) => { + HuffmanDecoder { + lookup: 1, + table: &[ + DecodeValue::Sym($first as u8), + DecodeValue::Sym($second as u8), + ] + } + }; + // 4-final + ( $first:expr, $second:expr, $third:expr, $fourth:expr ) => { + HuffmanDecoder { + lookup: 2, + table: &[ + DecodeValue::Sym($first as u8), + DecodeValue::Sym($second as u8), + DecodeValue::Sym($third as u8), + DecodeValue::Sym($fourth as u8), + ] + } + }; + // 2-final-partial + ( $first:expr, => $second:ident ) => { + HuffmanDecoder { + lookup: 1, + table: &[ + DecodeValue::Sym($first as u8), + DecodeValue::Partial(&$second), + ] + } + }; + // 2-partial + ( => $first:ident, => $second:ident ) => { + HuffmanDecoder { + lookup: 1, + table: &[ + DecodeValue::Partial(&$first), + DecodeValue::Partial(&$second), + ] + } + }; + // 4-partial + ( => $first:ident, => $second:ident, + => $third:ident, => $fourth:ident ) => { + HuffmanDecoder { + lookup: 2, + table: &[ + DecodeValue::Partial(&$first), + DecodeValue::Partial(&$second), + DecodeValue::Partial(&$third), + DecodeValue::Partial(&$fourth), + ] + } + }; + [ $( $name:ident => ( $($value:tt)* ), )* ] => { + $( const $name: HuffmanDecoder = bits_decode!( $( $value )* ); )* + }; +} + +#[rustfmt::skip] +bits_decode![ + HPACK_STRING => ( + lookup: 5, [ b'0', b'1', b'2', b'a', b'c', b'e', b'i', b'o', b's', b't', + => END0_01010, => END0_01011, => END0_01100, => END0_01101, + => END0_01110, => END0_01111, => END0_10000, => END0_10001, + => END0_10010, => END0_10011, => END0_10100, => END0_10101, + => END0_10110, => END0_10111, => END0_11000, => END0_11001, + => END0_11010, => END0_11011, => END0_11100, => END0_11101, + => END0_11110, => END0_11111, + ]), + END0_01010 => ( 32, b'%'), + END0_01011 => (b'-', b'.'), + END0_01100 => (b'/', b'3'), + END0_01101 => (b'4', b'5'), + END0_01110 => (b'6', b'7'), + END0_01111 => (b'8', b'9'), + END0_10000 => (b'=', b'A'), + END0_10001 => (b'_', b'b'), + END0_10010 => (b'd', b'f'), + END0_10011 => (b'g', b'h'), + END0_10100 => (b'l', b'm'), + END0_10101 => (b'n', b'p'), + END0_10110 => (b'r', b'u'), + END0_10111 => (b':', b'B', b'C', b'D'), + END0_11000 => (b'E', b'F', b'G', b'H'), + END0_11001 => (b'I', b'J', b'K', b'L'), + END0_11010 => (b'M', b'N', b'O', b'P'), + END0_11011 => (b'Q', b'R', b'S', b'T'), + END0_11100 => (b'U', b'V', b'W', b'Y'), + END0_11101 => (b'j', b'k', b'q', b'v'), + END0_11110 => (b'w', b'x', b'y', b'z'), + END0_11111 => (=> END5_00, => END5_01, => END5_10, => END5_11), + END5_00 => (b'&', b'*'), + END5_01 => (b',', 59), + END5_10 => (b'X', b'Z'), + END5_11 => (=> END7_0, => END7_1), + END7_0 => (b'!', b'"', b'(', b')'), + END7_1 => (=> END8_0, => END8_1), + END8_0 => (b'?', => END9A_1), + END9A_1 => (b'\'', b'+'), + END8_1 => (lookup: 2, [b'|', => END9B_01, => END9B_10, => END9B_11,]), + END9B_01 => (b'#', b'>'), + END9B_10 => (0, b'$', b'@', b'['), + END9B_11 => (lookup: 2, [b']', b'~', => END13_10, => END13_11,]), + END13_10 => (b'^', b'}'), + END13_11 => (=> END14_0, => END14_1), + END14_0 => (b'<', b'`'), + END14_1 => (b'{', => END15_1), + END15_1 => + (lookup: 4, [ b'\\', 195, 208, => END19_0011, + => END19_0100, => END19_0101, => END19_0110, => END19_0111, + => END19_1000, => END19_1001, => END19_1010, => END19_1011, + => END19_1100, => END19_1101, => END19_1110, => END19_1111, + ]), + END19_0011 => (128, 130), + END19_0100 => (131, 162), + END19_0101 => (184, 194), + END19_0110 => (224, 226), + END19_0111 => (153, 161, 167, 172), + END19_1000 => (176, 177, 179, 209), + END19_1001 => (216, 217, 227, 229), + END19_1010 => (lookup: 2, [230, => END19_1010_01, => END19_1010_10, + => END19_1010_11,]), + END19_1010_01 => (129, 132), + END19_1010_10 => (133, 134), + END19_1010_11 => (136, 146), + END19_1011 => (lookup: 3, [154, 156, 160, 163, 164, 169, 170, 173,]), + END19_1100 => (lookup: 3, [178, 181, 185, 186, 187, 189, 190, 196,]), + END19_1101 => (lookup: 3, [198, 228, 232, 233, + => END23A_100, => END23A_101, + => END23A_110, => END23A_111,]), + END23A_100 => ( 1, 135), + END23A_101 => (137, 138), + END23A_110 => (139, 140), + END23A_111 => (141, 143), + END19_1110 => (lookup: 4, [147, 149, 150, 151, 152, 155, 157, 158, + 165, 166, 168, 174, 175, 180, 182, 183,]), + END19_1111 => (lookup: 4, [188, 191, 197, 231, 239, + => END23B_0101, => END23B_0110, => END23B_0111, + => END23B_1000, => END23B_1001, => END23B_1010, + => END23B_1011, => END23B_1100, => END23B_1101, + => END23B_1110, => END23B_1111,]), + END23B_0101 => ( 9, 142), + END23B_0110 => (144, 145), + END23B_0111 => (148, 159), + END23B_1000 => (171, 206), + END23B_1001 => (215, 225), + END23B_1010 => (236, 237), + END23B_1011 => (199, 207, 234, 235), + END23B_1100 => (lookup: 3, [192, 193, 200, 201, 202, 205, 210, 213,]), + END23B_1101 => (lookup: 3, [218, 219, 238, 240, 242, 243, 255, + => END27A_111,]), + END27A_111 => (203, 204), + END23B_1110 => (lookup: 4, [211, 212, 214, 221, 222, 223, 241, 244, + 245, 246, 247, 248, 250, 251, 252, 253,]), + END23B_1111 => (lookup: 4, [ 254, => END27B_0001, => END27B_0010, + => END27B_0011, => END27B_0100, => END27B_0101, + => END27B_0110, => END27B_0111, => END27B_1000, + => END27B_1001, => END27B_1010, => END27B_1011, + => END27B_1100, => END27B_1101, => END27B_1110, + => END27B_1111,]), + END27B_0001 => (2, 3), + END27B_0010 => (4, 5), + END27B_0011 => (6, 7), + END27B_0100 => (8, 11), + END27B_0101 => (12, 14), + END27B_0110 => (15, 16), + END27B_0111 => (17, 18), + END27B_1000 => (19, 20), + END27B_1001 => (21, 23), + END27B_1010 => (24, 25), + END27B_1011 => (26, 27), + END27B_1100 => (28, 29), + END27B_1101 => (30, 31), + END27B_1110 => (127, 220), + END27B_1111 => (lookup: 1, [249, => END31_1,]), + END31_1 => (lookup: 2, [10, 13, 22, => EOF,]), + EOF => (lookup: 8, []), + ]; + +pub struct DecodeIter<'a> { + bit_pos: BitWindow, + content: &'a Vec, +} + +impl Iterator for DecodeIter<'_> { + type Item = Result; + + fn next(&mut self) -> Option { + match HPACK_STRING.decode_next(&mut self.bit_pos, self.content) { + Ok(Some(x)) => Some(Ok(x)), + Err(err) => Some(Err(err)), + Ok(None) => None, + } + } +} + +pub trait HpackStringDecode { + fn hpack_decode(&self) -> DecodeIter<'_>; +} + +impl HpackStringDecode for Vec { + fn hpack_decode(&self) -> DecodeIter<'_> { + DecodeIter { + bit_pos: BitWindow::new(), + content: self, + } + } +} diff --git a/third_party/web-transport-proto/src/lib.rs b/third_party/web-transport-proto/src/lib.rs new file mode 100644 index 0000000..ff213bb --- /dev/null +++ b/third_party/web-transport-proto/src/lib.rs @@ -0,0 +1,18 @@ +mod capsule; +mod connect; +mod error; +mod frame; +mod settings; +mod stream; +mod varint; + +pub use capsule::*; +pub use connect::*; +pub use error::*; +pub use frame::*; +pub use settings::*; +pub use stream::*; +pub use varint::*; + +mod huffman; +mod qpack; diff --git a/third_party/web-transport-proto/src/qpack.rs b/third_party/web-transport-proto/src/qpack.rs new file mode 100644 index 0000000..0f3c7cd --- /dev/null +++ b/third_party/web-transport-proto/src/qpack.rs @@ -0,0 +1,627 @@ +// This is a minimal QPACK implementation that only supports the static table and literals. +// By refusing to acknowledge the QPACK encoder, we can avoid implementing the dynamic table altogether. +// This is not recommended for a full HTTP/3 implementation but it's literally more efficient for handling a single WebTransport CONNECT request. + +use std::collections::HashMap; + +use bytes::{Buf, BufMut}; + +use super::huffman::{self, HpackStringDecode}; +use thiserror::Error; + +#[derive(Error, Debug, Clone)] +pub enum DecodeError { + #[error("unexpected end of input")] + UnexpectedEnd, + + #[error("varint bounds exceeded")] + BoundsExceeded, + + #[error("dynamic references not supported")] + DynamicEntry, + + #[error("unknown entry")] + UnknownEntry, + + #[error("huffman decoding error")] + HuffmanError(#[from] huffman::Error), + + #[error("invalid utf8 header")] // technically not required by the HTTP spec + Utf8Error(#[from] std::str::Utf8Error), +} + +#[cfg(target_pointer_width = "64")] +const MAX_POWER: usize = 10 * 7; + +#[cfg(target_pointer_width = "32")] +const MAX_POWER: usize = 5 * 7; + +// Simple QPACK implementation that ONLY supports the static table and literals. +#[derive(Debug, Default)] +pub struct Headers { + fields: HashMap, +} + +impl Headers { + pub fn get(&self, name: &str) -> Option<&str> { + self.fields.get(name).map(|v| v.as_str()) + } + + pub fn set(&mut self, name: &str, value: &str) { + self.fields.insert(name.to_string(), value.to_string()); + } + + pub fn decode(mut buf: &mut B) -> Result { + // We don't support dynamic entries so we can skip these. + let (_, _insert_count) = decode_prefix(buf, 8)?; + let (_sign, _delta_base) = decode_prefix(buf, 7)?; + + let mut fields = HashMap::new(); + while buf.has_remaining() { + // Read the first byte; + let peek = buf.get_u8(); + + // Read the byte again by chaining Bufs. + let first = [peek]; + let mut chain = first.chain(buf); + + // See: https://www.rfc-editor.org/rfc/rfc9204.html#section-4.5.2 + // This is over-engineered, LUL + let (name, value) = match peek & 0b1100_0000 { + // Indexed line field from static table + 0b1100_0000 => Self::decode_index(&mut chain)?, + + // Indexed line field from dynamic table + 0b1000_0000 => return Err(DecodeError::DynamicEntry), + + _ => match peek & 0b1101_0000 { + // Indexed with literal name ref from static table + 0b0101_0000 => Self::decode_literal_value(&mut chain)?, + + // Indexed with literal name ref from dynamic table + 0b0100_0000 => return Err(DecodeError::DynamicEntry), + + // Literal + _ if peek & 0b1110_0000 == 0b0010_0000 => Self::decode_literal(&mut chain)?, + + _ => match peek & 0b1111_0000 { + // Indexed with post base + 0b0001_0000 => return Err(DecodeError::DynamicEntry), + + // Indexed with post base name ref + 0b0000_0000 => return Err(DecodeError::DynamicEntry), + + // ugh + _ => return Err(DecodeError::UnknownEntry), + }, + }, + }; + + fields.insert(name, value); + + // Get the buffer back. + (_, buf) = chain.into_inner(); + } + + Ok(Self { fields }) + } + + fn decode_index(buf: &mut B) -> Result<(String, String), DecodeError> { + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 1 | 1 | Index (6+) | + +---+---+-----------------------+ + */ + + let (_, index) = decode_prefix(buf, 6)?; + let (name, value) = StaticTable::get(index)?; + Ok((name.to_string(), value.to_string())) + } + + fn decode_literal_value(buf: &mut B) -> Result<(String, String), DecodeError> { + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 0 | 1 | N | 1 |Name Index (4+)| + +---+---+---+---+---------------+ + | H | Value Length (7+) | + +---+---------------------------+ + | Value String (Length bytes) | + +-------------------------------+ + */ + + let (_, name) = decode_prefix(buf, 4)?; + let (name, _) = StaticTable::get(name)?; + + let value = decode_string(buf, 8)?; + let value = std::str::from_utf8(&value)?; + + Ok((name.to_string(), value.to_string())) + } + + fn decode_literal(buf: &mut B) -> Result<(String, String), DecodeError> { + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 0 | 0 | 1 | N | H |NameLen(3+)| + +---+---+---+---+---+-----------+ + | Name String (Length bytes) | + +---+---------------------------+ + | H | Value Length (7+) | + +---+---------------------------+ + | Value String (Length bytes) | + +-------------------------------+ + */ + + let name = decode_string(buf, 4)?; + let name = std::str::from_utf8(&name)?; + + let value = decode_string(buf, 8)?; + let value = std::str::from_utf8(&value)?; + + Ok((name.to_string(), value.to_string())) + } + + pub fn encode(&self, buf: &mut B) { + // We don't support dynamic entries so we can skip these. + encode_prefix(buf, 8, 0, 0); + encode_prefix(buf, 7, 0, 0); + + // We must encode pseudo-headers first. + // https://datatracker.ietf.org/doc/html/rfc9114#section-4.1.2 + let mut headers: Vec<_> = self.fields.iter().collect(); + headers.sort_by_key(|&(key, _)| !key.starts_with(':')); + + for (name, value) in headers.iter() { + if let Some(index) = StaticTable::find(name, value) { + Self::encode_index(buf, index) + } else if let Some(index) = StaticTable::find_name(name) { + Self::encode_literal_value(buf, index, value) + } else { + Self::encode_literal(buf, name, value) + } + } + } + + fn encode_index(buf: &mut B, index: usize) { + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 1 | 1 | Index (6+) | + +---+---+-----------------------+ + */ + + encode_prefix(buf, 6, 0b11, index); + } + + fn encode_literal_value(buf: &mut B, name: usize, value: &str) { + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 0 | 1 | N | 1 |Name Index (4+)| + +---+---+---+---+---------------+ + | H | Value Length (7+) | + +---+---------------------------+ + | Value String (Length bytes) | + +-------------------------------+ + */ + + encode_prefix(buf, 4, 0b0101, name); + encode_prefix(buf, 7, 0b0, value.len()); + + buf.put_slice(value.as_bytes()); + } + + fn encode_literal(buf: &mut B, name: &str, value: &str) { + /* + 0 1 2 3 4 5 6 7 + +---+---+---+---+---+---+---+---+ + | 0 | 0 | 1 | N | H |NameLen(3+)| + +---+---+---+---+---+-----------+ + | Name String (Length bytes) | + +---+---------------------------+ + | H | Value Length (7+) | + +---+---------------------------+ + | Value String (Length bytes) | + +-------------------------------+ + */ + + encode_prefix(buf, 3, 0b00100, name.len()); + buf.put_slice(name.as_bytes()); + + encode_prefix(buf, 7, 0b0, value.len()); + buf.put_slice(value.as_bytes()); + } +} + +// An integer that uses a fixed number of bits, otherwise a variable number of bytes if it's too large. +// https://www.rfc-editor.org/rfc/rfc7541#section-5.1 + +// Based on : https://github.com/hyperium/h3/blob/master/h3/src/qpack/prefix_int.rs +// License: MIT + +pub fn decode_prefix(buf: &mut B, size: u8) -> Result<(u8, usize), DecodeError> { + assert!(size <= 8); + + if !buf.has_remaining() { + return Err(DecodeError::UnexpectedEnd); + } + + let mut first = buf.get_u8(); + + // NOTE: following casts to u8 intend to trim the most significant bits, they are used as a + // workaround for shiftoverflow errors when size == 8. + let flags = ((first as usize) >> size) as u8; + let mask = 0xFF >> (8 - size); + first &= mask; + + // if first < 2usize.pow(size) - 1 + if first < mask { + return Ok((flags, first as usize)); + } + + let mut value = mask as usize; + let mut power = 0usize; + loop { + if !buf.has_remaining() { + return Err(DecodeError::UnexpectedEnd); + } + + let byte = buf.get_u8() as usize; + value += (byte & 127) << power; + power += 7; + + if byte & 128 == 0 { + break; + } + + if power >= MAX_POWER { + return Err(DecodeError::BoundsExceeded); + } + } + + Ok((flags, value)) +} + +pub fn encode_prefix(buf: &mut B, size: u8, flags: u8, value: usize) { + assert!(size > 0 && size <= 8); + + // NOTE: following casts to u8 intend to trim the most significant bits, they are used as a + // workaround for shiftoverflow errors when size == 8. + let mask = !(0xFF << size) as u8; + let flags = ((flags as usize) << size) as u8; + + // if value < 2usize.pow(size) - 1 + if value < (mask as usize) { + buf.put_u8(flags | value as u8); + return; + } + + buf.put_u8(mask | flags); + let mut remaining = value - mask as usize; + + while remaining >= 128 { + let rest = (remaining % 128) as u8; + buf.put_u8(rest + 128); + remaining /= 128; + } + + buf.put_u8(remaining as u8); +} + +pub fn decode_string(buf: &mut B, size: u8) -> Result, DecodeError> { + if !buf.has_remaining() { + return Err(DecodeError::UnexpectedEnd); + } + + let (flags, len) = decode_prefix(buf, size - 1)?; + if buf.remaining() < len { + return Err(DecodeError::UnexpectedEnd); + } + + let payload = buf.copy_to_bytes(len); + let value: Vec = if flags & 1 == 0 { + payload.into_iter().collect() + } else { + let mut decoded = Vec::new(); + for byte in payload.into_iter().collect::>().hpack_decode() { + decoded.push(byte?); + } + decoded + }; + Ok(value) +} + +// Based on https://github.com/hyperium/h3/blob/master/h3/src/qpack/static_.rs +// I switched over to str because it's nicer in Rust... even though HTTP doesn't use utf8. +struct StaticTable {} + +impl StaticTable { + pub fn get(index: usize) -> Result<(&'static str, &'static str), DecodeError> { + match PREDEFINED_HEADERS.get(index) { + Some(v) => Ok(*v), + None => Err(DecodeError::UnknownEntry), + } + } + + // TODO combine find and find_name to do a single lookup + pub fn find(name: &str, value: &str) -> Option { + match (name, value) { + (":authority", "") => Some(0), + (":path", "/") => Some(1), + ("age", "0") => Some(2), + ("content-disposition", "") => Some(3), + ("content-length", "0") => Some(4), + ("cookie", "") => Some(5), + ("date", "") => Some(6), + ("etag", "") => Some(7), + ("if-modified-since", "") => Some(8), + ("if-none-match", "") => Some(9), + ("last-modified", "") => Some(10), + ("link", "") => Some(11), + ("location", "") => Some(12), + ("referer", "") => Some(13), + ("set-cookie", "") => Some(14), + (":method", "CONNECT") => Some(15), + (":method", "DELETE") => Some(16), + (":method", "GET") => Some(17), + (":method", "HEAD") => Some(18), + (":method", "OPTIONS") => Some(19), + (":method", "POST") => Some(20), + (":method", "PUT") => Some(21), + (":scheme", "http") => Some(22), + (":scheme", "https") => Some(23), + (":status", "103") => Some(24), + (":status", "200") => Some(25), + (":status", "304") => Some(26), + (":status", "404") => Some(27), + (":status", "503") => Some(28), + ("accept", "*/*") => Some(29), + ("accept", "application/dns-message") => Some(30), + ("accept-encoding", "gzip, deflate, br") => Some(31), + ("accept-ranges", "bytes") => Some(32), + ("access-control-allow-headers", "cache-control") => Some(33), + ("access-control-allow-headers", "content-type") => Some(34), + ("access-control-allow-origin", "*") => Some(35), + ("cache-control", "max-age=0") => Some(36), + ("cache-control", "max-age=2592000") => Some(37), + ("cache-control", "max-age=604800") => Some(38), + ("cache-control", "no-cache") => Some(39), + ("cache-control", "no-store") => Some(40), + ("cache-control", "public, max-age=31536000") => Some(41), + ("content-encoding", "br") => Some(42), + ("content-encoding", "gzip") => Some(43), + ("content-type", "application/dns-message") => Some(44), + ("content-type", "application/javascript") => Some(45), + ("content-type", "application/json") => Some(46), + ("content-type", "application/x-www-form-urlencoded") => Some(47), + ("content-type", "image/gif") => Some(48), + ("content-type", "image/jpeg") => Some(49), + ("content-type", "image/png") => Some(50), + ("content-type", "text/css") => Some(51), + ("content-type", "text/html; charset=utf-8") => Some(52), + ("content-type", "text/plain") => Some(53), + ("content-type", "text/plain;charset=utf-8") => Some(54), + ("range", "bytes=0-") => Some(55), + ("strict-transport-security", "max-age=31536000") => Some(56), + ("strict-transport-security", "max-age=31536000; includesubdomains") => Some(57), + ("strict-transport-security", "max-age=31536000; includesubdomains; preload") => { + Some(58) + } + ("vary", "accept-encoding") => Some(59), + ("vary", "origin") => Some(60), + ("x-content-type-options", "nosniff") => Some(61), + ("x-xss-protection", "1; mode=block") => Some(62), + (":status", "100") => Some(63), + (":status", "204") => Some(64), + (":status", "206") => Some(65), + (":status", "302") => Some(66), + (":status", "400") => Some(67), + (":status", "403") => Some(68), + (":status", "421") => Some(69), + (":status", "425") => Some(70), + (":status", "500") => Some(71), + ("accept-language", "") => Some(72), + ("access-control-allow-credentials", "FALSE") => Some(73), + ("access-control-allow-credentials", "TRUE") => Some(74), + ("access-control-allow-headers", "*") => Some(75), + ("access-control-allow-methods", "get") => Some(76), + ("access-control-allow-methods", "get, post, options") => Some(77), + ("access-control-allow-methods", "options") => Some(78), + ("access-control-expose-headers", "content-length") => Some(79), + ("access-control-request-headers", "content-type") => Some(80), + ("access-control-request-method", "get") => Some(81), + ("access-control-request-method", "post") => Some(82), + ("alt-svc", "clear") => Some(83), + ("authorization", "") => Some(84), + ( + "content-security-policy", + "script-src 'none'; object-src 'none'; base-uri 'none'", + ) => Some(85), + ("early-data", "1") => Some(86), + ("expect-ct", "") => Some(87), + ("forwarded", "") => Some(88), + ("if-range", "") => Some(89), + ("origin", "") => Some(90), + ("purpose", "prefetch") => Some(91), + ("server", "") => Some(92), + ("timing-allow-origin", "*") => Some(93), + ("upgrade-insecure-requests", "1") => Some(94), + ("user-agent", "") => Some(95), + ("x-forwarded-for", "") => Some(96), + ("x-frame-options", "deny") => Some(97), + ("x-frame-options", "sameorigin") => Some(98), + _ => None, + } + } + + pub fn find_name(name: &str) -> Option { + match name { + ":authority" => Some(0), + ":path" => Some(1), + "age" => Some(2), + "content-disposition" => Some(3), + "content-length" => Some(4), + "cookie" => Some(5), + "date" => Some(6), + "etag" => Some(7), + "if-modified-since" => Some(8), + "if-none-match" => Some(9), + "last-modified" => Some(10), + "link" => Some(11), + "location" => Some(12), + "referer" => Some(13), + "set-cookie" => Some(14), + ":method" => Some(15), + ":scheme" => Some(22), + ":status" => Some(24), + "accept" => Some(29), + "accept-encoding" => Some(31), + "accept-ranges" => Some(32), + "access-control-allow-headers" => Some(33), + "access-control-allow-origin" => Some(35), + "cache-control" => Some(36), + "content-encoding" => Some(42), + "content-type" => Some(44), + "range" => Some(55), + "strict-transport-security" => Some(56), + "vary" => Some(59), + "x-content-type-options" => Some(61), + "x-xss-protection" => Some(62), + "accept-language" => Some(72), + "access-control-allow-credentials" => Some(73), + "access-control-allow-methods" => Some(76), + "access-control-expose-headers" => Some(79), + "access-control-request-headers" => Some(80), + "access-control-request-method" => Some(81), + "alt-svc" => Some(83), + "authorization" => Some(84), + "content-security-policy" => Some(85), + "early-data" => Some(86), + "expect-ct" => Some(87), + "forwarded" => Some(88), + "if-range" => Some(89), + "origin" => Some(90), + "purpose" => Some(91), + "server" => Some(92), + "timing-allow-origin" => Some(93), + "upgrade-insecure-requests" => Some(94), + "user-agent" => Some(95), + "x-forwarded-for" => Some(96), + "x-frame-options" => Some(97), + _ => None, + } + } +} + +const PREDEFINED_HEADERS: [(&str, &str); 99] = [ + (":authority", ""), + (":path", "/"), + ("age", "0"), + ("content-disposition", ""), + ("content-length", "0"), + ("cookie", ""), + ("date", ""), + ("etag", ""), + ("if-modified-since", ""), + ("if-none-match", ""), + ("last-modified", ""), + ("link", ""), + ("location", ""), + ("referer", ""), + ("set-cookie", ""), + (":method", "CONNECT"), + (":method", "DELETE"), + (":method", "GET"), + (":method", "HEAD"), + (":method", "OPTIONS"), + (":method", "POST"), + (":method", "PUT"), + (":scheme", "http"), + (":scheme", "https"), + (":status", "103"), + (":status", "200"), + (":status", "304"), + (":status", "404"), + (":status", "503"), + ("accept", "*/*"), + ("accept", "application/dns-message"), + ("accept-encoding", "gzip, deflate, br"), + ("accept-ranges", "bytes"), + ("access-control-allow-headers", "cache-control"), + ("access-control-allow-headers", "content-type"), + ("access-control-allow-origin", "*"), + ("cache-control", "max-age=0"), + ("cache-control", "max-age=2592000"), + ("cache-control", "max-age=604800"), + ("cache-control", "no-cache"), + ("cache-control", "no-store"), + ("cache-control", "public, max-age=31536000"), + ("content-encoding", "br"), + ("content-encoding", "gzip"), + ("content-type", "application/dns-message"), + ("content-type", "application/javascript"), + ("content-type", "application/json"), + ("content-type", "application/x-www-form-urlencoded"), + ("content-type", "image/gif"), + ("content-type", "image/jpeg"), + ("content-type", "image/png"), + ("content-type", "text/css"), + ("content-type", "text/html; charset=utf-8"), + ("content-type", "text/plain"), + ("content-type", "text/plain;charset=utf-8"), + ("range", "bytes=0-"), + ("strict-transport-security", "max-age=31536000"), + ( + "strict-transport-security", + "max-age=31536000; includesubdomains", + ), + ( + "strict-transport-security", + "max-age=31536000; includesubdomains; preload", + ), + ("vary", "accept-encoding"), + ("vary", "origin"), + ("x-content-type-options", "nosniff"), + ("x-xss-protection", "1; mode=block"), + (":status", "100"), + (":status", "204"), + (":status", "206"), + (":status", "302"), + (":status", "400"), + (":status", "403"), + (":status", "421"), + (":status", "425"), + (":status", "500"), + ("accept-language", ""), + ("access-control-allow-credentials", "FALSE"), + ("access-control-allow-credentials", "TRUE"), + ("access-control-allow-headers", "*"), + ("access-control-allow-methods", "get"), + ("access-control-allow-methods", "get, post, options"), + ("access-control-allow-methods", "options"), + ("access-control-expose-headers", "content-length"), + ("access-control-request-headers", "content-type"), + ("access-control-request-method", "get"), + ("access-control-request-method", "post"), + ("alt-svc", "clear"), + ("authorization", ""), + ( + "content-security-policy", + "script-src 'none'; object-src 'none'; base-uri 'none'", + ), + ("early-data", "1"), + ("expect-ct", ""), + ("forwarded", ""), + ("if-range", ""), + ("origin", ""), + ("purpose", "prefetch"), + ("server", ""), + ("timing-allow-origin", "*"), + ("upgrade-insecure-requests", "1"), + ("user-agent", ""), + ("x-forwarded-for", ""), + ("x-frame-options", "deny"), + ("x-frame-options", "sameorigin"), +]; diff --git a/third_party/web-transport-proto/src/settings.rs b/third_party/web-transport-proto/src/settings.rs new file mode 100644 index 0000000..110c538 --- /dev/null +++ b/third_party/web-transport-proto/src/settings.rs @@ -0,0 +1,254 @@ +use std::{ + collections::HashMap, + fmt::Debug, + ops::{Deref, DerefMut}, + sync::Arc, +}; + +use bytes::{Buf, BufMut, BytesMut}; + +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use super::{Frame, StreamUni, VarInt, VarIntUnexpectedEnd}; + +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct Setting(pub VarInt); + +impl Setting { + pub fn decode(buf: &mut B) -> Result { + Ok(Setting(VarInt::decode(buf)?)) + } + + pub fn encode(&self, buf: &mut B) { + self.0.encode(buf) + } + + // Reference : https://datatracker.ietf.org/doc/html/rfc9114#section-7.2.4.1 + pub fn is_grease(&self) -> bool { + let val = self.0.into_inner(); + if val < 0x21 { + return false; + } + + #[allow(unknown_lints, clippy::manual_is_multiple_of)] + { + (val - 0x21) % 0x1f == 0 + } + } +} + +impl Debug for Setting { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match *self { + Setting::QPACK_MAX_TABLE_CAPACITY => write!(f, "QPACK_MAX_TABLE_CAPACITY"), + Setting::MAX_FIELD_SECTION_SIZE => write!(f, "MAX_FIELD_SECTION_SIZE"), + Setting::QPACK_BLOCKED_STREAMS => write!(f, "QPACK_BLOCKED_STREAMS"), + Setting::ENABLE_CONNECT_PROTOCOL => write!(f, "ENABLE_CONNECT_PROTOCOL"), + Setting::ENABLE_DATAGRAM => write!(f, "ENABLE_DATAGRAM"), + Setting::ENABLE_DATAGRAM_DEPRECATED => write!(f, "ENABLE_DATAGRAM_DEPRECATED"), + Setting::WEBTRANSPORT_ENABLE_DEPRECATED => write!(f, "WEBTRANSPORT_ENABLE_DEPRECATED"), + Setting::WEBTRANSPORT_MAX_SESSIONS_DEPRECATED => { + write!(f, "WEBTRANSPORT_MAX_SESSIONS_DEPRECATED") + } + Setting::WEBTRANSPORT_MAX_SESSIONS => write!(f, "WEBTRANSPORT_MAX_SESSIONS"), + x if x.is_grease() => write!(f, "GREASE SETTING [{:x?}]", x.0.into_inner()), + x => write!(f, "UNKNOWN_SETTING [{:x?}]", x.0.into_inner()), + } + } +} + +macro_rules! settings { + {$($name:ident = $val:expr,)*} => { + impl Setting { + $(pub const $name: Setting = Setting(VarInt::from_u32($val));)* + } + } +} + +settings! { + // These are for HTTP/3 and we can ignore them + QPACK_MAX_TABLE_CAPACITY = 0x1, // default is 0, which disables QPACK dynamic table + MAX_FIELD_SECTION_SIZE = 0x6, + QPACK_BLOCKED_STREAMS = 0x7, + + // Both of these are required for WebTransport + ENABLE_CONNECT_PROTOCOL = 0x8, + ENABLE_DATAGRAM = 0x33, + ENABLE_DATAGRAM_DEPRECATED = 0xFFD277, // still used by Chrome + + // Removed in draft 06 + WEBTRANSPORT_ENABLE_DEPRECATED = 0x2b603742, + WEBTRANSPORT_MAX_SESSIONS_DEPRECATED = 0x2b603743, + + // New way to enable WebTransport + WEBTRANSPORT_MAX_SESSIONS = 0xc671706a, +} + +#[derive(Error, Debug, Clone)] +pub enum SettingsError { + #[error("unexpected end of input")] + UnexpectedEnd, + + #[error("unexpected stream type {0:?}")] + UnexpectedStreamType(StreamUni), + + #[error("unexpected frame {0:?}")] + UnexpectedFrame(Frame), + + #[error("invalid size")] + InvalidSize, + + #[error("io error: {0}")] + Io(Arc), +} + +impl From for SettingsError { + fn from(err: std::io::Error) -> Self { + SettingsError::Io(Arc::new(err)) + } +} + +// A map of settings to values. +#[derive(Default, Debug)] +pub struct Settings(HashMap); + +impl Settings { + pub fn decode(buf: &mut B) -> Result { + let typ = StreamUni::decode(buf).map_err(|_| SettingsError::UnexpectedEnd)?; + if typ != StreamUni::CONTROL { + return Err(SettingsError::UnexpectedStreamType(typ)); + } + + let (typ, mut data) = Frame::read(buf).map_err(|_| SettingsError::UnexpectedEnd)?; + if typ != Frame::SETTINGS { + return Err(SettingsError::UnexpectedFrame(typ)); + } + + let mut settings = Settings::default(); + while data.has_remaining() { + // These return a different error because retrying won't help. + let id = Setting::decode(&mut data).map_err(|_| SettingsError::InvalidSize)?; + let value = VarInt::decode(&mut data).map_err(|_| SettingsError::InvalidSize)?; + // Only add if it is not grease + if !id.is_grease() { + settings.0.insert(id, value); + } + } + + Ok(settings) + } + + pub async fn read(stream: &mut S) -> Result { + let mut buf = Vec::new(); + + loop { + if stream.read_buf(&mut buf).await? == 0 { + return Err(SettingsError::UnexpectedEnd); + } + + // Look at the buffer we've already read. + let mut limit = std::io::Cursor::new(&buf); + + match Settings::decode(&mut limit) { + Ok(settings) => return Ok(settings), + Err(SettingsError::UnexpectedEnd) => continue, // More data needed. + Err(e) => return Err(e), + }; + } + } + + pub fn encode(&self, buf: &mut B) { + StreamUni::CONTROL.encode(buf); + Frame::SETTINGS.encode(buf); + + // Encode to a temporary buffer so we can learn the length. + // TODO avoid doing this, just use a fixed size varint. + let mut tmp = Vec::new(); + for (id, value) in &self.0 { + id.encode(&mut tmp); + value.encode(&mut tmp); + } + + VarInt::from_u32(tmp.len() as u32).encode(buf); + buf.put_slice(&tmp); + } + + pub async fn write(&self, stream: &mut S) -> Result<(), SettingsError> { + // TODO avoid allocating to the heap + let mut buf = BytesMut::new(); + self.encode(&mut buf); + stream.write_all_buf(&mut buf).await?; + Ok(()) + } + + pub fn enable_webtransport(&mut self, max_sessions: u32) { + let max = VarInt::from_u32(max_sessions); + + self.insert(Setting::ENABLE_CONNECT_PROTOCOL, VarInt::from_u32(1)); + self.insert(Setting::ENABLE_DATAGRAM, VarInt::from_u32(1)); + self.insert(Setting::ENABLE_DATAGRAM_DEPRECATED, VarInt::from_u32(1)); + self.insert(Setting::WEBTRANSPORT_MAX_SESSIONS, max); + + // TODO remove when 07 is in the wild + self.insert(Setting::WEBTRANSPORT_MAX_SESSIONS_DEPRECATED, max); + self.insert(Setting::WEBTRANSPORT_ENABLE_DEPRECATED, VarInt::from_u32(1)); + } + + // Returns the maximum number of sessions supported. + pub fn supports_webtransport(&self) -> u64 { + // Sent by Chrome 114.0.5735.198 (July 19, 2023) + // Setting(1): 65536, // qpack_max_table_capacity + // Setting(6): 16384, // max_field_section_size + // Setting(7): 100, // qpack_blocked_streams + // Setting(51): 1, // enable_datagram + // Setting(16765559): 1 // enable_datagram_deprecated + // Setting(727725890): 1, // webtransport_max_sessions_deprecated + // Setting(4445614305): 454654587, // grease + + // NOTE: The presence of ENABLE_WEBTRANSPORT implies ENABLE_CONNECT is supported. + + let datagram = self + .get(&Setting::ENABLE_DATAGRAM) + .or(self.get(&Setting::ENABLE_DATAGRAM_DEPRECATED)) + .map(|v| v.into_inner()); + + if datagram != Some(1) { + return 0; + } + + // The deprecated (before draft-07) way of enabling WebTransport was to send two parameters. + // Both would send ENABLE=1 and the server would send MAX_SESSIONS=N to limit the sessions. + // Now both just send MAX_SESSIONS, and a non-zero value means WebTransport is enabled. + + if let Some(max) = self.get(&Setting::WEBTRANSPORT_MAX_SESSIONS) { + return max.into_inner(); + } + + let enabled = self + .get(&Setting::WEBTRANSPORT_ENABLE_DEPRECATED) + .map(|v| v.into_inner()); + if enabled != Some(1) { + return 0; + } + + // Only the server is allowed to set this one, so if it's None we assume it's 1. + self.get(&Setting::WEBTRANSPORT_MAX_SESSIONS_DEPRECATED) + .map(|v| v.into_inner()) + .unwrap_or(1) + } +} + +impl Deref for Settings { + type Target = HashMap; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for Settings { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} diff --git a/third_party/web-transport-proto/src/stream.rs b/third_party/web-transport-proto/src/stream.rs new file mode 100644 index 0000000..d32ae1c --- /dev/null +++ b/third_party/web-transport-proto/src/stream.rs @@ -0,0 +1,45 @@ +use bytes::{Buf, BufMut}; + +use super::{VarInt, VarIntUnexpectedEnd}; + +// Sent as the first bytes of a unidirectional stream to identify the type. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct StreamUni(pub VarInt); + +impl StreamUni { + pub fn decode(buf: &mut B) -> Result { + Ok(StreamUni(VarInt::decode(buf)?)) + } + + pub fn encode(&self, buf: &mut B) { + self.0.encode(buf) + } + + pub fn is_grease(&self) -> bool { + let val = self.0.into_inner(); + if val < 0x21 { + return false; + } + + #[allow(unknown_lints, clippy::manual_is_multiple_of)] + { + (val - 0x21) % 0x1f == 0 + } + } +} + +macro_rules! streams_uni { + {$($name:ident = $val:expr,)*} => { + impl StreamUni { + $(pub const $name: StreamUni = StreamUni(VarInt::from_u32($val));)* + } + } +} + +streams_uni! { + CONTROL = 0x00, + PUSH = 0x01, + QPACK_ENCODER = 0x02, + QPACK_DECODER = 0x03, + WEBTRANSPORT = 0x54, +} diff --git a/third_party/web-transport-proto/src/varint.rs b/third_party/web-transport-proto/src/varint.rs new file mode 100644 index 0000000..a0c2086 --- /dev/null +++ b/third_party/web-transport-proto/src/varint.rs @@ -0,0 +1,233 @@ +// Based on Quinn: https://github.com/quinn-rs/quinn/tree/main/quinn-proto/src +// Licensed under Apache-2.0 OR MIT + +use std::{convert::TryInto, fmt, io::Cursor}; + +use bytes::{Buf, BufMut}; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +/// An integer less than 2^62 +/// +/// Values of this type are suitable for encoding as QUIC variable-length integer. +// It would be neat if we could express to Rust that the top two bits are available for use as enum +// discriminants +#[derive(Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub struct VarInt(pub(crate) u64); + +impl VarInt { + /// The largest representable value + pub const MAX: Self = Self((1 << 62) - 1); + /// The largest encoded value length + pub const MAX_SIZE: usize = 8; + + /// Construct a `VarInt` infallibly + pub const fn from_u32(x: u32) -> Self { + Self(x as u64) + } + + /// Succeeds if `x` < 2^62 + pub fn from_u64(x: u64) -> Result { + if x < 2u64.pow(62) { + Ok(Self(x)) + } else { + Err(VarIntBoundsExceeded) + } + } + + /// Create a VarInt without ensuring it's in range + /// + /// # Safety + /// + /// `x` must be less than 2^62. + pub const unsafe fn from_u64_unchecked(x: u64) -> Self { + Self(x) + } + + /// Extract the integer value + pub const fn into_inner(self) -> u64 { + self.0 + } + + /// Compute the number of bytes needed to encode this value + pub fn size(self) -> usize { + let x = self.0; + if x < 2u64.pow(6) { + 1 + } else if x < 2u64.pow(14) { + 2 + } else if x < 2u64.pow(30) { + 4 + } else if x < 2u64.pow(62) { + 8 + } else { + unreachable!("malformed VarInt"); + } + } +} + +impl From for u64 { + fn from(x: VarInt) -> Self { + x.0 + } +} + +impl From for VarInt { + fn from(x: u8) -> Self { + Self(x.into()) + } +} + +impl From for VarInt { + fn from(x: u16) -> Self { + Self(x.into()) + } +} + +impl From for VarInt { + fn from(x: u32) -> Self { + Self(x.into()) + } +} + +impl std::convert::TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + /// Succeeds iff `x` < 2^62 + fn try_from(x: u64) -> Result { + Self::from_u64(x) + } +} + +impl std::convert::TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + /// Succeeds iff `x` < 2^62 + fn try_from(x: u128) -> Result { + Self::from_u64(x.try_into().map_err(|_| VarIntBoundsExceeded)?) + } +} + +impl std::convert::TryFrom for VarInt { + type Error = VarIntBoundsExceeded; + /// Succeeds iff `x` < 2^62 + fn try_from(x: usize) -> Result { + Self::try_from(x as u64) + } +} + +impl fmt::Debug for VarInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl fmt::Display for VarInt { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl VarInt { + pub fn decode(r: &mut B) -> Result { + if !r.has_remaining() { + return Err(VarIntUnexpectedEnd); + } + let mut buf = [0; 8]; + buf[0] = r.get_u8(); + let tag = buf[0] >> 6; + buf[0] &= 0b0011_1111; + let x = match tag { + 0b00 => u64::from(buf[0]), + 0b01 => { + if r.remaining() < 1 { + return Err(VarIntUnexpectedEnd); + } + r.copy_to_slice(&mut buf[1..2]); + u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap())) + } + 0b10 => { + if r.remaining() < 3 { + return Err(VarIntUnexpectedEnd); + } + r.copy_to_slice(&mut buf[1..4]); + u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap())) + } + 0b11 => { + if r.remaining() < 7 { + return Err(VarIntUnexpectedEnd); + } + r.copy_to_slice(&mut buf[1..8]); + u64::from_be_bytes(buf) + } + _ => unreachable!(), + }; + Ok(Self(x)) + } + + // Read a varint from the stream. + pub async fn read(stream: &mut S) -> Result { + // 8 bytes is the max size of a varint + let mut buf = [0; 8]; + + // Read the first byte because it includes the length. + stream + .read_exact(&mut buf[0..1]) + .await + .map_err(|_| VarIntUnexpectedEnd)?; + + // 0b00 = 1, 0b01 = 2, 0b10 = 4, 0b11 = 8 + let size = 1 << (buf[0] >> 6); + stream + .read_exact(&mut buf[1..size]) + .await + .map_err(|_| VarIntUnexpectedEnd)?; + + // Use a cursor to read the varint on the stack. + let mut cursor = Cursor::new(&buf[..size]); + let v = VarInt::decode(&mut cursor).unwrap(); + + Ok(v) + } + + pub fn encode(&self, w: &mut B) { + let x = self.0; + if x < 2u64.pow(6) { + w.put_u8(x as u8); + } else if x < 2u64.pow(14) { + w.put_u16((0b01 << 14) | x as u16); + } else if x < 2u64.pow(30) { + w.put_u32((0b10 << 30) | x as u32); + } else if x < 2u64.pow(62) { + w.put_u64((0b11 << 62) | x); + } else { + unreachable!("malformed VarInt") + } + } + + pub async fn write( + &self, + stream: &mut S, + ) -> Result<(), VarIntUnexpectedEnd> { + // Super jaink but keeps everything on the stack. + let mut buf = [0u8; 8]; + let mut cursor: &mut [u8] = &mut buf; + self.encode(&mut cursor); + let size = 8 - cursor.len(); + + let mut cursor = &buf[..size]; + stream + .write_all_buf(&mut cursor) + .await + .map_err(|_| VarIntUnexpectedEnd)?; + + Ok(()) + } +} + +/// Error returned when constructing a `VarInt` from a value >= 2^62 +#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)] +#[error("value too large for varint encoding")] +pub struct VarIntBoundsExceeded; + +#[derive(Error, Debug, Copy, Clone, Eq, PartialEq)] +#[error("unexpected end of buffer")] +pub struct VarIntUnexpectedEnd;