(* Module RTLdefgen *)

Require Import MoreRTL.
Import Coqlib.
Import Integers.
Import Floats.
Import Globalenvs.
Import Errors.
Import Memdata.
Import AST.
Import Values.
Import Maps.
Import Op.
Import Registers.
Import RTL.
Import AssocList.
Import Annotations.
Require Import ArithLib.
Import Utf8.

Open Scope error_monad_scope.

(** Translation state threaded through the instrumentation pass.
    - [st_nextreg]: the register returned by the next call to [new_reg];
    - [st_nextnode]: the node at which the next [add_instr] stores its instruction;
    - [st_code]: the code graph built so far;
    - [st_wf]: invariant stating that every node is either strictly below
      [st_nextnode] or unmapped in [st_code]. *)
Record state: Type :=
  mkstate {
      st_nextreg: positive;
      st_nextnode: positive;
      st_code: code;
      st_wf: forall (pc: positive), Plt pc st_nextnode \/ st_code!pc = None
    }.

(** Every node [pc] is either strictly below [Psucc (max_pc_function f)] or
    unmapped in the code of [f]. This is exactly the [st_wf] obligation
    discharged when building [init_state f]. *)
Remark max_pc_function_st_wf:
  forall f pc, Plt pc (Psucc (max_pc_function f)) \/ f.(fn_code)!pc = None.
Proof.
  intros f pc.
  destruct (f.(fn_code)!pc) as [instr | ] eqn:CODE.
  - (* [pc] carries an instruction: it is bounded by [max_pc_function f]. *)
    pose proof (max_pc_function_sound f pc instr CODE) as BOUND.
    left. xomega.
  - (* [pc] is unmapped. *)
    right. reflexivity.
Qed.

(** Initial translation state for [f]: the first fresh register is
    [Psucc (max_reg_function f)], the first fresh node is
    [Psucc (max_pc_function f)], and the code is the code of [f] itself.
    Well-formedness is given by [max_pc_function_st_wf]. *)
Definition init_state (f: function) :=
  mkstate (Psucc (max_reg_function f)) (Psucc (max_pc_function f)) f.(fn_code) (max_pc_function_st_wf f).

(** Well-formedness is preserved when an instruction is stored at
    [st_nextnode] and the node counter is advanced by one. This is the
    [st_wf] obligation of [add_instr]. *)
Remark add_instr_wf:
  forall s i pc,
    Plt pc (Psucc s.(st_nextnode)) \/ (PTree.set s.(st_nextnode) i s.(st_code))!pc = None.
Proof.
  intros s i pc.
  destruct (peq pc (st_nextnode s)) as [EQ | NEQ].
  - (* [pc] is the freshly written node: it is below the new counter. *)
    subst pc. left. apply Plt_succ.
  - (* Any other node: fall back on the invariant of [s]. *)
    destruct (st_wf s pc) as [LT | NONE].
    + left. apply Plt_trans_succ. exact LT.
    + right. rewrite PTree.gso by exact NEQ. exact NONE.
Qed.
      
(** [add_instr instr s] stores [instr] at node [st_nextnode s] and advances
    [st_nextnode] by one. [st_nextreg] is unchanged. Callers that want the
    new instruction to fall through to the next emitted one pass
    [Psucc s.(st_nextnode)] as its successor. *)
Definition add_instr (instr: instruction) (s: state): state :=
  mkstate s.(st_nextreg) (Psucc (s.(st_nextnode))) (PTree.set s.(st_nextnode) instr s.(st_code)) (add_instr_wf s instr).

(** [new_reg s] returns the register [st_nextreg s] together with a state
    whose [st_nextreg] is advanced by one. Nodes and code are untouched, so
    the well-formedness proof of [s] is reused as is. *)
Definition new_reg (s: state): reg * state :=
  (s.(st_nextreg), mkstate (Psucc s.(st_nextreg)) s.(st_nextnode) s.(st_code) s.(st_wf)).

Section TRANSL.

  (* Parameters of the translation.
     - [prog]: the program, passed to [is_trivial_annotation] in [add_checks];
     - [ge]: global environment used for [Genv.find_symbol] lookups;
     - [STK], [SIZE]: identifiers of the two globals addressed by the
       generated code ([Abased STK ...] and [Aglobal SIZE ...]).
     [transf_program] instantiates them with [p], [Genv.globalenv p] and two
     identifiers chosen above every name occurring in [p]. *)
  Variable prog: program.
  Variable ge: genv.
  Variable STK: ident.
  Variable SIZE: ident.

  (** [is_singleton alpha] is [true] when [alpha] is non-empty and all of its
      blocks designate the same variable as the head block:
      - a one-element list is always accepted;
      - if the head is [ABglobal id _], every other block must be an
        [ABglobal] with the same identifier;
      - if the head is [ABlocal d id _], every other block must be an
        [ABlocal] with the same depth and identifier.
      The ranges carried by the blocks are ignored. Despite the name, lists
      with more than one element can be accepted. *)
  Definition is_singleton (alpha: list ablock): bool :=
    match alpha with
    | nil => false
    | _::nil => true
    | (ABglobal id _)::alpha => fold_left (fun acc x => match x with ABlocal _ _ _ => false | ABglobal id' _ => ((ident_eq id id') && acc) end) alpha true
    | (ABlocal d id _)::alpha => fold_left (fun acc x => match x with ABglobal _ _ => false | ABlocal d' id' _ => ((eq_nat_dec d d') && (ident_eq id id') && acc) end) alpha true
    end.

  (** [put_stack_range_in_regs kappa ofs depth n s] visits the [n] consecutive
      offsets [ofs], [ofs + 1], ... (machine-integer addition). For every
      offset that is a multiple of [align_chunk kappa] it allocates three
      fresh registers and emits three instructions, each falling through to
      the next emitted node:
      - [reg''] := 32-bit load from global [SIZE];
      - [reg'] := 32-bit load at [Abased STK (-4 * depth)] indexed by [reg'']
        (presumably the frame address saved [depth] slots below the top of
        the shadow stack, see [add_prologue] -- confirm);
      - [reg] := [reg'] + [ofs], or a plain move when [ofs] is zero.
      The result lists the [reg]s in increasing offset order. Offsets that are
      not aligned emit nothing. This function never returns [Error]. *)
  Fixpoint put_stack_range_in_regs kappa (ofs: Int.int) (depth: nat) (n: nat) (s: state): res (list reg * state) :=
    match n with
    | S n =>
      (* [post] prepends the register for the current offset, or is the
         identity when the offset is skipped. The instructions for the current
         offset are emitted before those of the remaining offsets. *)
      let (post, s) :=
          if Zdivides_dec (align_chunk kappa) (Int.unsigned ofs)
          then
            let (reg, s) := new_reg s in
            let (reg', s) := new_reg s in
            let (reg'', s) := new_reg s in
            let s := add_instr (Iload (xH, nil) Mint32 (Aglobal SIZE Int.zero) nil reg'' (Psucc s.(st_nextnode))) s in
            let s := add_instr (Iload (xH, nil) Mint32 (Abased STK (Int.repr (-4 * Z.of_nat depth))) (reg''::nil) reg' (Psucc s.(st_nextnode))) s in
            let s := add_instr (Iop (if Int.eq_dec ofs Int.zero then Omove else Oaddimm ofs) (reg'::nil) reg (Psucc s.(st_nextnode))) s in ((fun a => reg::a), s) else (id, s) in
      do (a, s) <- put_stack_range_in_regs kappa (Int.add ofs Int.one) depth n s;
      OK (post a, s)
    | _ => OK (nil, s)
    end.

  (** [put_symbol_range_in_regs kappa g ofs n s] is the global-variable
      counterpart of [put_stack_range_in_regs]: for each of the [n]
      consecutive offsets starting at [ofs] that is a multiple of
      [align_chunk kappa], it allocates one fresh register and emits
      [Oaddrsymbol g ofs] into it. Registers are returned in increasing
      offset order. This function never returns [Error]. *)
  Fixpoint put_symbol_range_in_regs kappa (g: ident) (ofs: Int.int) (n: nat) (s: state): res (list reg * state) :=
    match n with
    | S n =>
      (* [post] prepends the register of the current offset, or is the
         identity when the offset is not aligned. *)
      let (post, s) :=
          if Zdivides_dec (align_chunk kappa) (Int.unsigned ofs)
          then
            let (reg, s) := new_reg s in
            let s := add_instr (Iop (Oaddrsymbol g ofs) nil reg (Psucc s.(st_nextnode))) s in
            ((fun a => reg :: a), s)
          else (id, s) in
      do (a, s) <- put_symbol_range_in_regs kappa g (Int.add ofs Int.one) n s;
      OK (post a, s)
    | _ => OK (nil, s)
    end.

  (** [make_zero_for_type ty r s] emits, when needed, one instruction that
      writes a zero of type [ty] into [r]:
      - [Tint], [Tany32], [Tany64]: nothing is emitted and [s] is returned
        unchanged (the only caller, [add_return_undef], passes a register it
        has already set to integer zero);
      - [Tfloat], [Tsingle]: a floating-point zero constant is loaded;
      - [Tlong]: [Omakelong] is applied to [r] and [r], so the result depends
        on the value [r] already holds. *)
  Definition make_zero_for_type (ty: AST.typ) (r: reg) (s: state) : state :=
    match ty with
    | Tint | Tany32 | Tany64 => s
    | Tfloat => add_instr (Iop (Ofloatconst Float.zero) nil r (Psucc s.(st_nextnode))) s
    | Tsingle => add_instr (Iop (Osingleconst Float32.zero) nil r (Psucc s.(st_nextnode))) s
    | Tlong => add_instr (Iop Omakelong (r :: r :: nil) r (Psucc s.(st_nextnode))) s
    end.

  (** [add_return_undef opt n s] emits the failure path used by [add_checks]:
      - [reg0] := 0;
      - [reg1] := [Zpos n] ([add_checks] passes [fst alpha] as [n]);
      - [reg] := [reg1] / [reg0], that is, a division by zero;
      - a conditional on [reg] whose two successors are the same next node;
      - a return: [Ireturn None] when [opt] is [None], otherwise a return of
        [reg0] after [make_zero_for_type] has adjusted it to the return type.
      NOTE(review): the division by zero looks deliberate, presumably so that
      reaching this path has no defined behaviour in the RTL semantics --
      confirm against the correctness proof. *)
  Definition add_return_undef opt (n: ident) (s: state) :=
    let (reg0, s) := new_reg s in
    let (reg1, s) := new_reg s in
    let (reg, s) := new_reg s in
    let s := add_instr (Iop (Ointconst Int.zero) nil reg0 (Psucc s.(st_nextnode))) s in
    let s := add_instr (Iop (Ointconst (Int.repr (Zpos n))) nil reg1 (Psucc s.(st_nextnode))) s in
    let s := add_instr (Iop Odiv (reg1::reg0::nil) reg (Psucc s.(st_nextnode))) s in
    (* Both branches go to the same node; the condition only reads [reg]. *)
    let s := add_instr (Icond (Ccompuimm Cne Int.zero) (reg::nil) (Psucc s.(st_nextnode)) (Psucc s.(st_nextnode))) s in
    match opt with
    | None => add_instr (Ireturn None) s
    | Some ty =>
      let s := make_zero_for_type ty reg0 s in
      add_instr (Ireturn (Some reg0)) s
    end.

  (** [put_all_in_regs kappa alpha s] emits code computing, for every block of
      [alpha], the addresses of all offsets of its range that are aligned for
      [kappa], and returns the registers holding them, block after block.
      The number of offsets visited for a block is
      [Int.unsigned (Int.sub bound base) + 1].
      Errors:
      - size is non positive: the computed size is not a [Zpos] (since
        [Int.unsigned] is presumably never negative, this branch should be
        unreachable -- confirm);
      - symbol is unused: a global block whose identifier has no binding in
        [ge]. The block found by the lookup is otherwise not used. *)
  Fixpoint put_all_in_regs kappa (alpha: list ablock) (s: state): res (list reg * state) :=
    match alpha with
      | nil => OK (nil, s)
      | (ABlocal depth varname (base, bound))::alpha =>
        let size := Int.unsigned (Int.sub bound base) + 1 in
        match size with
          | Zpos size =>
            do (l', s) <- put_stack_range_in_regs kappa base depth (Pos.to_nat size) s;
            do (l, s) <- put_all_in_regs kappa alpha s;
            OK (l' ++ l, s)
          | _ => Error (msg "size is non positive")
        end
      | (ABglobal b (base, bound))::alpha =>
        let size := Int.unsigned (Int.sub bound base) + 1 in
        match size with
          | Zpos size =>
            match Genv.find_symbol ge b with
              | Some id =>
                do (l', s) <- put_symbol_range_in_regs kappa b base (Pos.to_nat size) s;
                do (l, s) <- put_all_in_regs kappa alpha s;
                OK (l' ++ l, s)
              | None => Error (msg "symbol is unused")
            end
          | _ => Error (msg "size is non positive")
        end
    end.

  (** [put_stack_range_in_regs' base bound depth s] emits code computing the
      end points of a stack range instead of enumerating its offsets:
      - when [base = bound], returns [inl reg] where [reg] holds the frame
        address plus [base];
      - otherwise returns [inr (reg0, reg1)] where [reg0] holds the frame
        address plus [base] and [reg1] the frame address plus [bound].
      The frame address is obtained as in [put_stack_range_in_regs]: load
      [SIZE], then load at [Abased STK (-4 * depth)] indexed by that value.
      In the two-register case this load sequence is emitted twice, reusing
      the same temporaries [reg''] and [reg']. Always returns [OK]. *)
  Definition put_stack_range_in_regs' (base bound: Int.int) (depth: nat) (s: state): res ((reg + (reg * reg)) * state) :=
    if Int.eq_dec base bound then
      let (reg, s) := new_reg s in
      let (reg', s) := new_reg s in
      let (reg'', s) := new_reg s in
      let s := add_instr (Iload (xH, nil) Mint32 (Aglobal SIZE Int.zero) nil reg'' (Psucc s.(st_nextnode))) s in
      let s := add_instr (Iload (xH, nil) Mint32 (Abased STK (Int.repr (-4 * Z.of_nat depth))) (reg''::nil) reg' (Psucc s.(st_nextnode))) s in
      let s := add_instr (Iop (if Int.eq_dec base Int.zero then Omove else Oaddimm base) (reg'::nil) reg (Psucc s.(st_nextnode))) s in
      OK (inl reg, s)
    else
      let (reg0, s) := new_reg s in
      let (reg1, s) := new_reg s in
      let (reg', s) := new_reg s in
      let (reg'', s) := new_reg s in
      (* Lower end point into [reg0]. *)
      let s := add_instr (Iload (xH, nil) Mint32 (Aglobal SIZE Int.zero) nil reg'' (Psucc s.(st_nextnode))) s in
      let s := add_instr (Iload (xH, nil) Mint32 (Abased STK (Int.repr (-4 * Z.of_nat depth))) (reg''::nil) reg' (Psucc s.(st_nextnode))) s in
      let s := add_instr (Iop (if Int.eq_dec base Int.zero then Omove else Oaddimm base) (reg'::nil) reg0 (Psucc s.(st_nextnode))) s in
      (* Upper end point into [reg1]; the frame address is reloaded. *)
      let s := add_instr (Iload (xH, nil) Mint32 (Aglobal SIZE Int.zero) nil reg'' (Psucc s.(st_nextnode))) s in
      let s := add_instr (Iload (xH, nil) Mint32 (Abased STK (Int.repr (-4 * Z.of_nat depth))) (reg''::nil) reg' (Psucc s.(st_nextnode))) s in
      let s := add_instr (Iop (if Int.eq_dec bound Int.zero then Omove else Oaddimm bound) (reg'::nil) reg1 (Psucc s.(st_nextnode))) s in
      OK (inr (reg0, reg1), s).

  (** [put_symbol_range_in_regs' g base bound s] is the global-variable
      counterpart of [put_stack_range_in_regs']:
      - when [base = bound], returns [inl reg] with [reg] := address of [g]
        at offset [base];
      - otherwise returns [inr (reg0, reg1)] holding the addresses of [g] at
        offsets [base] and [bound].
      Always returns [OK]. *)
  Definition put_symbol_range_in_regs' (g: ident) (base bound: Int.int) (s: state): res ((reg + (reg * reg)) * state) :=
    if Int.eq_dec base bound then
      let (reg, s) := new_reg s in
      let s := add_instr (Iop (Oaddrsymbol g base) nil reg (Psucc s.(st_nextnode))) s in
      OK (inl reg, s)
    else
      let (reg0, s) := new_reg s in
      let (reg1, s) := new_reg s in
      let s := add_instr (Iop (Oaddrsymbol g base) nil reg0 (Psucc s.(st_nextnode))) s in
      let s := add_instr (Iop (Oaddrsymbol g bound) nil reg1 (Psucc s.(st_nextnode))) s in
      OK (inr (reg0, reg1), s).

  Definition find_symbol (g: ident) (msg: identerrmsg) : res block :=
    match Genv.find_symbol ge g with
    | None => Error (msg g)
    | Some b => OK b
    end.

  Fixpoint put_all_in_regs' (alpha: list ablock) (s: state): res (list (reg + (reg * reg)) * state) :=
    match alpha with
    | nil => OK (nil, s)
    | (ABlocal depth varname (base, bound))::alpha =>
      do (x1, s) <- put_stack_range_in_regs' base bound depth s;
      do (x2, s) <- put_all_in_regs' alpha s;
      OK (x1::x2, s)
    | (ABglobal b (base, bound))::alpha =>
      do _ <- find_symbol bb, MSG "pute_all_in_regs: unbound global symbol" :: POS b :: nil);
      do (x1, s) <- put_symbol_range_in_regs' b base bound s;
      do (x2, s) <- put_all_in_regs' alpha s;
      OK (x1::x2, s)
    end.

  (** [put_checks reg0 regs s] emits, for each register [reg1] of [regs], an
      unsigned equality comparison of [reg0] with [reg1] whose boolean result
      overwrites [reg1]. Always returns [OK]. *)
  Fixpoint put_checks (reg0: reg) (regs: list reg) (s: state): res state :=
    match regs with
      | nil => OK s
      | reg1::regs =>
        let s := add_instr (Iop (Ocmp (Ccompu Ceq)) (reg0::reg1::nil) reg1 (Psucc s.(st_nextnode))) s in
        put_checks reg0 regs s
    end.

  (** [put_checks' reg0 regs s] is the range-based variant of [put_checks].
      For each entry of [regs]:
      - [inr (reg, reg')]: emits [reg] := ([reg] <= [reg0]) and
        [reg'] := ([reg0] <= [reg']), both unsigned, then
        [reg] := [reg] and [reg']; the verdict for the entry ends up in [reg];
      - [inl reg1]: emits [reg1] := ([reg1] = [reg0]), unsigned.
      Always returns [OK]. *)
  Fixpoint put_checks' (reg0: reg) (regs: list (reg + (reg * reg))) (s: state): res state :=
    match regs with
    | nil => OK s
    | (inr (reg, reg'))::regs =>
      let s := add_instr (Iop (Ocmp (Ccompu Cle)) (reg::reg0::nil) reg (Psucc s.(st_nextnode))) s in
      let s := add_instr (Iop (Ocmp (Ccompu Cle)) (reg0::reg'::nil) reg' (Psucc s.(st_nextnode))) s in
      let s := add_instr (Iop Oand (reg::reg'::nil) reg (Psucc s.(st_nextnode))) s in
      put_checks' reg0 regs s
    | (inl reg1)::regs =>
      let s := add_instr (Iop (Ocmp (Ccompu Ceq)) (reg1::reg0::nil) reg1 (Psucc s.(st_nextnode))) s in
      put_checks' reg0 regs s
    end.

  (** [put_ors reg0 regs s] emits, for each [reg1] of [regs],
      [reg0] := [reg0] or [reg1], accumulating the disjunction of all
      registers into [reg0]. Always returns [OK]. *)
  Fixpoint put_ors (reg0: reg) (regs: list reg) (s: state): res state :=
    match regs with
    | nil => OK s
    | reg1::regs =>
      let s := add_instr (Iop Oor (reg0::reg1::nil) reg0 (Psucc s.(st_nextnode))) s in
      put_ors reg0 regs s
    end.

  (** [add_checks opt i alpha kappa addr args s] emits the memory access
      [i alpha kappa addr args], guarded by a run-time check of the accessed
      address against the annotation [alpha].
      - If [is_trivial_annotation prog alpha kappa addr] holds, only the
        access itself is emitted.
      - Otherwise the emitted sequence is:
        1. [reg0] := the accessed address ([Omove] when [addr] is
           [Aindexed] with offset zero, [Olea addr] otherwise);
        2. the candidate addresses of [snd alpha]: range end points through
           [put_all_in_regs'] when [is_singleton (snd alpha)] holds, every
           aligned address through [put_all_in_regs] otherwise;
        3. one comparison per candidate ([put_checks'] or [put_checks]);
        4. [reg0] := disjunction of all verdicts (a move of the first one
           followed by [put_ors]); from here on [reg0] no longer holds the
           address;
        5. a conditional on [reg0] being non-zero: the true branch goes to the
           next node, which holds the access; the false branch goes to the
           node after it, where [add_return_undef opt (fst alpha)] starts.
      Fails with Should not happen when no candidate register was produced,
      and propagates the errors of [put_all_in_regs] / [put_all_in_regs'].
      The order of the [add_instr] calls fixes the node numbering that the
      branch targets above rely on. *)
  Definition add_checks opt i (alpha: annotation) (kappa: memory_chunk) (addr: addressing) (args: list reg) (s:state): res state:=
    if is_trivial_annotation prog alpha kappa addr
    then OK (add_instr (i alpha kappa addr args) s)
    else if is_singleton (snd alpha) then
           let (reg0, s) := new_reg s in
           let s := add_instr (Iop (match addr with | Aindexed n => if Int.eq_dec n Int.zero then Omove else Olea addr | _ => Olea addr end) args reg0 (Psucc s.(st_nextnode))) s in
           do (regs, s) <- put_all_in_regs' (snd alpha) s;
           do s <- put_checks' reg0 regs s;
           (* Keep, for each entry, the register holding its verdict: the
              register itself for [inl], the first component for [inr]. *)
           match map (fun x => match x with | inl r => r | inr r => fst r end) regs with
           | nil => Error (msg "Should not happen")
           | reg1::regs' => let s := add_instr (Iop Omove (reg1::nil) reg0 (Psucc s.(st_nextnode))) s in
                           do s <- put_ors reg0 regs' s;
                           let s := add_instr (Icond (Ccompuimm Cne Int.zero) (reg0::nil) (Psucc s.(st_nextnode)) (Psucc (Psucc s.(st_nextnode)))) s in
                           let s := add_instr (i alpha kappa addr args) s in
                           OK (add_return_undef opt (fst alpha) s)
           end
         else
           let (reg0, s) := new_reg s in
           let s := add_instr (Iop (match addr with | Aindexed n => if Int.eq_dec n Int.zero then Omove else Olea addr | _ => Olea addr end) args reg0 (Psucc s.(st_nextnode))) s in
           do (regs, s) <- put_all_in_regs kappa (snd alpha) s;
           do s <- put_checks reg0 regs s;
           match regs with
           | nil => Error (msg "Should not happen")
           | reg1::regs' => let s := add_instr (Iop Omove (reg1::nil) reg0 (Psucc s.(st_nextnode))) s in
                           do s <- put_ors reg0 regs' s;
                           let s := add_instr (Icond (Ccompuimm Cne Int.zero) (reg0::nil) (Psucc s.(st_nextnode)) (Psucc (Psucc s.(st_nextnode)))) s in
                           let s := add_instr (i alpha kappa addr args) s in
                           OK (add_return_undef opt (fst alpha) s)
           end.

  Remark update_instr_wf:
    forall s n i,
      Plt n s.(st_nextnode) ->
      forall pc,
        Plt pc s.(st_nextnode) \/ (PTree.set n i s.(st_code))!pc = None.
Proof.
    intros.
    case (peq pc n); intro.
    subst pc; left; assumption.
    rewrite PTree.gso; auto. exact (st_wf s pc).
  Qed.
  
  (** [update_instr n i s] replaces the instruction at node [n] by [i].
      It succeeds only when [n] is strictly below [st_nextnode s]; the proof
      of that fact is what [update_instr_wf] needs to rebuild the invariant.
      Otherwise it fails with the message update_instr. Neither counter of
      the state changes. *)
  Definition update_instr (n: node) (i: instruction) (s: state): res state :=
    match plt n s.(st_nextnode) with
      | left LT =>
        OK (mkstate s.(st_nextreg) s.(st_nextnode) (PTree.set n i s.(st_code)) (update_instr_wf s n i LT))
      | _ => Error (msg "update_instr")
    end.

  (** [add_epilogue i s] emits code that decrements the 32-bit word stored
      at global [SIZE] by 4 (load, constant 4, subtract, store back) and then
      emits [i] itself. It mirrors the increment done by [add_prologue] and
      is used by [transf_instr] in front of returns and tail calls.
      Always returns [OK]. *)
  Definition add_epilogue (i: instruction) (s: state): res state :=
    let (r, s) := new_reg s in
    let (r', s) := new_reg s in
    let s := add_instr (Iload (xH, nil) Mint32 (Aglobal SIZE Int.zero) nil r (Psucc s.(st_nextnode))) s in
    let s := add_instr (Iop (Ointconst (Int.repr 4)) nil r' (Psucc s.(st_nextnode))) s in
    let s := add_instr (Iop Osub (r::r'::nil) r (Psucc s.(st_nextnode))) s in
    let s := add_instr (Istore (xH, nil) Mint32 (Aglobal SIZE Int.zero) nil r (Psucc s.(st_nextnode))) s in
    let s := add_instr i s in OK s.

  Definition check_globals_of_builtin_args (id: ident) (args: list (builtin_arg reg)) :=
    if in_dec peq id (globals_of_builtin_args args) then
      Error (msg "checks globals of builtin args fail")
    else OK tt.

  (** [check_annotations_depth alpha] validates every block of [alpha]; the
      tail is checked before the head:
      - a local block must have a depth strictly below 128
        (NOTE(review): presumably tied to the 512-byte [STK_globvar] seen as
        128 four-byte slots -- confirm);
      - a global block must not name [STK] or [SIZE], the globals reserved
        for the instrumentation. *)
  Fixpoint check_annotations_depth (alpha: list ablock) :=
    match alpha with
    | nil => OK tt
    | (ABlocal depth varname range)::alpha =>
      do _ <- check_annotations_depth alpha;
      if lt_dec depth 128 then OK tt else Error (msg "annotation too deep")
    | (ABglobal b range)::alpha =>
      do _ <- check_annotations_depth alpha;
      if peq b STK then Error (msg "STK appears in annotations") else
      if peq b SIZE then Error (msg "SIZE appears in annotations") else
      OK tt
    end.

  (** [check_annotations_range alpha] checks, for every block of [alpha]
      (local or global alike, head first), that its range [(base, bound)]
      satisfies, on unsigned values, [base <= bound] and
      [bound < Int.modulus - 1]. The first violation found aborts the check
      with the corresponding error. *)
  Fixpoint check_annotations_range (alpha: list ablock) :=
    match alpha with
    | nil => OK tt
    | (ABlocal depth varname (base, bound))::alpha =>
      if zle (Int.unsigned base) (Int.unsigned bound) then
        if zlt (Int.unsigned bound) (Int.modulus - 1) then check_annotations_range alpha
        else Error (msg "range too high")
      else Error (msg "lower bound is greater that greater bound")
    | (ABglobal b (base, bound))::alpha =>
      if zle (Int.unsigned base) (Int.unsigned bound) then
        if zlt (Int.unsigned bound) (Int.modulus - 1) then check_annotations_range alpha
        else Error (msg "range too high")
      else Error (msg "lower bound is greater that greater bound")
    end.
      
  (** [check_annotations_divides' kappa ofs n] succeeds as soon as one of the
      [n] consecutive offsets starting at [ofs] is a multiple of
      [align_chunk kappa]; it fails when the [n] offsets are exhausted
      (in particular when [n] is zero). *)
  Fixpoint check_annotations_divides' (kappa: memory_chunk) (ofs: Int.int) (n: nat) :=
    match n with
    | O => Error (msg "Annotation does not respect alignment constraint")
    | S n => if Zdivides_dec (align_chunk kappa) (Int.unsigned ofs) then OK tt else check_annotations_divides' kappa (Int.add ofs Int.one) n
    end.
  
  (** [check_annotations_divides kappa alpha] checks that the range of every
      block of [alpha] contains at least one offset aligned for [kappa]
      (see [check_annotations_divides']). The tail is checked before the
      head. The number of offsets of a block is
      [Int.unsigned (Int.sub bound base) + 1], the same size computation as
      in [put_all_in_regs]; the non-positive case is reported as
      should not happen. *)
  Fixpoint check_annotations_divides (kappa: memory_chunk) (alpha: list ablock) :=
    match alpha with
    | nil => OK tt
    | (ABlocal depth varname (base, bound))::alpha =>
      do _ <- check_annotations_divides kappa alpha;
      let size := Int.unsigned (Int.sub bound base) + 1 in
        match size with
        | Zpos size => check_annotations_divides' kappa base (Pos.to_nat size)
        | _ => Error (msg "should not happen")
        end
    | (ABglobal id (base, bound))::alpha =>
      do _ <- check_annotations_divides kappa alpha;
      let size := Int.unsigned (Int.sub bound base) + 1 in
      match size with
      | Zpos size => check_annotations_divides' kappa base (Pos.to_nat size)
      | _ => Error (msg "should not happen")
      end
    end.
  
  (** [transf_instr opt st pc instr] is the step function folded over the
      original code by [transf_function]; [st] is the accumulated state (or
      an earlier error) and [opt] the return type of the function, forwarded
      to [add_checks].
      - [Iload] / [Istore]: the annotation is validated (depth, alignment,
        range) before [st] is even inspected, so a validation error takes
        precedence over an earlier error. Then the instruction at [pc] is
        replaced by an [Inop] jumping to the first fresh node, where
        [add_checks] emits the guarded access.
      - [Ireturn] / [Itailcall]: the instruction at [pc] is replaced by an
        [Inop] jumping to fresh code that runs [add_epilogue] and then the
        original instruction.
      - [Ibuiltin]: the arguments must not mention [STK] or [SIZE]; the code
        is left unchanged.
      - any other instruction: [st] is returned unchanged. *)
  Definition transf_instr opt (st: res state) (pc: positive) (instr: instruction) :=
    match instr with
      | Iload alpha kappa addr args dst k =>
        do _ <- check_annotations_depth (snd alpha);
        do _ <- check_annotations_divides kappa (snd alpha);
        do _ <- check_annotations_range (snd alpha);
        do st <- st;
        (* Redirect [pc] to the node where the instrumented code starts. *)
        do st <- update_instr pc (Inop st.(st_nextnode)) st;
        add_checks opt (λ α κ addr args, Iload α κ addr args dst k) alpha kappa addr args st
      | Istore alpha kappa addr args src k =>
        do _ <- check_annotations_depth (snd alpha);
        do _ <- check_annotations_divides kappa (snd alpha);
        do _ <- check_annotations_range (snd alpha);
        do st <- st;
        do st <- update_instr pc (Inop st.(st_nextnode)) st;
        add_checks opt (λ α κ addr args, Istore α κ addr args src k) alpha kappa addr args st
      | Ireturn or =>
        do st <- st;
        do st <- update_instr pc (Inop st.(st_nextnode)) st;
        add_epilogue (Ireturn or) st
      | Itailcall sig ros args =>
        do st <- st;
        do st <- update_instr pc (Inop st.(st_nextnode)) st;
        add_epilogue (Itailcall sig ros args) st
      | Ibuiltin ef args res k =>
        do _ <- check_globals_of_builtin_args STK args;
        do _ <- check_globals_of_builtin_args SIZE args;
        st
      | _ => st
    end.

  (** [add_prologue k s] emits the function-entry sequence and ends with an
      [Inop] jumping to [k]:
      1. load the 32-bit word at global [SIZE], add 4 and store it back;
      2. reload [SIZE] into [r'];
      3. compute [Olea (Ainstack Int.zero)] into [r''] (presumably the
         address of the current stack frame at offset zero -- confirm);
      4. store [r''] at [Abased STK Int.zero] indexed by [r'].
      Together with [add_epilogue] this maintains a stack of frame addresses
      in [STK], with [SIZE] holding the byte offset of its top entry.
      NOTE(review): nothing here bounds [SIZE] by the size of [STK] --
      confirm how deep call chains are handled. *)
  Definition add_prologue (k: node) (s: state): state :=
    let (r, s) := new_reg s in
    let (r', s) := new_reg s in
    let (r'', s) := new_reg s in
    let s := add_instr (Iload (xH, nil) Mint32 (Aglobal SIZE Int.zero) nil r (Psucc s.(st_nextnode))) s in
    let s := add_instr (Iop (Oaddimm (Int.repr 4)) (r::nil) r (Psucc s.(st_nextnode))) s in
    let s := add_instr (Istore (xH, nil) Mint32 (Aglobal SIZE Int.zero) nil r (Psucc s.(st_nextnode))) s in
    let s := add_instr (Iload (xH, nil) Mint32 (Aglobal SIZE Int.zero) nil r' (Psucc s.(st_nextnode))) s in
    let s := add_instr (Iop (Olea (Ainstack Int.zero)) nil r'' (Psucc s.(st_nextnode))) s in
    let s := add_instr (Istore (xH, nil) Mint32 (Abased STK Int.zero) (r'::nil) r'' (Psucc s.(st_nextnode))) s in
    add_instr (Inop k) s.

  (** [transf_function f] instruments one function:
      - the new entry point is the first fresh node of [init_state f], where
        [add_prologue] places its code, ending with a jump to the original
        entry point;
      - [transf_instr] is then folded over the original code of [f],
        accumulating the rewritten code in the state;
      - the result keeps the signature, parameters and stack size of [f].
      Any error raised by [transf_instr] is propagated. *)
  Definition transf_function (f: function): res function :=
    let s := init_state f in
    let new_entrypoint := s.(st_nextnode) in
    let s := add_prologue f.(fn_entrypoint) s in
    do s <- PTree.fold (transf_instr f.(fn_sig).(sig_res)) f.(fn_code) (OK s);
    OK (mkfunction f.(fn_sig) f.(fn_params) f.(fn_stacksize) s.(st_code) new_entrypoint).

  (** Lifts [transf_function] to function definitions through
      [AST.transf_partial_fundef]. *)
  Definition transf_fundef (fd: fundef): res fundef := AST.transf_partial_fundef transf_function fd.

End TRANSL.

(** Definition of the global bound to [STK] by [transf_program]: 512 bytes of
    [Init_space]. The two [false] flags are presumably the read-only and
    volatile attributes of [mkglobvar] -- confirm. *)
Definition STK_globvar :=
  mkglobvar tt ((Init_space 512)::nil) false false.

(** Definition of the global bound to [SIZE] by [transf_program]: one 32-bit
    integer initialised to -4. Since [add_prologue] adds 4 before using the
    value as an index into [STK], the first prologue executed works with
    index 0. *)
Definition SIZE_globvar :=
  mkglobvar tt ((Init_int32 (Int.repr (-4)))::nil) false false.

(** [check_init_data_list_size l] is [true] exactly when every variable
    definition of [l] has an initialisation-data size, as computed by
    [Genv.init_data_list_size], of at most [Int.max_unsigned]. Function
    definitions are skipped. *)
Fixpoint check_init_data_list_size (l: list (globdef fundef unit)) :=
  match l with
  | nil => true
  | Gfun _ :: rest => check_init_data_list_size rest
  | Gvar gv :: rest =>
    if zle (Genv.init_data_list_size (gvar_init gv)) Int.max_unsigned
    then check_init_data_list_size rest
    else false
  end.

(** [transf_program p] instruments a whole program.
    - [STK] and [SIZE] are chosen strictly above the maximum of all
      identifiers defined in [p] or listed in [prog_public p], so they are
      distinct from each other and from every one of those names.
    - The program is rejected when its definition names contain a repetition
      or when some variable has initialisation data larger than
      [Int.max_unsigned] (see [check_init_data_list_size]).
    - Otherwise every function is transformed by [transf_fundef], variables
      are kept as they are, and the definitions of [STK] and [SIZE] are added
      through [AST.transform_partial_augment_program]; the main identifier is
      unchanged. *)
Definition transf_program (p: program): res program :=
  let names := List.map fst p.(prog_defs) ++ (prog_public p) in
  let STK := Psucc (Psucc (List.fold_left Pmax names xH)) in
  let SIZE := Psucc STK in
  if list_norepet_dec peq (map fst p.(prog_defs)) then
    if check_init_data_list_size (map snd p.(prog_defs)) then
      AST.transform_partial_augment_program (transf_fundef p (Genv.globalenv p) STK SIZE) (fun v => OK v) ((STK, Gvar STK_globvar)::(SIZE, Gvar SIZE_globvar)::nil) (p.(prog_main)) p
    else Error (msg "cannot allocate global variables with more than 2^32 space")
  else Error (msg "repeating identifiers in program definitions").