(** Multi-way branches (``switch'' statements) and their compilation
    to comparison trees. *)
Require Import EqNat.
Require Import FMaps.
Require FMapAVL.
Require Import Coqlib.
Require Import Integers.
Require Import Ordered.
(** Finite maps whose keys are machine integers of type [int], implemented
    with the standard-library AVL-tree functor and ordered by [OrderedInt]. *)
Module IntMap := FMapAVL.Make(OrderedInt).

(** Generic facts about finite maps, instantiated for [IntMap]. *)
Module IntMapF := FMapFacts.Facts(IntMap).
(** A multi-way branch is composed of a list of (key, action) pairs,
    plus a default action. *)
(** A switch table: an association list pairing integer keys with actions
    of type [A].  Lookup ([switch_target]) honours the first entry whose
    key matches. *)
Definition table : Type -> Type :=
  fun (A: Type) => list (int * A).
(** [switch_target n dfl cases] scans [cases] from the front and returns
    the action of the first entry whose key is equal to [n] (per [Int.eq]).
    If no entry matches, the default action [dfl] is returned. *)
Fixpoint switch_target {A: Type} (n: int) (dfl: A) (cases: table A)
                       {struct cases} : A :=
  match cases with
  | nil => dfl
  | (k, a) :: rest =>
      match Int.eq n k with
      | true => a
      | false => switch_target n dfl rest
      end
  end.
(** Multi-way branches are translated to comparison trees.
    Each node of the tree performs one of the following:
- an equality test against one of the keys;
- a "less than" test against one of the keys;
- a computed branch (jump table) over a range of key values. *)
Section COMPTREE.

(** Actions are drawn from an arbitrary type [A], implicit in the
    definitions of this section. *)
Context {A: Type}.

(** Decidable equality on actions.  The validator uses it to compare the
    actions selected by a comparison tree with those of the table. *)
Variable eqA: forall (x y: A), {x=y} + {x<>y}.
(** Comparison trees, as interpreted by [comptree_match]. *)
Inductive comptree: Type :=
  (* Leaf: selects the given action. *)
  | CTaction: A -> comptree
  (* [CTifeq key act t']: if the value equals [key], select [act];
     otherwise continue with [t']. *)
  | CTifeq: int -> A -> comptree -> comptree
  (* [CTiflt key t1 t2]: continue with [t1] if [Int.ltu value key],
     with [t2] otherwise. *)
  | CTiflt: int -> comptree -> comptree -> comptree
  (* [CTjumptable ofs sz tbl t']: if [Int.ltu (Int.sub value ofs) sz],
     select the element of [tbl] at index [Int.unsigned (Int.sub value ofs)];
     otherwise continue with [t']. *)
  | CTjumptable: int -> int -> list A -> comptree -> comptree.
(** [comptree_match n t] runs the comparison tree [t] on the value [n]
    and returns the action it selects.  Leaves and equality hits always
    yield [Some]; the only source of [None] is the [list_nth_z] lookup in
    a jump-table node (presumably when the computed index lies past the
    end of the target list). *)
Fixpoint comptree_match (n: int) (t: comptree) {struct t}: option A :=
  match t with
  | CTaction a =>
      Some a
  | CTifeq k a rest =>
      if Int.eq n k then Some a else comptree_match n rest
  | CTiflt k below above =>
      if Int.ltu n k then comptree_match n below else comptree_match n above
  | CTjumptable base size targets rest =>
      if Int.ltu (Int.sub n base) size
      then list_nth_z targets (Int.unsigned (Int.sub n base))
      else comptree_match n rest
  end.
(** The translation from a table to a comparison tree is performed
    by untrusted OCaml code (function compile_switch in
    file RTLgenaux.ml).  In Coq, we validate the result of this
    function a posteriori.  In other words, we now develop
    and prove correct Coq functions that take a table and a comparison
    tree, and check that their semantics are equivalent. *)
(** [split_lt pivot cases] partitions [cases] in two: the entries whose key
    satisfies [Int.ltu key pivot], and all the other entries.  Within each
    part, entries keep their original relative order. *)
Fixpoint split_lt (pivot: int) (cases: table A)
                  {struct cases} : table A * table A :=
  match cases with
  | nil => (nil, nil)
  | (k, a) :: rest =>
      match split_lt pivot rest with
      | (below, above) =>
          if Int.ltu k pivot
          then ((k, a) :: below, above)
          else (below, (k, a) :: above)
      end
  end.
(** [split_eq pivot cases] returns the action of the first entry whose key
    equals [pivot] (or [None] if there is none), together with [cases]
    with every entry keyed by [pivot] removed.  Because the result is built
    from the tail upward, an earlier matching entry overrides a later one. *)
Fixpoint split_eq (pivot: int) (cases: table A)
                  {struct cases} : option A * table A :=
  match cases with
  | nil => (None, nil)
  | (k, a) :: rest =>
      match split_eq pivot rest with
      | (found, kept) =>
          if Int.eq k pivot
          then (Some a, kept)
          else (found, (k, a) :: kept)
      end
  end.
(** [split_between ofs sz cases] separates the entries whose key [k]
    satisfies [Int.ltu (Int.sub k ofs) sz] from the others.  The former are
    collected into an [IntMap] keyed by [k]; the latter stay in a table, in
    their original order.  Entries are added to the map from the tail
    upward, so an earlier entry for a key replaces a later one (assuming the
    usual FMap semantics of [add]). *)
Fixpoint split_between (ofs sz: int) (cases: table A)
                       {struct cases} : IntMap.t A * table A :=
  match cases with
  | nil => (IntMap.empty A, nil)
  | (k, a) :: rest =>
      match split_between ofs sz rest with
      | (inside, outside) =>
          if Int.ltu (Int.sub k ofs) sz
          then (IntMap.add k a inside, outside)
          else (inside, (k, a) :: outside)
      end
  end.
(** Raise the lower bound [lo] by one when it coincides with [v];
    otherwise return [lo] unchanged.  [validate] uses this after an
    equality test on [v]. *)
Definition refine_low_bound (v lo: Z) : Z :=
  match zeq v lo with
  | left _ => lo + 1
  | right _ => lo
  end.
(** Lower the upper bound [hi] by one when it coincides with [v];
    otherwise return [hi] unchanged.  [validate] uses this after an
    equality test on [v]. *)
Definition refine_high_bound (v hi: Z) : Z :=
  match zeq v hi with
  | left _ => hi - 1
  | right _ => hi
  end.
(** [validate_jumptable cases default tbl n] checks that, for every
    position [i] of [tbl], the [i]-th action is equal (via [eqA]) to the
    action bound to key [n + i] (computed with [Int.add]) in the map
    [cases], or to [default] when that key is unbound.  The sumbool result
    of [eqA] is used as a boolean through the coercion in scope
    (presumably [proj_sumbool]). *)
Fixpoint validate_jumptable (cases: IntMap.t A) (default: A)
                            (tbl: list A) (n: int) {struct tbl} : bool :=
  match tbl with
  | nil => true
  | target :: rest =>
      eqA target
          (match IntMap.find n cases with
           | Some a => a
           | None => default
           end)
      && validate_jumptable cases default rest (Int.add n Int.one)
  end.
(** [validate default cases t lo hi] checks the comparison tree [t] against
    the table [cases] with fallback [default], for values [v] satisfying
    [lo <= Int.unsigned v <= hi] (inclusive bounds).  The intended
    guarantee is the statement of [validate_correct_rec].  [cases] holds
    the entries not yet handled by the enclosing tree nodes, and [lo]/[hi]
    record what the enclosing tests have established about the value. *)
Fixpoint validate (default: A) (cases: table A) (t: comptree)
                  (lo hi: Z) {struct t} : bool :=
  match t with
  | CTaction act =>
      (* Leaf.  Either no case remains and [act] must be the default, or the
         range contains a single value ([lo = hi]), that value is the key of
         the first remaining case, and [act] is that case's action. *)
      match cases with
      | nil =>
          eqA act default
      | (key1, act1) :: _ =>
          zeq (Int.unsigned key1) lo && zeq lo hi && eqA act act1
      end
  | CTifeq pivot act t' =>
      (* [pivot] must be one of the keys and be mapped to [act].  The other
         cases are checked on [t'], with the range shrunk by one when
         [pivot] is one of its endpoints. *)
      match split_eq pivot cases with
      | (None, _) =>
          false
      | (Some act', others) =>
          eqA act act'
          && validate default others t'
                      (refine_low_bound (Int.unsigned pivot) lo)
                      (refine_high_bound (Int.unsigned pivot) hi)
      end
  | CTiflt pivot t1 t2 =>
      (* Split both the cases and the range at [pivot].  Keys below [pivot]
         go to [t1] with range [lo .. pivot - 1]; the others go to [t2] with
         range [pivot .. hi]. *)
      match split_lt pivot cases with
      | (lcases, rcases) =>
          validate default lcases t1 lo (Int.unsigned pivot - 1)
          && validate default rcases t2 (Int.unsigned pivot) hi
      end
  | CTjumptable ofs sz tbl t' =>
      (* The table must have at least [sz] entries, and at most
         [Int.max_signed].  It must agree with the cases selected by
         [split_between ofs sz].  The remaining cases are checked on [t'].
         NOTE(review): [t'] is validated with the unchanged range [lo .. hi]. *)
      let tbl_len := list_length_z tbl in
      match split_between ofs sz cases with
      | (inside, outside) =>
          zle (Int.unsigned sz) tbl_len
          && zle tbl_len Int.max_signed
          && validate_jumptable inside default tbl ofs
          && validate default outside t' lo hi
      end
  end.
(** Top-level validator: check the tree [t] against the table [cases] with
    fallback [default], over the whole unsigned range
    [0 .. Int.max_unsigned]. *)
Definition validate_switch (default: A) (cases: table A) (t: comptree) : bool :=
  validate default cases t 0 Int.max_unsigned.
(** ** Correctness proof for validation *)
(** Splitting with [split_eq] preserves lookups.  For the value [n] itself,
    the extracted action applies (or [default] if none was found).  Any
    other value is looked up in the remaining cases. *)
Lemma split_eq_prop:
  forall v default n cases optact cases',
  split_eq n cases = (optact, cases') ->
  switch_target v default cases =
   (if Int.eq v n
    then match optact with Some act => act | None => default end
    else switch_target v default cases').
Proof.
  induction cases; simpl; intros until cases'.
  (* Case [cases = nil]: [split_eq] returns [(None, nil)]. *)
  intros. inversion H; subst. simpl.
  destruct (Int.eq v n); auto.
  (* Case [cases = (key, act) :: cases]. *)
  destruct a as [key act].
  case_eq (split_eq n cases). intros same other SEQ.
  rewrite (IHcases _ _ SEQ).
  predSpec Int.eq Int.eq_spec key n; intro EQ; inversion EQ; simpl.
  (* Subcase [key = n]: the head entry supplies the extracted action. *)
  subst n. destruct (Int.eq v key). auto. auto.
  (* Subcase [key <> n]: the head entry is kept among the other cases. *)
  predSpec Int.eq Int.eq_spec v key.
  subst v. predSpec Int.eq Int.eq_spec key n. congruence. auto.
  auto.
Qed.
(** Splitting with [split_lt] preserves lookups.  A value below [n]
    (per [Int.ltu]) is looked up in the left part; any other value is
    looked up in the right part. *)
Lemma split_lt_prop:
  forall v default n cases lcases rcases,
  split_lt n cases = (lcases, rcases) ->
  switch_target v default cases =
   (if Int.ltu v n
    then switch_target v default lcases
    else switch_target v default rcases).
Proof.
  induction cases; intros until rcases; simpl.
  (* Case [cases = nil]: both parts are empty. *)
  intro. inversion H; subst. simpl.
  destruct (Int.ltu v n); auto.
  (* Case [cases = (key, act) :: cases]. *)
  destruct a as [key act].
  case_eq (split_lt n cases). intros lc rc SEQ.
  rewrite (IHcases _ _ SEQ).
  case_eq (Int.ltu key n); intros; inv H0; simpl.
  (* Subcase [Int.ltu key n = true]: the head entry went to [lcases]. *)
  predSpec Int.eq Int.eq_spec v key.
  subst v. rewrite H. auto.
  auto.
  (* Subcase [Int.ltu key n = false]: the head entry went to [rcases]. *)
  predSpec Int.eq Int.eq_spec v key.
  subst v. rewrite H. auto.
  auto.
Qed.
(** Splitting with [split_between] preserves lookups.  When
    [Int.sub v ofs] is below [sz] (per [Int.ltu]), [v] is looked up in the
    map [inside], falling back to [default]; otherwise it is looked up in
    the table [outside]. *)
Lemma split_between_prop:
  forall v default ofs sz cases inside outside,
  split_between ofs sz cases = (inside, outside) ->
  switch_target v default cases =
   (if Int.ltu (Int.sub v ofs) sz
    then match IntMap.find v inside with Some a => a | None => default end
    else switch_target v default outside).
Proof.
  (* NOTE(review): the proof script of this lemma is missing from this copy
     of the file.  [Proof.] is followed directly by the next [Lemma] with no
     [Qed], so the file will not check as-is. *)
(** If [validate_jumptable] succeeds from base key [base], then every index
    [Int.unsigned v] within the bounds of [tbl] holds the action bound to
    key [Int.add base v] in [cases], or [default] if that key is unbound. *)
Lemma validate_jumptable_correct_rec:
  forall cases default tbl base v,
  validate_jumptable cases default tbl base = true ->
  0 <= Int.unsigned v < list_length_z tbl ->
  list_nth_z tbl (Int.unsigned v) =
    Some(match IntMap.find (Int.add base v) cases with Some a => a | None => default end).
Proof.
  (* NOTE(review): the proof script of this lemma is missing from this copy
     of the file.  [Proof.] is followed directly by the next [Lemma] with no
     [Qed], so the file will not check as-is. *)
(** Jump-table lookup as performed by [comptree_match] is correct.  When
    [Int.sub v ofs] is below [sz] and [tbl] has at least [Int.unsigned sz]
    entries, the selected entry is the action bound to [v] in [cases], or
    [default] if [v] is unbound. *)
Lemma validate_jumptable_correct:
  forall cases default tbl ofs v sz,
  validate_jumptable cases default tbl ofs = true ->
  Int.ltu (Int.sub v ofs) sz = true ->
  Int.unsigned sz <= list_length_z tbl ->
  list_nth_z tbl (Int.unsigned (Int.sub v ofs)) =
    Some(match IntMap.find v cases with Some a => a | None => default end).
Proof.
  (* NOTE(review): the proof script of this lemma is missing from this copy
     of the file.  [Proof.] is followed directly by the next [Lemma] with no
     [Qed], so the file will not check as-is. *)
(** Main invariant of [validate].  If the check succeeds for the range
    [lo .. hi], then for every value in that range the tree selects exactly
    the action given by the table. *)
Lemma validate_correct_rec:
  forall default v t cases lo hi,
  validate default cases t lo hi = true ->
  lo <= Int.unsigned v <= hi ->
  comptree_match v t = Some (switch_target v default cases).
Proof.
  (* NOTE(review): the proof script of this lemma is missing from this copy
     of the file.  [Proof.] is followed directly by the next definition with
     no [Qed], so the file will not check as-is. *)
(** [table_tree_agree default cases t] states that the comparison tree [t]
    is a faithful implementation of the table [cases] with fallback
    [default].  For every integer value, the tree selects the same action
    as [switch_target]. *)
Definition table_tree_agree
    (default: A) (cases: table A) (t: comptree) : Prop :=
  forall v: int,
    comptree_match v t = Some (switch_target v default cases).
(** Soundness of the top-level validator.  A tree accepted by
    [validate_switch] agrees with the table on every integer value. *)
Theorem validate_switch_correct:
  forall default t cases,
  validate_switch default cases t = true ->
  table_tree_agree default cases t.
Proof.
  (* NOTE(review): the proof script of this theorem is missing from this
     copy of the file.  [Proof.] is followed directly by [End COMPTREE.]
     with no [Qed], so the file will not check as-is. *)
(* Close the section that abstracts over the action type [A] and its
   decidable equality [eqA]. *)
End COMPTREE.