import math;
import os;
import shutil;
import subprocess;
import sys;
import re;
import random;

# Directory in which the temporary program files are created
workingDirectory = os.path.join(".", "xorTestTemp")
# Name of the predicate used for the generated "ok" rule and constraint;
# parseArgs() appends a number to it if the name already occurs in the programs
xorTestPredicate = "xorTestPredicate"
# Path of the generated reference program (common file + reference file)
referenceProgramPath = os.path.join(workingDirectory, "reference.lp")
# Path of the generated program under test (common file + test file)
programUnderTestPath = os.path.join(workingDirectory, "test.lp")
# Path of the program that is currently handed to the solver
programToSolvePath = os.path.join(workingDirectory, "solve.lp")
# Extra arguments passed on to clasp (one entry per --a=... option)
settingArgs = []
# Extra arguments passed on to gringo (one entry per --g=... option)
settingArgsGringo = []
# Maximum number of answer sets to pick per iteration (--c)
settingC = 1
# Number of answer sets (test cases) to request from xorsample (--n)
settingN = 1
# Probability (0.01 <= Q <= 0.5) for an atom to be included in an XOR
# constraint (--q). Stored as parsed; the raw --q option is forwarded to xorsample.
settingQ = 0.5
# Initial number of XOR constraints (--s); stays -1 while --s is not given
settingS = -1
# Time limit in seconds for clasp (--t); 0 means no time limit
settingT = 0
# Path of the reference program file (--rf)
settingReferenceFile = ""
# Path of the program under test (--tf)
settingTestFile = ""
# Path of the common program file (--cf)
settingCommonFile = ""
# Whether #show statements should be added (cleared by --dontAddShows)
settingAddShows = True
# Options that are forwarded verbatim to xorsample
xorSampleArguments = []
# (test case number, answer set) pairs of the failed test cases
failedRuns = []

# Atom names of the grounded program under test, filled by getAtoms()
atoms = []
# Number of the current test case
testCaseCounter = 0
# Whether the failed test cases are listed again at the end (--summary)
repeatFailedCasesAtEnd = False


def cleanUp():
    """Delete the working directory with all of its contents; errors are ignored."""
    shutil.rmtree(workingDirectory, ignore_errors=True)

def usage():
    """Print the command-line help text and terminate the program via sys.exit()."""
    helpLines = (
        os.path.basename(sys.argv[0]) + " [--help] [--a=args]* [--g=gArgs]* [--c=C] [--n=N] [--q=Q] [--s=S] [--t=T] [--dontAddShows] [--summary] --cf=p --rf=p --tf=p",
        "A tool for testing answer sets programs",
        "usage:",
        "--help:\tdisplays this usage info",
        "--a=args:\tArguments that should be passed on to clasp.",
        "--g=Gargs:\tArguments that should be passed on to gringo.",
        "--c=C:\t\tThe maximum number of answer sets to pick per iteration\n\t\t(default: 1)",
        "--n=N:\t\tset an positive integer N as the number of answer sets to compute\n\t\t(default: 1)",
        "--q=Q:\t\tset Q as the probability for an atom to be included in an XOR\n\t\tconstraint (0.01 <= Q <= 0.5, default:0.5)",
        "--s=S:\t\tset an positive integer S as the initial number of constraints\n\t\t(default:log(X) where X is the number of atoms in the grounding)",
        "--t=T:\t\tset the time limit in seconds for clasp (only for xorsample)\n\t\t(default: 0, 0 = no time limit)",
        "--dontAddShows:\tDoesn't add #show expressions if supplied",
        "--summary:\tIf supplied, the failed test cases get repeated at the end",
        "--cf=p:\t\tThe path to the common asp program",
        "--rf=p:\t\tThe path to the reference asp program",
        "--tf=p:\t\tThe path to the asp program under test",
    )
    for helpLine in helpLines:
        print(helpLine)
    # No exit code is given, so the process exits with status 0.
    sys.exit()
    
def parseArgs():
    """Parse the command-line arguments into the module-level settings and prepare the encodings.

    The recognised options are described in usage(). Options that xorsample also
    understands (--a, --g, --c, --q, --s, --t, --dontAddShows) are additionally
    collected verbatim in xorSampleArguments. usage(), which exits, is called in
    these cases:
    - for --help;
    - for an unknown option;
    - for a repeated single-use option;
    - for an invalid option value;
    - when one of --cf, --rf or --tf is missing.
    If one of those three files does not exist, a message is printed and the
    program exits.

    Afterwards the common file is concatenated with the reference file into
    referenceProgramPath and with the test file into programUnderTestPath.
    xorTestPredicate is then extended with a number until it no longer occurs
    as a substring of either encoding.
    """
    global settingArgs
    global settingArgsGringo
    global settingC
    global settingN
    global settingQ
    global settingS
    global settingT
    global settingAddShows
    global settingReferenceFile
    global settingTestFile
    global settingCommonFile
    global repeatFailedCasesAtEnd
    global xorTestPredicate
    global xorSampleArguments

    # Single-use options that have already been seen.
    seenOnce = set()

    def markSeen(option):
        """Record a single-use option; call usage() if it was already given."""
        if option in seenOnce:
            usage()
        seenOnce.add(option)

    for arg in sys.argv[1:]:  # skip the program name
        if arg == "--help":
            usage()
        elif arg.startswith("--a="):
            settingArgs.append(arg[4:])
            xorSampleArguments.append(arg)
        elif arg.startswith("--g="):
            settingArgsGringo.append(arg[4:])
            xorSampleArguments.append(arg)
        elif arg.startswith("--c="):
            markSeen("--c=")
            settingC = tryParsePositiveInteger(arg[4:])
            xorSampleArguments.append(arg)
        elif arg.startswith("--n="):
            # Not forwarded: the number of answer sets is passed to xorsample separately.
            markSeen("--n=")
            settingN = tryParsePositiveInteger(arg[4:])
        elif arg.startswith("--q="):
            markSeen("--q=")
            settingQ = tryPraseParameterQ(arg[4:])
            xorSampleArguments.append(arg)
        elif arg.startswith("--s="):
            markSeen("--s=")
            settingS = tryParsePositiveInteger(arg[4:])
            xorSampleArguments.append(arg)
        elif arg.startswith("--t="):
            markSeen("--t=")
            settingT = tryParseInteger(arg[4:])
            xorSampleArguments.append(arg)
        elif arg == "--dontAddShows":
            # Exact match: a prefix match would accept (and forward) junk such as "--dontAddShowsX".
            markSeen(arg)
            settingAddShows = False
            xorSampleArguments.append(arg)
        elif arg == "--summary":
            markSeen(arg)
            repeatFailedCasesAtEnd = True
        elif arg.startswith("--cf="):
            markSeen("--cf=")
            settingCommonFile = arg[5:]
        elif arg.startswith("--rf="):
            markSeen("--rf=")
            settingReferenceFile = arg[5:]
        elif arg.startswith("--tf="):
            markSeen("--tf=")
            settingTestFile = arg[5:]
        else:
            usage()

    if not {"--cf=", "--rf=", "--tf="} <= seenOnce:
        usage()

    for path, label in ((settingCommonFile, "common"),
                        (settingReferenceFile, "reference"),
                        (settingTestFile, "test")):
        if not os.path.isfile(path):
            print("The " + label + " file was not found")
            sys.exit()

    # Generate the encoding files. --cf is mandatory (enforced above), so the
    # common file is always prepended.
    completeEncoding = concatFiles([settingCommonFile, settingReferenceFile], referenceProgramPath)
    completeEncoding += concatFiles([settingCommonFile, settingTestFile], programUnderTestPath)

    # Pick a predicate name that does not occur anywhere in the encodings.
    candidate = xorTestPredicate
    suffix = 1
    while candidate in completeEncoding:
        candidate = xorTestPredicate + str(suffix)
        suffix += 1
    xorTestPredicate = candidate
    
    
def concatFiles(inputFiles, outputFile, text = ""):
    """Concatenate inputFiles into outputFile, optionally append text, and return what was written.

    :param inputFiles: A list of input file paths, copied in the given order
    :param outputFile: The path of the output file (created or overwritten)
    :param text: Optional text appended after a line break; nothing is appended when empty
    :returns: The complete content written to outputFile: the file contents,
        followed by the appended text if any
    """
    parts = []
    with open(outputFile, 'w') as oFile:
        for fname in inputFiles:
            with open(fname) as infile:
                content = infile.read()
            oFile.write(content)
            parts.append(content)
        if text != "":
            # "\n" rather than os.linesep: the file is opened in text mode, which
            # already turns "\n" into the platform line ending on write.
            suffix = "\n" + text
            oFile.write(suffix)
            # Append (not overwrite) so the return value matches the file content.
            parts.append(suffix)

    return "".join(parts)
    
def tryParsePositiveInteger(s):
    """Convert s into an integer greater than zero, calling usage() otherwise.

    :param s: The string that should be parsed
    :returns: The parsed integer
    """
    parsed = tryParseInteger(s)
    # tryParseInteger already rejects everything that is not a plain
    # non-negative integer; only zero remains to be excluded here.
    if parsed == 0:
        usage()
    return parsed
    
def tryParseInteger(s):
    """Convert s into a non-negative integer, calling usage() if that is not possible.

    Only strings consisting of digit characters are accepted, so signs
    ("+5", "-5") and surrounding whitespace are rejected.

    :param s: The string that should be parsed
    :returns: The parsed integer (always >= 0)
    """
    if not s.isdigit():
        usage()
    try:
        return int(s)
    except ValueError:
        # str.isdigit() is also true for characters such as "²" that int()
        # cannot convert; report them as invalid input instead of crashing.
        usage()

def tryPraseParameterQ(s):
    """Convert s into a real number in [0.01, 0.5], calling usage() otherwise.

    Non-numeric input, NaN and values outside the range all lead to usage().

    :param s: The string that should be parsed
    :returns: The parsed probability
    """
    try:
        q = float(s)
        # NaN compares false with everything; the explicit isnan check makes that visible.
        inRange = 0.01 <= q <= 0.5 and not math.isnan(q)
        if not inRange:
            usage()
        return q
    except ValueError:
        usage()


def callGringo(parameters):
    """Run gringo and return its standard output.

    The user-supplied gringo arguments (settingArgsGringo) come before the
    given parameters. gringo's stderr is captured and discarded.

    :param parameters: The parameters that should be passed to gringo
    :returns: gringo's standard output decoded as UTF-8
    """
    commandLine = ["gringo"] + settingArgsGringo + parameters
    gringo = subprocess.Popen(commandLine, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdoutData, _ = gringo.communicate()
    return str(stdoutData, "utf-8")

def getAtoms(programFile):
    """Ground programFile with gringo and append the atom names of its grounding to the global atoms list.

    The atom block is taken to start after the first line consisting of "0" and
    to end at the next such line. Each line in between has the form
    "<number> <atom>".
    NOTE(review): this layout matches gringo's smodels/lparse output; confirm it
    for the gringo version and arguments in use.

    :param programFile: Path of the program to ground
    :raises Exception: if a line inside the atom block carries no atom name
    """
    regexAtom = re.compile('(?<=[0-9] ).*', re.I + re.S)

    # Ground the program using gringo. splitlines() also copes with "\r\n"
    # endings, which split(os.linesep) misses when they differ from the platform's.
    grounding = callGringo([programFile]).splitlines()

    # Parse the atoms from the grounding.
    global atoms
    atomsBlockStarted = False
    for line in grounding:
        if not atomsBlockStarted:
            # Skip everything up to and including the first "0" line.
            atomsBlockStarted = line == "0"
            continue

        if line == "0":  # end of the block with the atoms
            break

        match = regexAtom.search(line)
        if match is None:
            raise Exception("Unexpected line in the grounding of " + programFile + ": " + line)
        atoms.append(match.group(0))

def doTest(testCase):
    """Check one sampled test case against the reference program and the program under test.

    The atoms of testCase are added as facts to the reference program, which is
    then solved. If it is consistent, the program under test (with the same
    facts) must have an answer set that agrees with the reference answer set on
    every atom in the global atoms list. Otherwise the program under test must
    be inconsistent as well.

    :param testCase: Space-separated atoms of the sampled answer set
    :returns: None if a solver call hit the time limit, otherwise a tuple
        (passed, referenceConsistent)
    :raises Exception: if the solver output is neither SATISFIABLE nor UNSATISFIABLE
    """
    # One fact per non-empty atom.
    testCaseString = "".join(atom + ".\n" for atom in testCase.split(" ") if atom)

    concatFiles([referenceProgramPath], programToSolvePath, "\n" + testCaseString)

    result = getAS(programToSolvePath)
    if result is None:
        return None

    referenceConsistent, answerSet = result

    if referenceConsistent:
        # okRule derives xorTestPredicate exactly when every atom of the program
        # under test has the same truth value as in the reference answer set;
        # the constraint then discards all answer sets for which that fails.
        answerSetAtoms = set(answerSet.split(" "))  # set: O(1) membership tests
        literals = [atom if atom in answerSetAtoms else "not " + atom for atom in atoms]
        if literals:
            okRule = xorTestPredicate + " :- " + ",".join(literals) + "."
        else:
            # No atoms to compare: emit a plain fact instead of a rule with an empty body.
            okRule = xorTestPredicate + "."
        constraintRule = ":- not " + xorTestPredicate + "."
        # "\n" rather than os.linesep: concatFiles writes in text mode.
        extraText = "\n" + testCaseString + "\n" + okRule + "\n" + constraintRule
    else:
        extraText = "\n" + testCaseString
    concatFiles([programUnderTestPath], programToSolvePath, extraText)

    solverResults = callSolver(1, programToSolvePath)

    if "TIME LIMIT   : 1" in solverResults:
        return None

    if "\nSATISFIABLE\n" in solverResults:
        satisfiable = True
    elif "\nUNSATISFIABLE\n" in solverResults:
        satisfiable = False
    else:
        raise Exception("Unexpected solver result: " + solverResults)

    # Consistent reference: passes iff SAT. Inconsistent reference: passes iff UNSAT.
    return (satisfiable == referenceConsistent, referenceConsistent)

def callSolver(n, programFile):
    """Ground programFile with gringo, pipe the result into clasp and return clasp's output.

    :param n: The number of wanted answer sets (passed to clasp as --models)
    :param programFile: The file of the program that should be solved
    :returns: clasp's standard output decoded as UTF-8
    """
    gringoArguments = ["gringo"] + settingArgsGringo + [programFile]
    claspArguments = (["clasp"] + settingArgs
                      + ["--models=" + str(n), "--time-limit=" + str(settingT)])

    # gringo's stderr is discarded instead of piped. A pipe that nobody reads
    # can fill up and block gringo, which would leave clasp waiting forever.
    gringoProcess = subprocess.Popen(gringoArguments, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
    claspProcess = subprocess.Popen(claspArguments, stdin=gringoProcess.stdout,
                                    stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # Close our copy of the read end right away so gringo notices when clasp
    # exits early (pipeline pattern from the subprocess documentation).
    gringoProcess.stdout.close()

    claspOutput, _ = claspProcess.communicate()

    # clasp is done, so gringo's output is no longer needed. Stop it if it is
    # still running, then reap it so no zombie process is left behind.
    if gringoProcess.poll() is None:
        gringoProcess.kill()
    gringoProcess.wait()

    return str(claspOutput, "utf-8")

def callXorSample(n, program):
    """Run xorsample.py on program and return the sampled answer sets.

    xorsample's output is expected to alternate between a header line starting
    with "Answer set " and a line holding the atoms of that answer set.

    :param n: The number of answer sets to request (passed as --n)
    :param program: Path of the program to sample from
    :returns: A list with one string of space-separated atoms per answer set
    :raises Exception: if a header line does not start with "Answer set "
    """
    xorSampleProcess = subprocess.Popen(["python3.4"] + ["xorsample.py"] + xorSampleArguments + ["--n=" + str(n)] + [program],
                                        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # communicate() drains stdout and stderr together and waits for the process.
    # A full, never-read stderr pipe therefore cannot block xorsample, and no zombie is left.
    output, unused_err = xorSampleProcess.communicate()

    lines = str(output, "utf-8").split("\n")
    if lines[-1] == "":
        lines.pop()  # artefact of the final newline, not an output line

    returnValue = []
    for index, line in enumerate(lines):
        if index % 2 == 0:  # header line of an answer set
            if not line.startswith("Answer set "):
                raise Exception("Failed to find answer sets using xorsample: " + line)
        else:
            # Strip only the line terminator. line[:-1] would also cut the last
            # character of an unterminated final line.
            returnValue.append(line.rstrip("\r"))

    return returnValue

def getAS(programFile):
    """Solve programFile for a single answer set.

    :param programFile: Path of the program to solve
    :returns: None if clasp hit the time limit; (False, "") if the program is
        unsatisfiable; otherwise (True, atoms), where atoms is the line clasp
        printed directly after "Answer: 1"
    :raises Exception: if the solver output is neither SATISFIABLE nor UNSATISFIABLE
    """
    output = callSolver(1, programFile)

    if "TIME LIMIT   : 1" in output:
        return None

    if "\nSATISFIABLE\n" in output:
        firstAnswer = re.compile(r"(?<=Answer: 1\n)(.*)", re.I)
        return (True, firstAnswer.findall(output)[0])

    if "\nUNSATISFIABLE\n" in output:
        return (False, "")

    raise Exception("Failed to find an answer set: " + output)
    
def getRandomAS(n, programFile):
    """Request n sampled answer sets of programFile from xorsample (see callXorSample)."""
    sampled = callXorSample(n, programFile)
    return sampled
       

if __name__ == "__main__":

    # Create a fresh working directory and refuse to reuse an existing one,
    # since cleanUp() deletes it at the end.
    try:
        if os.path.isdir(workingDirectory):
            print("The working directory " + workingDirectory + " already exists. Aborting...")
            sys.exit()
        os.makedirs(workingDirectory)
    except Exception as e:
        # Python 3 exceptions have no .message attribute; str(e) works for all of them.
        # Nothing below can work without the working directory, so stop here.
        print("Failed to prepare the working directory " + workingDirectory + " : " + str(e))
        sys.exit(1)

    try:
        parseArgs()
        getAtoms(programUnderTestPath)

        testcases = getRandomAS(settingN, settingCommonFile)

        testCaseCounter = 0
        consistentFailedCounter = 0
        inconsistentFailedCounter = 0
        consistentPassedCounter = 0
        inconsistentPassedCounter = 0
        indefiniteCounter = 0

        for testCaseCounter, testcase in enumerate(testcases, start=1):
            testResult = doTest(testcase)
            if testResult is None:
                print("Test case " + str(testCaseCounter) + ": INDEFINITE")
                indefiniteCounter += 1
                continue

            passed, referenceConsistent = testResult
            if passed:
                if referenceConsistent:
                    print("Test case " + str(testCaseCounter) + ": PASSED (CONSISTENT)")
                    consistentPassedCounter += 1
                else:
                    print("Test case " + str(testCaseCounter) + ": PASSED (INCONSISTENT)")
                    inconsistentPassedCounter += 1
            else:
                if referenceConsistent:
                    print("Test case " + str(testCaseCounter) + ": FAILED (CONSISTENT)")
                    consistentFailedCounter += 1
                else:
                    print("Test case " + str(testCaseCounter) + ": FAILED (INCONSISTENT)")
                    inconsistentFailedCounter += 1
                print("Answer Set: " + testcase)
                failedRuns.append((testCaseCounter, testcase))

        print("Failed testcases: " + str(len(failedRuns)))
        print("Failed consistent testcases: " + str(consistentFailedCounter))
        print("Failed inconsistent testcases: " + str(inconsistentFailedCounter))
        print("Passed consistent testcases: " + str(consistentPassedCounter))
        print("Passed inconsistent testcases: " + str(inconsistentPassedCounter))
        print("Indefinite testcases: " + str(indefiniteCounter))

        # Print the summary (if the argument --summary is supplied).
        if repeatFailedCasesAtEnd and failedRuns:
            print("The following testcases failed:")
            for runNumber, answerSet in failedRuns:
                print("Testcase " + str(runNumber))
                print(answerSet)
            print("Failed testcases: " + str(len(failedRuns)))

    except Exception as e:
        print("An unexpected exception occurred: " + str(e))
    finally:
        # Also runs when usage() exits through SystemExit.
        cleanUp()