Quadratic-time One-way Functions

If one-way functions exist, then there exists a one-way function computable in time $n^2$. In other words, **quadratic time is a universal upper bound** on the complexity of one-way functions: you never need more than $n^2$.

**Proof sketch.** Given a one-way function `f` computable in time $n^c$, we construct a new function $f'$ that runs in time $n^2$ and is also a OWF. $f'$ passes a prefix of its input through unchanged and applies `f` only to the remaining suffix. On a uniform input the prefix is uniform and independent of the suffix, so an attacker on `f` can generate that prefix itself. This lets us reduce any attack on $f'$ back to an attack on `f`.

# Definition 1: The padded function $f'$

Given `f`, exponent `c`, and block size `m`, define

$$f'(x) = x_{:m^c} \mathbin{++} f(x_{m^c:})$$

On an input $x = a \mathbin{++} b$ with $|a| = m^c$ and $|b| = m$, this returns $a \mathbin{++} f(b)$, leaving the prefix untouched and applying `f` only to the suffix. The Lean definition does not fix the suffix length: on an $n$-bit input with $n \geq m^c$ the suffix has $n - m^c$ bits, which equals $m$ exactly when $n = m^c + m$.
import Onewayf.OneWayFunction
import Onewayf.Bitstrings

def fPrime (f : BitString → BitString) (c m : ℕ) : BitString → BitString :=
  fun x => List.take (m ^ c) x ++ f (List.drop (m ^ c) x)
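To see the definition on concrete data, here is a minimal sketch that copies the body of `fPrime` into a hypothetical `fPrimeDemo` over `List Bool`. It assumes that `BitString` behaves like a list of bits, which the use of `List.take`/`List.drop` above suggests. `List.reverse` stands in for `f`. With $c = m = 2$, the untouched prefix has $2^2 = 4$ bits and the suffix has $m = 2$ bits.

```lean
namespace FPrimeEval

/-- Hypothetical stand-in for `fPrime`, specialised to `List Bool`. -/
def fPrimeDemo (f : List Bool → List Bool) (c m : Nat) (x : List Bool) : List Bool :=
  List.take (m ^ c) x ++ f (List.drop (m ^ c) x)

-- c = 2, m = 2: the first 2 ^ 2 = 4 bits pass through unchanged and
-- `List.reverse` (playing the role of `f`) acts only on the last 2 bits.
example :
    fPrimeDemo List.reverse 2 2 [true, true, false, false, true, false]
      = [true, true, false, false, false, true] := by
  decide

end FPrimeEval
```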
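# Theorem 1: $f'$ is poly-time computable

`take`/`drop` are $O(n)$ and appending is $O(n)$. On an input of length $n = m^c + m$, applying `f` to the $m$-bit suffix costs $O(m^c)$, which is $O(n)$ because $m^c \leq n$. The total is therefore $O(n)$, within the $O(n^2)$ bound. The formal statement below asserts only that $f'$ is polynomial-time computable.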
theorem fPrime_polytime (f : BitString → BitString) (c m : ℕ)
    (hf : PolyTimeComputable f) : PolyTimeComputable (fPrime f c m) :=
  .append (.take (m ^ c)) (.comp hf (.drop (m ^ c)))
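The cost count above can be written as a single tally for an input of the intended length $n = m^c + m$. Here $O(m^c)$ is the assumed running time of `f` on the $m$-bit suffix:

$$T_{f'}(n) \;=\; \underbrace{O(n)}_{\texttt{take}/\texttt{drop}} \;+\; \underbrace{O(m^c)}_{f \text{ on the suffix}} \;+\; \underbrace{O(n)}_{\text{append}} \;=\; O(n), \qquad \text{since } m^c \le m^c + m = n.$$

For example, with $c = 3$ and $m = 10$ the input has $n = 10^3 + 10 = 1010$ bits, and the call to `f` costs on the order of $10^3 = 1000 \le n$ steps.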
# Definition 2: The reduction adversary

Given an adversary `A'` that inverts $f'$, we build an adversary for `f`. On input $y = f(b)$ (where $b$ has length $m$):

1. Sample $a \leftarrow \{0,1\}^{m^c}$ uniformly.
2. Run `A'` on $a \mathbin{++} y$ to get output $z$.
3. Return $z_{m^c:}$ as the guess for $b$.

The idea: `A'` expects to see $f'(a \mathbin{++} b) = a \mathbin{++} f(b)$, which is exactly $a \mathbin{++} y$. Because $a$ is a fresh uniform string independent of $b$, the string $a \mathbin{++} y$ has the same distribution as an honest challenge $f'(x)$ for uniform $x$. `A'` therefore runs on exactly the inputs it was designed for.
noncomputable def buildAdv (A' : BitString → PMF BitString) (m c : ℕ) :
    BitString → PMF BitString :=
  fun y => do
    let a ← uniformBitStringOfLength (m ^ c)
    let z ← A' (a ++ y)
    return List.drop (m ^ c) z
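The simulation argument rests on two list identities. They are sketched here for a hypothetical `fPrimeDemo` over `List Bool`, a local copy of the `fPrime` body that is not part of the development.

* If the sampled prefix `a` has length `m ^ c`, the string handed to `A'` is exactly $f'(a \mathbin{++} b)$.
* Step 3 recovers `b` from a correct preimage.

The proofs assume the same Lean version as the development above, in which `List.take_left'` and `List.drop_left'` are available; later proofs on this page use both.

```lean
namespace BuildAdvView

/-- Hypothetical stand-in for `fPrime` over `List Bool`. -/
def fPrimeDemo (f : List Bool → List Bool) (c m : Nat) (x : List Bool) : List Bool :=
  List.take (m ^ c) x ++ f (List.drop (m ^ c) x)

/-- The string `a ++ f b` that the reduction hands to `A'` is exactly
`fPrimeDemo f c m (a ++ b)`, provided the sampled prefix has length `m ^ c`. -/
theorem view_eq (f : List Bool → List Bool) (c m : Nat) (a b : List Bool)
    (h : a.length = m ^ c) :
    fPrimeDemo f c m (a ++ b) = a ++ f b := by
  simp only [fPrimeDemo, List.take_left' h, List.drop_left' h]

/-- Step 3 of the reduction: dropping the `m ^ c`-bit prefix of `a ++ b`
returns exactly `b`. -/
theorem drop_recovers (c m : Nat) (a b : List Bool) (h : a.length = m ^ c) :
    List.drop (m ^ c) (a ++ b) = b :=
  List.drop_left' h

end BuildAdvView
```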
# Theorem 2: [`buildAdv`](QuadraticOWF.html#definition-2 "Quadratic-time one-way functions, Definition 2") is a poly-time adversary

Sampling $a$ and calling `A'` on a padded input are both poly-time; dropping the prefix is $O(n)$. The `uniformPad` constructor of `PolyTimeAdversary` captures exactly this pattern.

**Note:** [`buildAdv`](QuadraticOWF.html#definition-2 "Quadratic-time one-way functions, Definition 2") is inherently randomized (step 1 samples $a$). Our `PolyTimeAdversary` model uses `uniformPad` to handle this, so the proof goes through directly.
theorem buildAdv_polytime (A' : BitString → PMF BitString) (m c : ℕ)
    (hA' : PolyTimeAdversary A') : PolyTimeAdversary (buildAdv A' m c) := by
  unfold buildAdv
  exact .uniformPad (m ^ c) hA'
# Lemma 1: The coupling inequality

For any input length $n \geq m^c$, attacking $f'$ with `A'` on $n$-bit inputs succeeds with probability **at most** that of attacking $f$ with `buildAdv A' m c` on $(n - m^c)$-bit inputs.

**Why inequality, not equality.** Write $x = a \mathbin{++} b$ with $|a| = m^c$ and $|b| = n - m^c$, so that $f'(x) = a \mathbin{++} f(b)$. For `A'` to invert $f'$, it must output $z$ satisfying **both** $z_{:m^c} = a$ (correct prefix) **and** $f(z_{m^c:}) = f(b)$ (correct suffix). The success event for [`buildAdv`](QuadraticOWF.html#definition-2 "Quadratic-time one-way functions, Definition 2") requires only the suffix condition $f(z_{m^c:}) = f(b)$. Every $f'$-success therefore yields an $f$-success, but not vice versa.

**Proof sketch.**
1. Reindex the LHS sum over $\{0,1\}^n$ as a double sum over $\{0,1\}^{m^c} \times \{0,1\}^{n-m^c}$ via the bijection `splitN` (defined below), and factor the uniform weight as $2^{-n} = 2^{-m^c} \cdot 2^{-(n-m^c)}$.
2. Expand [`buildAdv`](QuadraticOWF.html#definition-2 "Quadratic-time one-way functions, Definition 2") in the RHS using `PMF.bind_apply` and swap summation order (`ENNReal.tsum_comm`); convert `uniformBitStringOfLength` back to a finite sum over `FixedBitString (m^c)`.
3. Both sides now have the form $\sum_{a,b} \tfrac{1}{2^n} \cdot I(a,b)$. The LHS indicator $\mathbf{1}[z_{:m^c} = a \wedge f(z_{m^c:}) = f(b)]$ is pointwise $\leq$ the RHS indicator $\mathbf{1}[f(z_{m^c:}) = f(b)]$.

The indicator inequality at the heart of the coupling: `fPrime` success (with the length constraint) implies success of `f` on the suffix.
private lemma fPrime_success_implies_suffix {k n : ℕ} (hkn : k ≤ n)
    {f : BitString → BitString}
    (a : FixedBitString k) (b : FixedBitString (n - k)) (x' : BitString)
    (hfp : List.take k x' ++ f (List.drop k x') = a.toList ++ f b.toList)
    (hlen : x'.length = n) :
    f (List.drop k x') = f b.toList ∧ (List.drop k x').length = n - k := by
  -- Since x'.length = n ≥ k, the take has exactly k elements, matching a.toList.
  -- Concatenation injectivity then gives f(x'.drop k) = f(b.toList).
  have htake_len : (List.take k x').length = k := by
    simp [List.length_take, Nat.min_eq_left (hlen ▸ hkn)]
  have hinj := List.append_inj hfp (htake_len.trans a.toList_length.symm)
  exact ⟨hinj.2, by simp [List.length_drop, hlen]⟩
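The lemma above only goes one way, which is why Lemma 1 is an inequality. The following concrete instance uses a hypothetical list-based `fPrimeDemo` with `List.reverse` as `f`, and sets $c = m = 1$ so that the prefix is a single bit.

The target is $a \mathbin{++} f(b)$ with $a = [\mathtt{true}]$ and $b = [\mathtt{false}, \mathtt{true}]$. The candidate $z = [\mathtt{false}] \mathbin{++} b$ has the correct suffix, so it counts as inverting `f`. Its first bit is wrong, however, so it does not invert $f'$.

```lean
namespace CouplingStrict

/-- Hypothetical stand-in for `fPrime` over `List Bool`. -/
def fPrimeDemo (f : List Bool → List Bool) (c m : Nat) (x : List Bool) : List Bool :=
  List.take (m ^ c) x ++ f (List.drop (m ^ c) x)

-- Target: a ++ f b with a = [true], b = [false, true], f = List.reverse.
-- Candidate z = [false, false, true]: wrong prefix bit, correct suffix.
example :
    -- z is not an f'-preimage of the target ...
    fPrimeDemo List.reverse 1 1 [false, false, true] ≠ [true] ++ List.reverse [false, true]
    -- ... yet its suffix is an f-preimage of f b
    ∧ List.reverse (List.drop (1 ^ 1) [false, false, true]) = List.reverse [false, true] := by
  decide

end CouplingStrict
```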
Bijection splitting a length-n bit string into a k-bit prefix and (n-k)-bit suffix.
private def splitN (k n : ℕ) (hkn : k ≤ n) :
    FixedBitString n ≃ FixedBitString k × FixedBitString (n - k) where
  toFun v :=
    ⟨⟨v.toList.take k, by have h := v.toList_length; simp [List.length_take]; omega⟩,
     ⟨v.toList.drop k, by have h := v.toList_length; simp [List.length_drop, h]⟩⟩
  invFun p :=
    ⟨p.1.toList ++ p.2.toList, by
      have h1 : p.1.toList.length = k := p.1.toList_length
      have h2 : p.2.toList.length = n - k := p.2.toList_length
      simp only [List.length_append, h1, h2, Nat.add_sub_cancel' hkn]⟩
  left_inv v := List.Vector.ext (by simp [List.take_append_drop])
  right_inv p := by
    rcases p with ⟨a, b⟩
    apply Prod.ext <;> apply List.Vector.ext
    all_goals simp [a.toList_length]
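`splitN` packages two plain list facts as an equivalence on `FixedBitString`:

* Splitting a list at position `k` and re-appending the pieces gives back the original list (the `left_inv` field).
* The dropped part has length `l.length - k` (the length proof for the second component).

A minimal sketch on plain `List Bool`, assuming the standard simp set of recent Lean versions:

```lean
namespace SplitFacts

-- Splitting at position k and re-appending is the identity (cf. `left_inv`).
example (k : Nat) (l : List Bool) : List.take k l ++ List.drop k l = l := by
  simp

-- The suffix has l.length - k bits (cf. the second component of `toFun`).
example (k : Nat) (l : List Bool) : (List.drop k l).length = l.length - k := by
  simp

-- A concrete split of a 3-bit string at k = 2.
example :
    List.take 2 [true, false, true] = [true, false] ∧
    List.drop 2 [true, false, true] = [true] := by
  decide

end SplitFacts
```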
The ENNReal sum representing `invertProb f F n` is ≤ 1, hence finite.
private lemma invertProb_ENNReal_le_one (f : BitString → BitString)
    (F : BitString → PMF BitString) (n : ℕ) :
    (∑ x : FixedBitString n,
      PMF.uniformOfFintype (FixedBitString n) x *
      (∑' x' : BitString, F (f x.toList) x' *
        (if f x' = f x.toList ∧ x'.length = n then (1 : ENNReal) else 0))) ≤ 1 := by
  -- Bound each term: uniform(n,x) * (inner sum) ≤ uniform(n,x) * 1
  -- then sum to 1.
  calc (∑ x : FixedBitString n,
        PMF.uniformOfFintype (FixedBitString n) x *
        (∑' x' : BitString, F (f x.toList) x' *
          (if f x' = f x.toList ∧ x'.length = n then (1 : ENNReal) else 0)))
      ≤ ∑ x : FixedBitString n, PMF.uniformOfFintype (FixedBitString n) x := by
        apply Finset.sum_le_sum; intro x _
        -- bound the inner sum by 1; the resulting uniform(n,x) * 1 collapses via mul_one
        calc PMF.uniformOfFintype (FixedBitString n) x *
              (∑' x', F (f x.toList) x' * (if f x' = f x.toList ∧ x'.length = n then 1 else 0))
            ≤ PMF.uniformOfFintype (FixedBitString n) x * 1 := by
              gcongr
              calc ∑' x', F (f x.toList) x' * (if f x' = f x.toList ∧ x'.length = n then 1 else 0)
                  ≤ ∑' x', F (f x.toList) x' * 1 := by
                    apply ENNReal.tsum_le_tsum; intro x'
                    gcongr; split_ifs <;> norm_num
                _ = 1 := by simp [PMF.tsum_coe]
          _ = PMF.uniformOfFintype (FixedBitString n) x := mul_one _
    _ = 1 := by
        have htotal := (PMF.uniformOfFintype (FixedBitString n)).tsum_coe
        rw [tsum_fintype (L := .unconditional _)] at htotal
        exact htotal
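The `calc` block above follows the chain below. The notation is introduced here for readability: $u_n(x) = 2^{-n}$ is the uniform weight, and $F_x = F(f(x))$ is the adversary's output distribution.

$$\sum_{x} u_n(x) \sum_{x'} F_x(x')\,\mathbf{1}[f(x') = f(x) \wedge |x'| = n] \;\le\; \sum_{x} u_n(x) \sum_{x'} F_x(x') \;=\; \sum_{x} u_n(x) \;=\; 1.$$

Each step has a direct justification:

* The first step bounds each indicator by 1.
* The second uses that $F_x$ is a probability distribution (`PMF.tsum_coe`).
* The last uses that the uniform weights sum to one.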

private lemma invertProb_ENNReal_ne_top (f : BitString → BitString)
    (F : BitString → PMF BitString) (n : ℕ) :
    (∑ x : FixedBitString n,
      PMF.uniformOfFintype (FixedBitString n) x *
      (∑' x' : BitString, F (f x.toList) x' *
        (if f x' = f x.toList ∧ x'.length = n then (1 : ENNReal) else 0))) ≠ ⊤ :=
  ne_top_of_le_ne_top ENNReal.one_ne_top (invertProb_ENNReal_le_one f F n)
The uniform measure splits over the splitN bijection: uniform(n, splitN.symm (a, b)) = uniform(k, a) * uniform(n-k, b)
private lemma uniform_split (k n : ℕ) (hkn : k ≤ n)
    (a : FixedBitString k) (b : FixedBitString (n - k)) :
    PMF.uniformOfFintype (FixedBitString n) ((splitN k n hkn).symm (a, b)) =
    PMF.uniformOfFintype (FixedBitString k) a *
    PMF.uniformOfFintype (FixedBitString (n - k)) b := by
  -- Both sides equal 1/2^n; show cardinality factorizes then split the inverse
  simp only [PMF.uniformOfFintype_apply]
  rw [show Fintype.card (FixedBitString n) =
        Fintype.card (FixedBitString k) * Fintype.card (FixedBitString (n - k)) from by
    simp [← pow_add, Nat.add_sub_cancel' hkn]]
  push_cast
  -- split (a*b)⁻¹ = a⁻¹ * b⁻¹; since ENNReal is not a group, ENNReal.mul_inv needs
  -- side conditions, discharged here by showing both factors are nonzero
  rw [ENNReal.mul_inv (Or.inl (by positivity)) (Or.inr (by positivity))]
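The cardinality step in `uniform_split` is the identity $2^n = 2^k \cdot 2^{n-k}$ for $k \le n$, which gives $2^{-n} = 2^{-k} \cdot 2^{-(n-k)}$. Below is a minimal Lean sketch of the natural-number identity. It uses the same `Nat.add_sub_cancel'` lemma as the proof above, together with `Nat.pow_add` from core.

```lean
namespace CardSplit

-- 2 ^ n factorises along the prefix/suffix split whenever k ≤ n.
example (k n : Nat) (h : k ≤ n) : 2 ^ n = 2 ^ k * 2 ^ (n - k) := by
  rw [← Nat.pow_add, Nat.add_sub_cancel' h]

-- Concrete instance: 8 = 2 * 4, so a uniform 3-bit string has weight 1/8 = 1/2 * 1/4.
example : 2 ^ 3 = 2 ^ 1 * 2 ^ (3 - 1) := by decide

end CardSplit
```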
fPrime success at (a ++ b) implies f-suffix success (dropping prefix constraint)
private lemma indicator_fPrime_le_suffix (f : BitString → BitString) (c m k n : ℕ)
    (hkn : k ≤ n) (hk : m ^ c = k)
    (a : FixedBitString k) (b : FixedBitString (n - k)) (x' : BitString) :
    (if fPrime f c m x' = a.toList ++ f b.toList ∧ x'.length = n then (1 : ENNReal) else 0) ≤
    (if f (x'.drop k) = f b.toList ∧ (x'.drop k).length = n - k then 1 else 0) := by
  -- Case split on whether fPrime succeeds
  by_cases h : fPrime f c m x' = a.toList ++ f b.toList ∧ x'.length = n
  · simp only [if_pos h]
    -- Unfold fPrime in h.1 (using hk : m^c = k) so fPrime_success_implies_suffix applies
    simp only [fPrime, hk] at h
    have hsuf := fPrime_success_implies_suffix hkn a b x' h.1 h.2
    rw [if_pos hsuf]
  · simp [if_neg h]
`buildAdv` applied to an output indicator integrates out the prefix sample:

∑' z', buildAdv A' m c y z' * [f z' = fy ∧ |z'| = r]
  = ∑ va, uniform(k, va) * ∑' z, A'(va.toList ++ y) z * [f (z.drop k) = fy ∧ |z.drop k| = r]
private lemma buildAdv_indicator_sum (A' : BitString → PMF BitString) (m c : ℕ)
    (f : BitString → BitString) (y fy : BitString) (r : ℕ) :
    (∑' z' : BitString, buildAdv A' m c y z' *
      (if f z' = fy ∧ z'.length = r then 1 else 0)) =
    ∑ va : FixedBitString (m ^ c),
      PMF.uniformOfFintype (FixedBitString (m ^ c)) va *
      (∑' z : BitString, A' (va.toList ++ y) z *
        (if f (z.drop (m ^ c)) = fy ∧ (z.drop (m ^ c)).length = r then 1 else 0)) := by
  -- Step 1: express buildAdv as a Finset.sum over FixedBitString k.
  -- uniformBitStringOfLength k = (uniformOfFintype k).map toList, so PMF.bind_map applies.
  have hbind : ∀ z' : BitString, buildAdv A' m c y z' =
      ∑ va : FixedBitString (m ^ c), PMF.uniformOfFintype (FixedBitString (m ^ c)) va *
      (∑' z : BitString, A' (va.toList ++ y) z * if z.drop (m ^ c) = z' then 1 else 0) := by
    intro z'
    -- convert do-notation to bind; uniformBitStringOfLength = uniformOfFintype.map toList
    change (PMF.bind (uniformBitStringOfLength (m ^ c)) fun a =>
            PMF.bind (A' (a ++ y)) fun z => PMF.pure (z.drop (m ^ c))) z' = _
    simp only [uniformBitStringOfLength, PMF.bind_map, PMF.bind_apply,
               tsum_fintype (L := .unconditional _)]
    -- now: ∑ va, uniformOfFintype k va * (inner bind applied to z')
    congr 1; ext va; congr 1
    simp only [Function.comp, PMF.bind_apply, PMF.pure_apply]
    congr 1; ext z; simp [eq_comm]
  simp_rw [hbind, Finset.sum_mul]
  -- Step 2: swap ∑' z' and ∑ va by going via ∑' va (tsum_fintype round-trip + tsum_comm)
  simp_rw [← tsum_fintype (L := .unconditional _)]
  rw [ENNReal.tsum_comm]
  simp_rw [tsum_fintype (L := .unconditional _)]
  -- Step 3: for each va, factor out uniform(k,va) and prove inner Fubini equality
  apply Finset.sum_congr rfl; intro va _
  -- re-associate products so uniform(k,va) is the head factor inside ∑' z'
  simp_rw [mul_assoc]
  -- factor uniform(k,va) out of the tsum on both sides:
  -- tsum_mul_left (forward): ∑' z', a * f z' = a * ∑' z', f z'
  simp_rw [ENNReal.tsum_mul_left]
  -- use suffices to state the inner tsum equality explicitly and rewrite with it; applying
  -- congr 1 directly to c * tsum = c * tsum would reduce the goal to an equality of summand functions
  suffices h : ∑' z', (∑' z, A' (va.toList ++ y) z * if z.drop (m ^ c) = z' then 1 else 0) *
      (if f z' = fy ∧ z'.length = r then 1 else 0) =
      ∑' z, A' (va.toList ++ y) z *
        (if f (z.drop (m ^ c)) = fy ∧ (z.drop (m ^ c)).length = r then 1 else 0) by rw [h]
  -- Step 4: Fubini: push the outer indicator inside, swap sums, integrate out z'
  -- push [f z' = fy ∧ ...] inside the inner tsum: (∑' z, g z z') * c z' → ∑' z, g z z' * c z'
  simp_rw [← ENNReal.tsum_mul_right]
  -- swap ∑' z' and ∑' z
  rw [ENNReal.tsum_comm]
  -- for each z, reassociate, factor out A'(va++y) z (constant in z'), then integrate out z'
  apply tsum_congr; intro z
  simp_rw [mul_assoc]
  -- tsum_mul_left (forward): ∑' z', A'(va++y) z * f z' = A'(va++y) z * ∑' z', f z'
  rw [ENNReal.tsum_mul_left]
  congr 1
  -- ∑' z', [z.drop k = z'] * [f z' = fy ∧ z'.length = r] = [f (z.drop k) = fy ∧ ...]
  -- the sum has a single nonzero term at z' = z.drop k
  rw [tsum_eq_single (z.drop (m ^ c)) (by intro z' hz'; simp [Ne.symm hz'])]
  simp
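In formulas, Steps 1 to 4 compute the chain below. The notation is introduced here: $k = m^c$, $u_k$ is the uniform weight on $k$-bit prefixes, and $\varphi(z') = \mathbf{1}[f(z') = \mathit{fy} \wedge |z'| = r]$ is shorthand for the output indicator.

$$\begin{aligned}
\sum_{z'} \mathrm{buildAdv}(y)(z')\,\varphi(z')
&= \sum_{z'} \sum_{a} u_k(a) \sum_{z} A'(a \mathbin{++} y)(z)\,\mathbf{1}[z_{k:} = z']\,\varphi(z') \\
&= \sum_{a} u_k(a) \sum_{z} A'(a \mathbin{++} y)(z) \sum_{z'} \mathbf{1}[z_{k:} = z']\,\varphi(z') \\
&= \sum_{a} u_k(a) \sum_{z} A'(a \mathbin{++} y)(z)\,\varphi(z_{k:}).
\end{aligned}$$

The first line is `hbind`, the second is the pair of summation swaps, and the last is `tsum_eq_single`: for fixed $z$, only the term $z' = z_{k:}$ survives.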

lemma buildAdv_invertProb_le (f : BitString → BitString)
    (A' : BitString → PMF BitString) (c m n : ℕ) (hn : m ^ c ≤ n) :
    invertProb (fPrime f c m) A' n ≤
    invertProb f (buildAdv A' m c) (n - m ^ c) := by
  set k := m ^ c with hk_def
  simp only [invertProb]
  -- Reduce the ℝ inequality to an ENNReal inequality (both sides are finite).
  rw [ENNReal.toReal_le_toReal
        (invertProb_ENNReal_ne_top (fPrime f c m) A' n)
        (invertProb_ENNReal_ne_top f (buildAdv A' m c) (n - k))]
  -- Reindex the LHS sum over FixedBitString n as a sum over k × (n-k) pairs via splitN
  -- Equiv.sum_comp e f : ∑ x, f (e x) = ∑ y, f y, so ← gives the reindexing direction
  rw [← Equiv.sum_comp (splitN k n hn).symm]
  -- Flatten the product sum and swap to ∑ b, ∑ a order
  rw [Fintype.sum_prod_type, Finset.sum_comm]
  -- Bound each b-term
  apply Finset.sum_le_sum; intro b _
  -- Simplify (splitN.symm (a,b)).toList = a.toList ++ b.toList
  -- and fPrime(a.toList ++ b.toList) = a.toList ++ f(b.toList)
  simp_rw [show ∀ a : FixedBitString k,
        ((splitN k n hn).symm (a, b)).toList = a.toList ++ b.toList from
      fun a => by simp [splitN]]
  simp_rw [show ∀ a : FixedBitString k,
        fPrime f c m (a.toList ++ b.toList) = a.toList ++ f b.toList from
      fun a => by simp only [fPrime,
        List.take_left' (a.toList_length.trans hk_def),
        List.drop_left' (a.toList_length.trans hk_def)]]
  calc ∑ a : FixedBitString k,
        PMF.uniformOfFintype (FixedBitString n) ((splitN k n hn).symm (a, b)) *
        (∑' x' : BitString, A' (a.toList ++ f b.toList) x' *
          (if fPrime f c m x' = a.toList ++ f b.toList ∧ x'.length = n then 1 else 0))
      -- Apply uniform split and drop-prefix indicator inequality
      ≤ ∑ a : FixedBitString k,
          PMF.uniformOfFintype (FixedBitString k) a *
          PMF.uniformOfFintype (FixedBitString (n - k)) b *
          (∑' x' : BitString, A' (a.toList ++ f b.toList) x' *
            (if f (x'.drop k) = f b.toList ∧ (x'.drop k).length = n - k then 1 else 0)) := by
        apply Finset.sum_le_sum; intro a _
        -- Replace uniform(n, symm(a,b)) with uniform(k,a) * uniform(n-k,b)
        rw [uniform_split k n hn a b]
        -- Apply indicator inequality termwise; gcongr unfolds the tsum and product structure
        gcongr with x'
        exact indicator_fPrime_le_suffix f c m k n hn hk_def a b x'
      -- Factor out uniform(n-k,b) and recognize the buildAdv tsum
      _ = PMF.uniformOfFintype (FixedBitString (n - k)) b *
          (∑' z' : BitString, buildAdv A' m c (f b.toList) z' *
            (if f z' = f b.toList ∧ z'.length = n - k then 1 else 0)) := by
        -- Expand the buildAdv tsum using buildAdv_indicator_sum
        rw [buildAdv_indicator_sum A' m c f (f b.toList) (f b.toList) (n - k)]
        -- Replace m^c with k throughout
        simp_rw [← hk_def]
        -- Factor: ∑ a, uniform(k,a) * uniform(n-k,b) * inner = uniform(n-k,b) * ∑ a, uniform(k,a) * inner
        -- Commute uniform(k,a) past uniform(n-k,b) inside each summand, then factor out
        conv_lhs =>
          arg 2; ext a
          rw [mul_comm (PMF.uniformOfFintype (FixedBitString k) a)
                       (PMF.uniformOfFintype (FixedBitString (n - k)) b),
              mul_assoc]
        -- Now the sum has the form ∑ a, uniform(n-k,b) * f(a); factor out the constant
        rw [← Finset.mul_sum]
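For a fixed suffix $b$, the `calc` above establishes the chain below. The notation is introduced here: $k = m^c$, $t_a = a \mathbin{++} f(b)$ is the challenge, and $\psi(x') = \mathbf{1}[f(x'_{k:}) = f(b) \wedge |x'_{k:}| = n - k]$ is the suffix-only indicator.

$$\begin{aligned}
\sum_{a} u_n(a \mathbin{++} b) \sum_{x'} A'(t_a)(x')\,\mathbf{1}[f'(x') = t_a \wedge |x'| = n]
&\le \sum_{a} u_k(a)\,u_{n-k}(b) \sum_{x'} A'(t_a)(x')\,\psi(x') \\
&= u_{n-k}(b) \sum_{z'} \mathrm{buildAdv}(f(b))(z')\,\mathbf{1}[f(z') = f(b) \wedge |z'| = n - k].
\end{aligned}$$

The first step combines `uniform_split` with the pointwise bound `indicator_fPrime_le_suffix`. The second is `buildAdv_indicator_sum` followed by factoring out $u_{n-k}(b)$. Summing over $b$ then yields the two sides of the lemma.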
# Theorem 3: $f'$ is a one-way function

**Efficiency:** [`fPrime_polytime`](QuadraticOWF.html#theorem-1 "Quadratic-time one-way functions, Theorem 1").

**Hardness:** let `A'` be any poly-time adversary against $f'$, and build `A = buildAdv A' m c`, which is poly-time by `buildAdv_polytime`. By `buildAdv_invertProb_le`, for every $n \geq m^c$, the probability that `A'` inverts $f'$ on $n$-bit inputs is at most the probability that `A` inverts $f$ on $(n - m^c)$-bit inputs. Since $f$ is a OWF, `A` has negligible inversion probability. [`negligible_shift`](Negligible.html#lemma-1 "Negligible functions, Lemma 1") shows that this probability, evaluated at the shifted parameter $n - m^c$, is still negligible as a function of $n$. This gives the required bound for all sufficiently large $n$.
theorem fPrime_is_owf (f : BitString → BitString) (hf : IsOneWayFunction f)
    (c m : ℕ) : IsOneWayFunction (fPrime f c m) where
  polytime := fPrime_polytime f c m hf.polytime
  hard := fun A' hA' => by
    have hA : PolyTimeAdversary (buildAdv A' m c) :=
      buildAdv_polytime A' m c hA'
    -- f is hard, so buildAdv A' m c inverts f with only negligible probability
    have hnegl : Negligible (invertProb f (buildAdv A' m c)) :=
      hf.hard _ hA
    -- shifting the security parameter by m^c preserves negligibility
    have hnegl' : Negligible (fun n => invertProb f (buildAdv A' m c) (n - m ^ c)) :=
      negligible_shift _ (m ^ c) hnegl
    intro k
    obtain ⟨N, hN⟩ := hnegl' k
    -- need n ≥ m^c for the coupling to apply, so the threshold is max N (m^c)
    refine ⟨max N (m ^ c), fun n hn => ?_⟩
    have hN_le  : N ≤ n     := (le_max_left _ _).trans hn
    have hmc_le : m ^ c ≤ n := (le_max_right _ _).trans hn
    -- coupling: inversion probability of f' is at most that of f (shifted)
    calc invertProb (fPrime f c m) A' n
        ≤ invertProb f (buildAdv A' m c) (n - m ^ c) :=
            buildAdv_invertProb_le f A' c m n hmc_le
      _ ≤ (n : ℝ) ^ (-(k : ℤ)) :=
            hN n hN_le
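The threshold `max N (m ^ c)` in the proof lets a single `n` satisfy both side conditions at once: the negligibility bound needs `N ≤ n`, and the coupling lemma needs `m ^ c ≤ n`. Below is a minimal sketch of that step with plain naturals, where `N` and `k` are arbitrary and `k` plays the role of `m ^ c`. It assumes the core lemmas `Nat.le_max_left` and `Nat.le_max_right`.

```lean
namespace Threshold

-- Any n above max N k is above both N and k.
example (N k n : Nat) (hn : max N k ≤ n) : N ≤ n ∧ k ≤ n :=
  ⟨Nat.le_trans (Nat.le_max_left N k) hn, Nat.le_trans (Nat.le_max_right N k) hn⟩

end Threshold
```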
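# Theorem 4: Quadratic time suffices for one-way functions

If any one-way function `f` exists, then `fPrime f 2 1` is also a one-way function (by `fPrime_is_owf`). This function passes the first $1^2 = 1$ bit of its input through unchanged and applies `f` to the rest. The formal statement below records only this existence claim; the $n^2$ running-time bound is not part of `owf_implies_quadratic_owf`.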
theorem owf_implies_quadratic_owf :
    (∃ f : BitString → BitString, IsOneWayFunction f) →
    (∃ g : BitString → BitString, IsOneWayFunction g) :=
  fun ⟨f, hf⟩ => ⟨fPrime f 2 1, fPrime_is_owf f hf 2 1⟩
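With the parameters chosen in `owf_implies_quadratic_owf`, $c = 2$ and $m = 1$, the untouched prefix is $1^2 = 1$ bit long. Here is a minimal sketch using a hypothetical list-based copy of `fPrime` (not part of the development), with `List.reverse` standing in for `f`:

```lean
namespace QuadInstance

/-- Hypothetical stand-in for `fPrime` over `List Bool`. -/
def fPrimeDemo (f : List Bool → List Bool) (c m : Nat) (x : List Bool) : List Bool :=
  List.take (m ^ c) x ++ f (List.drop (m ^ c) x)

-- For c = 2, m = 1, exactly one leading bit passes through unchanged.
example (f : List Bool → List Bool) (x : List Bool) :
    fPrimeDemo f 2 1 x = List.take 1 x ++ f (List.drop 1 x) := rfl

-- Concretely: the first bit is kept and f = List.reverse acts on the remaining two.
example : fPrimeDemo List.reverse 2 1 [true, true, false] = [true, false, true] := by
  decide

end QuadInstance
```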
Raw source
import Onewayf.OneWayFunction
import Onewayf.Bitstrings

-- leandown
-- [meta]
-- title = "Quadratic-time one-way functions"
-- group = "Cryptography"
-- [content]

-- If one-way functions exist, then there exists a one-way function computable
-- in time $n^2$. In other words, **quadratic time is a universal upper bound**
-- on the complexity of one-way functions: you never need more than $n^2$.
--
-- **Proof sketch.** Given a one-way function `f` computable in time $n^c$,
-- we construct a new function $f'$ that runs in time $n^2$ and is also a OWF.
-- $f'$ pads its input with a random-looking prefix, which lets us reduce any
-- attack on $f'$ back to an attack on `f`.

-- # {{definition}}: The padded function $f'$
--
-- Given `f`, exponent `c`, and block size `m`, define
--
-- $$f'(x) = x_{:m^c} \mathbin{++} f(x_{m^c:})$$
--
-- On an input $x = a \mathbin{++} b$ with $|a| = m^c$ and $|b| = m$, this
-- returns $a \mathbin{++} f(b)$, leaving the prefix untouched and applying
-- `f` only to the suffix.
--
def fPrime (f : BitString → BitString) (c m : ℕ) : BitString → BitString :=
  fun x => List.take (m ^ c) x ++ f (List.drop (m ^ c) x)

-- # {{theorem}}: $f'$ is poly-time computable
--
-- `take`/`drop` are $O(n)$; applying `f` costs $O(m^c) \leq O(n^c)$;
-- appending is $O(n)$. With $n = m^c + m$ the total is $O(n^2)$.
--
theorem fPrime_polytime (f : BitString → BitString) (c m : ℕ)
    (hf : PolyTimeComputable f) : PolyTimeComputable (fPrime f c m) :=
  .append (.take (m ^ c)) (.comp hf (.drop (m ^ c)))

-- # {{definition}}: The reduction adversary
--
-- Given an adversary `A'` that inverts $f'$, we build an adversary for `f`.
-- On input $y = f(b)$ (where $b$ has length $m$):
--
-- 1. Sample $a \leftarrow \{0,1\}^{m^c}$ uniformly.
-- 2. Run `A'` on $a \mathbin{++} y$ to get output $z$.
-- 3. Return $z_{m^c:}$ as the guess for $b$.
--
-- The idea: `A'` expects to see $f'(a \mathbin{++} b) = a \mathbin{++} f(b)$,
-- which is exactly $a \mathbin{++} y$. So we can simulate $f'$'s interface
-- for `A'` by supplying a fresh random prefix.
--
noncomputable def buildAdv (A' : BitString → PMF BitString) (m c : ℕ) :
    BitString → PMF BitString :=
  fun y => do
    let a ← uniformBitStringOfLength (m ^ c)
    let z ← A' (a ++ y)
    return List.drop (m ^ c) z

-- # {{theorem}}: `buildAdv` is a poly-time adversary
--
-- Sampling $a$ and calling `A'` on a padded input are both poly-time;
-- dropping the prefix is $O(n)$. The `uniformPad` constructor of
-- `PolyTimeAdversary` captures exactly this pattern.
--
-- **Note:** `buildAdv` is inherently randomized (step 1 samples $a$).
-- Our `PolyTimeAdversary` model uses `uniformPad` to handle this, so
-- the proof goes through directly.
--
theorem buildAdv_polytime (A' : BitString → PMF BitString) (m c : ℕ)
    (hA' : PolyTimeAdversary A') : PolyTimeAdversary (buildAdv A' m c) := by
  unfold buildAdv
  exact .uniformPad (m ^ c) hA'

-- # {{lemma}}: The coupling inequality
--
-- For any input length $n \geq m^c$, attacking $f'$ with `A'` on $n$-bit
-- inputs succeeds with probability **at most** that of attacking $f$ with
-- `buildAdv A' m c` on $(n - m^c)$-bit inputs.
--
-- **Why inequality, not equality.** Writing $x = a \mathbin{++} b$ with
-- $|a| = m^c$ and $|b| = n - m^c$, we have $f'(x) = a \mathbin{++} f(b)$.
-- For `A'` to invert $f'$ it must output $z$ satisfying **both**
-- $z_{:m^c} = a$ (correct prefix) **and** $f(z_{m^c:}) = f(b)$ (correct suffix).
-- `buildAdv` only checks the suffix condition $f(z_{m^c:}) = f(b)$, so
-- every $f'$-success implies an $f$-success but not vice versa.
--
-- **Proof sketch.**
-- 1. Reindex the LHS sum over $\{0,1\}^n$ as a double sum over
--    $\{0,1\}^{m^c} \times \{0,1\}^{n-m^c}$ via `vectorAppendEquiv`.
-- 2. Expand `buildAdv` in the RHS using `PMF.bind_apply` and swap summation
--    order (`ENNReal.tsum_comm`); convert `uniformBitStringOfLength` back to
--    a finite sum over `FixedBitString (m^c)`.
-- 3. Both sides now have the form $\sum_{a,b} \tfrac{1}{2^n} \cdot I(a,b)$.
--    The LHS indicator $\mathbf{1}[z_{:m^c} = a \wedge f(z_{m^c:}) = f(b)]$ is
--    pointwise $\leq$ the RHS indicator $\mathbf{1}[f(z_{m^c:}) = f(b)]$.
--
-- The indicator inequality at the heart of the coupling:
-- fPrime success (with length constraint) implies suffix-f success.
private lemma fPrime_success_implies_suffix {k n : ℕ} (hkn : k ≤ n)
    {f : BitString → BitString}
    (a : FixedBitString k) (b : FixedBitString (n - k)) (x' : BitString)
    (hfp : List.take k x' ++ f (List.drop k x') = a.toList ++ f b.toList)
    (hlen : x'.length = n) :
    f (List.drop k x') = f b.toList ∧ (List.drop k x').length = n - k := by
  -- Since x'.length = n ≥ k, the take has exactly k elements — matching a.toList.
  -- Concatenation injectivity then gives f(x'.drop k) = f(b.toList).
  have htake_len : (List.take k x').length = k := by
    simp [List.length_take, Nat.min_eq_left (hlen ▸ hkn)]
  have hinj := List.append_inj hfp (htake_len.trans a.toList_length.symm)
  exact ⟨hinj.2, by simp [List.length_drop, hlen]⟩

-- Bijection splitting a length-n bit string into a k-bit prefix and (n-k)-bit suffix.
private def splitN (k n : ℕ) (hkn : k ≤ n) :
    FixedBitString n ≃ FixedBitString k × FixedBitString (n - k) where
  toFun v :=
    ⟨⟨v.toList.take k, by have h := v.toList_length; simp [List.length_take]; omega⟩,
     ⟨v.toList.drop k, by have h := v.toList_length; simp [List.length_drop, h]⟩⟩
  invFun p :=
    ⟨p.1.toList ++ p.2.toList, by
      have h1 : p.1.toList.length = k := p.1.toList_length
      have h2 : p.2.toList.length = n - k := p.2.toList_length
      simp only [List.length_append, h1, h2, Nat.add_sub_cancel' hkn]⟩
  left_inv v := List.Vector.ext (by simp [List.take_append_drop])
  right_inv p := by
    rcases p with ⟨a, b⟩
    apply Prod.ext <;> apply List.Vector.ext
    all_goals simp [a.toList_length]

-- The ENNReal sum representing `invertProb f F n` is ≤ 1, hence finite.
private lemma invertProb_ENNReal_le_one (f : BitString → BitString)
    (F : BitString → PMF BitString) (n : ℕ) :
    (∑ x : FixedBitString n,
      PMF.uniformOfFintype (FixedBitString n) x *
      (∑' x' : BitString, F (f x.toList) x' *
        (if f x' = f x.toList ∧ x'.length = n then (1 : ENNReal) else 0))) ≤ 1 := by
  -- Bound each term: uniform(n,x) * (inner sum) ≤ uniform(n,x) * 1
  -- then sum to 1.
  calc (∑ x : FixedBitString n,
        PMF.uniformOfFintype (FixedBitString n) x *
        (∑' x' : BitString, F (f x.toList) x' *
          (if f x' = f x.toList ∧ x'.length = n then (1 : ENNReal) else 0)))
      ≤ ∑ x : FixedBitString n, PMF.uniformOfFintype (FixedBitString n) x := by
        apply Finset.sum_le_sum; intro x _
        -- bound inner sum by 1, then use uniform ≤ 1 is absorbed by mul_one
        calc PMF.uniformOfFintype (FixedBitString n) x *
              (∑' x', F (f x.toList) x' * (if f x' = f x.toList ∧ x'.length = n then 1 else 0))
            ≤ PMF.uniformOfFintype (FixedBitString n) x * 1 := by
              gcongr
              calc ∑' x', F (f x.toList) x' * (if f x' = f x.toList ∧ x'.length = n then 1 else 0)
                  ≤ ∑' x', F (f x.toList) x' * 1 := by
                    apply ENNReal.tsum_le_tsum; intro x'
                    gcongr; split_ifs <;> norm_num
                _ = 1 := by simp [PMF.tsum_coe]
          _ = PMF.uniformOfFintype (FixedBitString n) x := mul_one _
    _ = 1 := by
        have htotal := (PMF.uniformOfFintype (FixedBitString n)).tsum_coe
        rw [tsum_fintype (L := .unconditional _)] at htotal
        exact htotal

private lemma invertProb_ENNReal_ne_top (f : BitString → BitString)
    (F : BitString → PMF BitString) (n : ℕ) :
    (∑ x : FixedBitString n,
      PMF.uniformOfFintype (FixedBitString n) x *
      (∑' x' : BitString, F (f x.toList) x' *
        (if f x' = f x.toList ∧ x'.length = n then (1 : ENNReal) else 0))) ≠ ⊤ :=
  ne_top_of_le_ne_top ENNReal.one_ne_top (invertProb_ENNReal_le_one f F n)

-- The uniform measure splits over the splitN bijection:
-- uniform(n, splitN.symm (a, b)) = uniform(k, a) * uniform(n-k, b)
private lemma uniform_split (k n : ℕ) (hkn : k ≤ n)
    (a : FixedBitString k) (b : FixedBitString (n - k)) :
    PMF.uniformOfFintype (FixedBitString n) ((splitN k n hkn).symm (a, b)) =
    PMF.uniformOfFintype (FixedBitString k) a *
    PMF.uniformOfFintype (FixedBitString (n - k)) b := by
  -- Both sides equal 1/2^n; show cardinality factorizes then split the inverse
  simp only [PMF.uniformOfFintype_apply]
  rw [show Fintype.card (FixedBitString n) =
        Fintype.card (FixedBitString k) * Fintype.card (FixedBitString (n - k)) from by
    simp [← pow_add, Nat.add_sub_cancel' hkn]]
  push_cast
  -- split (a*b)⁻¹ = a⁻¹ * b⁻¹; requires nonzero conditions since ENNReal isn't a group
  rw [ENNReal.mul_inv (Or.inl (by positivity)) (Or.inr (by positivity))]

-- fPrime success at (a ++ b) implies f-suffix success (dropping prefix constraint)
private lemma indicator_fPrime_le_suffix (f : BitString → BitString) (c m k n : ℕ)
    (hkn : k ≤ n) (hk : m ^ c = k)
    (a : FixedBitString k) (b : FixedBitString (n - k)) (x' : BitString) :
    (if fPrime f c m x' = a.toList ++ f b.toList ∧ x'.length = n then (1 : ENNReal) else 0) ≤
    (if f (x'.drop k) = f b.toList ∧ (x'.drop k).length = n - k then 1 else 0) := by
  -- Case split on whether fPrime succeeds
  by_cases h : fPrime f c m x' = a.toList ++ f b.toList ∧ x'.length = n
  · simp only [if_pos h]
    -- Unfold fPrime in h.1 (using hk : m^c = k) so fPrime_success_implies_suffix applies
    simp only [fPrime, hk] at h
    have hsuf := fPrime_success_implies_suffix hkn a b x' h.1 h.2
    rw [if_pos hsuf]
  · simp [if_neg h]

-- buildAdv applied to an output indicator integrates out the prefix sample:
-- ∑' z', buildAdv A' m c y z' * [f z' = fy ∧ |z'|=r]
-- = ∑ va, uniform(k,va) * ∑' z, A'(va.toList ++ y) z * [f(z.drop k)=fy ∧ |z.drop k|=r]
private lemma buildAdv_indicator_sum (A' : BitString → PMF BitString) (m c : ℕ)
    (f : BitString → BitString) (y fy : BitString) (r : ℕ) :
    (∑' z' : BitString, buildAdv A' m c y z' *
      (if f z' = fy ∧ z'.length = r then 1 else 0)) =
    ∑ va : FixedBitString (m ^ c),
      PMF.uniformOfFintype (FixedBitString (m ^ c)) va *
      (∑' z : BitString, A' (va.toList ++ y) z *
        (if f (z.drop (m ^ c)) = fy ∧ (z.drop (m ^ c)).length = r then 1 else 0)) := by
  -- Step 1: express buildAdv as a Finset.sum over FixedBitString k.
  -- uniformBitStringOfLength k = (uniformOfFintype k).map toList, so PMF.bind_map applies.
  have hbind : ∀ z' : BitString, buildAdv A' m c y z' =
      ∑ va : FixedBitString (m ^ c), PMF.uniformOfFintype (FixedBitString (m ^ c)) va *
      (∑' z : BitString, A' (va.toList ++ y) z * if z.drop (m ^ c) = z' then 1 else 0) := by
    intro z'
    -- convert do-notation to bind; uniformBitStringOfLength = uniformOfFintype.map toList
    change (PMF.bind (uniformBitStringOfLength (m ^ c)) fun a =>
            PMF.bind (A' (a ++ y)) fun z => PMF.pure (z.drop (m ^ c))) z' = _
    simp only [uniformBitStringOfLength, PMF.bind_map, PMF.bind_apply,
               tsum_fintype (L := .unconditional _)]
    -- now: ∑ va, uniformOfFintype k va * (inner bind applied to z')
    congr 1; ext va; congr 1
    simp only [Function.comp, PMF.bind_apply, PMF.pure_apply]
    congr 1; ext z; simp [eq_comm]
  simp_rw [hbind, Finset.sum_mul]
  -- Step 2: swap ∑' z' and ∑ va by going via ∑' va (tsum_fintype round-trip + tsum_comm)
  simp_rw [← tsum_fintype (L := .unconditional _)]
  rw [ENNReal.tsum_comm]
  simp_rw [tsum_fintype (L := .unconditional _)]
  -- Step 3: for each va, factor out uniform(k,va) and prove inner Fubini equality
  apply Finset.sum_congr rfl; intro va _
  -- re-associate products so uniform(k,va) is the head factor inside ∑' z'
  simp_rw [mul_assoc]
  -- factor uniform(k,va) out of the tsum on both sides:
  -- tsum_mul_left (forward): ∑' z', a * f z' = a * ∑' z', f z'
  simp_rw [ENNReal.tsum_mul_left]
  -- use suffices to state the inner tsum equality explicitly and rewrite; this avoids
  -- applying congr 1 directly on c * tsum = c * tsum, which would peel tsum to fun equality
  suffices h : ∑' z', (∑' z, A' (va.toList ++ y) z * if z.drop (m ^ c) = z' then 1 else 0) *
      (if f z' = fy ∧ z'.length = r then 1 else 0) =
      ∑' z, A' (va.toList ++ y) z *
        (if f (z.drop (m ^ c)) = fy ∧ (z.drop (m ^ c)).length = r then 1 else 0) by rw [h]
  -- Step 4: Fubini — push outer indicator inside, swap sums, integrate out z'
  -- push [f z' = fy ∧ ...] inside the inner tsum: (∑' z, g z z') * c z' → ∑' z, g z z' * c z'
  simp_rw [← ENNReal.tsum_mul_right]
  -- swap ∑' z' and ∑' z
  rw [ENNReal.tsum_comm]
  -- for each z, reassociate, factor out A'(va++y) z (constant in z'), then integrate out z'
  apply tsum_congr; intro z
  simp_rw [mul_assoc]
  -- tsum_mul_left (forward): ∑' z', A'(va++y) z * f z' = A'(va++y) z * ∑' z', f z'
  rw [ENNReal.tsum_mul_left]
  congr 1
  -- ∑' z', [z.drop k = z'] * [f z' = fy ∧ z'.length = r] = [f (z.drop k) = fy ∧ ...]
  -- the sum has a single nonzero term at z' = z.drop k
  rw [tsum_eq_single (z.drop (m ^ c)) (by intro z' hz'; simp [Ne.symm hz'])]
  simp

lemma buildAdv_invertProb_le (f : BitString → BitString)
    (A' : BitString → PMF BitString) (c m n : ℕ) (hn : m ^ c ≤ n) :
    invertProb (fPrime f c m) A' n ≤
    invertProb f (buildAdv A' m c) (n - m ^ c) := by
  set k := m ^ c with hk_def
  simp only [invertProb]
  -- Reduce the ℝ inequality to an ENNReal inequality (both sides are finite).
  rw [ENNReal.toReal_le_toReal
        (invertProb_ENNReal_ne_top (fPrime f c m) A' n)
        (invertProb_ENNReal_ne_top f (buildAdv A' m c) (n - k))]
  -- Reindex the LHS sum over FixedBitString n as a sum over k × (n-k) pairs via splitN
  -- Equiv.sum_comp e f : ∑ x, f (e x) = ∑ y, f y, so ← gives the reindexing direction
  rw [← Equiv.sum_comp (splitN k n hn).symm]
  -- Flatten the product sum and swap to ∑ b, ∑ a order
  rw [Fintype.sum_prod_type, Finset.sum_comm]
  -- Bound each b-term
  apply Finset.sum_le_sum; intro b _
  -- Simplify (splitN.symm (a,b)).toList = a.toList ++ b.toList
  -- and fPrime(a.toList ++ b.toList) = a.toList ++ f(b.toList)
  simp_rw [show ∀ a : FixedBitString k,
        ((splitN k n hn).symm (a, b)).toList = a.toList ++ b.toList from
      fun a => by simp [splitN]]
  simp_rw [show ∀ a : FixedBitString k,
        fPrime f c m (a.toList ++ b.toList) = a.toList ++ f b.toList from
      fun a => by simp only [fPrime,
        List.take_left' (a.toList_length.trans hk_def),
        List.drop_left' (a.toList_length.trans hk_def)]]
  calc ∑ a : FixedBitString k,
        PMF.uniformOfFintype (FixedBitString n) ((splitN k n hn).symm (a, b)) *
        (∑' x' : BitString, A' (a.toList ++ f b.toList) x' *
          (if fPrime f c m x' = a.toList ++ f b.toList ∧ x'.length = n then 1 else 0))
      -- Apply uniform split and drop-prefix indicator inequality
      ≤ ∑ a : FixedBitString k,
          PMF.uniformOfFintype (FixedBitString k) a *
          PMF.uniformOfFintype (FixedBitString (n - k)) b *
          (∑' x' : BitString, A' (a.toList ++ f b.toList) x' *
            (if f (x'.drop k) = f b.toList ∧ (x'.drop k).length = n - k then 1 else 0)) := by
        apply Finset.sum_le_sum; intro a _
        -- Replace uniform(n, symm(a,b)) with uniform(k,a) * uniform(n-k,b)
        rw [uniform_split k n hn a b]
        -- Apply the indicator inequality termwise: gcongr descends through the product and
        -- the tsum, leaving one indicator inequality for each x'
        gcongr with x'
        exact indicator_fPrime_le_suffix f c m k n hn hk_def a b x'
      -- Factor out uniform(n-k,b) and recognize the buildAdv tsum
      _ = PMF.uniformOfFintype (FixedBitString (n - k)) b *
          (∑' z' : BitString, buildAdv A' m c (f b.toList) z' *
            (if f z' = f b.toList ∧ z'.length = n - k then 1 else 0)) := by
        -- Expand the buildAdv tsum using buildAdv_indicator_sum
        rw [buildAdv_indicator_sum A' m c f (f b.toList) (f b.toList) (n - k)]
        -- Replace m^c with k throughout
        simp_rw [← hk_def]
        -- Factor: ∑ a, uniform(k,a) * uniform(n-k,b) * inner = uniform(n-k,b) * ∑ a, uniform(k,a) * inner
        -- Commute uniform(k,a) past uniform(n-k,b) inside each summand, then factor out
        conv_lhs =>
          arg 2; ext a
          rw [mul_comm (PMF.uniformOfFintype (FixedBitString k) a)
                       (PMF.uniformOfFintype (FixedBitString (n - k)) b),
              mul_assoc]
        -- Now the sum has the form ∑ a, uniform(n-k,b) * f(a); factor out the constant
        rw [← Finset.mul_sum]
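
-- The reshaping steps in `buildAdv_invertProb_le` are standard `Finset` identities. The
-- `example`s below replay them on small `Fin` index types, using a hypothetical summand `g`
-- and constant `u`. They only illustrate the rewrite directions and are not used elsewhere
-- in the development. The first flattens a sum over a product type and swaps the order of
-- summation, as in the `Fintype.sum_prod_type, Finset.sum_comm` step.
example (g : Fin 2 × Fin 3 → ℕ) :
    ∑ p : Fin 2 × Fin 3, g p = ∑ b : Fin 3, ∑ a : Fin 2, g (a, b) := by
  -- `Fintype.sum_prod_type` gives `∑ a, ∑ b, g (a, b)`; `Finset.sum_comm` swaps the order
  rw [Fintype.sum_prod_type, Finset.sum_comm]

-- The second factors a constant out of a sum, as in the final `rw [← Finset.mul_sum]`.
-- `Finset.mul_sum` reads `u * ∑ a, h a = ∑ a, u * h a`, so rewriting right to left
-- pulls `u` out of the sum.
example (u : ℕ) (h : Fin 4 → ℕ) :
    ∑ a : Fin 4, u * h a = u * ∑ a : Fin 4, h a := by
  rw [← Finset.mul_sum]

-- The third is the counting identity behind splitting the uniform distribution. Assuming
-- `FixedBitString n` has $2^n$ elements, each has probability $2^{-n}$. The factorization
-- $2^{-n} = 2^{-k} \cdot 2^{-(n-k)}$ then follows from $2^n = 2^k \cdot 2^{n-k}$ for $k \le n$.
example (k n : ℕ) (h : k ≤ n) : (2 : ℕ) ^ n = 2 ^ k * 2 ^ (n - k) := by
  -- fold the product into `2 ^ (k + (n - k))`, then cancel `k + (n - k) = n`
  rw [← pow_add, Nat.add_sub_of_le h]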


-- # {{theorem}}: $f'$ is a one-way function
--
-- **Efficiency:** `fPrime_polytime`.
--
-- **Hardness:** let `A'` be any polynomial-time adversary for $f'$, and build
-- `A = buildAdv A' m c`, which is polynomial-time by `buildAdv_polytime`. By
-- `buildAdv_invertProb_le`, whenever $m^c \le n$, the probability that `A'` inverts $f'$
-- on $n$-bit inputs is at most the probability that `A` inverts $f$ on $(n - m^c)$-bit
-- inputs. Since $f$ is a OWF, `A` has negligible inversion probability.
-- `negligible_shift` transfers this to the shifted parameter $n - m^c$. Choosing the
-- threshold large enough to also guarantee $m^c \le n$ gives the required bound.
--
theorem fPrime_is_owf (f : BitString → BitString) (hf : IsOneWayFunction f)
    (c m : ℕ) : IsOneWayFunction (fPrime f c m) where
  polytime := fPrime_polytime f c m hf.polytime
  hard := fun A' hA' => by
    have hA : PolyTimeAdversary (buildAdv A' m c) :=
      buildAdv_polytime A' m c hA'
    -- f is one-way and buildAdv A' m c is poly-time, so it inverts f only with
    -- negligible probability
    have hnegl : Negligible (invertProb f (buildAdv A' m c)) :=
      hf.hard _ hA
    -- shifting the security parameter by m^c preserves negligibility
    have hnegl' : Negligible (fun n => invertProb f (buildAdv A' m c) (n - m ^ c)) :=
      negligible_shift _ (m ^ c) hnegl
    intro k
    obtain ⟨N, hN⟩ := hnegl' k
    -- need n ≥ m^c for the coupling to apply, so threshold is max N (m^c)
    refine ⟨max N (m ^ c), fun n hn => ?_⟩
    have hN_le  : N ≤ n     := (le_max_left _ _).trans hn
    have hmc_le : m ^ c ≤ n := (le_max_right _ _).trans hn
    -- coupling: inversion probability of f' is at most that of f (shifted)
    calc invertProb (fPrime f c m) A' n
        ≤ invertProb f (buildAdv A' m c) (n - m ^ c) :=
            buildAdv_invertProb_le f A' c m n hmc_le
      _ ≤ (n : ℝ) ^ (-(k : ℤ)) :=
            hN n hN_le
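
-- The threshold `max N (m ^ c)` in the proof above is chosen so that one lower bound on `n`
-- yields both side conditions at once. The `example` below is a minimal stand-alone check
-- of that step. Plain natural numbers `N`, `a` and `n` stand in for the quantities in the
-- proof, with `a` playing the role of `m ^ c`.
example (N a n : ℕ) (hn : max N a ≤ n) : N ≤ n ∧ a ≤ n :=
  -- `max_le_iff` splits `max N a ≤ n` into `N ≤ n` and `a ≤ n`
  max_le_iff.mp hn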

-- # {{theorem}}: Quadratic time suffices for one-way functions
--
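-- If any one-way function `f` exists, then `fPrime f 2 1` is again a one-way function
-- (by `fPrime_is_owf`). As formalized, the conclusion is only `∃ g, IsOneWayFunction g`.
-- The $n^2$ running-time bound is not part of this statement beyond the `polytime`
-- requirement of `IsOneWayFunction`.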
--
theorem owf_implies_quadratic_owf :
    (∃ f : BitString → BitString, IsOneWayFunction f) →
    (∃ g : BitString → BitString, IsOneWayFunction g) :=
  fun ⟨f, hf⟩ => ⟨fPrime f 2 1, fPrime_is_owf f hf 2 1⟩
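
-- Concretely, with `c = 2` and `m = 1` the prefix has length `1 ^ 2 = 1`. One bit passes
-- through unchanged and `f` is applied to the rest. The `example` below unfolds the
-- `take`/`drop` definition of `fPrime` by hand on a three-bit input. It assumes bits are
-- modelled as `List Bool` and uses `List.reverse` as a hypothetical stand-in for `f`.
-- `List.reverse` is not one-way; it only makes the output easy to predict.
example :
    List.take (1 ^ 2) [true, false, true] ++
      List.reverse (List.drop (1 ^ 2) [true, false, true]) = [true, true, false] :=
  -- take 1 keeps [true]; drop 1 leaves [false, true], whose reverse is [true, false]
  rfl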