symgeoalg

symbolic geometric algebra calculations
git clone git://git.rr3.xyz/symgeoalg
Log | Files | Refs | README | LICENSE

Main.lean (15495B)


      1 import Batteries.Data.Rat
      2 
      3 --------------------------------------------------------------------------------
      4 -- Fin util
      5 
      6 def finSum (i : Fin (m + n)) : Sum (Fin m) (Fin n) :=
      7   -- Commute m and n in i's type to please Fin.subNat.
      8   have fin_comm : Fin (m + n) = Fin (n + m) := by rw [Nat.add_comm]
      9   let i : Fin (n + m) := cast fin_comm i
     10 
     11   if h : i.val < m then
     12     Sum.inl (Fin.castLT i h)
     13   else
     14     have m_le_i : m <= i.val := Nat.le_of_not_gt h
     15     Sum.inr (Fin.subNat m i m_le_i)
     16 
     17 def finProd (i : Fin (m * n)) : Prod (Fin m) (Fin n) :=
     18   let mul_gt_zero_imp_gt_zero {m n : Nat} : (m * n > 0) -> (m > 0) := by
     19     intro
     20     rw [<-Nat.zero_div n]
     21     apply Nat.div_lt_of_lt_mul
     22     rw [Nat.mul_comm]
     23     assumption
     24 
     25   have mn_gt_zero : m * n > 0 := Fin.pos i
     26   have m_gt_zero : m > 0 := mul_gt_zero_imp_gt_zero mn_gt_zero
     27   let q := Fin.ofNat' (i.val / n) m_gt_zero
     28 
     29   have nm_gt_zero : n * m > 0 := by rw [Nat.mul_comm]; exact mn_gt_zero
     30   have n_gt_zero : n > 0 := mul_gt_zero_imp_gt_zero nm_gt_zero
     31   let r := Fin.ofNat' (i.val % n) n_gt_zero
     32 
     33   (q, r)
     34 
     35 --------------------------------------------------------------------------------
     36 -- BitVec util
     37 
     38 def bitVecFoldl (f : α -> Bool -> α) (bv : BitVec n) (init : α) : α :=
     39   Fin.foldl n (fun acc i => f acc (bv.getLsb i)) init
     40 
     41 def bitVecEnuml (n : Nat) (f : α -> BitVec n -> α) (init : α) : α :=
     42   let loop acc i := f acc (BitVec.ofFin i)
     43   Fin.foldl (2^n) loop init
     44 
     45 def bitVecChoose (n k : Nat) (f : α -> BitVec n -> α) (acc : α) : α :=
     46   match n, k with
     47   | 0, 0 => f acc BitVec.nil
     48   | 0, _ + 1 => acc
     49   | _ + 1, 0 => f acc 0
     50   | n' + 1, k' + 1 =>
     51     let f' (msb : Bool) (acc : α) (bv' : BitVec n') : α :=
     52       let bv := BitVec.cons msb bv'
     53       f acc bv
     54     let acc := bitVecChoose n' k (f' false) acc
     55     bitVecChoose n' k' (f' true) acc
     56 
     57 def bitVecPopCount (bv : BitVec n) : Nat :=
     58   bitVecFoldl (fun acc bit => acc + bit.toNat) bv 0
     59 
     60 --------------------------------------------------------------------------------
     61 -- Notation
     62 
     63 instance : HPow Type Nat Type where
     64   hPow Y n := Fin n -> Y
     65 
     66 def vec (coords : List Y) : Y ^ coords.length := coords.get
     67 
     68 -- SMul represents the ability to use "⋅" as an infix operator.
     69 class SMul (α β : Type) where
     70   smul : α -> β -> β
     71 
     72 -- TODO: Is 75 an appropriate precedence?
     73 -- How do I figure out precedences of standard Lean operators?
     74 infixr:75 " ⋅ " => SMul.smul
     75 
     76 class Inv (α : Type) where
     77   inv : α -> α
     78 
     79 --------------------------------------------------------------------------------
     80 -- Additive (commutative) semigroups
     81 
     82 class AddSemigroup (A : Type) where
     83   add : A -> A -> A
     84 
     85 instance [AddSemigroup A] : Add A where
     86   add := AddSemigroup.add
     87 
     88 instance [AddSemigroup A] [AddSemigroup B] : AddSemigroup (A × B) where
     89   add x y := (x.1 + y.1, x.2 + y.2)
     90 
     91 instance AddSemigroup.instFun {X : Type} [AddSemigroup A] : AddSemigroup (X -> A) where
     92   add f g := fun x => f x + g x
     93 
     94 instance [AddSemigroup A] {n : Nat} : AddSemigroup (A ^ n) := AddSemigroup.instFun
     95 
     96 --------------------------------------------------------------------------------
     97 -- Additive (commutative) monoids
     98 
     99 class AddMonoid (A : Type) extends AddSemigroup A where
    100   zero : A
    101 
    102 instance [AddMonoid A] : OfNat A 0 where
    103   ofNat := AddMonoid.zero
    104 
    105 instance [AddMonoid A] [AddMonoid B] : AddMonoid (A × B) where
    106   zero := (0, 0)
    107 
    108 instance AddMonoid.instFun {X : Type} [AddMonoid A] : AddMonoid (X -> A) where
    109   zero := fun _ => 0
    110 
    111 instance [AddMonoid A] {n : Nat} : AddMonoid (A ^ n) := AddMonoid.instFun
    112 
    113 --------------------------------------------------------------------------------
    114 -- Additive (commutative) groups
    115 
    116 class AddGroup (A : Type) extends AddMonoid A where
    117   sub : A -> A -> A
    118   neg : A -> A := fun x => sub zero x
    119 
    120 instance [AddGroup A] : Sub A where
    121   sub := AddGroup.sub
    122 instance [AddGroup A] : Neg A where
    123   neg := AddGroup.neg
    124 
    125 instance [AddGroup A] [AddGroup B] : AddGroup (A × B) where
    126   sub x y := (x.1 - y.1, x.2 - y.2)
    127 
    128 instance AddGroup.instFun {X : Type} [AddGroup A] : AddGroup (X -> A) where
    129   sub f g := fun x => f x - g x
    130 
    131 instance [AddGroup A] {n : Nat} : AddGroup (A ^ n) := AddGroup.instFun
    132 
    133 --------------------------------------------------------------------------------
    134 -- Multiplicative semigroups, non-commutative or commutative
    135 
    136 class MulSemigroup (A : Type) where
    137   mul : A -> A -> A
    138 class MulCommSemigroup (A : Type) extends MulSemigroup A where
    139 
    140 instance [MulSemigroup A] : Mul A where
    141   mul := MulSemigroup.mul
    142 
    143 instance [MulSemigroup A] [MulSemigroup B] : MulSemigroup (A × B) where
    144   mul x y := (x.1 * y.1, x.2 * y.2)
    145 instance [MulCommSemigroup A] [MulCommSemigroup B] : MulCommSemigroup (A × B) where
    146 
    147 instance MulSemigroup.instFun {X : Type} [MulSemigroup A] : MulSemigroup (X -> A) where
    148   mul f g := fun x => f x * g x
    149 instance MulCommSemigroup.instFun {X : Type} [MulCommSemigroup A] : MulCommSemigroup (X -> A) where
    150 
    151 instance [MulSemigroup A] {n : Nat} : MulSemigroup (A ^ n) := MulSemigroup.instFun
    152 instance [MulCommSemigroup A] {n : Nat} : MulCommSemigroup (A ^ n) := MulCommSemigroup.instFun
    153 
    154 --------------------------------------------------------------------------------
    155 -- Multiplicative monoids, non-commutative or commutative
    156 
    157 class MulMonoid (A : Type) extends MulSemigroup A where
    158   one : A
    159 class MulCommMonoid (A : Type) extends MulMonoid A, MulCommSemigroup A where
    160 
    161 instance [MulMonoid A] : OfNat A 1 where
    162   ofNat := MulMonoid.one
    163 
    164 instance [MulMonoid A] [MulMonoid B] : MulMonoid (A × B) where
    165   one := (1, 1)
    166 instance [MulCommMonoid A] [MulCommMonoid B] : MulCommMonoid (A × B) where
    167 
    168 instance MulMonoid.instFun {X : Type} [MulMonoid A] : MulMonoid (X -> A) where
    169   one := fun _ => 1
    170 instance MulCommMonoid.instFun {X : Type} [MulCommMonoid A] : MulCommMonoid (X -> A) where
    171 
    172 instance [MulMonoid A] {n : Nat} : MulMonoid (A ^ n) := MulMonoid.instFun
    173 instance [MulCommMonoid A] {n : Nat} : MulCommMonoid (A ^ n) := MulCommMonoid.instFun
    174 
    175 --------------------------------------------------------------------------------
    176 -- Multiplicative groups, non-commutative or commutative
    177 
    178 class MulGroup (A : Type) extends MulMonoid A where
    179   div : A -> A -> A
    180   inv : A -> A := fun x => div one x
    181 class MulCommGroup (A : Type) extends MulGroup A, MulCommMonoid A where
    182 
    183 instance [MulGroup A] : Div A where
    184   div := MulGroup.div
    185 instance [MulGroup A] : Inv A where
    186   inv := MulGroup.inv
    187 
    188 instance [MulGroup A] [MulGroup B] : MulGroup (A × B) where
    189   div x y := (x.1 / y.1, x.2 / y.2)
    190 instance [MulCommGroup A] [MulCommGroup B] : MulCommGroup (A × B) where
    191 
    192 instance MulGroup.instFun {X : Type} [MulGroup A] : MulGroup (X -> A) where
    193   div f g := fun x => f x / g x
    194 instance MulCommGroup.instFun {X : Type} [MulCommGroup A] : MulCommGroup (X -> A) where
    195 
    196 instance [MulGroup A] {n : Nat} : MulGroup (A ^ n) := MulGroup.instFun
    197 instance [MulCommGroup A] {n : Nat} : MulCommGroup (A ^ n) := MulCommGroup.instFun
    198 
    199 --------------------------------------------------------------------------------
    200 -- Commutative and unital rings
    201 
    202 class Ring (R : Type) extends AddGroup R, MulCommMonoid R
    203 
    204 instance : Ring Rat where
    205   zero := 0
    206   add := Rat.add
    207   sub := Rat.sub
    208   one := 1
    209   mul := Rat.mul
    210 
    211 -- TODO: PolyRing
    212 
    213 --------------------------------------------------------------------------------
    214 -- Modules
    215 
    216 -- R is an outParam, because usually V will only be a module over one ring, and
    217 -- this allows us to leave R implicit in many cases.
    218 class Module (R : outParam Type) [Ring R] (M : Type) [AddGroup M] where
    219   smul : R -> M -> M
    220 
    221 instance [Ring R] [AddGroup M] [Module R M] : SMul R M where
    222   smul := Module.smul
    223 
    224 instance [Ring R] : Module R R where
    225   smul := MulSemigroup.mul
    226 
    227 instance [Ring R] [AddGroup M] [Module R M] [AddGroup N] [Module R N] : Module R (M × N) where
    228   smul a p := (a ⋅ p.1, a ⋅ p.2)
    229 
    230 instance Module.instFun [Ring R] [AddGroup M] [Module R M] : Module R (X -> M) where
    231   smul a f := fun x => a ⋅ f x
    232 
    233 instance [Ring R] [AddGroup M] [Module R M] {n : Nat} : Module R (M ^ n) := Module.instFun
    234 
    235 def Coordinates [Ring R] (M : Type) [AddGroup M] [Module R M] (n : Nat) : Type :=
    236   Fin n -> R
    237 
    238 instance [Ring R] [ToString R] [AddGroup M] [Module R M] : ToString (Coordinates M n) where
    239   toString coords :=
    240     match n with
    241     | 1 => toString (coords 0)
    242     | n =>
    243       let consCoordString i acc := toString (coords i) :: acc
    244       s!"({String.intercalate ", " $ Fin.foldr n consCoordString []})"
    245 
    246 structure FiniteBasis [Ring R] (M : Type) [AddGroup M] [Module R M] (n : Nat) where
    247   basis : Fin n -> M
    248   coords : M -> Coordinates M n
    249 
    250 -- XXX: Calling this "vector" is maybe a bit weird.
    251 def FiniteBasis.vector [Ring R] [AddGroup M] [Module R M]
    252     (b : FiniteBasis M n) (coords : Coordinates M n) : M :=
    253   Fin.foldl n (fun acc i => acc + coords i ⋅ b.basis i) 0
    254 
    255 def FiniteBasis.ring (R : Type) [Ring R] : FiniteBasis R 1 :=
    256   ⟨fun 0 => 1, fun a 0 => a⟩
    257 
    258 def FiniteBasis.mul [Ring R] [AddGroup M] [Module R M] (b : FiniteBasis M m)
    259     [AddGroup N] [Module R N] (c : FiniteBasis N n) : FiniteBasis (M × N) (m + n) :=
    260   let basis i :=
    261     match finSum i with
    262     | .inl i => (b.basis i, 0)
    263     | .inr i => (0, c.basis i)
    264   let coords v i :=
    265     match finSum i with
    266     | .inl i => b.coords v.1 i
    267     | .inr i => c.coords v.2 i
    268   ⟨basis, coords⟩
    269 
    270 def FiniteBasis.pow [Ring R] [AddGroup M] [Module R M] (b : FiniteBasis M m)
    271     (n : Nat) : FiniteBasis (M ^ n) (m * n) :=
    272   let basis i :=
    273     let ⟨iq, ir⟩ := finProd i
    274     fun j => if j = ir then b.basis iq else 0
    275   let coords v i :=
    276     let ⟨iq, ir⟩ := finProd i
    277     b.coords (v ir) iq
    278   ⟨basis, coords⟩
    279 
    280 -- OrthoQuadraticForm b is a quadratic form with respect to which the finite
    281 -- basis b is orthogonal.
    282 structure OrthoQuadraticForm [Ring R] [AddGroup M] [Module R M] (_ : FiniteBasis M n) where
    283   evalBasis : Fin n -> R
    284 
    285 def OrthoQuadraticForm.eval [Ring R] [AddGroup M] [Module R M]
    286     {b : FiniteBasis M n} (q : OrthoQuadraticForm b) (v : M) : R :=
    287   let vCoords := b.coords v
    288   -- q v = q (∑ vCoords i ⋅ b.basis i)
    289   --     = ∑ q (vCoords i ⋅ b.basis i)        (orthogonality of b wrt q)
    290   --     = ∑ (vCoords i)^2 * q (b.basis i)    (def. of quadratic form)
    291   --     = ∑ (vCoords i)^2 * q.evalBasis i
    292   Fin.foldl n (fun acc i => acc + vCoords i * vCoords i * q.evalBasis i) 0
    293 
    294 --------------------------------------------------------------------------------
    295 -- Algebras
    296 
    297 class Algebra (R : outParam Type) [Ring R] (A : Type) [AddGroup A] [MulMonoid A] extends Module R A
    298 
    299 instance [Ring R] : Algebra R R where
    300 
    301 -- TODO: More Algebra instances
    302 
    303 --------------------------------------------------------------------------------
    304 -- Geometric algebra generation from vector spaces with a quadratic form and
    305 -- an orthogonal finite basis
    306 
    307 structure GeometricAlgebra [Ring R] [AddGroup M] [Module R M]
    308     {b : FiniteBasis M n} (q : OrthoQuadraticForm b) : Type where
    309   coords : BitVec n -> R
    310 
    311 instance [Ring R] [DecidableEq R] [ToString R] [AddGroup M] [Module R M]
    312     {b : FiniteBasis M n} {q : OrthoQuadraticForm b} : ToString (GeometricAlgebra q) where
    313   toString v :=
    314     let bldToString (bld : BitVec n) : String :=
    315       let indexToString (i : Fin n) acc :=
    316         if bld.getLsb i then toString i :: acc else acc
    317       let indexStrings := Fin.foldr n indexToString []
    318       let sep := if n > 9 then "," else ""
    319       "{" ++ String.intercalate sep indexStrings ++ "}"
    320 
    321     let termToString acc bld :=
    322       if v.coords bld != 0 then
    323         let coeffString := toString (v.coords bld)
    324         let termString :=
    325           if bld != 0 then
    326              s!"{coeffString} * {bldToString bld}"
    327           else
    328             coeffString
    329         termString :: acc
    330       else
    331         acc
    332     let termStrings := Fin.foldl (n + 1)
    333       (fun acc k => bitVecChoose n k.val termToString acc) []
    334 
    335     if termStrings.length > 0 then
    336       String.intercalate " + " (List.reverse termStrings)
    337     else
    338       "0"
    339 
    340 instance [Ring R] [AddGroup M] [Module R M] {b : FiniteBasis M n}
    341     {q : OrthoQuadraticForm b} : AddGroup (GeometricAlgebra q) where
    342   zero := ⟨0⟩
    343   add v w := ⟨v.coords + w.coords⟩
    344   sub v w := ⟨v.coords - w.coords⟩
    345 
    346 instance [Ring R] [AddGroup M] [Module R M] {b : FiniteBasis M n}
    347     {q : OrthoQuadraticForm b} : Module R (GeometricAlgebra q) where
    348   smul a v := ⟨a ⋅ v.coords⟩
    349 
    350 -- Multiplication algorithm:
    351 -- Given v w : GeometricAlgebra q and bld : BitVec n, the bld-th component
    352 -- of the product v * w is a sum of 2^n terms, where each term is the product of
    353 -- the vbld-th component of v and the wbld-th component of w, for all
    354 -- vbld wbld : BitVec n that multiply (XOR) to bld (modulo a scalar). For
    355 -- example, if we take n = 3 and denote blades as bit vectors, then we can write
    356 -- generic vectors v and w as follows:
    357 --    v0 * 000 + v1 * 001 + v2 * 010 + v3 * 011 + v4 * 100 + v5 * 101 + v6 * 110 + v7 * 111
    358 --    w0 * 000 + w1 * 001 + w2 * 010 + w3 * 011 + w4 * 100 + w5 * 101 + w6 * 110 + w7 * 111
    359 -- Now, if we want the component of v * w corresponding to blade bld = 010, we
    360 -- consider the following pairs of vbld and wbld:
    361 --    vbld   wbld
    362 --     000    010
    363 --     001    011
    364 --     010    000
    365 --     011    001
    366 --     100    110
    367 --     101    111
    368 --     110    100
    369 --     111    101
    370 -- (These are just all 2^3 3-bit numbers, but with one column (wbld in this
    371 -- case) XORed with bld.) Then we multiply the coefficients corresponding to
    372 -- each row and sum the results. Additionally, we must calculate a sign from
    373 -- vbld and wbld (in accordance with the equation ei * ej = - ej * ei for all
-- basis vectors ei != ej), and we must calculate a scalar by multiplying the
    375 -- "squares" of each basis vector shared by vbld and wbld, where "square" means
    376 -- apply the quadratic form.
    377 
    378 instance [Ring R] [AddGroup M] [Module R M] {b : FiniteBasis M n}
    379     {q : OrthoQuadraticForm b} : MulMonoid (GeometricAlgebra q) where
    380   one := ⟨fun bld => if bld = 0 then 1 else 0⟩
    381   mul v w :=
    382     let sign (x y : BitVec n) : R :=
    383       let addSwaps (acc : Nat) (i : Fin n) : Nat :=
    384         acc + bitVecPopCount (x &&& (y <<< (i.val + 1)))
    385       let nSwaps := Fin.foldl n addSwaps 0
    386       if nSwaps % 2 = 0 then 1 else -1
    387 
    388     let scalar (x y : BitVec n) : R :=
    389       let z := x &&& y
    390       let mulSquares (acc : R) (i : Fin n) : R :=
    391         if z.getLsb i then acc * q.evalBasis i else acc
    392       Fin.foldl n mulSquares 1
    393 
    394     let coords (bld : BitVec n) : R :=
    395       let addTerm (acc : R) (vbld : BitVec n) : R :=
    396         let wbld : BitVec n := vbld ^^^ bld
    397         acc + sign vbld wbld * scalar vbld wbld * v.coords vbld * w.coords wbld
    398       bitVecEnuml n addTerm 0
    399     ⟨coords⟩
    400 
    401 instance [Ring R] [AddGroup M] [Module R M] {b : FiniteBasis M n}
    402     {q : OrthoQuadraticForm b} : Algebra R (GeometricAlgebra q) where
    403 
    404 --------------------------------------------------------------------------------
    405 -- Main
    406 
    407 abbrev R : Type := Rat
    408 abbrev n : Nat := 3
    409 abbrev M : Type := R ^ n
    410 abbrev b : FiniteBasis M n := FiniteBasis.pow (FiniteBasis.ring R) n
    411 abbrev q : OrthoQuadraticForm b := ⟨
    412     fun
    413     | 0 => 1
    414     | 1 => 1
    415     | 2 => 1
    416    417 abbrev G : Type := GeometricAlgebra q
    418 
    419 def main : IO Unit := do
    420   let v : G := ⟨
    421       fun
    422       | 0 => 0
    423       | 1 => 2
    424       | 2 => 0
    425       | 3 => 0
    426       | 4 => 0
    427       | 5 => 0
    428       | 6 => 0
    429       | 7 => 0
    430    431   let w : G := ⟨
    432       fun
    433       | 0 => 0
    434       | 1 => 0
    435       | 2 => 3
    436       | 3 => 0
    437       | 4 => 0
    438       | 5 => 0
    439       | 6 => 0
    440       | 7 => 0
    441    442   let x : G := v * w
    443   IO.println v
    444   IO.println w
    445   IO.println x