
(* 2017 (C) Jussi Rintanen *)

(* Reduce reachability constraints to sets of clauses.
   Greedy strategy for minimizing the number of separate
   reachability constraints that need to be CNF encoded:
     1. Find the node with the maximum number of occurrences
       as s or t in reachability constraints (s,t).
     2. Remove all those constraints.
     3. If constraints left, go to 1.
   Reachability constraints can be constructed from the beginning or
   from the end. Consider a collection of constraints (s1,t), (s2,t),...,
   (sN,t), all with the same target node. Now it is best to encode
   the distance TO t, and say that (si,t) is satisfied if si has
   a finite distance to t. The number of formulas is 1/N in comparison
   to encoding the distances starting from all of s1,...,sN and
   requiring that t has a finite distance in all cases.
 *)

(* Encode reachability by the Distance Encoding.
  If fwd=true, then we encode distances forward from a node
  (following arcs s->t from s to t). Otherwise distances are
  backwards, following arcs from t to s.
 *)

(* Find the element with the most occurrences in a list, and return it
  together with its number of occurrences.
  Elements are compared directly with < and =, so callers apply any
  abstraction themselves beforehand (e.g. map p23 over constraints).
  The list is first ordered by mergesort so that equal elements are
  adjacent; counting the occurrences is then a single linear pass.
*)

(* mostOccurs es: return (x,c) where x is an element of es with the greatest
   number of occurrences and c is that number of occurrences.
   If every element occurs exactly once, the head of es is returned with
   count 1; otherwise the first maximal run of the sorted list wins.
   Raises (via ERROR) on the empty list. *)
fun mostOccurs (es as (e::_)) =
  let
    (* Sort so that equal elements form contiguous runs. *)
    val eso = mergesort (op <) es
    (* count (e0,cnt,l): extend a run of e0 that is already cnt long over the
       prefix of l equal to e0. Returns the remainder of l starting AT the
       first differing element (it must not be dropped: it begins the next
       run) together with the final run length. *)
    fun count (e0,cnt,l as (x::xs)) =
          if e0 = x then count (e0,cnt+1,xs) else (l,cnt)
      | count (_,cnt,[]) = ([],cnt)
    (* getMost: walk the runs of the sorted list, keeping the longest one. *)
    fun getMost ([],best,bestcnt) = (best,bestcnt)
      | getMost (x::xs,best,bestcnt) =
          let val (rest,cnt) = count (x,1,xs)
          in
            if cnt > bestcnt then getMost (rest,x,cnt)
            else getMost (rest,best,bestcnt)
          end
  in
    (* Scan the SORTED list; scanning es would only count adjacent repeats. *)
    getMost (eso,e,1)
  end
  | mostOccurs [] = ERROR "0 elements in mostOccurs";

(* neighborsOf (n,dir,arcs): pairs (m,v) for every arc (v,s,t) touching n.
   dir = true : arcs leaving n (s = n); m is the arc's other end t.
   dir = false: arcs entering n (t = n); m is the arc's other end s.
   v is the variable attached to the arc. *)
fun neighborsOf (n,dir,arcs) =
  let
    fun collect ((v,s,t),found) =
      let
        (* near is compared against n, far is what gets reported *)
        val (near,far) = if dir then (s,t) else (t,s)
      in
        if near = n then (far,v)::found else found
      end
  in
    fold collect arcs []
  end;

(* distEncoding (src,nodes,arcs,fwd,(distVar,auxVar),ac)
   Distance Encoding of reachability from src; the new clauses are
   prepended to the clause list ac.
   distVar(src,n,d) holds iff n is at distance at most d from src, for
   d = 0 .. (length nodes)-1. If fwd, distances follow arcs s->t from s
   to t; otherwise they follow arcs backwards, from t to s.
   arcs are triples (v,s,t); v is the literal variable of the arc.
   auxVar(src,n,m,d-1) stands for "m has distance <= d-1 and the arc
   between m and n is present". *)
fun distEncoding (src,nodes,arcs,fwd,(distVar,auxVar),ac) =
  let
    fun oneStep ((n,d),ac) =
      let
        (* Nodes m from which n is one step away in the encoding direction:
           predecessors (arcs m->n) if fwd, successors (arcs n->m) if not.
           Each entry is (m, distVar(m,d-1), aux, arc variable). *)
        val neighbors =
            map (fn (m,v) => (m,distVar(src,m,d-1),auxVar(src,n,m,d-1),v))
                (neighborsOf(n,not fwd,arcs))
        val dn = distVar(src,n,d)
        val dn1 = distVar(src,n,d-1)
      in
        (* dist(n,d) -> dist(n,d-1) \/ aux(n,m1,d-1) \/ ... \/ aux(n,mM,d-1) *)
        ((Neg dn)::(Pos dn1)::(map (Pos o p34) neighbors))
        (* dist(n,d-1) -> dist(n,d): distances are monotone in d *)
        ::[Neg dn1,Pos dn]
        (* aux -> dist(m,d-1),  aux -> arc,  dist(m,d-1) /\ arc -> dist(n,d) *)
        ::(fold (fn ((_,dm1,aux,arcv),ac) => [Neg aux,Pos dm1]
                                           ::[Neg aux,Pos arcv]
                                           ::[Neg dm1,Neg arcv,Pos dn]
                                           ::ac)
                neighbors
                ac)
      end
  in
    (* Distance 0: only src itself. Then one step per (node, distance) pair. *)
    fold oneStep (product(nodes,fromto(1,(length nodes)-1)))
         ((map (fn n => [(if n=src then Pos else Neg) (distVar(src,n,0))])
               nodes)@ac)
  end;

(* When there are only unreachability constraints, we can use the trivial encoding
   of reachability that provides an assignment that may include reachabilities
   that do not hold, but if something is claimed unreachable, it is not.
   This is by trivial implications reachable(s) & arc(s,t) -> reachable(t).
*)

fun simpleEncoding (src,arcs,fwd,rvar,ac) =
  let
    (* Orient an arc (s,t) as (pre,post): reachability flows from pre to post.
       Forward encoding follows s->t; backward encoding follows t->s. *)
    fun orient (s,t) = if fwd then (s,t) else (t,s)
    (* One Horn clause per arc:  rvar(src,post) <- rvar(src,pre) & v.
       Arcs leading into src are skipped: src is asserted reachable below. *)
    fun addArc ((v,s,t),clauses) =
      let val (pre,post) = orient (s,t)
      in
        if post = src
        then clauses
        else [Pos (rvar (src,post)), Neg (rvar (src,pre)), Neg v] :: clauses
      end
    val srcUnit = [Pos (rvar (src,src))]
  in
    fold addArc arcs (srcUnit :: ac)
  end;

(* Direction in which reachability is propagated from an origin node:
   FWD follows arcs s->t from s to t, BWD follows them from t to s. *)
datatype arrowend = FWD | BWD;

(* One-letter tag that distinguishes FWD and BWD variables in generated names. *)
fun dir2tag direction =
  case direction of
      FWD => "F"
    | BWD => "B";

(* reqDistances (reacha,nonreacha,connVar,debugoutput)
   Greedily partition the reachability constraints reacha and the
   unreachability constraints nonreacha (triples (s,t,v)) into groups that
   share an origin node. Each round picks the most frequent target or source
   (the target only if strictly more frequent) and handles every constraint
   anchored at it.
   connVar = (variable constructor used when the group contains a
   reachability constraint, variable constructor used otherwise); each
   takes a direction tag and a pair (origin,other endpoint).
   Returns (origins, clauses):
     origins: one (origin, direction, hasReachability) per group;
     clauses: [negLit v, Pos r] for reacha and [negLit v, Neg r] for nonreacha,
              where r is the variable of the other endpoint w.r.t. the origin. *)
fun reqDistances ([],[],_,_) = ([],[])
  | reqDistances (reacha : (int * int * lit) list,nonreacha,connVar,debugoutput) =
    let
      val everything = reacha@nonreacha
      (* Candidate origins: most frequent target vs. most frequent source. *)
      val (bestend,cnt1) = mostOccurs (map p23 everything)
      val (beststart,cnt2) = mostOccurs (map p13 everything)
      val (origin,direction) =
          if cnt1 > cnt2 then (bestend,BWD) else (beststart,FWD)
      (* anchor: endpoint compared with origin; target: the other endpoint. *)
      val (anchor,target) =
          case direction of
              BWD => (p23,p13)
            | FWD => (p13,p23)
      fun covered c = anchor c = origin
      fun uncovered c = not (covered c)

      val hasReachability = forsome covered reacha
      val rvar = (if hasReachability then p12 else p22) connVar
      val dirtag = dir2tag direction
      (* v -> polarity(var of the other endpoint, seen from origin) *)
      fun imply polarity (s,t,v) =
          [negLit v, polarity (rvar dirtag (origin,target (s,t,v)))]
      val encodingR = map (imply Pos) (filter covered reacha)
      val encodingN = map (imply Neg) (filter covered nonreacha)

      (* Constraints not anchored at origin are handled by later rounds. *)
      val (ld,le) =
          reqDistances (filter uncovered reacha,
                        filter uncovered nonreacha,
                        connVar,
                        debugoutput)

      val () =
          if debugoutput
          then app print ["GOT ORIGIN ",Int.toString origin," ",
                          (case direction of FWD => "FWD" | BWD => "BWD"),
                          "; reachability ",
                          (if hasReachability then "YES" else "NO"),
                          "\n"]
          else ()
    in
      ((origin,direction,hasReachability)::ld, encodingR@encodingN@le)
    end;

(* reach2cnf ((varcnt,graph,acyclic,reacha,nonreacha,clauses,symtab),debugoutput)
   Replace the reachability constraints reacha and unreachability
   constraints nonreacha by ordinary clauses over freshly named variables.
   graph must be SOME g; nodes are fromto(0,(p12 g)-1) and the arcs are p22 g.
   Constraint groups that need real reachability use the Distance Encoding;
   groups with only unreachability constraints use the simple encoding.
   Returns the problem with the graph and the constraints consumed, the new
   clauses prepended, the fresh symbols merged into symtab (sorted by
   index), and an updated variable count.
   Acyclicity constraints (third component true) are not supported. *)
fun reach2cnf ((_,_,true,_,_,_,_),_) = ERROR "Reduction from acyclicity to CNF not implemented"
  | reach2cnf ((_,NONE,false,_,_,_,_),_) = ERROR "No graph given: cannot reduce reachability to CNF"
  | reach2cnf ((varcnt,SOME graph,false,reacha,nonreacha,clauses,symtab),debugoutput) =
    let
      (* Fresh-variable table; presumably reserveN keeps the varcnt indices
         already in use from being handed out again. *)
      val ST = SymbolTable.create(stringlisthash,1000000,String.concat)
      val _ = SymbolTable.reserveN(ST,varcnt)

      (* Variable names as string lists; String.concat joins them. *)
      fun reachVarName(tag,s,n) = ["reach",tag,"_",Int.toString s,"_",Int.toString n]
      fun distVarName(tag,s,n,d) = ["dist",tag,"_",Int.toString s,"_",Int.toString n,"_",Int.toString d]
      fun auxVarName4(tag,s,n,m,d) = ["aux",tag,"_",Int.toString s,"_",Int.toString n,"_",Int.toString m,"_",Int.toString d]

      (* A shortest path visits each node at most once: at most N-1 arcs. *)
      val maxdistance = (p12 graph)-1

      fun reachVar tag (s,n) = SymbolTable.add(ST,reachVarName(tag,s,n),())
      fun distMaxVar tag (s,n) = SymbolTable.add(ST,distVarName(tag,s,n,maxdistance),())
      fun distVar tag (s,n,d) = SymbolTable.add(ST,distVarName(tag,s,n,d),())
      fun auxVar4 tag (s,n,m,d) = SymbolTable.add(ST,auxVarName4(tag,s,n,m,d),())

      val (origins,connects) = reqDistances(reacha,nonreacha,(distMaxVar,reachVar),debugoutput)

      val _ = if debugoutput
              then app (fn (orig,dir,_) =>
                           app print ["ORIGIN: ",Int.toString orig,
                                      " ",(case dir of FWD => "FWD"
                                                     | BWD => "BWD"),
                                      "\n"])
                       origins
              else ()

      fun VARS tag = (distVar tag,auxVar4 tag)

      val arcs = p22 graph
      val nodes = fromto(0,maxdistance)

      (* One encoding per origin: exact distances if some reachability must
         hold, the cheaper over-approximating encoding otherwise. *)
      val newclauses = fold (fn ((orig,dir,true),ac) =>
                                distEncoding(orig,nodes,arcs,dir=FWD,VARS (dir2tag dir),ac)
                              | ((orig,dir,false),ac) =>
                                simpleEncoding(orig,arcs,dir=FWD,reachVar (dir2tag dir),ac))
                            origins
                            (connects@clauses)

      val newsymbols = map (fn (a,b,_) => (b,String.concat a)) (SymbolTable.allvalues ST)
      val newsymtab = mergesort (fn (e1,e2) => (p12 e1) < (p12 e2)) (newsymbols@symtab)

      (* Seed with varcnt-1 so that when no fresh variables were created (no
         (un)reachability constraints) the input count is kept instead of
         collapsing to 1. Fresh indices are never below varcnt-1, so the
         non-empty case is unchanged. *)
      val newvarcnt = (fold intmax (map p12 newsymbols) (varcnt-1))+1
    in
      (newvarcnt,NONE,false,[],[],newclauses,newsymtab)
    end;
