Bit String Distributions

Definitions

`FixedBitString n` is a length-$n$ bit string (`List.Vector Bool n`).

`bitStringCardinality` counts $|\{0,1\}^n| = 2^n$.

`uniformBitStringOfLength` wraps the uniform `PMF` on `FixedBitString n` into a `PMF BitString` by forgetting the length index.
/-- A bit string of fixed length `n`. -/
abbrev FixedBitString (n : ℕ) := List.Vector Bool n

/-- There are exactly `2^n` bit strings of length `n`. -/
lemma bitStringCardinality (n : ℕ) :
    Fintype.card (FixedBitString n) = 2 ^ n := by
  simp [card_vector, Fintype.card_bool]

/-- Sample a uniform bit string of length `n`. -/
noncomputable def uniformBitStringOfLength (n : ℕ) : PMF BitString :=
  (PMF.uniformOfFintype (FixedBitString n)).map (·.toList)
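As a quick sanity check of the cardinality lemma, the sketch below instantiates it at $n = 3$, where $2^3 = 8$. It assumes the declarations above are in scope and that `norm_num` closes the remaining arithmetic goal; it is an illustration and has not been checked against a particular Mathlib version.

```lean
-- Illustrative only: there are 2^3 = 8 bit strings of length 3.
-- `bitStringCardinality 3` gives `card = 2 ^ 3`; `norm_num` is assumed to prove `2 ^ 3 = 8`.
example : Fintype.card (FixedBitString 3) = 8 :=
  (bitStringCardinality 3).trans (by norm_num)
```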
# Theorem 1: The joint distribution of two independent uniforms is uniform on their product

**given:**
- $k, n : \mathbb{N}$ (lengths)
- $A$ the uniform distribution on $\{0,1\}^k$
- $B$ the uniform distribution on $\{0,1\}^n$

**then:**
- sampling $a \sim A$ and $b \sim B$ independently gives the uniform distribution on $\{0,1\}^k \times \{0,1\}^n$
- i.e. $\Pr[(a, b)] = \frac{1}{2^k} \cdot \frac{1}{2^n} = \frac{1}{2^{k+n}}$

/-- Sampling two independent uniforms and pairing gives the uniform
    distribution on the product. -/
lemma bitstring_prod_eq (k n : ℕ) :
    (do let a ← PMF.uniformOfFintype (FixedBitString k)
        let b ← PMF.uniformOfFintype (FixedBitString n)
        return (a, b)) =
    PMF.uniformOfFintype (FixedBitString k × FixedBitString n) := by
reduce the do-notation to explicit bind/pure form
  show (PMF.uniformOfFintype (FixedBitString k)).bind (fun a =>
    (PMF.uniformOfFintype (FixedBitString n)).bind (fun b =>
    PMF.pure (a, b))) =
    PMF.uniformOfFintype (FixedBitString k × FixedBitString n)
fix a concrete pair $(a, b)$; it suffices to show the joint distribution and the product uniform distribution assign $(a, b)$ the same probability
  ext ⟨a, b⟩
expand the PMF bind/pure operations, replace $|\{0,1\}^n|$ with $2^n$, and split the product cardinality
  simp only [PMF.bind_apply, PMF.pure_apply, PMF.uniformOfFintype_apply,
             tsum_fintype, Fintype.card_prod, Prod.mk.injEq,
             bitStringCardinality, Nat.cast_pow, Nat.cast_ofNat]
the goal is now: $\sum_{a_1} \frac{1}{2^k} \cdot \sum_{b_1} \frac{1}{2^n} \cdot \mathbf{1}[a=a_1 \wedge b=b_1] = \frac{1}{2^k \cdot 2^n}$

key lemma: for each fixed $a_1$, the inner sum over $b_1$ collapses: $\sum_{b_1} \frac{1}{2^n} \cdot \mathbf{1}[a=a_1 \wedge b=b_1] = \frac{1}{2^n} \cdot \mathbf{1}[a=a_1]$, because $b$ is fixed, so exactly one $b_1$ (namely $b_1 = b$) contributes
  have h_inner : ∀ (a_1 : FixedBitString k),
      ∑ b_1 : FixedBitString n,
        (2^n : ENNReal)⁻¹ * ite (a = a_1 ∧ b = b_1) (1 : ENNReal) 0
      = (2^n : ENNReal)⁻¹ * ite (a = a_1) (1 : ENNReal) 0 := fun a_1 => by
    by_cases hx : a = a_1
**Case $a = a_1$:** the conjunction reduces to $\mathbf{1}[b = b_1]$, so $\sum_{b_1} \frac{1}{2^n} \cdot \mathbf{1}[b=b_1] = \frac{1}{2^n}$
    · subst hx; simp [Finset.mem_univ]
**Case $a \neq a_1$:** the conjunction is always false, so the sum is $0$
    · simp [hx]
substitute h_inner, factor out $\frac{1}{2^k}$, and evaluate the outer sum
  simp_rw [h_inner, ← mul_assoc, ← Finset.mul_sum]
$\sum_{a_1} \mathbf{1}[a=a_1] = 1$ (exactly one term is nonzero)
  simp only [Finset.sum_ite_eq, Finset.mem_univ, if_true, mul_one]
$\frac{1}{2^k} \cdot \frac{1}{2^n} = \frac{1}{2^k \cdot 2^n}$
  simp [ENNReal.mul_inv]
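To make the statement concrete, take $k = 1$ and $n = 2$: every pair $(a, b)$ then has probability $\frac{1}{2} \cdot \frac{1}{4} = \frac{1}{8} = \frac{1}{2^{1+2}}$. The sketch below only specializes `bitstring_prod_eq` to these lengths; it assumes the lemma above is in scope and adds no new proof content.

```lean
-- Illustrative specialization of `bitstring_prod_eq` to k = 1, n = 2.
example :
    (do let a ← PMF.uniformOfFintype (FixedBitString 1)
        let b ← PMF.uniformOfFintype (FixedBitString 2)
        return (a, b)) =
    PMF.uniformOfFintype (FixedBitString 1 × FixedBitString 2) :=
  bitstring_prod_eq 1 2
```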
# Bijection 1: Pairing bijects with concatenation

++ indicates concatenation.

The map $(va, vb) \mapsto va \mathbin{++} vb$ is a bijection $\{0,1\}^k \times \{0,1\}^n \simeq \{0,1\}^{k+n}$, with inverse $v \mapsto (v[{:}k],\, v[k{:}])$.

/-- The bijection sending `(va, vb) : FixedBitString k × FixedBitString n`
    to the concatenated vector `va ++ vb : FixedBitString (k + n)`. -/
private def vectorAppendEquiv (k n : ℕ) :
    FixedBitString k × FixedBitString n ≃ FixedBitString (k + n) where
  toFun := fun ⟨a, b⟩ => ⟨a.toList ++ b.toList, by simp⟩
  invFun := fun v =>
    ⟨⟨v.toList.take k, by simp⟩, ⟨v.toList.drop k, by simp⟩⟩
  left_inv := fun ⟨⟨la, hla⟩, ⟨lb, hlb⟩⟩ => by
    simp only [List.Vector.toList]
    ext1
    · simp [← hla]
    · simp [← hla]
  right_inv := fun ⟨l, hl⟩ => by simp [List.take_append_drop]
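A small worked instance may help: with $k = 1$ and $n = 2$, the pair $((1), (0, 1))$ should map to $(1, 0, 1)$. The sketch below states this for the underlying lists. It is an assumption that `rfl` unfolds the equivalence far enough to close the goal, and because `vectorAppendEquiv` is `private`, the example would have to live in the same file.

```lean
-- Illustrative only: concatenating [true] with [false, true] through the equivalence.
-- The `rfl` proofs inside the anonymous constructors discharge the length conditions.
example :
    (vectorAppendEquiv 1 2 (⟨[true], rfl⟩, ⟨[false, true], rfl⟩)).toList
      = [true, false, true] :=
  rfl
```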
# Theorem 2: Concatenating independent uniform bit strings yields a uniform bit string

**given:**
- $k, n : \mathbb{N}$
- $a \sim \{0,1\}^k$ and $b \sim \{0,1\}^n$ sampled independently

**then:** (++ indicates concatenation)
- $a \mathbin{++} b \sim \{0,1\}^{k+n}$

**proof sketch:**
- the joint distribution of $(a, b)$ is uniform on $\{0,1\}^k \times \{0,1\}^n$ by [`bitstring_prod_eq`](Bitstrings.html#theorem-1 "Bit String Distributions, Theorem 1")
- concatenation is a bijection $\{0,1\}^k \times \{0,1\}^n \simeq \{0,1\}^{k+n}$, with inverse $v \mapsto (v[{:}k],\, v[k{:}])$ (see [Bijection 1](Bitstrings.html#bijection-1))
- push the uniform product distribution along this bijection using [`PMF.uniformOfFintype_map_equiv`](UniformProofs.html#theorem-1 "Uniform Distribution Proofs, Theorem 1"); the result is uniform on $\{0,1\}^{k+n}$

/-- Independently sampling a k-bit string and an n-bit string and
    concatenating is the same distribution as sampling a (k+n)-bit
    string uniformly at random. -/
theorem uniform_concat_eq (k n : ℕ) :
    (do let a ← uniformBitStringOfLength k
        let b ← uniformBitStringOfLength n
        return a ++ b) =
    uniformBitStringOfLength (k + n) := by
  simp only [uniformBitStringOfLength]
restate the goal in fully explicit form so subsequent rewrites have a precise target to match against
  show (PMF.map List.Vector.toList
          (PMF.uniformOfFintype (FixedBitString k))).bind
      (fun a => (PMF.map List.Vector.toList
          (PMF.uniformOfFintype (FixedBitString n))).bind
        (fun b => PMF.pure (a ++ b))) =
      PMF.map List.Vector.toList
        (PMF.uniformOfFintype (FixedBitString (k + n)))
`PMF.bind_map`: `(f.map g).bind h = f.bind (h ∘ g)` push the `toList` maps inside the binds, composing them into the bound function so we work directly on `FixedBitString` values
  simp only [PMF.bind_map]
rewrite using side goal `?_`: the nested binds equal a single map, i.e. sampling $a \sim \{0,1\}^k$ then $b \sim \{0,1\}^n$ and concatenating is the same as sampling $(a,b) \sim \{0,1\}^k \times \{0,1\}^n$ and concatenating.

after the rewrite the goal reduces to: `(uniform (k×n)).map concat = (uniform (k+n)).map toList`
  rw [show
      (PMF.uniformOfFintype (FixedBitString k)).bind
        ((fun a => (PMF.uniformOfFintype (FixedBitString n)).bind
          ((fun b => PMF.pure (a ++ b)) ∘
            List.Vector.toList)) ∘ List.Vector.toList) =
      (PMF.uniformOfFintype (FixedBitString k × FixedBitString n)).map
        (fun p => p.1.toList ++ p.2.toList) from ?_]
after this rewrite, the goal becomes: `(uniform (k×n)).map (fun p => p.1.toList ++ p.2.toList) = (uniform (k+n)).map toList`

<hr>

### Main goal

i.e. if $(a,b) \sim U(\{0,1\}^k \times \{0,1\}^n)$ then $a \mathbin{++} b \sim U(\{0,1\}^{k+n})$

show `(uniform (k×n)).map (fun p => p.1.toList ++ p.2.toList)` $=$ `(uniform (k+n)).map toList`

**Step 1.** `fun p => p.1.toList ++ p.2.toList` is definitionally equal to `toList ∘ vectorAppendEquiv`; rewrite by `funext`/`rfl` to make this explicit.
  · rw [show
        (fun p : FixedBitString k × FixedBitString n =>
          p.1.toList ++ p.2.toList) =
        List.Vector.toList ∘ vectorAppendEquiv k n from
        funext fun ⟨va, vb⟩ => rfl]
**Step 2.** `← PMF.map_comp` splits `map (toList ∘ vectorAppendEquiv)` into `(map vectorAppendEquiv).map toList`. Then [`PMF.uniformOfFintype_map_equiv`](UniformProofs.html#theorem-1 "Uniform Distribution Proofs, Theorem 1") converts `uniform (k×n)` to `uniform (k+n)`, closing the goal.
    rw [← PMF.map_comp, PMF.uniformOfFintype_map_equiv]
<hr>

### Side goal

prove that the two nested binds (after `bind_map`) equal `(uniform (k×n)).map (fun p => p.1.toList ++ p.2.toList)`

**Step 1.** `← PMF.bind_pure_comp` rewrites `f.map g` as `f.bind (pure ∘ g)`, converting the RHS to a `bind` so that `bind` lemmas apply.

**Step 2.** [`bitstring_prod_eq`](Bitstrings.html#theorem-1 "Bit String Distributions, Theorem 1") (backwards) replaces `uniform (k×n)` with sequential sampling so both sides have the same bind structure.

**Step 3.** `PMF.bind_bind` (associativity) flattens the three nested binds. `congr 1` + `ext` reduces to checking the bound functions agree pointwise. `PMF.pure_bind` eliminates the `pure (va, vb)` wrapper, closing the goal.
  · rw [← PMF.bind_pure_comp
          (f := fun p : FixedBitString k × FixedBitString n =>
            p.1.toList ++ p.2.toList)]
    have prod_eq :
        PMF.uniformOfFintype (FixedBitString k × FixedBitString n) =
        (PMF.uniformOfFintype (FixedBitString k)).bind fun va =>
          (PMF.uniformOfFintype (FixedBitString n)).bind fun vb =>
            PMF.pure (va, vb) :=
      (bitstring_prod_eq k n).symm
    rw [prod_eq, PMF.bind_bind]
    congr 1; ext va; congr 1; ext vb
    simp [PMF.pure_bind]
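As a usage illustration, the theorem can be specialized to concrete lengths: sampling one uniform bit, then two more, and concatenating gives the same distribution as sampling a uniform string of length $1 + 2$. The sketch below assumes `uniform_concat_eq` is in scope and is only an instance of the statement above.

```lean
-- Illustrative specialization of `uniform_concat_eq` to k = 1, n = 2.
example :
    (do let a ← uniformBitStringOfLength 1
        let b ← uniformBitStringOfLength 2
        return a ++ b) =
    uniformBitStringOfLength (1 + 2) :=
  uniform_concat_eq 1 2
```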
Raw source
import Mathlib.Probability.Distributions.Uniform
import Mathlib.Data.Fintype.BigOperators
import Mathlib.Data.Fintype.Card
import Onewayf.PolyTime
import Onewayf.UniformProofs

-- leandown
-- [meta]
-- title = "Bit string distributions"
-- group = "Cryptography"
-- [content]

-- [hidden: Definitions]
-- `FixedBitString n` is a length-$n$ bit string (`List.Vector Bool n`).
--
-- `bitStringCardinality` counts $|\{0,1\}^n| = 2^n$.
--
-- `uniformBitStringOfLength` wraps the uniform `PMF` on `FixedBitString n`
-- into a `PMF BitString` by forgetting the length index.
/-- A bit string of fixed length `n`. -/
abbrev FixedBitString (n : ℕ) := List.Vector Bool n

/-- There are exactly `2^n` bit strings of length `n`. -/
lemma bitStringCardinality (n : ℕ) :
    Fintype.card (FixedBitString n) = 2 ^ n := by
  simp [card_vector, Fintype.card_bool]

/-- Sample a uniform bit string of length `n`. -/
noncomputable def uniformBitStringOfLength (n : ℕ) : PMF BitString :=
  (PMF.uniformOfFintype (FixedBitString n)).map (·.toList)
-- [/hidden]

-- # {{theorem}}: The joint distribution of two independent uniforms is uniform on their product
--
-- **given:**
-- - $k, n : \mathbb{N}$ (lengths)
-- - $A$ the uniform distribution on $\{0,1\}^k$
-- - $B$ the uniform distribution on $\{0,1\}^n$
--
-- **then:**
-- - sampling $a \sim A$ and $b \sim B$ independently gives the uniform
--   distribution on $\{0,1\}^k \times \{0,1\}^n$
-- - i.e. $\Pr[(a, b)] = \frac{1}{2^k} \cdot \frac{1}{2^n} = \frac{1}{2^{k+n}}$
--
/-- Sampling two independent uniforms and pairing gives the uniform
    distribution on the product. -/
lemma bitstring_prod_eq (k n : ℕ) :
    (do let a ← PMF.uniformOfFintype (FixedBitString k)
        let b ← PMF.uniformOfFintype (FixedBitString n)
        return (a, b)) =
    PMF.uniformOfFintype (FixedBitString k × FixedBitString n) := by
  -- reduce the do-notation to explicit bind/pure form
  show (PMF.uniformOfFintype (FixedBitString k)).bind (fun a =>
    (PMF.uniformOfFintype (FixedBitString n)).bind (fun b =>
    PMF.pure (a, b))) =
    PMF.uniformOfFintype (FixedBitString k × FixedBitString n)
  -- fix a concrete pair $(a, b)$; it suffices to show the joint distribution
  -- and the product uniform distribution assign $(a, b)$ the same probability
  ext ⟨a, b⟩
  -- expand the PMF bind/pure operations, replace $|\{0,1\}^n|$ with $2^n$,
  -- and split the product cardinality
  simp only [PMF.bind_apply, PMF.pure_apply, PMF.uniformOfFintype_apply,
             tsum_fintype, Fintype.card_prod, Prod.mk.injEq,
             bitStringCardinality, Nat.cast_pow, Nat.cast_ofNat]
  -- the goal is now:
  --   $\sum_{a_1} \frac{1}{2^k} \cdot \sum_{b_1} \frac{1}{2^n}
  --     \cdot \mathbf{1}[a=a_1 \wedge b=b_1] = \frac{1}{2^k \cdot 2^n}$

  -- key lemma: for each fixed $a_1$, the inner sum over $b_1$ collapses:
  --   $\sum_{b_1} \frac{1}{2^n} \cdot \mathbf{1}[a=a_1 \wedge b=b_1]
  --   = \frac{1}{2^n} \cdot \mathbf{1}[a=a_1]$
  -- because $b$ is fixed, so exactly one $b_1$ (namely $b_1 = b$) contributes
  have h_inner : ∀ (a_1 : FixedBitString k),
      ∑ b_1 : FixedBitString n,
        (2^n : ENNReal)⁻¹ * ite (a = a_1 ∧ b = b_1) (1 : ENNReal) 0
      = (2^n : ENNReal)⁻¹ * ite (a = a_1) (1 : ENNReal) 0 := fun a_1 => by
    by_cases hx : a = a_1
    -- **Case $a = a_1$:** the conjunction reduces to $\mathbf{1}[b = b_1]$,
    -- so $\sum_{b_1} \frac{1}{2^n} \cdot \mathbf{1}[b=b_1] = \frac{1}{2^n}$
    · subst hx; simp [Finset.mem_univ]
    -- **Case $a \neq a_1$:** the conjunction is always false,
    -- so the sum is $0$
    · simp [hx]
  -- substitute h_inner, factor out $\frac{1}{2^k}$, and evaluate the outer sum
  simp_rw [h_inner, ← mul_assoc, ← Finset.mul_sum]
  -- $\sum_{a_1} \mathbf{1}[a=a_1] = 1$ (exactly one term is nonzero)
  simp only [Finset.sum_ite_eq, Finset.mem_univ, if_true, mul_one]
  -- $\frac{1}{2^k} \cdot \frac{1}{2^n} = \frac{1}{2^k \cdot 2^n}$
  simp [ENNReal.mul_inv]

-- # {{bijection}}: Pairing bijects with concatenation
-- ++ indicates concatenation.
--
-- The map $(va, vb) \mapsto va \mathbin{++} vb$ is a bijection
-- $\{0,1\}^k \times \{0,1\}^n \simeq \{0,1\}^{k+n}$,
-- with inverse $v \mapsto (v[{:}k],\, v[k{:}])$.
--
/-- The bijection sending `(va, vb) : FixedBitString k × FixedBitString n`
    to the concatenated vector `va ++ vb : FixedBitString (k + n)`. -/
private def vectorAppendEquiv (k n : ℕ) :
    FixedBitString k × FixedBitString n ≃ FixedBitString (k + n) where
  toFun := fun ⟨a, b⟩ => ⟨a.toList ++ b.toList, by simp⟩
  invFun := fun v =>
    ⟨⟨v.toList.take k, by simp⟩, ⟨v.toList.drop k, by simp⟩⟩
  left_inv := fun ⟨⟨la, hla⟩, ⟨lb, hlb⟩⟩ => by
    simp only [List.Vector.toList]
    ext1
    · simp [← hla]
    · simp [← hla]
  right_inv := fun ⟨l, hl⟩ => by simp [List.take_append_drop]

-- # {{theorem}}: Concatenating independent uniform bit strings yields a uniform bit string
--
-- **given:**
-- - $k, n : \mathbb{N}$
-- - $a \sim \{0,1\}^k$ and $b \sim \{0,1\}^n$ sampled independently
--
-- **then:**
--  (++ indicates concatenation)
-- - $a \mathbin{++} b \sim \{0,1\}^{k+n}$
--
-- **proof sketch:**
-- - the joint distribution of $(a, b)$ is uniform on
--   $\{0,1\}^k \times \{0,1\}^n$ by `bitstring_prod_eq`
-- - concatenation is a bijection
--   $\{0,1\}^k \times \{0,1\}^n \simeq \{0,1\}^{k+n}$,
--   with inverse $v \mapsto (v[{:}k],\, v[k{:}])$ (see [ref:Bijection 1])
-- - push the uniform product distribution along this bijection
--   using `PMF.uniformOfFintype_map_equiv`;
--   the result is uniform on $\{0,1\}^{k+n}$
--
/-- Independently sampling a k-bit string and an n-bit string and
    concatenating is the same distribution as sampling a (k+n)-bit
    string uniformly at random. -/
theorem uniform_concat_eq (k n : ℕ) :
    (do let a ← uniformBitStringOfLength k
        let b ← uniformBitStringOfLength n
        return a ++ b) =
    uniformBitStringOfLength (k + n) := by
  simp only [uniformBitStringOfLength]
  -- restate the goal in fully explicit form so subsequent rewrites
  -- have a precise target to match against
  show (PMF.map List.Vector.toList
          (PMF.uniformOfFintype (FixedBitString k))).bind
      (fun a => (PMF.map List.Vector.toList
          (PMF.uniformOfFintype (FixedBitString n))).bind
        (fun b => PMF.pure (a ++ b))) =
      PMF.map List.Vector.toList
        (PMF.uniformOfFintype (FixedBitString (k + n)))
  -- `PMF.bind_map`: `(f.map g).bind h = f.bind (h ∘ g)`
  -- push the `toList` maps inside the binds, composing them into
  -- the bound function so we work directly on `FixedBitString` values
  simp only [PMF.bind_map]
  -- rewrite using side goal `?_`: the nested binds equal a single map,
  --
  -- i.e. sampling $a \sim \{0,1\}^k$ then $b \sim \{0,1\}^n$ and concatenating
  -- is the same as sampling $(a,b) \sim \{0,1\}^k \times \{0,1\}^n$ and concatenating.
  -- after the rewrite the goal reduces to:
  -- `(uniform (k×n)).map concat = (uniform (k+n)).map toList`
  rw [show
      (PMF.uniformOfFintype (FixedBitString k)).bind
        ((fun a => (PMF.uniformOfFintype (FixedBitString n)).bind
          ((fun b => PMF.pure (a ++ b)) ∘
            List.Vector.toList)) ∘ List.Vector.toList) =
      (PMF.uniformOfFintype (FixedBitString k × FixedBitString n)).map
        (fun p => p.1.toList ++ p.2.toList) from ?_]
  -- after this rewrite, the goal becomes:
  -- `(uniform (k×n)).map (fun p => p.1.toList ++ p.2.toList)`
  -- `  = (uniform (k+n)).map toList`
  -- <hr>
  --
  -- ### Main goal
  --
  -- i.e. if $(a,b) \sim U(\{0,1\}^k \times \{0,1\}^n)$ then $a \mathbin{++} b \sim U(\{0,1\}^{k+n})$
  --
  -- show `(uniform (k×n)).map (fun p => p.1.toList ++ p.2.toList)`
  -- $=$ `(uniform (k+n)).map toList`
  --
  -- **Step 1.** `fun p => p.1.toList ++ p.2.toList` is definitionally equal to
  -- `toList ∘ vectorAppendEquiv`; rewrite by `funext`/`rfl` to make this explicit.
  --
  · rw [show
        (fun p : FixedBitString k × FixedBitString n =>
          p.1.toList ++ p.2.toList) =
        List.Vector.toList ∘ vectorAppendEquiv k n from
        funext fun ⟨va, vb⟩ => rfl]
  -- **Step 2.** `← PMF.map_comp` splits `map (toList ∘ vectorAppendEquiv)` into
  -- `(map vectorAppendEquiv).map toList`. Then `PMF.uniformOfFintype_map_equiv`
  -- converts `uniform (k×n)` to `uniform (k+n)`, closing the goal.
    rw [← PMF.map_comp, PMF.uniformOfFintype_map_equiv]
  -- <hr>
  --
  -- ### Side goal
  --
  -- prove that the two nested binds (after `bind_map`) equal
  -- `(uniform (k×n)).map (fun p => p.1.toList ++ p.2.toList)`
  --
  -- **Step 1.** `← PMF.bind_pure_comp` rewrites `f.map g` as `f.bind (pure ∘ g)`,
  -- converting the RHS to a `bind` so that `bind` lemmas apply.
  --
  -- **Step 2.** `bitstring_prod_eq` (backwards) replaces `uniform (k×n)` with
  -- sequential sampling so both sides have the same bind structure.
  --
  -- **Step 3.** `PMF.bind_bind` (associativity) flattens the three nested binds.
  -- `congr 1` + `ext` reduces to checking the bound functions agree pointwise.
  -- `PMF.pure_bind` eliminates the `pure (va, vb)` wrapper, closing the goal.
  · rw [← PMF.bind_pure_comp
          (f := fun p : FixedBitString k × FixedBitString n =>
            p.1.toList ++ p.2.toList)]
    have prod_eq :
        PMF.uniformOfFintype (FixedBitString k × FixedBitString n) =
        (PMF.uniformOfFintype (FixedBitString k)).bind fun va =>
          (PMF.uniformOfFintype (FixedBitString n)).bind fun vb =>
            PMF.pure (va, vb) :=
      (bitstring_prod_eq k n).symm
    rw [prod_eq, PMF.bind_bind]
    congr 1; ext va; congr 1; ext vb
    simp [PMF.pure_bind]