Module UnLoad


Require Import
  Coqlib Utf8 Util
  AST Integers Floats
  Csharpminor.

Require Import
  AssocList
  Psatz.

Section LEAST_LARGER.

Local Open Scope positive_scope.

(** [least_larger l] is an identifier strictly above every element of [l]:
    [1] for the empty list, otherwise one more than the maximum element
    (see [least_larger_larger] and [least_larger_least]).
    The fold carries a candidate bound [a]; whenever an element [j] is not
    already below it, the candidate is bumped to [Pos.succ j]. *)
Definition least_larger (l: list ident) : ident :=
  match l with
  | nil => 1
  | i :: l' => List.fold_left (λ a j, if a <=? j then Pos.succ j else a) l' (Pos.succ i)
  end.

(** Every element of [l] is strictly below [least_larger l]. *)
Lemma least_larger_larger l :
  ∀ x, In x l → x < least_larger l.
Proof.
  destruct l as [|i l]. intros ? ().
  simpl least_larger.
  induction l as [|j l IH] using rev_ind.
  simpl. intros x [ <- | () ]. lia.
  intros x Hx. rewrite fold_left_app. simpl.
  destruct Hx as [ -> | Hx ].
  specialize (IH _ (or_introl eq_refl)).
  case (Pos.leb_spec0); lia.
  rewrite in_app in Hx. destruct Hx as [Hx | [ <- | () ] ].
  specialize (IH _ (or_intror Hx)).
  case Pos.leb_spec0; lia.
  specialize (IH _ (or_introl eq_refl)).
  case Pos.leb_spec0; lia.
Qed.

(** [least_larger l] is the least such bound: any [x] strictly above all
    elements of [l] is at least [least_larger l]. *)
Lemma least_larger_least l :
  ∀ x, (∀ y, In y l → y < x) →
       least_larger l <= x.
Proof.
  destruct l as [|i l].
  simpl. intros x _. lia.
  simpl least_larger.
  induction l as [|j l IH] using rev_ind.
  simpl. intros x H. specialize (H _ (or_introl eq_refl)). lia.
  intros x H. rewrite fold_left_app. simpl.
  assert (j < x). apply H. right. rewrite in_app. right; left; reflexivity.
  assert (∀ y, In y (i :: l) → y < x) as IH'.
  intros y H1. apply H. destruct H1 as [ -> | H1 ]. left; reflexivity. right. apply in_app. auto.
  specialize (IH _ IH').
  case Pos.leb_spec0; lia.
Qed.

End LEAST_LARGER.

Module ST.
  (** State monad: a computation of type [t S A] takes a state of type [S]
      and returns a value of type [A] together with the new state. *)
  Definition t S A : Type := S → A * S.

  (** Monad instance: [bind] threads the state from [m] into [f];
      [ret] leaves the state untouched. *)
  Instance mon {S} : monad (t S) :=
    { bind A B m f := λ s, let '(a, s') := m(s) in f a s';
      ret A a := λ s, (a, s) }.

  (** Return the current state without changing it. *)
  Definition get {S} : t S S := λ s, (s, s).
  (** Overwrite the state with [s]. *)
  Definition put {S} s : t S _ := λ _, (tt, s).

  (** Apply [f] to the current state. *)
  Definition upd {S} (f: S → S) : t S _ := λ s, (tt, f s).

  (** Run a computation over the first (resp. second) component of a pair
      state, leaving the other component unchanged. *)
  Definition left {S₁ S₂ A} (m: t S₁ A) : t (S₁ * S₂) A := λ s, let '(a, s') := m (fst s) in (a, (s', snd s)).
  Definition right {S₁ S₂ A} (m: t S₂ A) : t (S₁ * S₂) A := λ s, let '(a, s') := m (snd s) in (a, (fst s, s')).

End ST.

Section EXPR_EQ.

(** Boolean equality on Csharpminor constants.  Integers use [Int.eq] /
    [Int64.eq]; floats are compared with the decidable equalities
    [Float.eq_dec] / [Float32.eq_dec] (coerced to [bool]).  Constants built
    with different constructors are never equal. *)
Definition constant_beq (c d: constant) : bool :=
  match c, d with
  | Ointconst i, Ointconst j => Int.eq i j
  | Ofloatconst f, Ofloatconst g => Float.eq_dec f g
  | Osingleconst f, Osingleconst g => Float32.eq_dec f g
  | Olongconst i, Olongconst j => Int64.eq i j
  | _, _ => false
  end.

(** [constant_beq] reflects Leibniz equality. *)
Lemma constant_beq_eq c d :
  constant_beq c d = true ↔ c = d.
Proof.
  destruct c as [i|f|f|i]; destruct d as [j|g|g|j]; simpl; split;
  intros H; eq_some_inv; try discriminate; f_equal.
  generalize (Int.eq_spec i j); rewrite H; exact id.
  inv H. apply Int.eq_true.
  InvBooleans. exact H.
  apply proj_sumbool_is_true. inv H. reflexivity.
  InvBooleans. exact H.
  apply proj_sumbool_is_true. inv H. reflexivity.
  generalize (Int64.eq_spec i j); rewrite H; exact id.
  inv H. apply Int64.eq_true.
Qed.

(** Decidable equality on comparisons, by case analysis. *)
Instance CmpEqDec : EqDec comparison.
Proof.
intros x y; destruct x; destruct y;
       first [ left; reflexivity | right; abstract congruence ].
Defined.

(** Decidable equality on unary operators, by case analysis. *)
Instance UopEqDec : EqDec unary_operation.
Proof.
intros x y; destruct x; destruct y;
       first [ left; reflexivity | right; abstract congruence ].
Defined.

(** Decidable equality on binary operators; operators carrying a
    [comparison] argument defer to [CmpEqDec]. *)
Instance BopEqDec : EqDec binary_operation.
Proof.
intros x y; destruct x; destruct y;
       first [ left; reflexivity | right; abstract congruence |
       match goal with x: comparison, y: comparison |- _ =>
                       destruct (eq_dec x y);[left;congruence|right;abstract congruence]
       end ].
Defined.

(** Structural boolean equality on expressions.  Identifiers of [Evar] and
    [Eaddrof] are compared with [Pos.eqb], constants with [constant_beq],
    and operators / memory chunks with [eq_dec]. *)
Fixpoint expr_beq (x y: expr) {struct x} : bool :=
  match x, y with
  | Evar i, Evar j
  | Eaddrof i, Eaddrof j => Pos.eqb i j
  | Econst c, Econst d => constant_beq c d
  | Eunop o a, Eunop p b =>
    eq_dec o p && expr_beq a b
  | Ebinop o a b, Ebinop p c d =>
    eq_dec o p && expr_beq a c && expr_beq b d
  | Eload κ a, Eload κ' b =>
    eq_dec κ κ' && expr_beq a b
  | _, _ => false
  end.

(** [expr_beq] reflects Leibniz equality. *)
Lemma expr_beq_eq:
  ∀ x y,
  expr_beq x y = true ↔ x = y.
Proof.
  intros x y. split.
  - revert y.
    induction x; destruct y; simpl;
    try (intros H; eq_some_inv; fail);
    try (case Pos.eqb_spec; congruence).
    rewrite constant_beq_eq; congruence.
    case eq_dec; simpl. 2: intros; eq_some_inv.
    intros <- H; rewrite (IHx _ H); reflexivity.
    case eq_dec; simpl. 2: intros; eq_some_inv.
    intros <- H; rewrite andb_true_iff in H. destruct H. erewrite IHx1, IHx2; eauto.
    case eq_dec; simpl. 2: intros; eq_some_inv.
    intros <-. intros H; erewrite IHx; eauto.
  - intros <-. induction x; simpl; rewrite ? eq_dec_true, ? Pos.eqb_refl; simpl; auto.
    apply constant_beq_eq; reflexivity.
    apply andb_true_iff; auto.
Qed.

End EXPR_EQ.

(** Decidable equality on expressions, derived from the boolean test
    [expr_beq] and its reflection lemma [expr_beq_eq].  Closed with
    [Defined] so the instance stays transparent and computes. *)
Instance ExprEqDec : EqDec expr.
Proof. exact (eq_dec_of_beq expr_beq expr_beq_eq). Defined.

Section EXTRACTION.

(** [unload e] runs in a state monad whose state is the list of load
    expressions collected so far.  Each [Eload] subterm of [e] is taken as
    a whole (address included; its subterms are not traversed) and
    prepended to the state unless it is already present.  The result is a
    function that, given an association list from expressions to
    identifiers, rebuilds [e] with every such load replaced by [Evar t],
    where [t] is the identifier associated with that load ([xH] if none). *)
Fixpoint unload (e: expr) : ST.t (list expr) (list (expr * ident) → expr) :=
  match e with
  | Eunop o e₁ => do x₁ <- unload e₁; ret (λ l, Eunop o (x₁ l))
  | Ebinop o e₁ e₂ => do x₁ <- unload e₁; do x₂ <- unload e₂; ret (λ l, Ebinop o (x₁ l) (x₂ l))
  | Eload _ _ =>
    do l <- ST.get;
    do _ <- if List.in_dec eq_dec e l
           then ret tt
           else ST.put (e::l);
    (* [env] is the expression-to-identifier map supplied later; it is
       named apart from the state list [l] to avoid shadowing. *)
    ret (λ env, Evar (match assoc e env with Some t => t | None => xH end))
  | Evar _
  | Eaddrof _
  | Econst _
    => ret (λ _, e)
  end.

End EXTRACTION.

Section TRANSFORM.

(** Expression extractor: splits an expression into a rebuilding function
    (parameterised by an expression-to-identifier map) and the list of
    subexpressions to hoist, e.g. [unload] run from an empty state. *)
Variable unload : expr → (list (expr * ident) → expr) * list expr.

(** Return the current temporary counter and advance it by one. *)
Definition next_temp : ST.t ident ident :=
  do i <- ST.get;
  do _ <- ST.put (Pos.succ i);
  ret i.

(** Allocate a fresh temporary [t] for [e] and push [(e, t)] onto the
    association list held in the first state component. *)
Definition next_assoc (e: expr) : ST.t (list (expr * ident) * ident) ident :=
  do t <- ST.right next_temp;
  do _ <- ST.left (ST.upd (λ s, (e, t) :: s));
  ret t.

(** For each [e] of [l], in order, emit [Sset t e] with a fresh [t]; then
    finish with [base] applied to the accumulated association list. *)
Fixpoint expand' (base: list (expr * ident) → stmt) (l: list expr) : ST.t (list (expr * ident) * ident) stmt :=
  match l with
  | nil => do x <- ST.left ST.get; ret (base x)
  | e :: l' => do t <- next_assoc e; do s <- expand' base l'; ret (Sseq (Sset t e) s)
  end.

(** [expand'] started from an empty association list, exposing only the
    temporary counter. *)
Definition expand (base: list (expr * ident) → stmt) (l: list expr) :
ST.t ident stmt :=
  λ i, let '(s, (_, i')) := expand' base l (nil, i) in (s, i').

(** Rewrite [s], threading the next free temporary.  Sequences,
    conditionals, loops and blocks are traversed; the expression of
    [Sreturn (Some _)], [Sset] and the stored value of [Sstore] is split by
    [unload], its hoisted subexpressions are assigned to fresh temporaries
    first, and the statement is rebuilt over those temporaries.
    NOTE(review): the condition of [Sifthenelse], the address of [Sstore],
    and the statements under [Slabel] / [Sswitch] are returned unchanged;
    confirm this is intended. *)
Fixpoint transform_stmt (s: stmt) : ST.t ident stmt :=
  match s with
  | Sseq s₁ s₂ =>
    do s₁' <- transform_stmt s₁;
    do s₂' <- transform_stmt s₂;
    ret (Sseq s₁' s₂')
  | Sifthenelse e s₁ s₂ =>
    do s₁' <- transform_stmt s₁;
    do s₂' <- transform_stmt s₂;
    ret (Sifthenelse e s₁' s₂')
  | Sloop s₁ =>
    do s₁' <- transform_stmt s₁;
    ret (Sloop s₁')
  | Sblock s₁ =>
    do s₁' <- transform_stmt s₁;
    ret (Sblock s₁')
  | Sreturn (Some e) =>
    let '(e', l) := unload e in
    expand (λ q, Sreturn (Some (e' q))) l
  | Sset x e =>
    let '(e', l) := unload e in
    expand (λ q, Sset x (e' q)) l
  | Sstore κ a e =>
    let '(e', l) := unload e in
    expand (λ q, Sstore κ a (e' q)) l
  | Sskip
  | Sexit _
  | Slabel _ _
  | Sgoto _
  | Sreturn None
  | Sswitch _ _ _
  | Scall _ _ _ _
  | Sbuiltin _ _ _
    => ret s
  end.

End TRANSFORM.

(** [all_ident acc first next] prepends the identifiers [first] ..
    [next - 1] to [acc], largest first:
    [(next-1) :: ... :: first :: acc].  When [next <= first] the positive
    subtraction yields [1] and the result is just [acc]. *)
Definition all_ident acc (first next: ident) : list ident :=
  let num := (Pos.succ next - first)%positive in
  Pos.peano_rect (λ _, list ident)
                 acc (λ p l, Pos.pred (p + first) :: l) num.


(** Apply [transform_stmt] (with [unload] run from an empty state) to the
    body of every internal function.  Fresh temporaries are numbered from
    [t₀], which lies above every existing temporary and parameter; the
    temporaries actually allocated, [t₀] .. [t - 1], are added in front of
    [fn_temps].  External functions are left unchanged. *)
Definition doit (p: program) : program :=
  transform_program
    (λ f,
     match f with
     | Internal f =>
       let t₀ := Pos.max (least_larger f.(fn_temps)) (least_larger f.(fn_params)) in
       let '(body, t) := transform_stmt (λ e, unload e nil) f.(fn_body) t₀ in
       (Internal (mkfunction f.(fn_sig) f.(fn_params) f.(fn_vars) (all_ident f.(fn_temps) t₀ t) body))
     | External _ => f
     end)
    p.