Module RTLsort

Require Import Coqlib Registers Integers Maps.

Require Import RTL.
Require Errors.

Require Import Utils.

Section SORT.

  (** Depth-first renumbering of the control-flow graph of [func]. *)
  Variable func : function.

  (** Base of the new numbering: the successor of [max_key] of the code tree.
      NOTE(review): presumably [max_key] is the largest node of the tree
      (defined in [Utils]) — confirm. *)
  Let m : node := Psucc (max_key func.(fn_code)).

  (** A node renaming: maps an original node to its new node. *)
  Definition nodemap := PTree.t node.

  (** [dfs cnt k next nm] visits node [k] depth-first.
      - [cnt] is fuel bounding the recursion depth; reaching [O] is an error.
      - [next] is the visit counter. [k] is bound to [m - next] in [nm], so
        nodes visited earlier get larger numbers. [-] is positive subtraction
        here, which bottoms out at [1] instead of going negative; any
        resulting clash is not detected by this function.
      - [nm] is the renaming built so far. [k] is recorded before its
        successors are explored, and successors already in the map are
        skipped, so cycles are not re-entered.
      Returns the updated counter and renaming. It fails with "DFS" if [k]
      has no instruction in [func]. Once the fold accumulator becomes an
      error, the remaining successors are skipped and that error is
      returned. *)
  Fixpoint dfs (cnt: nat) (k: node) (next: node) (nm: nodemap)
    { struct cnt }
    : Errors.res (node * nodemap) :=
    match cnt with
      | S cnt' =>
          match func.(fn_code) ! k with
            | Some i =>
                let nm' := nm ! k <- (m - next)%positive in
                let next' := Psucc next in
                  List.fold_left
                      (fun acc n =>
                         match acc with
                           | Errors.OK (acc_next, acc_nm) =>
                               match acc_nm ! n with
                                 (* already numbered: do not revisit *)
                                 | Some _ => acc
                                 | None => dfs cnt' n acc_next acc_nm
                               end
                           (* propagate the first error unchanged *)
                           | Errors.Error _ => acc
                         end
                      )
                      (successors_instr i)
                      (Errors.OK (next', nm'))
            | None => Errors.Error (Errors.msg "DFS")
          end
      (* Out of fuel (Spanish: "give me more gas"). *)
      | O => Errors.Error (Errors.msg "Dame más gasolina")
    end.

  (** Computes the depth-first renaming of [func]. The traversal starts at the
      entry point with counter [xH], an empty map and [nat_of_P m] units of
      fuel. The final counter is discarded and errors from [dfs] are passed
      through. *)
  Definition sort_cfg : Errors.res nodemap :=
    match dfs (nat_of_P m) func.(fn_entrypoint) xH (PTree.empty _) with
      | Errors.OK (_, nm) => Errors.OK nm
      | Errors.Error msg => Errors.Error msg
    end.

End SORT.

Section SUCCESSORS.
  (** Validation that every successor of every instruction of the code
      tree [c] is itself a node of [c]. *)
  Variable c : code.
  (* [haskey k] is [true] exactly when [c] has an instruction at [k]. *)
  Let haskey k : bool := match c ! k with Some _ => true | None => false end.
  (* Per-binding predicate: every successor of [i] is present in [c].
     The node argument is ignored. Presumably it exists only to match the
     predicate shape expected by [tree_forall] (from [Utils]). *)
  Let f := fun (_: node) i => forallb haskey (successors_instr i).

  (* Decidability witness for [f], passed to [tree_forall]. *)
  Lemma check_successors_dec k i :
    { f k i } + { ~ f k i }.
Proof.
subst f. simpl. destruct (forallb haskey (successors_instr i)); intuition. Qed.

  (** Returns [OK tt] when [tree_forall] accepts [f] over [c], and the
      error "Bad successors" otherwise. *)
  Definition check_successors : Errors.res unit :=
    if tree_forall f check_successors_dec c
    then Errors.OK tt
    else Errors.Error (Errors.msg "Bad successors").

  (** Soundness of [check_successors]: if it succeeds, every successor [s]
      of any instruction [i] stored in [c] has an instruction in [c]. *)
  Lemma check_successors_correct :
    check_successors = Errors.OK tt ->
    forall pc i, c ! pc = Some i ->
                 forall s, In s (successors_instr i) -> c ! s <> None.
Proof.
    unfold check_successors.
    case_eq (tree_forall f check_successors_dec c); clarify.
    intros H _.
    (* [tree_forall_correct] turns the boolean result into [f pc i] for
       every binding [pc -> i] of [c]. *)
    assert (K := tree_forall_correct f check_successors_dec c).
    rewrite H in K. unfold forall_spec in K.
    intros pc i U s Hs.
    generalize (K pc i U). unfold f, is_true. rewrite forallb_forall.
    intros V. generalize (V _ Hs). unfold haskey.
    case (c ! s); clarify.
  Qed.

End SUCCESSORS.

Section BIJECTION.

  (** Checks that a renaming [nm] covers every node of [c] (domain) and is
      injective (range). *)
  Variable c : code.
  Variable nm : nodemap.

  (* Decidability witness for the domain predicate used by [check_dom].
     The instruction argument [i] is unused. Presumably it is required by
     the predicate shape [tree_forall] expects. *)
  Lemma check_dom_dec k (i: instruction) :
   {match nm ! k with
    | Some _ => true
    | None => false
    end} + {~ match nm ! k with
              | Some _ => true
              | None => false
              end}.
Proof.
destruct (nm ! k); intuition. Qed.

  (** Returns [OK tt] when every node [k] bound in [c] also has an entry in
      [nm], and the error "Bad domain" otherwise. *)
  Definition check_dom : Errors.res unit :=
    if tree_forall (fun k _ => match nm ! k with Some _ => true | None => false end) check_dom_dec c
    then Errors.OK tt
    else Errors.Error (Errors.msg "Bad domain").

  (** Soundness of [check_dom]: every node with an instruction in [c] is
      renamed by [nm]. *)
  Lemma check_dom_correct (HOK: check_dom = Errors.OK tt) {k v} :
    c ! k = Some v -> nm ! k <> None.
Proof.
    unfold check_dom in HOK.
    case_eq (tree_forall _ check_dom_dec c); intros H; rewrite H in HOK; inv HOK.
    pose proof (tree_forall_correct _ check_dom_dec c) as K.
    unfold forall_spec in K. rewrite H in K.
    intros U. generalize (K k v U).
    destruct (nm ! k); clarify.
  Qed.

  (* Decidability witness for the range predicate used by [check_ran]:
     "[rev_find peq nm a] returns exactly [k]". *)
  Lemma check_ran_dec :
    forall k a : positive,
   {match rev_find peq nm a with
    | Some k' => peq k k'
    | None => false
    end} +
   {~ match rev_find peq nm a with
      | Some k' => peq k k'
      | None => false
      end}.
Proof.
    intros k a. case (rev_find peq nm a). 2: intuition. intros p. case (peq k p); intuition.
  Qed.

  (** Returns [OK tt] when, for every binding [k -> v] of [nm], the reverse
      lookup [rev_find peq nm v] yields [Some k], and the error "Bad image"
      otherwise. [rev_find] comes from [Utils]; presumably it returns some
      key bound to [v] — confirm. Because [rev_find] is a function, this
      forces [nm] to be injective (see [check_ran_inj]). *)
  Definition check_ran : Errors.res unit :=
    if tree_forall (fun k v => match rev_find peq nm v with Some k' => peq k k' | None => false end)
                   check_ran_dec nm
    then Errors.OK tt
    else Errors.Error (Errors.msg "Bad image").

  (** Soundness of [check_ran]: two nodes renamed to the same target are
      equal. *)
  Lemma check_ran_inj (HOK:check_ran = Errors.OK tt) { n n' k } :
    nm ! n = Some k -> nm ! n' = Some k -> n = n'.
Proof.
    unfold check_ran in HOK.
    case_eq (tree_forall _ check_ran_dec nm); intros H; rewrite H in HOK; inv HOK.
    pose proof (tree_forall_correct _ check_ran_dec nm) as K.
    unfold forall_spec in K. rewrite H in K.
    intros A B.
    (* Both [n] and [n'] must equal the single result of [rev_find ... k]. *)
    assert (U := K n k A).
    assert (V := K n' k B).
    case_eq (rev_find peq nm k);[intros u|]; intros X; rewrite X in *; clarify.
    destruct (peq n u); clarify.
    destruct (peq n' u); clarify.
  Qed.

End BIJECTION.

Section TR_NODES.

  (** Applying a node renaming [nm] to the function [func]. *)
  Variable nm : nodemap.
  Variable func : function.

  (** Image of node [n] under [nm]. Nodes absent from [nm] are sent to
      [xH]. *)
  Definition upd_node (n: node) : node :=
    match PTree.get n nm with
      | None => xH
      | Some target => target
    end.

  (** Renames every successor node mentioned by instruction [i]. All other
      operands are copied unchanged, and [Ireturn] has no successor to
      rename. *)
  Definition upd_instr (i: instruction) : instruction :=
    match i with
      | Ireturn optres => Ireturn optres
      | Icond cond args ifso ifnot =>
          Icond cond args (upd_node ifso) (upd_node ifnot)
      | Inop succ => Inop (upd_node succ)
      | Ifence succ => Ifence (upd_node succ)
      | Iop op args dst succ => Iop op args dst (upd_node succ)
      | Iload chunk addr args dst succ =>
          Iload chunk addr args dst (upd_node succ)
      | Istore chunk addr args src succ =>
          Istore chunk addr args src (upd_node succ)
      | Icall sg ros args res succ =>
          Icall sg ros args res (upd_node succ)
      | Ithreadcreate fn arg succ =>
          Ithreadcreate fn arg (upd_node succ)
      | Iatomic aop args res succ =>
          Iatomic aop args res (upd_node succ)
    end.

  (** Builds a fresh code tree. Each instruction of [c], with its successors
      renamed, is stored at the renamed position [upd_node pc]. If two
      positions collide, the binding inserted last by [PTree.fold] is
      kept. *)
  Definition tr_code (c: code) : code :=
    PTree.fold
      (fun acc pc instr => acc ! (upd_node pc) <- (upd_instr instr))
      c
      (PTree.empty _).

  (** [func] with its code and entry point renamed. Signature, parameters
      and stack size are copied unchanged. *)
  Definition tr_nodes : function :=
    mkfunction
      (fn_sig func)
      (fn_params func)
      (fn_stacksize func)
      (tr_code (fn_code func))
      (upd_node (fn_entrypoint func)).

End TR_NODES.

Section TR_PROGRAM.

  (** Succeeds exactly when the entry point of [f] carries an
      instruction. *)
  Definition check_ep (f: function) : Errors.res unit :=
    match PTree.get (fn_entrypoint f) (fn_code f) with
        | None => Errors.Error (Errors.msg "No instruction at function entry")
        | Some _ => Errors.OK tt
    end.

  (** Computes the depth-first renaming of [f] and validates it. The checks
      run in this order, and the first failing check's error is returned:
      - the successor check on the code;
      - domain coverage;
      - injectivity;
      - presence of an entry instruction.
      On success the renaming itself is returned. *)
  Definition compute_nodemap (f: function) : Errors.res nodemap :=
    Errors.bind (sort_cfg f) (fun nm =>
      Errors.bind (check_successors (fn_code f)) (fun _ =>
        Errors.bind (check_dom (fn_code f) nm) (fun _ =>
          Errors.bind (check_ran nm) (fun _ =>
            Errors.bind (check_ep f) (fun _ =>
              Errors.OK nm))))).

  (** External functions pass through untouched. Internal functions are
      renumbered using the validated renaming, or the validation error is
      propagated. *)
  Definition transf_fundef (f: fundef) : Errors.res fundef :=
    match f with
      | Ast.External _ => Errors.OK f
      | Ast.Internal func =>
          Errors.bind
            (compute_nodemap func)
            (fun nm => Errors.OK (Ast.Internal (tr_nodes nm func)))
    end.

  (** Applies [transf_fundef] to every function definition of [p], failing
      if any of them fails. *)
  Definition transf_program (p : program) : Errors.res program :=
    Ast.transform_partial_program transf_fundef p.

End TR_PROGRAM.