Main.lean (15495B)
import Batteries.Data.Rat

--------------------------------------------------------------------------------
-- Fin util

/-- Split an index into `Fin (m + n)` into either an index into the first `m`
slots (`Sum.inl`) or, shifted down by `m`, an index into the last `n` slots
(`Sum.inr`). -/
def finSum (i : Fin (m + n)) : Sum (Fin m) (Fin n) :=
  -- `Fin.subNat m` expects an index of type `Fin (n + m)`, so transport `i`
  -- along `m + n = n + m` before branching.
  have hcomm : Fin (m + n) = Fin (n + m) := by rw [Nat.add_comm]
  let j : Fin (n + m) := cast hcomm i
  if hlt : j.val < m then
    Sum.inl (Fin.castLT j hlt)
  else
    have hge : m <= j.val := Nat.le_of_not_gt hlt
    Sum.inr (Fin.subNat m j hge)

/-- Decompose an index into `Fin (m * n)` as `(i / n, i % n)`: a "quotient"
index in `Fin m` and a "remainder" index in `Fin n`. -/
def finProd (i : Fin (m * n)) : Prod (Fin m) (Fin n) :=
  -- If a product of naturals is positive then so is its left factor.
  let left_pos {a b : Nat} : (a * b > 0) -> (a > 0) := by
    intro
    rw [<-Nat.zero_div b]
    apply Nat.div_lt_of_lt_mul
    rw [Nat.mul_comm]
    assumption
  -- `i` inhabits `Fin (m * n)`, so `m * n` (and therefore `m` and `n`) is
  -- positive, which `Fin.ofNat'` needs as its bound proof.
  have hmn : m * n > 0 := Fin.pos i
  have hnm : n * m > 0 := by rw [Nat.mul_comm]; exact hmn
  (Fin.ofNat' (i.val / n) (left_pos hmn), Fin.ofNat' (i.val % n) (left_pos hnm))

--------------------------------------------------------------------------------
-- BitVec util

/-- Left fold over the bits of `bv`, visiting bit indices `0, 1, …, n - 1`. -/
def bitVecFoldl (f : α -> Bool -> α) (bv : BitVec n) (init : α) : α :=
  Fin.foldl n (fun acc (idx : Fin n) => f acc (bv.getLsb idx)) init

/-- Left fold over all `2^n` bit vectors of width `n`, in increasing numeric
order starting from `0`. -/
def bitVecEnuml (n : Nat) (f : α -> BitVec n -> α) (init : α) : α :=
  Fin.foldl (2^n) (fun acc (k : Fin (2^n)) => f acc (BitVec.ofFin k)) init

/-- Left fold over all bit vectors of width `n` with exactly `k` bits set.
At every width, vectors whose top bit is clear are visited before vectors
whose top bit is set. -/
def bitVecChoose (n k : Nat) (f : α -> BitVec n -> α) (acc : α) : α :=
  match n, k with
  | 0, 0 => f acc BitVec.nil
  | 0, _ + 1 => acc
  | _ + 1, 0 => f acc 0
  | n' + 1, k' + 1 =>
    -- Fold over width-`n'` tails, re-attaching the chosen top bit with `cons`.
    let withTop (top : Bool) (acc : α) (tail : BitVec n') : α :=
      f acc (BitVec.cons top tail)
    -- Top bit clear: all `k' + 1` set bits must come from the tail.
    let afterClear := bitVecChoose n' (k' + 1) (withTop false) acc
    -- Top bit set: only `k'` set bits remain for the tail.
    bitVecChoose n' k' (withTop true) afterClear

/-- Number of set bits in `bv`. -/
def bitVecPopCount (bv : BitVec n) : Nat :=
  bitVecFoldl (fun count isSet => if isSet then count + 1 else count) bv 0
--------------------------------------------------------------------------------
-- Notation

-- `Y ^ n` denotes the type of length-`n` tuples of `Y`, encoded as functions
-- `Fin n -> Y`.
instance : HPow Type Nat Type where
  hPow Y n := Fin n -> Y

/-- View a list as a tuple whose length is the list's length. -/
def vec (coords : List Y) : Y ^ coords.length := coords.get

-- SMul represents the ability to use "⋅" as an infix operator.
class SMul (α β : Type) where
  smul : α -> β -> β

-- Precedence: Lean core declares `+`/`-` as `infixl:65`, `*`/`/` as
-- `infixl:70` and `^` as `infixr:75` (go-to-definition on an operator shows
-- its declaration). Level 73, the level Mathlib uses for its `•`, makes `⋅`
-- bind tighter than `+` and `*` but looser than `^`. So `a ⋅ x ^ 2` parses as
-- `a ⋅ (x ^ 2)` and `a ^ 2 ⋅ x` parses as `(a ^ 2) ⋅ x`. At level 75 the
-- latter would parse as `a ^ (2 ⋅ x)`.
infixr:73 " ⋅ " => SMul.smul

class Inv (α : Type) where
  inv : α -> α

--------------------------------------------------------------------------------
-- Additive (commutative) semigroups

class AddSemigroup (A : Type) where
  add : A -> A -> A

instance [AddSemigroup A] : Add A where
  add := AddSemigroup.add

instance [AddSemigroup A] [AddSemigroup B] : AddSemigroup (A × B) where
  add x y := (x.1 + y.1, x.2 + y.2)

instance AddSemigroup.instFun {X : Type} [AddSemigroup A] : AddSemigroup (X -> A) where
  add f g := fun x => f x + g x

instance [AddSemigroup A] {n : Nat} : AddSemigroup (A ^ n) := AddSemigroup.instFun

--------------------------------------------------------------------------------
-- Additive (commutative) monoids

class AddMonoid (A : Type) extends AddSemigroup A where
  zero : A

instance [AddMonoid A] : OfNat A 0 where
  ofNat := AddMonoid.zero

instance [AddMonoid A] [AddMonoid B] : AddMonoid (A × B) where
  zero := (0, 0)

instance AddMonoid.instFun {X : Type} [AddMonoid A] : AddMonoid (X -> A) where
  zero := fun _ => 0

instance [AddMonoid A] {n : Nat} : AddMonoid (A ^ n) := AddMonoid.instFun

--------------------------------------------------------------------------------
-- Additive (commutative) groups

-- `neg` defaults to `zero - x`; instances may override it with a direct
-- implementation.
class AddGroup (A : Type) extends AddMonoid A where
  sub : A -> A -> A
  neg : A -> A := fun x => sub zero x

instance [AddGroup A] : Sub A where
  sub := AddGroup.sub
instance [AddGroup A] : Neg A where
  neg := AddGroup.neg

instance [AddGroup A] [AddGroup B] : AddGroup (A × B) where
  sub x y := (x.1 - y.1, x.2 - y.2)

instance AddGroup.instFun {X : Type} [AddGroup A] : AddGroup (X -> A) where
  sub f g := fun x => f x - g x

instance [AddGroup A] {n : Nat} : AddGroup (A ^ n) := AddGroup.instFun

--------------------------------------------------------------------------------
-- Multiplicative semigroups, non-commutative or commutative

class MulSemigroup (A : Type) where
  mul : A -> A -> A
class MulCommSemigroup (A : Type) extends MulSemigroup A where

instance [MulSemigroup A] : Mul A where
  mul := MulSemigroup.mul

instance [MulSemigroup A] [MulSemigroup B] : MulSemigroup (A × B) where
  mul x y := (x.1 * y.1, x.2 * y.2)
instance [MulCommSemigroup A] [MulCommSemigroup B] : MulCommSemigroup (A × B) where

instance MulSemigroup.instFun {X : Type} [MulSemigroup A] : MulSemigroup (X -> A) where
  mul f g := fun x => f x * g x
instance MulCommSemigroup.instFun {X : Type} [MulCommSemigroup A] : MulCommSemigroup (X -> A) where

instance [MulSemigroup A] {n : Nat} : MulSemigroup (A ^ n) := MulSemigroup.instFun
instance [MulCommSemigroup A] {n : Nat} : MulCommSemigroup (A ^ n) := MulCommSemigroup.instFun

--------------------------------------------------------------------------------
-- Multiplicative monoids, non-commutative or commutative

class MulMonoid (A : Type) extends MulSemigroup A where
  one : A
class MulCommMonoid (A : Type) extends MulMonoid A, MulCommSemigroup A where

instance [MulMonoid A] : OfNat A 1 where
  ofNat := MulMonoid.one

instance [MulMonoid A] [MulMonoid B] : MulMonoid (A × B) where
  one := (1, 1)
instance [MulCommMonoid A] [MulCommMonoid B] : MulCommMonoid (A × B) where

instance MulMonoid.instFun {X : Type} [MulMonoid A] : MulMonoid (X -> A) where
  one := fun _ => 1
instance MulCommMonoid.instFun {X : Type} [MulCommMonoid A] : MulCommMonoid (X -> A) where

instance [MulMonoid A] {n : Nat} : MulMonoid (A ^ n) := MulMonoid.instFun
instance [MulCommMonoid A] {n : Nat} : MulCommMonoid (A ^ n) := MulCommMonoid.instFun

--------------------------------------------------------------------------------
-- Multiplicative groups, non-commutative or commutative

-- `inv` defaults to `one / x`; instances may override it.
class MulGroup (A : Type) extends MulMonoid A where
  div : A -> A -> A
  inv : A -> A := fun x => div one x
class MulCommGroup (A : Type) extends MulGroup A, MulCommMonoid A where

instance [MulGroup A] : Div A where
  div := MulGroup.div
instance [MulGroup A] : Inv A where
  inv := MulGroup.inv

instance [MulGroup A] [MulGroup B] : MulGroup (A × B) where
  div x y := (x.1 / y.1, x.2 / y.2)
instance [MulCommGroup A] [MulCommGroup B] : MulCommGroup (A × B) where

instance MulGroup.instFun {X : Type} [MulGroup A] : MulGroup (X -> A) where
  div f g := fun x => f x / g x
instance MulCommGroup.instFun {X : Type} [MulCommGroup A] : MulCommGroup (X -> A) where

instance [MulGroup A] {n : Nat} : MulGroup (A ^ n) := MulGroup.instFun
instance [MulCommGroup A] {n : Nat} : MulCommGroup (A ^ n) := MulCommGroup.instFun

--------------------------------------------------------------------------------
-- Commutative and unital rings

class Ring (R : Type) extends AddGroup R, MulCommMonoid R

instance : Ring Rat where
  zero := 0
  add := Rat.add
  sub := Rat.sub
  -- Use Rat's own negation instead of the default `fun x => sub zero x`,
  -- which would go through a full subtraction.
  neg := Rat.neg
  one := 1
  mul := Rat.mul

-- TODO: PolyRing

--------------------------------------------------------------------------------
-- Modules

-- R is an outParam, because usually M will only be a module over one ring, and
-- this allows us to leave R implicit in many cases.
class Module (R : outParam Type) [Ring R] (M : Type) [AddGroup M] where
  smul : R -> M -> M

-- Every module gets the `⋅` notation for its scalar multiplication.
instance [Ring R] [AddGroup M] [Module R M] : SMul R M where
  smul := Module.smul

-- A ring is a module over itself, scaling by ring multiplication.
instance [Ring R] : Module R R where
  smul r s := r * s

-- A product of two modules over the same ring scales componentwise.
instance [Ring R] [AddGroup M] [Module R M] [AddGroup N] [Module R N] : Module R (M × N) where
  smul a p := ⟨a ⋅ p.fst, a ⋅ p.snd⟩

-- Functions into a module scale pointwise.
instance Module.instFun [Ring R] [AddGroup M] [Module R M] : Module R (X -> M) where
  smul a f x := a ⋅ f x

instance [Ring R] [AddGroup M] [Module R M] {n : Nat} : Module R (M ^ n) := Module.instFun

/-- Coordinates of an element of `M` with respect to an `n`-element basis:
one scalar from the ring `R` per basis index. -/
def Coordinates [Ring R] (M : Type) [AddGroup M] [Module R M] (n : Nat) : Type :=
  Fin n -> R

-- A single coordinate prints bare. Otherwise the coordinates print as a
-- parenthesized, comma-separated tuple in index order.
instance [Ring R] [ToString R] [AddGroup M] [Module R M] : ToString (Coordinates M n) where
  toString coords :=
    match n with
    | 1 => toString (coords 0)
    | n =>
      let parts := Fin.foldr n (fun i rest => toString (coords i) :: rest) []
      "(" ++ String.intercalate ", " parts ++ ")"

/-- A basis of `M` with `n` elements, bundled with the map that sends each
element of `M` to its coordinates. NOTE(review): nothing here checks that
`coords` is consistent with `basis`; that is up to whoever builds the value. -/
structure FiniteBasis [Ring R] (M : Type) [AddGroup M] [Module R M] (n : Nat) where
  basis : Fin n -> M
  coords : M -> Coordinates M n

-- XXX: Calling this "vector" is maybe a bit weird.
/-- Reassemble an element of `M` from its coordinates:
`∑ i, coords i ⋅ b.basis i`. -/
def FiniteBasis.vector [Ring R] [AddGroup M] [Module R M]
    (b : FiniteBasis M n) (coords : Coordinates M n) : M :=
  Fin.foldl n (fun acc i => acc + coords i ⋅ b.basis i) 0

/-- The one-element basis `{1}` of a ring viewed as a module over itself. The
only coordinate of `a` is `a` itself. -/
def FiniteBasis.ring (R : Type) [Ring R] : FiniteBasis R 1 :=
  -- `Fin 1` has a single element, so the index can be ignored.
  ⟨fun _ => 1, fun a _ => a⟩

/-- Basis of `M × N` built from bases `b` of `M` and `c` of `N`. The first `m`
indices map to `(b.basis i, 0)` and the remaining `n` to `(0, c.basis i)`.
Coordinates are read off the matching component. -/
def FiniteBasis.mul [Ring R] [AddGroup M] [Module R M] (b : FiniteBasis M m)
    [AddGroup N] [Module R N] (c : FiniteBasis N n) : FiniteBasis (M × N) (m + n) :=
  let basis i :=
    match finSum i with
    | .inl i => (b.basis i, 0)
    | .inr i => (0, c.basis i)
  let coords v i :=
    match finSum i with
    | .inl i => b.coords v.1 i
    | .inr i => c.coords v.2 i
  ⟨basis, coords⟩

/-- Basis of `M ^ n` built from a basis `b` of `M`. Index `i` is split by
`finProd` into `(i / n, i % n)`. The basis vector puts `b.basis (i / n)` in
slot `i % n` and `0` in every other slot. -/
def FiniteBasis.pow [Ring R] [AddGroup M] [Module R M] (b : FiniteBasis M m)
    (n : Nat) : FiniteBasis (M ^ n) (m * n) :=
  let basis i :=
    let ⟨iq, ir⟩ := finProd i
    fun j => if j = ir then b.basis iq else 0
  let coords v i :=
    let ⟨iq, ir⟩ := finProd i
    b.coords (v ir) iq
  ⟨basis, coords⟩

-- OrthoQuadraticForm b is a quadratic form with respect to which the finite
-- basis b is orthogonal. It is determined by its values on the basis vectors.
structure OrthoQuadraticForm [Ring R] [AddGroup M] [Module R M] (_ : FiniteBasis M n) where
  evalBasis : Fin n -> R

/-- Evaluate the quadratic form `q` on `v`, using the coordinates of `v` in
`q`'s basis. -/
def OrthoQuadraticForm.eval [Ring R] [AddGroup M] [Module R M]
    {b : FiniteBasis M n} (q : OrthoQuadraticForm b) (v : M) : R :=
  let vCoords := b.coords v
  -- q v = q (∑ vCoords i ⋅ b.basis i)
  --     = ∑ q (vCoords i ⋅ b.basis i)      (orthogonality of b wrt q)
  --     = ∑ (vCoords i)^2 * q (b.basis i)  (def. of quadratic form)
  --     = ∑ (vCoords i)^2 * q.evalBasis i
  Fin.foldl n (fun acc i => acc + vCoords i * vCoords i * q.evalBasis i) 0

--------------------------------------------------------------------------------
-- Algebras

-- An R-algebra: an additive group with a multiplicative monoid structure that
-- is also an R-module. No compatibility laws are recorded.
class Algebra (R : outParam Type) [Ring R] (A : Type) [AddGroup A] [MulMonoid A] extends Module R A

instance [Ring R] : Algebra R R where

-- TODO: More Algebra instances

--------------------------------------------------------------------------------
-- Geometric algebra generation from vector spaces with a quadratic form and
-- an orthogonal finite basis

-- An element of the geometric algebra generated by `q`. It stores one
-- coefficient per blade, where a blade is a bit vector whose set bits select
-- basis indices.
structure GeometricAlgebra [Ring R] [AddGroup M] [Module R M]
    {b : FiniteBasis M n} (q : OrthoQuadraticForm b) : Type where
  coords : BitVec n -> R

-- Prints the nonzero terms grade by grade, e.g. `1 + 2 * {0} + 3 * {01}`,
-- or `0` when every coefficient is zero.
instance [Ring R] [DecidableEq R] [ToString R] [AddGroup M] [Module R M]
    {b : FiniteBasis M n} {q : OrthoQuadraticForm b} : ToString (GeometricAlgebra q) where
  toString v :=
    -- A blade prints as the set of its basis indices, e.g. `{02}`.
    let bldToString (bld : BitVec n) : String :=
      let indexToString (i : Fin n) acc :=
        if bld.getLsb i then toString i :: acc else acc
      let indexStrings := Fin.foldr n indexToString []
      -- Indices need separators only when some index has two or more digits.
      -- The largest index is n - 1, so that happens exactly when n > 10.
      let sep := if n > 10 then "," else ""
      "{" ++ String.intercalate sep indexStrings ++ "}"

    -- Prepend the term for `bld` to `acc`, skipping zero coefficients. The
    -- scalar blade (bld = 0) prints as its bare coefficient.
    let termToString acc bld :=
      if v.coords bld != 0 then
        let coeffString := toString (v.coords bld)
        let termString :=
          if bld != 0 then
            s!"{coeffString} * {bldToString bld}"
          else
            coeffString
        termString :: acc
      else
        acc
    -- Visit blades grade by grade (k = 0, …, n). Terms are accumulated by
    -- prepending, so they are reversed before joining.
    let termStrings := Fin.foldl (n + 1)
      (fun acc k => bitVecChoose n k.val termToString acc) []

    if termStrings.isEmpty then
      "0"
    else
      String.intercalate " + " termStrings.reverse

-- Addition and subtraction act coefficient-wise.
instance [Ring R] [AddGroup M] [Module R M] {b : FiniteBasis M n}
    {q : OrthoQuadraticForm b} : AddGroup (GeometricAlgebra q) where
  zero := ⟨0⟩
  add v w := ⟨v.coords + w.coords⟩
  sub v w := ⟨v.coords - w.coords⟩
-- Scalar multiplication acts coefficient-wise.
instance [Ring R] [AddGroup M] [Module R M] {b : FiniteBasis M n}
    {q : OrthoQuadraticForm b} : Module R (GeometricAlgebra q) where
  smul a v := { coords := a ⋅ v.coords }

-- Multiplication algorithm:
-- Given v w : GeometricAlgebra q and bld : BitVec n, the bld-th component of
-- the product v * w is a sum of 2^n terms. There is one term for each pair of
-- blades vbld wbld : BitVec n whose product is bld up to a scalar factor, which
-- for blades means vbld XOR wbld = bld. Each term is the product of the vbld-th
-- component of v and the wbld-th component of w.
--
-- For example, take n = 3 and denote blades as bit vectors. Generic elements
-- v and w can then be written as follows:
--   v0 * 000 + v1 * 001 + v2 * 010 + v3 * 011 + v4 * 100 + v5 * 101 + v6 * 110 + v7 * 111
--   w0 * 000 + w1 * 001 + w2 * 010 + w3 * 011 + w4 * 100 + w5 * 101 + w6 * 110 + w7 * 111
-- Now, if we want the component of v * w corresponding to blade bld = 010, we
-- consider the following pairs of vbld and wbld:
--   vbld wbld
--   000  010
--   001  011
--   010  000
--   011  001
--   100  110
--   101  111
--   110  100
--   111  101
-- These are just all 2^3 3-bit numbers, with one column (wbld in this case)
-- XORed with bld. We multiply the coefficients in each row and sum the
-- results. Two extra factors enter each term:
--   * a sign computed from vbld and wbld, following ei * ej = - ej * ei for
--     all basis vectors ei != ej;
--   * a scalar obtained by multiplying the "squares" of each basis vector
--     shared by vbld and wbld, where "square" means applying the quadratic
--     form.
-- Geometric product. The unit is the scalar blade (bld = 0) with coefficient 1.
instance [Ring R] [AddGroup M] [Module R M] {b : FiniteBasis M n}
    {q : OrthoQuadraticForm b} : MulMonoid (GeometricAlgebra q) where
  one := ⟨fun bld => if bld = 0 then 1 else 0⟩
  mul v w :=
    -- Sign from reordering the concatenated basis vectors of x and y into
    -- increasing index order: (-1)^k, where k counts the pairs (i in x, j in y)
    -- with i > j. The shift by s = idx + 1 counts the pairs with i - j = s.
    let sign (x y : BitVec n) : R :=
      let swaps := Fin.foldl n
        (fun (total : Nat) (idx : Fin n) =>
          total + bitVecPopCount (x &&& (y <<< (idx.val + 1))))
        0
      if swaps % 2 = 0 then 1 else -1

    -- Basis vectors present in both x and y contract to scalars. Multiply
    -- their values under the quadratic form.
    let scalar (x y : BitVec n) : R :=
      let shared := x &&& y
      Fin.foldl n
        (fun (prod : R) (idx : Fin n) =>
          if shared.getLsb idx then prod * q.evalBasis idx else prod)
        1

    -- The bld-th coefficient sums over every vbld, pairing it with the unique
    -- wbld for which vbld XOR wbld = bld.
    let coords (bld : BitVec n) : R :=
      bitVecEnuml n
        (fun (total : R) (vbld : BitVec n) =>
          let wbld : BitVec n := vbld ^^^ bld
          total + sign vbld wbld * scalar vbld wbld * v.coords vbld * w.coords wbld)
        0
    ⟨coords⟩

-- The module and monoid structures above make the geometric algebra an
-- R-algebra.
instance [Ring R] [AddGroup M] [Module R M] {b : FiniteBasis M n}
    {q : OrthoQuadraticForm b} : Algebra R (GeometricAlgebra q) where

--------------------------------------------------------------------------------
-- Main

abbrev R : Type := Rat
abbrev n : Nat := 3
abbrev M : Type := R ^ n
abbrev b : FiniteBasis M n := FiniteBasis.pow (FiniteBasis.ring R) n
-- Every basis vector squares to 1 (Euclidean signature). Replace the constant
-- function with a match on the index to get another signature.
abbrev q : OrthoQuadraticForm b := ⟨fun _ => 1⟩
abbrev G : Type := GeometricAlgebra q

-- Print v = 2 e0, w = 3 e1 and their geometric product.
def main : IO Unit := do
  -- Blades are bit vectors: 0b001 (= 1) is e0 and 0b010 (= 2) is e1.
  let v : G := ⟨fun bld => if bld = 1 then 2 else 0⟩
  let w : G := ⟨fun bld => if bld = 2 then 3 else 0⟩
  let x : G := v * w
  IO.println v
  IO.println w
  IO.println x