Module Product

Require Sylvie Rel.
Require Import TreeAl.
Import Utf8.
Import Coqlib AST Integers Maps.
Import Util ShareTree.
Import Sylvie Sexpr Rel.

(** [rev_map_o f xs] applies the partial function [f] to each element of
    [xs] and keeps only the [Some] results.  Results are consed onto an
    accumulator that is never reversed, so they come out in the reverse
    order of [xs].  The recursive call is in tail position. *)
Definition rev_map_o {X Y} (f: X → option Y) (xs: list X) : list Y :=
  (fix rec xs ys :=
    match xs with
    | nil => ys
    | x :: xs' => rec xs' match f x with Some y => y :: ys | None => ys end
    end) xs nil.

Unset Elimination Schemes.
Set Implicit Arguments.

(** The two programs of the product: [L] tags steps of the left program,
    [R] tags steps of the right one. *)
Inductive side : Type := L | R.

(** Decidable equality on [side] by case analysis.  In the two disequality
    cases the proof eliminates the impossible equation. *)
Instance side_dec : EqDec side :=
  λ x y,
  match x, y with
  | L, L => left eq_refl
  | R, R => left eq_refl
  | L, R => right (λ K : L = R, match K with eq_refl => I end)
  | R, L => right (λ K : R = L, match K with eq_refl => I end)
  end.

(** Share trees keyed by pairs of nodes [(pc1, pc2)], i.e. a node of the left
    program paired with a node of the right program.  Built as the product
    of two [PShareTree]s. *)
Module PPTree : SHARETREE with Definition elt := (node * node)%type
  := ProdShareTree PShareTree PShareTree.

(** The list stored at key [pp] in [m].  A key with no entry reads as the
    empty list. *)
Definition ppmap_get {X} (pp: node * node) (m: PPTree.t (list X)) : list X :=
  match PPTree.get pp m with
  | None => nil
  | Some found => found
  end.

Section REGISTER.

(** Register type of the product program, with decidable equality. *)
Context (reg: Type) (reg_dec: EqDec reg).

(** Renamings of the left and right programs' registers into product
    registers. *)
Context (left right: Registers.reg → reg).

(** Symbolic expression for a callee [f], as used for [Icall] in
    [instructions_match].  A register callee [inl r] becomes [Reg (ι r)],
    i.e. the register after renaming by [ι].  A named callee [inr n]
    becomes [Name n]. *)
Definition sexp_of_func (ι: Registers.reg → reg) f :=
  match f with
  | inl r => Reg (ι r)
  | inr n => Name n
  end.

(** Annotations attached to product nodes in the [deco] table.
    - [Invariant] wraps the assertions added by [add_invariant].
    - [Assertion] wraps the assertions emitted while building the product
      in [mk_prod].
    [get_assertions] keeps only the [Assertion]s. *)
Inductive decoration : Type :=
| Invariant `(assertion reg)
| Assertion `(assertion reg)
.

(** State threaded through the product construction. *)
Record state : Type :=
  State {
      code : Sylvie.code reg;          (* product program built so far *)
      ppmap: PPTree.t (list node);     (* product nodes created for each pair (pc1, pc2), oldest first *)
      deco: PTree.t (list decoration); (* decorations of each product node, most recent first *)
      hints: PTree.t side;             (* program side of the nodes emitted with a side tag *)
      next_node: node                  (* next product node to allocate *)
    }.

(** Starting state: every table is empty and node allocation begins at
    [xH]. *)
Definition init_state : state :=
  State (PTree.empty _) (PPTree.empty _) (PTree.empty _) (PTree.empty _) xH.

(** [Inop] and [Iop] are invisible.  Every other RTL instruction is
    visible. *)
Definition is_visible (i: RTL.instruction) : bool :=
  match i with
  | RTL.Inop _ | RTL.Iop _ _ _ _ => false
  | _ => true
  end.

(** [is_dangerous_op op args] is [true] when evaluating [op] on [args] might
    fail.  [false] is returned only for combinations that always evaluate
    (see [not_dangerous_eval_operation]).  It is [true] for:
    - arity mismatches, as detailed per group below;
    - [Odiv], [Odivu], [Omod], [Omodu], [Oshrximm] and the listed int/float
      conversions, whatever their arguments (presumably because they can
      fail on some argument values).
    Every other operator/argument combination is considered safe.
    NOTE(review): the constructor set ([Olea], [Oshldimm], [Oindirectsymbol])
    looks like CompCert's IA32 [Op]; confirm before porting to another back
    end. *)
Definition is_dangerous_op op (args: list Registers.reg) : bool :=
    match op, args with
    (* this group is dangerous unless applied to exactly one argument *)
    | (Op.Omove | Op.Ocast8signed | Op.Ocast8unsigned | Op.Ocast16signed | Op.Ocast16unsigned
      | Op.Oneg | Op.Omulimm _ | Op.Oandimm _ | Op.Oorimm _ | Op.Oxorimm _ | Op.Onot
      | Op.Oshlimm _ | Op.Oshrimm _ | Op.Oshruimm _ | Op.Ororimm _
      | Op.Onegf | Op.Oabsf | Op.Onegfs | Op.Oabsfs
      | Op.Osingleoffloat | Op.Ofloatofsingle | Op.Olowlong | Op.Ohighlong),
      (nil | _ :: _ :: _)
    (* constants are dangerous when given any argument *)
    | (Op.Ointconst _ | Op.Ofloatconst _ | Op.Osingleconst _ | Op.Oindirectsymbol _), _ :: _
    (* this group is dangerous unless applied to exactly two arguments *)
    | (Op.Osub | Op.Omul | Op.Omulhs | Op.Omulhu | Op.Oand | Op.Oor | Op.Oxor
       | Op.Oshl | Op.Oshr | Op.Oshru | Op.Oshldimm _
       | Op.Oaddf | Op.Osubf | Op.Omulf | Op.Odivf
       | Op.Oaddfs | Op.Osubfs | Op.Omulfs | Op.Odivfs
       | Op.Omakelong), (nil | _ :: nil | _ :: _ :: _ :: _)
    (* always dangerous, whatever the arguments *)
    | (Op.Odiv | Op.Odivu | Op.Omod | Op.Omodu | Op.Oshrximm _
       | Op.Ointoffloat | Op.Ofloatofint | Op.Ointofsingle | Op.Osingleofint), _
      => true
    (* [Olea]: the expected arity depends on the addressing mode *)
    | Op.Olea addr, _ =>
      match addr, args with
      | (Op.Aindexed _ | Op.Ascaled _ _ | Op.Abased _ _ | Op.Abasedscaled _ _ _), (nil | _ :: _ :: _)
      | (Op.Aindexed2 _ | Op.Aindexed2scaled _ _), (nil | _ :: nil | _ :: _ :: _ :: _)
      | (Op.Aglobal _ _ | Op.Ainstack _), _ :: _
        => true
      | _, _ => false
      end
    | _, _ => false
    end.

(** [Inop] is never dangerous.  An [Iop] is dangerous exactly when its
    operation is ([is_dangerous_op]).  Every other instruction is
    dangerous. *)
Definition is_dangerous (i: RTL.instruction) : bool :=
  match i with
  | RTL.Iop o rargs _ _ => is_dangerous_op o rargs
  | RTL.Inop _ => false
  | _ => true
  end.

(** If [is_dangerous_op op args] is [false], evaluating [op] on the values
    of [args] in [rs] always yields some value, for any [ge], [sp] and
    [m]. *)
Lemma not_dangerous_eval_operation {F V} (ge: Globalenvs.Genv.t F V) sp rs m {op args} :
  is_dangerous_op op args = false →
  ∃ v, Op.eval_operation ge sp op (map (λ r, rs !! r) args) m = Some v.
Proof.
  intros SAFE.
  (* Case split on the operator (and on the addressing mode of [Olea]), then
     on up to three leading arguments.  Dangerous cases are refuted from
     SAFE; the remaining goals are presumably discharged by [vauto]. *)
  destruct op;
    try match goal with x : Op.addressing |- _ => destruct x end;
    try exact (false_not_true (eq_sym SAFE) _);
    vauto;
    destruct args as [ | α args ]; try exact (false_not_true (eq_sym SAFE) _);
    vauto;
    (try destruct args as [ | β args ]; try exact (false_not_true (eq_sym SAFE) _));
    vauto;
    (try destruct args as [ | γ args ]; try exact (false_not_true (eq_sym SAFE) _));
    vauto.
Qed.

(** Node [pc] holding instruction [i] is a mark when [i] is dangerous, or
    when [pred pc] has at least two entries (presumably [pred] gives
    predecessors, so this is a join point).  [mk_prod] only synchronizes the
    two programs at marks. *)
Definition is_mark (pred: node → list node) (pc: node) (i: RTL.instruction) : bool :=
  if is_dangerous i
  then true
  else
    match pred pc with
    | _ :: _ :: _ => true
    | _ => false
    end.

(** Emit instruction [i pc'] at the fresh node [pc := next_node st], where
    [pc' := Pos.succ pc] becomes both its successor and the new
    [next_node].  [pc] is appended to the [ppmap] entry of [k].  When [s] is
    [Some side], [pc] is tagged with that side in [hints].  Fails if [pc] is
    already occupied in [code st]. *)
Definition add_instruction (k: node * node) (s: option side) (i: node → edge reg) (st: state) : Errors.res state :=
  let pc := next_node st in
  let pc' := Pos.succ pc in
  match (code st) ! pc with
  | Some _ => Errors.Error (Errors.msg "add_instruction")
  | None =>
    let old := ppmap_get k (ppmap st) in
    Errors.OK {|
        code := PTree.set pc (i pc') (code st);
        ppmap := PPTree.set k (old ++ pc :: nil) (ppmap st);
        deco := deco st;
        hints := match s with None => hints st | Some s => PTree.set pc s (hints st) end;
        next_node := pc' |}
  end.

(** Reserve the fresh node [next_node st] for [k] without emitting any code
    there.  [mk_prod] fills it later with [patch_instruction].  Returns the
    reserved node and the state with [next_node] advanced. *)
Definition skip_instruction (k: node * node) (st: state) : Errors.res (node * state) :=
  let here := next_node st in
  let after := Pos.succ here in
  let seen := ppmap_get k (ppmap st) in
  let st' :=
    {| code := code st;
       ppmap := PPTree.set k (seen ++ here :: nil) (ppmap st);
       deco := deco st;
       hints := hints st;
       next_node := after |} in
  Errors.OK (here, st').

(** Install [i] at a previously reserved node [pc].  Fails if [pc] already
    holds an instruction.  All other components of [st] are kept. *)
Definition patch_instruction (pc: node) (i: edge reg) (st: state) : Errors.res state :=
  match (code st) ! pc with
  | None =>
    Errors.OK {| code := PTree.set pc i (code st);
                 ppmap := ppmap st;
                 deco := deco st;
                 hints := hints st;
                 next_node := next_node st |}
  | Some _ => Errors.Error (Errors.msg "patch_instruction")
  end.

(** Record decoration [d] at a node: the given node [p] when [opc = Some p],
    otherwise the next fresh node [next_node st].  Decorations are
    prepended, so each [deco] entry lists them most recent first.  Only
    [deco] changes; node allocation is left untouched. *)
Definition add_assertion (opc: option node) (d: decoration) (st: state) : Errors.res state :=
  let pc := match opc with Some p => p | None => next_node st end in
  let old := match (deco st) ! pc with Some o => o | None => nil end in
  Errors.OK {|
      code := code st;
      ppmap := ppmap st;
      deco := PTree.set pc (d :: old) (deco st);
      hints := hints st;
      (* Do not reset the allocator to [pc].  For [opc = Some p], [p] may be
         an already-allocated node (e.g. one taken from [ppmap] by
         [add_invariant]), and a later [add_instruction] at it would fail. *)
      next_node := next_node st |}.

Import Errors.
Import String.
Open Scope error_monad_scope.

(** Short human-readable name of an RTL instruction kind, for error
    messages. *)
Definition str_of_i (i: RTL.instruction) : String.string :=
  match i with
  | RTL.Inop _ => "nop"
  | RTL.Iop _ _ _ _ => "op"
  | RTL.Icond _ _ _ _ => "cond"
  | RTL.Ijumptable _ _ => "jump table"
  | RTL.Iload _ _ _ _ _ _ => "load"
  | RTL.Istore _ _ _ _ _ _ => "store"
  | RTL.Icall _ _ _ _ _ => "call"
  | RTL.Itailcall _ _ _ => "tail-call"
  | RTL.Ibuiltin _ _ _ _ => "built-in"
  | RTL.Ireturn _ => "return"
  end%string.

Require Import PrintPos.
(** Render a pair of nodes as ["(pc1, pc2)"] for error messages. *)
Definition str_of_k (k: node * node) : string :=
  let lhs := print_pos (fst k) in
  let rhs := print_pos (snd k) in
  ("(" ++ lhs ++ ", " ++ rhs ++ ")")%string.

(** Emit one unsynchronized step of a single program into the product.

    [Inop pc'] becomes an [Egoto].  [Iop op args dst pc'] becomes an [Eop]
    on registers renamed by [ι].  The emitted node is filed under [k] and
    tagged with side [s].  Returns the successor [pc'] in the source program
    together with the new state.  Any other instruction is an error.

    The renaming parameter is named [ι] rather than [reg] so that it does
    not shadow the section's register type [reg]. *)
Definition add_one (k: node * node) (s: side) (ι: Registers.reg → reg) i (st: state) : res (node * state) :=
  match i with
  | RTL.Inop pc' =>
    do st' <- add_instruction k (Some s) Egoto st;
    OK (pc', st')
  | RTL.Iop op args dst pc' =>
    do st' <- add_instruction k (Some s) (Eop (ι dst) (Operation op) (List.map ι args)) st;
    OK (pc', st')
  | _ => Error (msg ("add_one: unexpected lone instruction " ++ str_of_i i ++ " at " ++ str_of_k k))
  end.

(** Attach relation [r] as invariants at the oldest product node recorded
    for [k] in [ppmap].  An empty relation yields the single invariant
    [True].  Otherwise each pair [(x, y)] of [r] yields
    [assert_eq_reg (left x) (right y)].  Fails when nothing is recorded for
    [k]. *)
Definition add_invariant (k: node * node) (r: rel) (st: state) : Errors.res state :=
  match ppmap_get k (ppmap st), r with
  | nil, _ => Error (msg "add_invariant")
  | pc :: _, nil => add_assertion (Some pc) (Invariant True) st
  | pc :: _, _ =>
    List.fold_left
      (λ acc xy,
       do cur <- acc;
       let '(x, y) := xy in
       add_assertion (Some pc) (Invariant (assert_eq_reg (left x) (right y))) cur)
      r (OK st)
  end.

(** Run [add_invariant] on each [(k, r)] entry of [inv], left to right,
    threading the state and propagating the first error. *)
Definition add_invariants (inv: list ((node * node) * rel)) (st: state) : Errors.res state :=
  List.fold_left
    (λ acc entry,
     do cur <- acc;
     let '(key, relation) := entry in
     add_invariant key relation cur)
    inv (OK st).

Section PRODUCT.

(* Presumably an oracle telling whether a node starts a dead branch; no
   definition in this section uses it. *)
Context (is_dead_branch: side → RTL.code → node → bool).

(** Left ([p1]) and right ([p2]) RTL programs whose product is built. *)
Context (p1 p2: RTL.code).

(** One equality assertion [left a1 = right a2] per position of [args1] and
    [args2], accumulated by prepending (so presumably in reverse order if
    [Util.fold_left2] folds left to right).  When the lengths differ, one of
    the two error callbacks fires. *)
Definition assert_same_args args1 args2 : res (list decoration) :=
  Util.fold_left2 (λ x a1 a2,
                   do y <- x;
                   OK (Assertion (assert_eq_reg (left a1) (right a2)) :: y))
                  (λ _ _, Error (msg "Too many left arguments"))
                  (λ _ _, Error (msg "Too many right arguments"))
                  args1 args2
                  (OK nil).

(** Like [assert_same_args] but for builtin arguments:
    - register arguments [BA r1]/[BA r2] produce an equality assertion;
    - integer constants must be equal;
    - global addresses must have the same symbol and offset.
    These last two add nothing.  Any other pairing, or a length mismatch,
    is an error. *)
Definition assert_same_builtin_args args1 args2 : res (list decoration) :=
  Util.fold_left2 (λ x a1 a2,
                   do y <- x;
                   match a1, a2 with
                   | BA r1, BA r2 => OK (Assertion (assert_eq_reg (left r1) (right r2)) :: y)
                   | BA_int i1, BA_int i2 => if Int.eq i1 i2 then x else Error (msg "BA_int mismatch")
                   | BA_addrglobal s1 o1, BA_addrglobal s2 o2 =>
                     if if ident_eq s1 s2 then Int.eq o1 o2 else false
                     then x
                     else Error (msg "BA_addrglobal mismatch")
                   | _, _ => Error (msg "Builtin arguments mismatch")
                   end)
                  (λ _ _, Error (msg "Too many left builtin arguments"))
                  (λ _ _, Error (msg "Too many right builtin arguments"))
                  args1 args2
                  (OK nil).

(** Pair up the result destinations of two builtins:
    - both absent gives [None];
    - both registers give [Some (x1, x2)];
    - any other combination fails. *)
Definition match_builtin_res (r1 r2: builtin_res Registers.reg) : res _ :=
  match r1, r2 with
  | BR_none, BR_none => OK None
  | BR x1, BR x2 => OK (Some (x1, x2))
  | _, _ => Error (msg "match_builtin_res failed")
  end.

(** Result of matching two synchronized instructions (consumed by [mk_prod]):
    - [IM_OK pre e post succ]: emit the assertions [pre], then the edge [e]
      (given its successor), then for each pair of [post] a move of the first
      register into the second; continue at [succ], or stop if [None].
    - [IM_Cond pre th1 th2 el1 el2 e]: conditional pair with assertion [pre],
      then-targets [th1]/[th2], else-targets [el1]/[el2], and branch edge [e]
      (given the then and else product nodes).
    - [IM_Error msg]: the instructions do not match. *)
Inductive instructions_match_t : Type :=
| IM_OK (pre: list decoration) `(node → edge reg) (post: list (reg * reg)) (succ: option (node * node))
| IM_Cond (pre: decoration) (th1 th2 el1 el2: node) `(node → node → edge reg)
| IM_Error (msg: errmsg).

(** Decide how two instructions [i1] (left) and [i2] (right), executed in
    lockstep, are emitted in the product.
    - [Ireturn]/[Ireturn]: both return nothing, or the returned registers are
      asserted equal; emit [Estop] with no successor.
    - [Iload]/[Iload] with the same chunk: addresses asserted equal; the left
      destination is havocked.
    - [Icall]/[Icall] with the same signature: callees and arguments asserted
      equal; the left destination is havocked.
    - [Ibuiltin]/[Ibuiltin] with the same external function: arguments
      matched; the left result register, if any, is havocked.
    - [Istore]/[Istore] with the same chunk: stored registers and addresses
      asserted equal; emit a plain [Egoto].
    - [Icond]/[Icond]: conditions asserted equivalent ([assert_iff]); the
      branch uses the left condition and arguments.
    - [Iop]/[Iop] with the same operation: arguments asserted equal; the left
      operation is emitted.
    For load, call, builtin and op, [post] pairs the left destination with
    the right one.  Any other combination is an [IM_Error].  [pc1] and [pc2]
    are unused. *)
Definition instructions_match (pc1 pc2: node) (i1 i2: RTL.instruction) : instructions_match_t :=
  match i1, i2 with
  | RTL.Ireturn None, RTL.Ireturn None => IM_OK nil (λ _, Estop) nil None
  | RTL.Ireturn (Some r1), RTL.Ireturn (Some r2) =>
    IM_OK (Assertion (assert_eq_reg (left r1) (right r2)) :: nil) (λ _, Estop) nil None

  | RTL.Iload _ κ1 addr1 args1 dst1 pc1', RTL.Iload _ κ2 addr2 args2 dst2 pc2' =>
    if chunk_eq κ1 κ2
    then
      IM_OK
        (Assertion (AssertEQ (addr addr1 left args1) (addr addr2 right args2)) :: nil)
        (Eop (left dst1) Havoc nil)
        ((left dst1, right dst2) :: nil)
        (Some (pc1', pc2'))
    else IM_Error (msg "Chunk mismatch in load")

  | RTL.Icall sg1 f1 args1 dst1 pc1', RTL.Icall sg2 f2 args2 dst2 pc2' =>
    if signature_eq sg1 sg2
    then
      match assert_same_args args1 args2 with
      | Error e => IM_Error e
      | OK pre =>
        IM_OK
          (Assertion (AssertEQ (sexp_of_func left f1) (sexp_of_func right f2)) :: pre)
          (Eop (left dst1) Havoc nil)
          ((left dst1, right dst2) :: nil)
          (Some (pc1', pc2'))
      end
    else IM_Error (msg "Signature mismatch in call")

  | RTL.Ibuiltin ef1 args1 dst1 pc1', RTL.Ibuiltin ef2 args2 dst2 pc2' =>
    if external_function_eq ef1 ef2
    then
      match assert_same_builtin_args args1 args2 with
      | Error e => IM_Error e
      | OK pre =>
        match match_builtin_res dst1 dst2 with
        | Error e => IM_Error e
        | OK out =>
          IM_OK
            pre
            match out with None => Egoto | Some (o1,_) => Eop (left o1) Havoc nil end
            match out with None => nil | Some (o1, o2) => (left o1, right o2) :: nil end
            (Some (pc1', pc2'))
        end
      end
    else IM_Error (msg "External function mismatch in builtin")

  | RTL.Istore _ κ1 addr1 args1 src1 pc1', RTL.Istore _ κ2 addr2 args2 src2 pc2' =>
    if chunk_eq κ1 κ2
    then
      IM_OK
        (Assertion (assert_eq_reg (left src1) (right src2))
         :: Assertion (AssertEQ (addr addr1 left args1) (addr addr2 right args2)) :: nil)
        Egoto
        nil
        (Some (pc1', pc2'))
    else IM_Error (msg "Chunk mismatch in store")

  | RTL.Icond cond1 args1 th1 el1, RTL.Icond cond2 args2 th2 el2 =>
    match assert_iff (comp cond1 left args1) (comp cond2 right args2) with
    | None => IM_Error (msg "mk_prod: assert_iff")
    | Some a =>
      IM_Cond (Assertion a) th1 th2 el1 el2 (Ebranch cond1 (List.map left args1))
    end

  | RTL.Iop op1 args1 dst1 pc1', RTL.Iop op2 args2 dst2 pc2' =>
    if Op.eq_operation op1 op2
    then
      match assert_same_args args1 args2 with
      | Error e => IM_Error e
      | OK pre =>
        IM_OK
          pre
          (Eop (left dst1) (Operation op1) (List.map left args1))
          ((left dst1, right dst2) :: nil)
          (Some (pc1', pc2'))
      end
    else IM_Error (msg "Operation mismatch")

  | _, _ => IM_Error (msg "Instruction mismatch")
  end.

(** Functions used by [is_mark] to detect join points in the left and right
    programs (presumably predecessor maps). *)
Context (pred1 pred2: node → list node).

(** Build the product of [p1] from [pc1] and [p2] from [pc2] into [st].

    Non-mark instructions ([is_mark]) are copied one at a time, left program
    first.  When both sides are at a mark:
    - if [(pc1, pc2)] already has product nodes in [ppmap], an [Egoto] to the
      oldest of them closes the loop;
    - otherwise a non-dangerous side is stepped alone (both sides, if neither
      is dangerous);
    - when both sides are dangerous, they must match ([instructions_match])
      and are emitted together.  For a conditional pair, the then targets are
      explored before the else targets, and the reserved branch node is then
      patched.
    Each recursive call consumes one unit of [fuel]; running out is an
    error. *)
Fixpoint mk_prod (pc1 pc2: node) (st: state) (fuel: nat) : res state :=
  match fuel with
  | O => Error (msg "mk_prod: Dame más gasolina")
  | S fuel' =>
    match p1 ! pc1 with
    | None => Error (msg "mk_prod: no instruction in program 1")
    | Some i1 =>
      if is_mark pred1 pc1 i1
      then
        match p2 ! pc2 with
        | None => Error (msg "mk_prod: no instruction in program 2")
        | Some i2 =>
          if is_mark pred2 pc2 i2
          then
            match ppmap_get (pc1, pc2) (ppmap st) with
            | pc :: _ =>
              (* (pc1, pc2) was already visited: jump back to its product node *)
              add_instruction (pc1, pc2) None (λ _, Egoto pc) st
            | nil =>
              let dangerous1 := is_dangerous i1 in
              let dangerous2 := is_dangerous i2 in

              match dangerous1, dangerous2 with
              | false, false =>
                do (pc1', st') <- add_one (pc1, pc2) L left i1 st;
                do (pc2', st'') <- add_one (pc1', pc2) R right i2 st';
                mk_prod pc1' pc2' st'' fuel'

              | true, false =>
                do (pc2', st) <- add_one (pc1, pc2) R right i2 st;
                mk_prod pc1 pc2' st fuel'

              | false, true =>
                do (pc1', st) <- add_one (pc1, pc2) L left i1 st;
                mk_prod pc1' pc2 st fuel'

              | true, true =>
                match instructions_match pc1 pc2 i1 i2 with
                | IM_Error m => Error (app m (msg (" " ++ str_of_i i1 ++ " × " ++ str_of_i i2 ++ " at " ++ str_of_k (pc1, pc2))))
                | IM_OK pre j post next =>
                  do st <- List.fold_left
                     (λ st a,
                      do st <- st;
                      add_assertion None a st)
                     pre
                     (OK st);
                  do st <- add_instruction (pc1, pc2) None j st;
                  (* copy each left destination into the matching right one *)
                  do st <- List.fold_left
                     (λ st out,
                      do st <- st;
                      add_instruction (pc1, pc2) None (Eop (snd out) (Operation Op.Omove) (fst out :: nil)) st)
                     post
                     (OK st);
                  match next with
                  | Some (pc1', pc2') => mk_prod pc1' pc2' st fuel'
                  | None => OK st
                  end
                | IM_Cond a th1 th2 el1 el2 j =>
                  do st <- add_assertion None a st;
                  (* reserve the branch node; it is patched once both targets exist *)
                  do (pc, st) <- skip_instruction (pc1, pc2) st;
                  let th := st.(next_node) in
                  do st <- mk_prod th1 th2 st fuel';
                  let el := st.(next_node) in
                  do st <- mk_prod el1 el2 st fuel';
                  patch_instruction pc (j th el) st
                end
              end
            end
          else
            do (pc2', st') <- add_one (pc1, pc2) R right i2 st;
            mk_prod pc1 pc2' st' fuel'
        end
      else
        do (pc1', st') <- add_one (pc1, pc2) L left i1 st;
        mk_prod pc1' pc2 st' fuel'
    end
  end.

End PRODUCT.

(** One step of [ep_annot]: prepend [assert_eq_reg (left x) (right y)] to
    the assertions accumulated in [acc].  An error in [acc] is passed
    through. *)
Definition ep_annot_fun acc x y :=
  do assertions <- acc;
  OK (assert_eq_reg (left x) (right y) :: assertions).

(** Length-mismatch callback for [ep_annot].  [n] identifies which of the two
    callbacks fired; the two remaining arguments are ignored. *)
Definition ep_annot_err {X Y} n (_: X) (_: Y) : Errors.res (list (Sexpr.assertion reg)) :=
  let text := ("ep_annot: length mismatch " ++ n)%string in
  Error (msg text).

(** Equality assertions [left x = right y] for corresponding registers of
    [p] and [q], folded with [ep_annot_fun].  A length mismatch reports
    "1" or "2" through [ep_annot_err]. *)
Definition ep_annot (p q: list Registers.reg) : Errors.res (list (Sexpr.assertion reg)) :=
  Util.fold_left2 ep_annot_fun (ep_annot_err "1") (ep_annot_err "2") p q (OK nil).

(** [with_precondition pre a] guards [a] with every assertion of [pre].
    Folding left, each [p] wraps the current result as [Implies p _], so the
    last element of [pre] ends up outermost. *)
Definition with_precondition (pre: list (Sexpr.assertion reg)) (a: Sexpr.assertion reg) : Sexpr.assertion reg :=
  List.fold_left (λ a p, Implies p a) pre a.

(** Elimination rule for [with_precondition]: if the guarded assertion holds
    and every assertion of [pre] holds, then [a] holds.  Proved by induction
    on [pre]. *)
Lemma eval_with_precondition fs sp vp ε :
  ∀ pre a,
    eval_assertion _ fs sp vp ε (with_precondition pre a) →
    (∀ i, In i pre → eval_assertion _ fs sp vp ε i) →
    eval_assertion _ fs sp vp ε a.
Proof.
  intros pre; elim pre; clear pre; simpl.
  - intros a H _; exact H.
  - intros i pre IH a Hia Hpre.
    (* guarding [a] by [i :: pre] is guarding [Implies i a] by [pre] *)
    specialize (IH _ Hia (λ j Hj, Hpre j (or_intror Hj))).
    clear Hia.
    simpl in IH. apply IH, Hpre; left; reflexivity.
Qed.

(** [pre_of_assertion_list fs sp pe al ε] states that every assertion of
    [al] satisfies [Sexpr.eval_assertion reg_dec fs sp pe ε]. *)
Definition pre_of_assertion_list fs sp pe (al: list (Sexpr.assertion reg)) (ε: env reg) : Prop :=
  ∀ a,
    In a al →
    Sexpr.eval_assertion reg_dec fs sp pe ε a.

(** For each node [pc], the [Assertion]s stored in [deco pc]; [Invariant]s
    are dropped.  [rev_map_o] returns them in the reverse order of the stored
    list. *)
Definition get_assertions (deco: hashmap _) : node → list (Sexpr.assertion reg) :=
  λ pc,
  rev_map_o
    (λ d, match d with Invariant _ => None | Assertion a => Some a end)
    (deco pc).

End REGISTER.

Arguments init_state {reg}.

(** Heuristic: [true] when [pc] starts the straight-line sequence
      [x1 := Ointconst v1; x2 := Ointconst v2; _ := Odiv y1 y2]
    in [c] and the divisor register [y2] holds a zero constant.  That is the
    case when
    - [(y1, y2) = (x1, x2)] and [v2] is zero, or
    - [x1 <> y1], [(y1, y2) = (x2, x1)] and [v1] is zero.
    The side [s] is ignored. *)
Definition is_dead_branch (s: side) (c: RTL.code) (pc: RTL.node) : bool :=
  match c ! pc with
  | Some (RTL.Iop (Op.Ointconst v1) nil x1 pc1) =>
    match c ! pc1 with
    | Some (RTL.Iop (Op.Ointconst v2) nil x2 pc2) =>
      match c ! pc2 with
      | Some (RTL.Iop Op.Odiv (y1 :: y2 :: nil) _ _) =>
        if eq_dec x1 y1
        then (if eq_dec x2 y2 then Int.eq v2 Int.zero else false)
        else if eq_dec x1 y2
             then (if eq_dec x2 y1 then Int.eq v1 Int.zero else false)
             else false
      | _ => false
      end
    | _ => false
    end
  | _ => false
  end.