
(* 2017 (C) Jussi Rintanen *)

(* Reduce reachability constraints to sets of clauses with an acyclicity constraint.
 *)

(* Clauses encoding reachability from source node s, to be combined with an
   acyclicity constraint on the justification variables.
   - The source itself is reachable: [reach(s)].
   - For every other node n with neighbors (t,v) (v is the literal of the arc
     between t and n; neighbors are taken with neighborsOf(n,not fwd,arcs)):
       reach(n) -> aux(n,t1) \/ ... \/ aux(n,tk)
       reach(t) & v -> reach(n)
       aux(n,t) -> v,  aux(n,t) -> J(n,t),  aux(n,t) -> reach(t)
     so n is reachable exactly when a reachable neighbor connected by a true arc
     is selected, and each selection sets the justification variable J.
   The new clauses are prepended to ac. *)
fun acycEncoding (s,nodes,arcs,fwd,(Jvar,rvar,auxVar),ac) =
  let
      (* Variables belonging to source s. *)
      fun reach n = rvar(s,n)
      fun aux (n,t) = auxVar(s,n,t)

      (* All clauses for the non-source node n, prepended to acc.
         The bindings are evaluated in the same order as the clauses appear,
         so variables are created in a fixed order. *)
      fun nodeClauses (n,acc) =
        let
            val nbrs = neighborsOf(n,not fwd,arcs)
            (* reach(n) -> some neighbor is selected *)
            val support = Neg(reach n) :: map (fn (t,_) => Pos(aux(n,t))) nbrs
            (* a reachable neighbor over a true arc makes n reachable *)
            val propagate = map (fn (t,v) => [Pos(reach n),Neg(reach t),Neg v]) nbrs
            (* a selected neighbor's arc must be true *)
            val arcHolds = map (fn (t,v) => [Neg(aux(n,t)),Pos v]) nbrs
            (* a selected neighbor justifies n *)
            val justified = map (fn (t,_) => [Neg(aux(n,t)),Pos(Jvar(s,n,t))]) nbrs
            (* a selected neighbor must itself be reachable *)
            val fromReached = map (fn (t,_) => [Neg(aux(n,t)),Pos(reach t)]) nbrs
        in
            support :: propagate @ arcHolds @ justified @ fromReached @ acc
        end
  in
      fold nodeClauses (difference(nodes,[s])) ([Pos(reach s)] :: ac)
  end;

(* When there are only unreachability constraints, we can use a trivial encoding
   of reachability: a satisfying assignment may claim reachabilities that do not
   actually hold, but every node it claims to be unreachable really is unreachable.
   The encoding consists only of the implications reachable(s) & arc(s,t) -> reachable(t).
*)

(* One-directional encoding of reachability from src.
   The encoding asserts [reach(src)] and, for each arc (v,s,t), oriented along the
   traversal direction as tail -> head (s -> t when fwd, t -> s otherwise), adds
       reach(src,tail) & v -> reach(src,head)
   unless head is src itself. The clauses are prepended to ac. *)
fun simpleEncoding (src,arcs,fwd,rvar,ac) =
  let
      fun arcClause ((v,s,t),acc) =
        let
            val (tail,head) = if fwd then (s,t) else (t,s)
        in
            if head = src
            then acc
            else [Pos(rvar(src,head)),Neg(rvar(src,tail)),Neg v] :: acc
        end
  in
      fold arcClause arcs ([Pos(rvar(src,src))] :: ac)
  end;

(* Replace the reachability constraints of a problem by clauses plus an
   acyclicity constraint over a new graph.
   The problem is a 7-tuple (varcnt,graph,acyc,reacha,nonreacha,clauses,symtab);
   debugoutput is passed on to reqDistances.
   Returns (newvarcnt,SOME newGraph,true,[],[],newclauses,newsymtab): both
   reachability lists are emptied, the third component is set to true, and the
   returned graph is the one whose acyclicity is required.
   Calls ERROR if the third component is already true or if no graph is given. *)
fun reach2acyc ((_,_,true,_,_,_,_),_) = ERROR "Reduction from acyclicity to CNF not implemented"
  | reach2acyc ((_,NONE,false,_,_,_,_),_) = ERROR "No graph given: cannot reduce reachability to acyc"
  | reach2acyc ((varcnt,SOME graph,false,reacha,nonreacha,clauses,symtab),debugoutput) =
    let
	(* Fresh symbol table for the new variables.
	   NOTE(review): reserveN presumably reserves the indices of the varcnt
	   existing variables so that new ones do not collide with them - confirm
	   against SymbolTable. *)
	val ST = SymbolTable.create(stringlisthash,1000000,String.concat)
	val _ = SymbolTable.reserveN(ST,varcnt)

	(* Names (as string lists) of the new variables; tag comes from dir2tag and
	   keeps the two directions apart. *)
	fun reachVarName(tag,s,n) = ["reach",tag,"_",Int.toString s,"_",Int.toString n]
	fun JVarName(tag,s,n,m) = ["J",tag,"_",Int.toString s,"_",Int.toString n,"_",Int.toString m]
	fun auxVarName3(tag,s,n,m) = ["aux",tag,"_",Int.toString s,"_",Int.toString n,"_",Int.toString m]

	(* graph appears to be (number of nodes, arc list); nodes are 0..maxdistance. *)
	val maxdistance = (p12 graph)-1

	(* Variable index for a name; presumably SymbolTable.add allocates a new
	   index on first use and returns the existing one afterwards. *)
	fun reachVar tag (s,n) = SymbolTable.add(ST,reachVarName(tag,s,n),())
	fun JVar tag (s,n,m) = SymbolTable.add(ST,JVarName(tag,s,n,m),())
	fun auxVar3 tag (s,n,m) = SymbolTable.add(ST,auxVarName3(tag,s,n,m),())

	(* origins: triples (source node, direction, non-trivial?) as matched below;
	   connects: clauses that are added to the result. *)
	val (origins,connects) = reqDistances(reacha,nonreacha,(reachVar,reachVar),debugoutput)

	(* Variable generators (J, reach, aux) for one direction tag. *)
	fun VARS tag = (JVar tag,reachVar tag,auxVar3 tag)

	val arcs = p22 graph
	val nodes = fromto(0,maxdistance)

	(* Non-trivial origins get the acyclicity-based encoding; the others get the
	   simple one-directional encoding. *)
	val newclauses = fold (fn ((orig,dir,true),ac) =>
				  acycEncoding(orig,nodes,arcs,dir=FWD,VARS (dir2tag dir),ac)
			      | ((orig,dir,false),ac) =>
				  simpleEncoding(orig,arcs,dir=FWD,reachVar (dir2tag dir),ac)
			      )
			      origins
			      (connects@clauses)

	(* Source nodes of those reachability constraints that are non-trivial *)
	val graphComponents = fold (fn ((orig,dir,true),ac) => (orig,dir)::ac | (_,ac) => ac) origins []
	(* The new graph is copies of the old for each source of a non-trivial constraint. *)

	val nOfNodes = p12 graph
	val NewNOfNodes = nOfNodes*(length graphComponents)

	(* Index of node n in the i-th copy of the graph. *)
	fun nodeIndex (i,n) = n+i*nOfNodes
	(* Every arc (v,s,t) of copy i is labeled by the J variable of the matching
	   justification and points from the justified node to its justifying
	   neighbor: t -> s for FWD (J(src,t,s)), s -> t for BWD (J(src,s,t)). *)
	val newArcs = map (fn ((i,(src,FWD)),(v,s,t)) => (JVar (dir2tag FWD) (src,t,s),
							  nodeIndex(i,t),
							  nodeIndex(i,s))
			  | ((i,(src,BWD)),(v,s,t)) => (JVar (dir2tag BWD) (src,s,t),
							nodeIndex(i,s),
							nodeIndex(i,t)))
			  (product(number(graphComponents,0),
				   arcs))

	val newGraph = (NewNOfNodes,newArcs)

	(* (index, name) pairs for all new variables, merged into the old symbol
	   table and sorted by index. *)
	val newsymbols = map (fn (a,b,c) => (b,String.concat a)) (SymbolTable.allvalues ST)
	val newsymtab = mergesort (fn (e1,e2) => (p12 e1) < (p12 e2)) (newsymbols@symtab)

	(* One more than the largest new variable index.
	   NOTE(review): if no new symbols were created this evaluates to 1 rather
	   than varcnt - confirm that case cannot occur (or that allvalues also
	   returns the reserved entries). *)
	val newvarcnt = (fold intmax (map p12 newsymbols) 0)+1

    in
	(newvarcnt,SOME newGraph,true,[],[],newclauses,newsymtab)
    end;
