(** * Module Opt *)

Require Import Coqlib.
Require Import Maps.
Require Import TrMaps2.
Require Import AST.
Require Import Op.
Require Import Registers.
Require Import Utils.
Require Import Integers.
Require Import Floats.
Require Import Classical.
Require Import Lattice.
Require Import Iteration.
Require Import DLib.
Require Import Kildall.
Require Import KildallComp.
Require Import Dom.
Require Import SSA.
Require Import SSAutils.

(** This file defines a generic program optimization. In our SSA middle-end, we instantiate it to SCCP and GVN-CSE, but in general it could also be instantiated to any intra-procedural, scalar optimization. *)

(** The optimization can be "conditional", hence we define the type [m_exec] for storing the executable flag of each CFG edge. *)

(* [m_exec] is a [P2Map] with boolean values.  According to the file
   header it records, for each CFG edge, whether the edge is executable
   (presumably the keys are pairs of nodes -- confirm against [P2Map]). *)
Definition m_exec := P2Map.t bool.

Section Opt.

(** ** Definition of the optimization *)

(* Parameters of the generic optimization; they are abstracted when the
   section is closed. *)

(* [res]: the type of analysis results, chosen by each instantiation. *)
Variable res : Type.
(* [analysis]: maps a function to an analysis result paired with an
   [m_exec] map. *)
Variable analysis : function -> (res * m_exec).
(* [transf_instr]: given an analysis result and a node, rewrites the
   instruction found at that node. *)
Variable transf_instr : res -> node -> instruction -> instruction.

(* [transf_function f] runs [analysis] on [f] and rebuilds the function
   with [transf_instr] mapped over its code.  Every other field
   (signature, parameters, stack size, phi-code, max index, entry point,
   external parameters, dominance test) is copied from [f] unchanged.
   The [m_exec] component of the analysis result is not used here. *)
Definition transf_function (f : function) : SSA.function :=
  let (lv, _) := analysis f in
  mkfunction
    (fn_sig f)
    (fn_params f)
    (fn_stacksize f)
    (PTree.map (transf_instr lv) (fn_code f))
    (fn_phicode f)
    (fn_max_indice f)
    (fn_entrypoint f)
    (fn_ext_params f)
    (fn_dom f)
.

(** ** Proof that it preserves the [wf_ssa_function] predicate *)

(* Assumption on [transf_instr], to be discharged by each instantiation.
   For a well-formed [f] and every instruction [ins] at node [pc], the
   transformed instruction is either
   - [ins] itself, or
   - when [ins] is an [Iop], [Iload], [Icall] or an [Ibuiltin] whose
     result is [BR dst]: an [Iop] with the same destination [dst] and the
     same successor [pc'], all of whose arguments [x] have a definition
     point [d] with [sdom f d pc].
   For any other instruction only the first alternative is possible. *)
Hypothesis new_code_same_or_Iop : forall f:function,
  wf_ssa_function f ->
  forall pc ins,
    (fn_code f)!pc = Some ins ->
    transf_instr (fst (analysis f)) pc ins = ins
    \/ match ins with
         | Iop _ _ dst pc'
         | Iload _ _ _ dst pc'
         | Icall _ _ _ dst pc'
         | Ibuiltin _ _ (BR dst) pc' =>
           exists op' args',
           transf_instr (fst (analysis f)) pc ins = Iop op' args' dst pc'
           /\ forall x, In x args' -> exists d : node, def f x d /\ sdom f d pc
         | _ => False
       end.

(* The transformation leaves the parameter list untouched. *)
Lemma fn_params_transf_function : forall f,
  fn_params (transf_function f) = fn_params f.
Proof.
  intros f.
  unfold transf_function.
  (* Expose the record built from the analysis result; the projection
     then computes. *)
  destruct (analysis f) as [lv es]; reflexivity.
Qed.

(* The transformation leaves the phi-code untouched. *)
Lemma fn_phicode_transf_function : forall f,
  fn_phicode (transf_function f) = fn_phicode f.
Proof.
  intros f.
  unfold transf_function.
  (* Expose the record built from the analysis result; the projection
     then computes. *)
  destruct (analysis f) as [lv es]; reflexivity.
Qed.

(* The transformation leaves the entry point untouched. *)
Lemma fn_entrypoint_transf_function : forall f,
  fn_entrypoint (transf_function f) = fn_entrypoint f.
Proof.
  intros f.
  unfold transf_function.
  (* Expose the record built from the analysis result; the projection
     then computes. *)
  destruct (analysis f) as [lv es]; reflexivity.
Qed.

(* Every node carrying an instruction in [f] still carries one in
   [transf_function f]. *)
Lemma fn_code_transf_function : forall (f:function) pc ins,
 (fn_code f)!pc = Some ins ->
 exists ins, (SSA.fn_code (transf_function f))!pc = Some ins.
Proof.
  intros f pc ins Hpc.
  unfold transf_function.
  destruct (analysis f) as [lv es]; simpl.
  (* The new code is a [PTree.map] of the old one: look through the map,
     then use the known instruction at [pc]. *)
  rewrite PTree.gmap, Hpc; simpl.
  eexists; reflexivity.
Qed.

(* Converse of [fn_code_transf_function]: every node carrying an
   instruction in [transf_function f] already carried one in [f]. *)
Lemma fn_code_transf_function' : forall f pc ins,
  (SSA.fn_code (transf_function f))!pc = Some ins ->
  exists ins, (fn_code f)!pc = Some ins.
Proof.
  simpl. intros.
  unfold transf_function in *.
  destruct (analysis f); simpl in *.
  (* The transformed code is a [PTree.map] of the original code. *)
  rewrite PTree.gmap in *.
  (* Case analysis on the original lookup; both cases are left to [eauto]. *)
  destruct ((fn_code f)!pc); eauto.
Qed.

(* Make the constructors of [assigned_code_spec] available to [auto] /
   [eauto].  The database is named explicitly: omitting it selects [core]
   implicitly, a form that is deprecated in recent Coq versions. *)
Hint Constructors assigned_code_spec : core.

(* Phi-assignments are the same before and after the transformation,
   because the phi-code is copied unchanged. *)
Lemma assigned_phi_spec_same : forall (f:function) pc r,
  assigned_phi_spec (f.(fn_phicode)) pc r <->
  assigned_phi_spec ((transf_function f).(fn_phicode)) pc r.
Proof.
  intros f pc r.
  unfold transf_function.
  (* Once the analysis result is exposed, both sides mention the same
     phi-code. *)
  destruct (analysis f) as [lv es]; simpl.
  reflexivity.
Qed.

(* Computational variant of [assigned_code_spec]: [x] is assigned at [pc]
   when the instruction at [pc] is an [Iop], [Iload], [Icall], or an
   [Ibuiltin] with result [BR dst], and its destination is [x].  Any
   other instruction, or no instruction at all, yields [False]. *)
Definition assigned_code_spec2 (code : code) (pc : node) (x:reg) : Prop :=
  match code!pc with
    | Some (Iop _ _ dst _)
    | Some (Iload _ _ _ dst _)
    | Some (Icall _ _ _ dst _)
    | Some (Ibuiltin _ _ (BR dst) _) => dst = x
    | _ => False
  end.

(* [assigned_code_spec] and [assigned_code_spec2] hold for the same
   code, node and register. *)
Lemma assigned_code_spec_equiv : forall code pc r,
  assigned_code_spec code pc r <->
  assigned_code_spec2 code pc r.
Proof.
  split; intros H.
  (* Left to right: invert the proof of [assigned_code_spec] and rewrite
     with the instruction lookup it provides. *)
  - inv H ; unfold assigned_code_spec2; rewrite H0; auto.
  (* Right to left: work on the unfolded match in [H]. *)
  - unfold assigned_code_spec2 in *.
    flatten H; go.
Qed.

(* For a well-formed [f], a register is assigned by the code at [pc] in
   [f] iff it is assigned by the code at [pc] in [transf_function f]. *)
Lemma assigned_code_spec_same : forall (f:function) (Hwf:wf_ssa_function f) pc r,
  assigned_code_spec (fn_code f) pc r <->
  assigned_code_spec ((transf_function f).(SSA.fn_code)) pc r.
Proof.
  intros f Hwf pc r.
  (* Switch both sides to the computational characterisation. *)
  repeat rewrite assigned_code_spec_equiv.
  unfold transf_function, assigned_code_spec2; simpl; intros.
  case_eq (analysis f) ; intros lv es Ha ; simpl in *.
  rewrite PTree.gmap.
  destruct ((fn_code f)!pc) eqn:E; simpl; try tauto.
  (* Use the assumption on [transf_instr] for the instruction [i] at [pc]. *)
  destruct (new_code_same_or_Iop f Hwf pc i) as [Hi | Hi]; auto.
  (* Instruction unchanged. *)
  - rewrite Ha in Hi; simpl in * ; rewrite Hi; tauto.
  (* Instruction replaced by an [Iop] with the same destination. *)
  - rewrite Ha in Hi; simpl in * .
    flatten Hi; try (elim Hi; fail); destruct Hi as (op' & args' & Hi & Hi2); rewrite Hi; tauto.
Qed.


(* For a well-formed [f], a register [x] used by the code at [u] in
   [transf_function f] is either used by the code at [u] in [f], or has a
   definition point [d] in [f] with [sdom f d u]. *)
Lemma use_code_transf_function: forall f:function,
  wf_ssa_function f -> forall x u,
  use_code (transf_function f) x u ->
  use_code f x u \/ exists d, def f x d /\ sdom f d u.
Proof.
  intros f H x u H0.
  (* One subgoal per constructor of [use_code]; all are handled by the
     same chained script. *)
  inv H0;
  simpl in H1;
  unfold transf_function in *;
  case_eq (analysis f) ; intros lv es Ha ; rewrite Ha in * ; simpl in *;
  rewrite PTree.gmap in H1;
  destruct ((fn_code f)!u) eqn:E; try (inv H1; fail);
  simpl in H1;
  (* Apply the assumption on [transf_instr]; when the instruction is
     unchanged the use already exists in [f] (left disjunct). *)
  (destruct (new_code_same_or_Iop f H u i) as [Hi | Hi]; auto;
   [left;
    rewrite Ha in * ; simpl in * ;
    rewrite Hi in H1;
    rewrite <- E in H1;
    go
   |idtac]);
  (* Otherwise the instruction became an [Iop]; its arguments satisfy the
     right disjunct by the second component of the assumption. *)
  destruct i;
    try (elim Hi; fail); try (destruct b ; try intuition auto) ; try (destruct Hi as (op' & args' & Hi1 & Hi2);
        try (rewrite Ha in * ; simpl in *) ;
        try (rewrite Ha in * ; simpl in * ; congruence);
        right; apply Hi2; congruence).
Qed.

(* For a well-formed [f], transforming an instruction of [f] does not
   change its list of successors. *)
Lemma new_code: forall pc ins (f:function) (Hwf:wf_ssa_function f),
  (fn_code f)!pc = Some ins ->
  successors_instr ins = successors_instr (transf_instr (fst (analysis f)) pc ins).
Proof.
  intros pc ins f Hwf Hi.
  destruct (new_code_same_or_Iop f Hwf pc ins); auto.
  (* Instruction unchanged. *)
  - rewrite H; auto.
  (* Instruction replaced by an [Iop] with the same successor [pc']. *)
  - destruct ins; try (elim H; fail); try (destruct b ; try intuition auto) ; destruct H as (op' & args' & H1 & H2); rewrite H1; auto.
Qed.

(* Register the constructors of [use_code] and [use] in the [ssa] hint
   database. *)
Hint Constructors use_code: ssa.
Hint Constructors use: ssa.

(* For a well-formed [f], the successors map of [transf_function f]
   agrees with that of [f] at every node. *)
Lemma successors_transf_function: forall (f:function) (Hwf:wf_ssa_function f) pc,
  (successors (transf_function f))!pc = (successors f)!pc.
Proof.
  unfold successors; intros.
  unfold transf_function.
  case_eq (analysis f) ; intros lv es Ha ; simpl in *.
  (* Push the lookup through the maps on both sides. *)
  repeat rewrite PTree.gmap1.
  unfold option_map.
  rewrite PTree.gmap.
  unfold option_map.
  flatten.
  (* Restate [lv] as [fst (analysis f)] so that [new_code] applies. *)
  replace lv with (fst (analysis f)) by (rewrite Ha ; auto).
  erewrite <- new_code; eauto.
Qed.

(* Mapping a composition with [PTree.map1] equals mapping each function
   in turn. *)
Lemma map1_assoc:
  forall (A B C: Type) (f1: A -> B) (f2: B -> C) (m: PTree.t A),
    PTree.map1 (fun x => f2 (f1 x)) m =
    PTree.map1 f2 (PTree.map1 f1 m).
Proof.
  intros.
  (* Rewrite each [PTree.map1] using [map1_opt], then conclude with
     [xmap_assoc] on the unfolded [PTree.map]. *)
  erewrite <- map1_opt ; eauto.
  erewrite <- map1_opt ; eauto.
  erewrite <- map1_opt ; eauto.
  unfold PTree.map.
  rewrite <- xmap_assoc.
  auto.
Qed.


(* For a well-formed [f], [j] is a join point of [transf_function f] iff
   it is a join point of [f]. *)
Lemma join_point_transf_function : forall (f:function) (Hwf:wf_ssa_function f) j,
  join_point j (transf_function f) <-> join_point j f.
Proof.
  split; intros.
  (* Transformed to original: move the predecessor map back to the code
     of [f], using the fact that successors agree. *)
  - inv H.
    erewrite @same_successors_same_predecessors with
        (m2:= (fn_code f))
        (f2:= successors_instr) in Hpreds; eauto.
    + econstructor; eauto.
    + eapply successors_transf_function; eauto.
  (* Original to transformed: the symmetric argument. *)
  - inv H.
    erewrite @same_successors_same_predecessors with
        (m2:= (SSA.fn_code (transf_function f)))
        (f2:= successors_instr) in Hpreds; eauto.
    + econstructor; eauto.
    + generalize (successors_transf_function f).
      unfold successors. intros Hsuccs i.
      rewrite Hsuccs at 1; auto.
Qed.

(* For a well-formed [f], the predecessor map computed from the
   transformed code equals the one computed from the original code. *)
Lemma make_predecessors_transf_function: forall (f:function) (Hwf:wf_ssa_function f),
  (Kildall.make_predecessors (SSA.fn_code (transf_function f)) successors_instr) =
  (Kildall.make_predecessors (fn_code f) successors_instr).
Proof.
  intros.
  (* Equal successors give equal predecessors. *)
  eapply same_successors_same_predecessors.
  apply successors_transf_function; auto.
Qed.

(* For a well-formed [f], the successor list of every node is the same
   in [transf_function f] and in [f]. *)
Lemma successors_list_transf_function: forall (f:function) (Hwf:wf_ssa_function f) pc,
  (Kildall.successors_list (successors (transf_function f)) pc) =
  (Kildall.successors_list (successors f) pc).
Proof.
  unfold Kildall.successors_list; intros.
  rewrite successors_transf_function; auto.
Qed.


(* For a well-formed [f], a use in the phi-code of [transf_function f]
   is a use in the phi-code of [f]. *)
Lemma use_phicode_transf_function: forall f:function,
  wf_ssa_function f -> forall x u,
  use_phicode (transf_function f) x u ->
  use_phicode f x u.
Proof.
  intros f H x u H0.
  inv H0.
  simpl in PHIB.
  (* Express the predecessor hypothesis over the original code. *)
  replace (Kildall.make_predecessors (SSA.fn_code (transf_function f))
               successors_instr) with
          (Kildall.make_predecessors (fn_code f) successors_instr) in KPRED.
  - unfold transf_function in * ; destruct analysis ; go.
  - rewrite make_predecessors_transf_function; auto.
Qed.


(* For a well-formed [f], a register [x] used at [u] in
   [transf_function f] is either used at [u] in [f], or has a definition
   point [d] in [f] with [sdom f d u]. *)
Lemma use_transf_function: forall (f:function) x u,
  wf_ssa_function f ->
  use (transf_function f) x u ->
  use f x u \/ exists d, def f x d /\ sdom f d u.
Proof.
  intros f x u H H0.
  inv H0.
  (* Use in the code: both outcomes of [use_code_transf_function] fit. *)
  - edestruct use_code_transf_function; eauto.
    left; go.
  (* Use in the phi-code: such uses are preserved as they are. *)
  - apply use_phicode_transf_function in H1; go.
Qed.

(* For a well-formed [f], a definition point of [x] in
   [transf_function f] is a definition point of [x] in [f]. *)
Lemma def_transf_function: forall (f:function) x d,
  wf_ssa_function f ->
  def (transf_function f) x d -> def f x d.
Proof.
  intros f.
  unfold transf_function in * ;
  case_eq (analysis f) ; intros lv es Ha.
  intros x d H H1; inv H1; simpl; go.
  (* First remaining case: inspect the hypothesis [H0] produced by the
     inversion; its non-trivial alternative provides a use at some [pc]. *)
  - inv H0; go.
    simpl in *.
    destruct H1 as (pc & Hpc).
    edestruct use_transf_function; eauto.
    unfold transf_function ; rewrite Ha ; simpl; eauto.
    (* The use already exists in [f]. *)
    + econstructor 1; eauto.
      econstructor 2; eauto.
      replace (SSA.fn_code f) with (fn_code f).
      intros; rewrite assigned_code_spec_same; auto.
      unfold transf_function ; rewrite Ha ; simpl; eauto.
      reflexivity.

    (* The register has a definition point [d] in [f]; the code and phi
       cases contradict the hypotheses [H3] and [H2]. *)
    + destruct H0 as (d & Hd & Hs).
      inv Hd; go.
      * eelim (H3 d).
        simpl.
        assert (Has:= (assigned_code_spec_same f H d x)); eauto.
        unfold transf_function in * ; rewrite Ha in * ; simpl in *; auto.
        eapply Has; auto.
      * eelim (H2 d).
        assert (Has:= (assigned_phi_spec_same f d x)); eauto.
  (* Second remaining case: transfer the code assignment back to [f]. *)
  - simpl in *.
    assert (Has:= (assigned_code_spec_same f H d x)); eauto.
    unfold transf_function in * ; rewrite Ha in * ; simpl in *; auto.
    eapply Has in H0; auto.
Qed.
  

(* For a well-formed [f], the control-flow edges of [f] and of
   [transf_function f] coincide. *)
Lemma cfg_transf_function : forall (f:function) (Hwf:wf_ssa_function f) i j,
  cfg f i j <-> cfg (transf_function f) i j.
Proof.
  (* Name every binder as in the statement.  The earlier [intros f i j]
     bound the well-formedness proof to [i] and the source node to [j],
     and left the target node to be auto-named. *)
  intros f Hwf i j.
  split; intros.
  (* Original to transformed: the node keeps an instruction, and
     [new_code] shows that it has the same successors. *)
  - inv H.
    exploit fn_code_transf_function ; eauto.
    intros [ins' Hins'].
    intros. unfold transf_function in *.
    case_eq (analysis f) ; intros lv es Ha ; simpl in *.
    rewrite Ha in * ; simpl in *.
    econstructor; simpl; eauto.
    rewrite PTree.gmap in *.
    unfold fn_code in *.
    unfold option_map in *. rewrite HCFG_ins in *.
    inv Hins'.
    replace lv with (fst (analysis f)) in * by (rewrite Ha ; auto).
    erewrite <- new_code; eauto.
  (* Transformed to original: the symmetric argument. *)
  - inv H.
    exploit fn_code_transf_function' ; eauto.
    intros [ins' Hins'].
    unfold transf_function in *; case_eq (analysis f) ; intros lv es Ha ;
    rewrite Ha in * ; simpl in *.
    econstructor; simpl; eauto.
    rewrite PTree.gmap in *.
    unfold option_map in *. rewrite Hins' in *.
    inv HCFG_ins.
    replace lv with (fst (analysis f)) in * by (rewrite Ha ; auto).
    erewrite new_code; eauto.
Qed.

(* For a well-formed [f], a node satisfies [reached] in [f] iff it does
   in [transf_function f]: the entry point and the CFG edges are the
   same. *)
Lemma reached_transf_function : forall (f:function) (Hwf:wf_ssa_function f) i,
  reached f i <-> reached (transf_function f) i.
Proof.
  split; intros.
  (* Original to transformed. *)
  rewrite fn_entrypoint_transf_function.
  apply star_eq with (cfg f); auto.
  intros; rewrite <- cfg_transf_function; auto.
  (* Transformed to original. *)
  apply star_eq with (cfg (transf_function f)); auto.
  intros a b; rewrite <- cfg_transf_function; auto.
  rewrite fn_entrypoint_transf_function in *; auto.
Qed.

(* For a well-formed [f], [exit] holds of a node in [f] iff it holds of
   that node in [transf_function f]. *)
Lemma exit_exit : forall (f:function) (Hwf: wf_ssa_function f) pc,
                    exit f pc <-> exit (transf_function f) pc.
Proof.
  split.
  (* Original to transformed. *)
  - intros.
    (* An exit node carries an instruction in [f]. *)
    assert (Hins: exists ins, (fn_code f) ! pc = Some ins).
    { unfold exit in *.
      flatten NOSTEP; eauto.
      inv H; go.
    }
    destruct Hins as [ins0 Hins0].
    exploit (fn_code_transf_function f pc); eauto.
    intros [ins Hins].
    unfold transf_function in *;
    unfold exit in *; rewrite Hins in *.
    unfold fn_code in *. rewrite Hins0 in *.
    case_eq (analysis f) ; intros lv es Ha ; rewrite Ha in *; simpl in *.
    rewrite PTree.gmap in *.
    unfold option_map in *.
    rewrite Hins0 in *.
    inv Hins.
    (* Relate the instruction to its transformed version via the
       assumption on [transf_instr]. *)
    exploit new_code_same_or_Iop; eauto.
    flatten; auto; (intros [Heq | Hcont]; [idtac | elim Hcont]);
    try intuition;
    try (replace lv with (fst (analysis f)) in * by (rewrite Ha ; auto));
    try congruence.
  (* Transformed to original. *)
  - intros.
    (* An exit node carries an instruction in the transformed code. *)
    assert (Hins: exists ins, (SSA.fn_code (transf_function f)) ! pc = Some ins).
    { unfold exit in *.
      flatten NOSTEP; eauto.
      inv H; go.
    }
    destruct Hins as [ins0 Hins0].
    unfold transf_function in *.
    unfold exit in *; rewrite Hins0 in *.
    case_eq (analysis f) ; intros lv es Ha ; rewrite Ha in *; simpl in *.
    unfold fn_code in *.
    rewrite PTree.gmap in *.
    unfold option_map in *.
    flatten Hins0; intuition;
    match goal with
    | id: analysis ?ff = _ |- _ => set (f:= ff) in *
    end.
    (* Three remaining subgoals, each closed in the same way: apply the
       assumption on [transf_instr], then either rewrite with the
       unchanged instruction or derive a contradiction from the [Iop]
       replacement. *)
    + destruct (new_code_same_or_Iop f Hwf pc i); auto.
      * replace lv with (fst (analysis f)) in * by (rewrite Ha ; auto).
        rewrite H0 in H1 ; inv H1; auto.
      * replace lv with (fst (analysis f)) in * by (rewrite Ha ; auto);
          destruct i; try (elim H1; fail);
            try (destruct b ; try intuition auto) ;
            destruct H1 as (op' & args' & H2 & H3); congruence.
    + destruct (new_code_same_or_Iop f Hwf pc i); auto.
      * replace lv with (fst (analysis f)) in * by (rewrite Ha ; auto).
        rewrite H0 in H ; inv H; auto.
      * replace lv with (fst (analysis f)) in * by (rewrite Ha ; auto);
          destruct i; try (elim H; fail);
            try (destruct b ; try intuition auto) ;
            destruct H as (op' & args' & H2 & H3); congruence.
    + destruct (new_code_same_or_Iop f Hwf pc i); auto.
      * replace lv with (fst (analysis f)) in * by (rewrite Ha ; auto).
        rewrite H0 in H1 ; inv H1; auto.
      * replace lv with (fst (analysis f)) in * by (rewrite Ha ; auto);
          destruct i; try (elim H1; fail);
            try (destruct b ; try intuition auto) ;
            destruct H1 as (op' & args' & H2 & H3); congruence.
Qed.

(* For a well-formed [f], [SSApath] relates the same states through the
   same node list in [f] and in [transf_function f]. *)
Lemma ssa_path_transf_function : forall (f:function) (Hwf:wf_ssa_function f) i p j,
  SSApath f i p j <-> SSApath (transf_function f) i p j.
Proof.
  split.
  (* Original to transformed, by induction on the path. *)
  - induction 1; go.
    econstructor 2 with s2; auto.
    inv STEP.
    (* Step along a CFG edge. *)
    + assert (cfg (transf_function f) pc pc')
        by (rewrite <- cfg_transf_function; eauto; econstructor; eauto).
      inv H0.
      econstructor; eauto.
      rewrite <- reached_transf_function; auto.
      go.
    (* Step out of an exit node. *)
    + assert (Hins: exists ins, (fn_code f) ! pc = Some ins).
      { unfold exit in *.
        flatten NOSTEP; eauto.
        inv H; go.
      }
      destruct Hins as [ins0 Hins0].
      exploit (fn_code_transf_function f pc); eauto. intros [ins Hins].
      econstructor; eauto.
      * eapply reached_transf_function in CFG; eauto.
      * eapply (exit_exit f); eauto.
  (* Transformed to original, by induction on the path. *)
  - induction 1; go.
    econstructor 2 with s2; auto.
    inv STEP.
    (* Step along a CFG edge. *)
    + assert (cfg f pc pc') by (rewrite cfg_transf_function; eauto; econstructor; eauto).
      inv H0.
      econstructor; eauto.
      go.
    (* Step out of an exit node. *)
    + assert (Hins: exists ins, (SSA.fn_code (transf_function f)) ! pc = Some ins).
      { unfold exit in *.
        flatten NOSTEP; eauto.
        inv H; go.
      }
      destruct Hins as [ins0 Hins0].
      eelim fn_code_transf_function' ; eauto.
      intros ins' H'.
      econstructor; eauto.
      eapply exit_exit; eauto.
Qed.
  
(* For a well-formed [f], the [dom] relation is the same in [f] and in
   [transf_function f]. *)
Lemma dom_transf_function : forall (f:function) (Hwf:wf_ssa_function f) i j,
  dom f i j <-> dom (transf_function f) i j.
Proof.
  split; intros.
  (* Original to transformed: transfer [reached], then map each path of
     the transformed function back to a path of [f]. *)
  - inv H; constructor.
    + rewrite <- reached_transf_function; auto.
    + unfold entry in *. intros.
      apply PATH.
      rewrite ssa_path_transf_function; auto.
      rewrite fn_entrypoint_transf_function in *.
      auto.
  (* Transformed to original: the symmetric argument. *)
  - inv H; constructor.
    + rewrite reached_transf_function; auto.
    + intros.
      apply PATH.
      rewrite <- ssa_path_transf_function; auto.
      rewrite fn_entrypoint_transf_function in *.
      unfold entry in *; auto.
Qed.

(* For a well-formed [f], an [Inop] instruction is left unchanged by the
   transformation. *)
Lemma transf_function_Inop : forall (f:function) (Hwf:wf_ssa_function f) pc jp,
  (fn_code f) ! pc = Some (Inop jp) ->
  (SSA.fn_code (transf_function f)) ! pc = Some (Inop jp).
Proof.
  intros.
  unfold transf_function; simpl.
  case_eq (analysis f) ; intros lv es Ha ; try rewrite Ha in * ; simpl in *.
  rewrite PTree.gmap.
  rewrite H; simpl.
  destruct (new_code_same_or_Iop f Hwf pc (Inop jp)); auto.
  (* Instruction unchanged. *)
  - replace lv with (fst (analysis f)) by (rewrite Ha ; auto). rewrite H0; auto.
  (* The second alternative of the assumption is [False] for [Inop]. *)
  - elim H0.
Qed.

(* Main result of the section: the generic transformation maps
   well-formed SSA functions to well-formed SSA functions.  One bullet
   per field of [wf_ssa_function]; each is reduced to the corresponding
   property of [f] through the preservation lemmas above. *)
Lemma transf_function_preserve_wf_ssa_function : forall (f:function),
  wf_ssa_function f -> wf_ssa_function (transf_function f).
Proof.
  constructor.

  (* Unique definitions: from [fn_ssa] of [f], since code and phi
     assignments are the same. *)
  - generalize (fn_ssa _ H).
    unfold unique_def_spec; split; intros.
    repeat rewrite <- assigned_phi_spec_same; auto.
    repeat rewrite <- assigned_code_spec_same; auto.
    destruct H0; auto.
    unfold transf_function in *.
    destruct (analysis f) in * ; simpl in *; eauto.
    destruct H0 as [_ H0]; eauto.
    
  (* Parameters: from [fn_ssa_params] of [f]. *)
  - intros x; rewrite fn_params_transf_function; intros T.
    elim fn_ssa_params with (1:=H) (2:=T); split; intros.
    rewrite <- assigned_code_spec_same; auto.
    rewrite <- assigned_phi_spec_same; auto.

  (* Strictness: a use in the transformed function is either an original
     use (apply [fn_strict]) or comes with a definition point satisfying
     [sdom], which is identified with [d] by [ssa_def_unique]. *)
  - intros x u d Hu Hd.
    rewrite <- dom_transf_function.
    apply def_transf_function in Hd; auto.
    apply use_transf_function in Hu; auto.
    destruct Hu as [Hu|[y [Hu1 Hu2]]].
    + apply fn_strict with x; auto.
    + assert (y=d) by (eapply ssa_def_unique; eauto); subst.
      inv Hu2; auto.
    + assumption.

  (* A register is not both used and assigned by the code at one node:
     from [fn_use_def_code] of [f], or from the [NEQ] component of [sdom]. *)
  - intros x pc Hu Ha.
    rewrite <- assigned_code_spec_same in Ha; auto.
    apply use_code_transf_function in Hu; auto.
    destruct Hu.
    + eelim fn_use_def_code; eauto.
    + destruct H0 as [d [D1 D2]].
      assert (def f x pc) by eauto.
      destruct D2.
      elim NEQ; eapply ssa_def_unique; eauto.

  (* Phi-blocks: from [fn_wf_block] of [f]; phi-code and predecessors
     are unchanged. *)
  - intros pc block args x.
    rewrite fn_phicode_transf_function.
    rewrite make_predecessors_transf_function; auto.
    eapply fn_wf_block; eauto.
    
  (* Normalization: from [fn_normalized] of [f]; [Inop] is preserved. *)
  - intros.
    rewrite join_point_transf_function in H0; auto.
    rewrite successors_list_transf_function in H1; auto.
    exploit fn_normalized; eauto; intros.
    exploit transf_function_Inop; eauto.
    
  (* Join points versus phi-code: from [fn_phicode_inv] of [f]. *)
  - intros.
    rewrite join_point_transf_function; auto.
    rewrite fn_phicode_transf_function; auto.
    eapply fn_phicode_inv; eauto.

  (* Field relating code nodes and [reached]; transferred back to [f]
     with [reached_transf_function] and [fn_code_transf_function']. *)
  - intros pc ins.
    rewrite <- reached_transf_function; intros; auto.
    eelim fn_code_transf_function'; eauto.
    
  (* Code closed under successors: from [fn_code_closed] of [f]. *)
  - intros pc pc' instr Hi Hj.
    assert (cfg f pc pc') by (rewrite cfg_transf_function; eauto; econstructor; eauto).
    inv H0.
    exploit fn_code_closed; eauto.
    intros [ii HH].
    eapply fn_code_transf_function; eauto.

  (* Entry instruction: from [fn_entry] of [f]; [Inop] is preserved. *)
  - exploit fn_entry; eauto.
    intros [s Hs].
    exists s; rewrite fn_entrypoint_transf_function.
    exploit transf_function_Inop; eauto.

  (* Entry predecessors: from [fn_entry_pred] of [f]. *)
  - intros pc; rewrite <- cfg_transf_function; auto; rewrite fn_entrypoint_transf_function.
    apply fn_entry_pred; auto.

  (* External parameters: from [fn_ext_params_complete] of [f]. *)
  - intros.
    inv H0.
    + unfold transf_function in *; simpl in *.
      destruct analysis; simpl in *.
      inv H; go.
    + destruct H1 as [pc H1].
      rewrite fn_phicode_transf_function in H2.
      exploit use_transf_function; eauto.
      unfold transf_function.
      destruct analysis; simpl in *.
      destruct 1.
      * inv H.
          eapply fn_ext_params_complete; eauto.
          econstructor 2; eauto.
          intros.
          rewrite (assigned_code_spec_same f) ; eauto.
          constructor; auto.
      * { destruct H0 as (d & H0 & _).
          inv H0; go.
          - inv H ; eauto.
          - eelim H3; rewrite <- assigned_code_spec_same; eauto.
          - eelim H2; eauto.
        }
     
  (* Dominance test: [fn_dom] is copied unchanged, so each component of
     [fn_dom_wf] carries over from [f]. *)
  - split.
    + intros.
      unfold transf_function ;
        destruct analysis ; simpl in * ; auto.
      apply H.(fn_dom_wf _).(dom_defined _).
      rewrite reached_transf_function; auto.
    + intros.
      unfold transf_function in *;
        destruct analysis ; simpl in * ; auto.
      eapply H.(fn_dom_wf _).(dom_pre_inj _); eauto.
    + intros.
      unfold transf_function in *;
        destruct analysis ; simpl in * ; auto.
      eapply H.(fn_dom_wf _).(dom_pre_le_post _); eauto.
    + intros.
      unfold transf_function in *;
        destruct analysis ; simpl in * ; auto.
      eapply H.(fn_dom_wf _).(dom_cases _); eauto.
    + intros. rewrite <- dom_transf_function; auto.
      unfold transf_function ;
        destruct analysis ; simpl in * ; auto.
      unfold dom_test; simpl; fold (dom_test f).
      eapply H.(fn_dom_wf _).(dom_test_correct _); eauto.
      rewrite reached_transf_function; auto.
Qed.

End Opt.