(* Module RTLofRTLinject *)

Require Import Coqlib Maps Utf8.
Require Import Integers Registers.
Require Import Op.
Require RTL.

Require Import MIR Common.
Require Import INJECT RTLinject Transcode.

(* [add_move rs rd nd]: when [rs] and [rd] are the same register, nothing is
   emitted and [nd] is returned unchanged.  Otherwise an [Omove] operation
   from [rs] to [rd] with successor [nd] is handed to [add_instr], and the
   result of [add_instr] is returned. *)
Definition add_move (rs rd: reg) (nd: node) : mon node :=
  match Reg.eq rs rd with
  | left _ => ret nd
  | right _ => add_instr (RTL.Iop Omove (rs :: nil) rd nd)
  end.

(* Hand an [RTL.Ifence] instruction with successor [nd] to [add_instr]. *)
Definition add_fence (nd: node) : mon node :=
  add_instr
    (RTL.Ifence nd).

(* Hand an [RTL.Inop] instruction with successor [nd] to [add_instr]. *)
Definition add_nop (nd: node) : mon node :=
  add_instr
    (RTL.Inop nd).

(* Statements use odd registers. External registers are even. *)

(* Backend description.  The definitions below read its fields [bi_abort],
   [bi_translate_abort_msg], [bi_leak], [bi_translate_leak] and
   [bi_show_leak].
   NOTE(review): [bi], [res] and [exit] look like section variables, but no
   enclosing Section/End is visible here.  [translate_instruction] passes a
   backend, a register list and a node explicitly to [transl_stmt], which
   suggests the section is closed before that point -- confirm. *)
Variable bi : backend_info.

(* Registers that receive the returned values: [transl_stmt] passes [res] as
   the destination list of [translate_return] in its [Sreturn] case. *)
Variable res : list reg.
(* Node used as the continuation of the moves emitted for [Sreturn]. *)
Variable exit : node.

(* The abort function. *)
(* Signature attached to the call emitted by [translate_abort]: its argument
   list is the single type [Ast.Tint], and the second component of the
   signature is [None]. *)
Definition abort_sig : Ast.signature :=
  Ast.mksignature (Ast.Tint :: nil) None.

(* [translate_abort res msg nd] submits two instructions to [add_instr], the
   later-executed one first:
   - a call to [bi.(bi_abort)] with signature [abort_sig], whose only
     argument and whose destination are both [res], with successor [nd];
   - an integer-constant operation storing [bi.(bi_translate_abort_msg) msg]
     into [res], whose successor is the value returned for the call.
   The value returned for the constant operation is the overall result. *)
Definition translate_abort (res: reg) (msg: abort_msg) (nd: node) : mon node :=
  do ncall <-
    add_instr (RTL.Icall abort_sig (inr _ bi.(bi_abort)) (res :: nil) res nd);
  add_instr
    (RTL.Iop (Op.Ointconst (bi.(bi_translate_abort_msg) msg)) nil res ncall).

(* Signature attached to the call emitted by [translate_leak]: its argument
   list is two [Ast.Tint] types, and the second component of the signature is
   [None]. *)
Definition leak_sig : Ast.signature :=
  Ast.mksignature (Ast.Tint :: Ast.Tint :: nil) None.

(* [translate_leak l r out nd] submits two instructions to [add_instr], the
   later-executed one first:
   - a call to [bi.(bi_leak)] with signature [leak_sig], with arguments
     [xI out] and [xI r] (in that order), destination [xI out] and
     successor [nd];
   - an integer-constant operation storing [bi.(bi_translate_leak) l] into
     [xI out], whose successor is the value returned for the call.
   The value returned for the constant operation is the overall result. *)
Definition translate_leak (l: leak) (r out: reg) (nd: node) : mon node :=
  do ncall <-
    add_instr
      (RTL.Icall leak_sig (inr _ bi.(bi_leak))
                 (xI out :: xI r :: nil) (xI out) nd);
  add_instr
    (RTL.Iop (Op.Ointconst (bi.(bi_translate_leak) l)) nil (xI out) ncall).

(* [translate_return erl irl nd] walks the two register lists in lockstep.
   For each pair [(e, i)] it submits a move from [xI i] to [e] to
   [add_instr]; the first pair gets successor [nd], and each later pair gets
   as successor the value returned for the previous move.  The result is the
   value returned for the last move, or [nd] when both lists are empty.
   Lists of different lengths produce the error below. *)
Fixpoint translate_return (erl irl : list reg) (nd: node) : mon node :=
  match erl, irl with
    | e::erl', i::irl' =>
        do ns <- add_instr (RTL.Iop Op.Omove (xI i::nil) e nd);
        translate_return erl' irl' ns
    | nil, nil => ret nd
    | _, _ => error (Errors.msg "Return values number mismatch")
  end.

  Fixpoint transl_stmt (s: statement) (nd: node)
                       {struct s} : mon node :=
    match s with
    | Sassume _ _
    | Srequestperm _ _ _
    | Sfreeperm
    | Srelease
    | Sskip =>
        ret nd
    | Sop op args dst =>
        if ok_op op
        then add_instr (RTL.Iop op ( xI args) (xI dst) nd)
        else error (Errors.msg "Bad op")
    | Sload _ addr al b =>
        add_instr (RTL.Iload Ast.Mint32 addr ( xI al) (xI b) nd)
    | Sstore _ _ addr al b =>
add_instr (RTL.Istore Ast.Mint32 addr ( xI al) (xI b) nd)
    | Sseq s1 s2 =>
        do ns <- transl_stmt s2 nd;
        transl_stmt s1 ns
    | Sifthenelse cnd args strue sfalse =>
          do nfalse <- transl_stmt sfalse nd;
          do ntrue <- transl_stmt strue nd;
        add_instr (RTL.Icond cnd ( xI args) ntrue nfalse)
    | Swhile cnd args sbody =>
        do n1 <- reserve_instr;
        do n2 <- transl_stmt sbody n1;
        do xx <- update_instr n1 (RTL.Icond cnd ( xI args) n2 nd);
        ret n1
    | Srepeat sbody cnd args =>
        do n1 <- reserve_instr;
        do n2 <- transl_stmt sbody n1;
        do xx <- update_instr n1 (RTL.Icond (Op.negate_condition cnd) ( xI args) n2 nd);
        ret n2
    | Satomicmem _ aop rargs r =>
        add_instr (RTL.Iatomic aop ( xI rargs) (xI r) nd)
    | Sfence _ =>
        add_fence nd
    | Sreturn tgtl =>
        translate_return res tgtl exit
  | Sabort _ _ msg =>
      translate_abort xH msg nd
  | Leak l r o =>
      if bi.(bi_show_leak)
      then translate_leak l r o nd
      else ret nd
  | Satomic _ _
  | Sloop _
  | Sbranch _ _
            => error (Errors.msg "Cannot compile high-level statements.")


(* Move actual arguments to statement registers and jump to nd. *)
  Fixpoint init_stmt (args: list reg) (param: list reg) (nd: node) { struct param } : mon node :=
    match args, param with
      | nil, nil => ret nd
      | a::args', p::param' =>
          do nd' <- init_stmt args' param' nd;
          add_instr (RTL.Iop Op.Omove (xO a::nil) (xI p) nd')
      | _, _ => error (Errors.msg "Argument number mismatch")


(* Backend description used by [translate_instruction], which passes it on to
   [transl_stmt].
   NOTE(review): this looks like a section variable, but no enclosing
   Section/End is visible here; [translate_function] later takes [bi] as an
   explicit argument and passes it to [translate_code] -- confirm where the
   section is opened and closed. *)
Variable bi : backend_info.

(* Sequencing for [Errors.res]: an error in [f] is propagated unchanged,
   otherwise [g] is applied to the value carried by [f]. *)
Definition bindE {A B: Type} (f: Errors.res A) (g: A -> Errors.res B) : Errors.res B :=
  match f with
    | Errors.Error m => Errors.Error m
    | Errors.OK a => g a
  end.

(* Converts the outcome [r] of a monadic computation started in state [s]
   into an [Errors.res]: an error keeps its message, and a success keeps only
   the state [s'] it carries, dropping the other two components. *)
Definition asE {A:Type} {s:state (X:=RTL.instruction)} (r: res A s) : Errors.res state :=
  match r with
    | Error m => Errors.Error m
    | OK _ s' _ => Errors.OK s'
  end.

(* Binding notation over [Errors.res]: the left-hand computation is run with
   [bindE], its value is bound to the given identifier, and the right-hand
   side is passed through [asE] so that only its resulting state (or its
   error) is kept. *)
Notation "'do' X <-- A ; B" := (bindE A (fun X => @asE _ _ B))
   (at level 200, X ident, A at level 100, B at level 200).

(* [translate_instruction k i] installs at node [k], via [update_instr], the
   RTL counterpart of instruction [i].

   Registers of the source instruction are wrapped with [xO]; a function
   given by register in [Icall] is wrapped too, while the other alternative
   of [func] is passed through unchanged.
   NOTE(review): [xO] is applied directly to what appear to be register
   lists; presumably a list map is intended there -- confirm.

   - [Iop] is rejected with an error unless [ok_op op] holds.
   - [Ireturn tgt] becomes a return of [Some (xO tgt)].
   - [Iinject] translates the low-level statement of the injected code with
     [transl_stmt], using [xO dst] as result registers and [succ] both as
     exit node and as continuation; [init_stmt] then prepends the moves from
     [args] into the parameters [ic_params], and node [k] becomes an
     [RTL.Inop] jumping to the resulting entry node. *)
Definition translate_instruction (k: node) (i: instruction) : mon unit :=
  match i with
    | Inop succ => update_instr k (RTL.Inop succ)
    | Iop op args dst succ =>
      if ok_op op
      then update_instr k (RTL.Iop op ( xO args) (xO dst) succ)
      else error (Errors.msg "Bad Op!")
    | Icall sg func args dst succ =>
        let func' := match func with inl f => inl _ (xO f) | inr f => inr _ f end in
          update_instr k (RTL.Icall sg func' ( xO args) (xO dst) succ)
    | Ithreadcreate fp arg succ =>
        update_instr k (RTL.Ithreadcreate (xO fp) (xO arg) succ)
    | Icond cond args if_so if_not =>
        update_instr k (RTL.Icond cond ( xO args) if_so if_not)
    | Ireturn tgt =>
        update_instr k (RTL.Ireturn (Some (xO tgt)))
    | Iinject ic args dst succ =>
        do nd <- transl_stmt bi ( xO dst) succ ic.(ic_stmt_low) succ;
        do np <- init_stmt args ic.(ic_params) nd;
        update_instr k (RTL.Inop np)
  end.

(* [translate_code c init] folds over the code map [c]: starting from [init],
   each binding [(k, i)] is processed by [translate_instruction k i] in the
   state accumulated so far.  An error, whether already present in the
   accumulator or raised by an instruction, is propagated to the end.

   The step function takes an accumulator, a key and a value, which is the
   argument shape of [PTree.fold]; applying [PTree.fold] to it gives the
   declared type.
   NOTE(review): assumes [code] is a [PTree.t] of instructions -- confirm. *)
Definition translate_code : code -> Errors.res state -> Errors.res state :=
  PTree.fold
       (fun es k i =>
          do s <-- es ;
          translate_instruction k i s).


(* [check_sig sg] succeeds with [tt] when every element of
   [sg.(Ast.sig_args)] is [Ast.Tint]; any other argument type yields the
   error below. *)
Definition check_sig (sg: Ast.signature) : Errors.res unit :=
  match
    forallb (fun t => match t with
                      | Ast.Tint => true
                      | _ => false
                      end)
            sg.(Ast.sig_args)
  with
  | true => Errors.OK tt
  | false => Errors.Error (Errors.msg "Float in sig")
  end.

(* Soundness of [check_sig]: if it succeeds on [sg], then every type listed
   in [sg.(Ast.sig_args)] is [Ast.Tint]. *)
Lemma check_sig_correct {sg} :
  check_sig sg = Errors.OK tt →
  ∀t, In t sg.(Ast.sig_args) → t = Ast.Tint.
Proof.
  (* Work on the bare argument list so that [sg] can be forgotten. *)
  unfold check_sig. generalize sg.(Ast.sig_args). clear sg.
  (* [bif2] is defined outside this file; the next step relies on it leaving
     a hypothesis [Htrue] stating that the [forallb] test is [true]. *)
  intros ? ? t H. bif2.
  pose proof (proj1 (forallb_forall _ _) Htrue t H).
  now destruct t.
Qed.

(* [translate_function bi f] produces the RTL function for [f], or an error.
   The visible steps are: compute [nodeM], the successor of the largest key
   of the code of [f]; check the signature with [check_sig]; translate the
   code with [translate_code] starting from [init_state nodeM]; and build the
   result inside [Errors.OK], with the parameters of [f] wrapped with [xO].

   NOTE(review): the body looks truncated.  The application starting on the
   line after the [let] has no visible head (presumably [Errors.bind]), the
   value built inside [Errors.OK] shows only its parameter component, and the
   parentheses are never closed.  Restore from the original before use. *)
Definition translate_function bi (f:function) : Errors.res RTL.function :=
  let nodeM : node := Psucc (max_key f.(fn_code)) in
    (check_sig f.(fn_sig))
    (fun _ => Errors.bind
          (translate_code bi f.(fn_code) (Errors.OK (init_state nodeM)))
          (fun s =>
             Errors.OK (
                ( xO f.(fn_params))

(* Lifts [translate_function bi] to function definitions through
   [Ast.transf_partial_fundef]. *)
Definition transl_fundef bi : fundef -> Errors.res RTL.fundef :=
  Ast.transf_partial_fundef
    (translate_function bi).

(* Whole-program entry point: applies [transl_fundef bi] to the program [p]
   through [Ast.transform_partial_program]. *)
Definition transf_program bi (p: program) : Errors.res RTL.program :=
  Ast.transform_partial_program
    (transl_fundef bi)
    p.