# @package      hubzero-submit-distributor
# @file         RemoteBatchPEGASUS.py
# @copyright    Copyright (c) 2012-2020 The Regents of the University of California.
# @license      http://opensource.org/licenses/MIT MIT
#
# Copyright (c) 2012-2020 The Regents of the University of California.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of The Regents of the University of California.
#
import os
import re
import logging
import warnings

def showWarning(message, category, filename, lineno, file=None, line=None):
   """Swallow a warning without displaying it.

   The signature mirrors warnings.showwarning so this function can be
   installed as the module-wide warning hook while the Pegasus API is
   imported below.
   """
   return None

# Remember the real warnings display hook so it can be restored after the
# Pegasus import below.
_default_show_warning = warnings.showwarning

# Silence warnings emitted while importing the Pegasus workflow API.  The
# candidates are tried in order: Pegasus.api, then Pegasus.DAX3, then the
# hubzero.submit copy of DAX3.
# NOTE(review): the bare excepts also swallow non-ImportError failures
# (e.g. a SyntaxError from an incompatible Pegasus install) - presumably
# deliberate so that the next fallback is tried; confirm before narrowing.
warnings.showwarning = showWarning
try:
   from Pegasus.api import *
except:
   try:
      from Pegasus.DAX3 import *
   except:
      from hubzero.submit.DAX3 import *
# NOTE(review): if the last fallback also fails, the exception propagates and
# the hook is not restored; the module import fails in that case anyway.
warnings.showwarning = _default_show_warning

from hubzero.submit.LogMessage        import getLogJobIdMessage as getLogMessage 
from hubzero.submit.ParameterTemplate import ParameterTemplate
from hubzero.submit.JobOutput         import JobOutput

class RemoteBatchPEGASUS:
   """Generate the batch script and Pegasus catalogs for a PEGASUS job.

   getBatchScript() renders the 'Batch'/'PEGASUS'/'serial' submission script
   (an empty script is returned for multi-core requests) and writes the
   rc/sites/tc catalogs rendered from siteInfo['pegasusTemplates'] into the
   instance directory.
   """
   SUBMISSIONSCRIPTCOMMANDPREFIX = ''

   # Suffixes identifying a YAML abstract workflow given as a positional
   # pegasus-plan argument.
   WORKFLOWSUFFIXES = ('.yml','.yaml')

   # pegasus-plan options recognized when filtering user arguments, mapped to
   # (disposition,takesValue):
   #    'keep'    - passed through in @@{ARGUMENTS}
   #    'discard' - removed and reported in @@{DISCARDED}
   #    'drop'    - removed without being reported
   # takesValue is True when the option is followed by exactly one value.
   # -D<property>=<value> (discard) and -v... (keep) are matched by prefix in
   # _planOptionDisposition().
   PLANOPTIONS = {
      '-d'                    : ('keep',True),
      '--dax'                 : ('keep',True),
      '-b'                    : ('keep',True),
      '--basename'            : ('keep',True),
      '-c'                    : ('discard',True),
      '--cache'               : ('discard',True),
      '-C'                    : ('keep',True),
      '--cluster'             : ('keep',True),
      '--conf'                : ('discard',True),
      '--dir'                 : ('discard',True),
      '-f'                    : ('keep',False),
      '--force'               : ('keep',False),
      '--force-replan'        : ('keep',False),
      '-g'                    : ('keep',True),
      '--group'               : ('keep',True),
      '-h'                    : ('discard',False),
      '--help'                : ('discard',False),
      '--inherited-rc-files'  : ('keep',True),
      '-j'                    : ('keep',True),
      '--j'                   : ('keep',True),
      '-n'                    : ('drop',False),
      '--nocleanup'           : ('drop',False),
      '--cleanup'             : ('discard',True),
      '-o'                    : ('discard',True),
      '--output'              : ('discard',True),
      '--output-site'         : ('discard',True),
      '-q'                    : ('keep',False),
      '--quiet'               : ('keep',False),
      '--relative-submit-dir' : ('discard',True),
      '-s'                    : ('discard',True),
      '--sites'               : ('discard',True),
      '-S'                    : ('discard',False),
      '--submit'              : ('discard',False),
      '--verbose'             : ('keep',False),
      '-V'                    : ('discard',False),
      '--version'             : ('discard',False),
   }

   def __init__(self,
                hubUserName,
                hubUserId,
                submitterClass,
                session,
                instanceToken,
                wsJobId,
                runName,
                localJobId,
                instanceId,
                instanceDirectory,
                scratchDirectory,
                useSetup,
                pegasusVersion,
                pegasusHome,
                appScriptName,
                executable,
                arguments,
                isMultiCoreRequest,
                siteInfo,
                toolFilesInfo,
                dockerImageInfo,
                submissionScriptsInfo,
                managerInfo,
                x509SubmitProxy,
                sshIdentityPath,
                gridsite,
                pegasusSiteInfo,
                nGpus,
                gpn,
                memoryMB,
                wallTime,
                timeHistoryLogs):
      """Store job, site and manager settings.

      When executable's basename starts with 'pegasus-' it is run with the
      user's arguments, where arguments naming existing files are made
      absolute.  Otherwise pegasus-plan is run on the generated DAX
      <localJobId>_<instanceId>.dax in instanceDirectory.
      timeHistoryLogs must provide 'timestampStart', 'timestampFinish' and
      'timeResults'.
      """
      self.logger            = logging.getLogger(__name__)
      self.jobOutput         = JobOutput()
      self.hubUserName       = hubUserName
      self.hubUserId         = hubUserId
      self.submitterClass    = submitterClass
      self.session           = session
      self.instanceToken     = instanceToken
      self.wsJobId           = wsJobId
      self.runName           = runName
      self.localJobId        = localJobId
      self.instanceId        = instanceId
      self.instanceDirectory = instanceDirectory
      self.scratchDirectory  = scratchDirectory
      self.useSetup          = useSetup
      self.pegasusVersion    = pegasusVersion
      self.pegasusHome       = pegasusHome
      self.appScriptName     = appScriptName
      self.nGpus             = nGpus
      self.gpn               = gpn
      self.memoryMB          = memoryMB
      self.wallTime          = wallTime
      self.timestampStart    = timeHistoryLogs['timestampStart']
      self.timestampFinish   = timeHistoryLogs['timestampFinish']
      self.timeResults       = timeHistoryLogs['timeResults']
      # Workflow file whose executables getUserDaxExecutables() inspects.
      self.checkDaxPath      = ""

      basename = os.path.basename(executable)
      if basename.startswith('pegasus-'):
         self.daxPath             = ""
         self.executable          = executable
         userArguments = arguments.split()
         scriptArguments = []
         for userArgument in userArguments:
            if os.path.isfile(userArgument):
               absolutePath = os.path.abspath(userArgument)
               scriptArguments.append(absolutePath)
            else:
               scriptArguments.append(userArgument)
         self.arguments           = ' '.join(scriptArguments)
         # The user's -d/--dax value is captured into checkDaxPath.
         self.checkDaxExecutables = True
      else:
         self.daxPath             = os.path.join(self.instanceDirectory,"%s_%s.dax" % (self.localJobId,self.instanceId))
         self.executable          = 'pegasus-plan'
         self.arguments           = "--dax %s" % (self.daxPath)
         self.checkDaxExecutables = False

      self.isMultiCoreRequest       = isMultiCoreRequest
      self.computationMode          = managerInfo['computationMode']
      self.preManagerCommands       = managerInfo['preManagerCommands']
      self.managerCommand           = managerInfo['managerCommand']
      self.postManagerCommands      = managerInfo['postManagerCommands']
      self.pegasusTemplates         = siteInfo['pegasusTemplates']
      self.submissionScriptCommands = siteInfo['submissionScriptCommands']
      self.timePaths                = siteInfo['timePaths']
      self.toolFilesInfo            = toolFilesInfo
      self.dockerImageInfo          = dockerImageInfo
      self.submissionScriptsInfo    = submissionScriptsInfo
      self.x509SubmitProxy          = x509SubmitProxy
      self.sshIdentityPath          = sshIdentityPath
      self.gridsite                 = gridsite
      self.pegasusSiteInfo          = pegasusSiteInfo

      self.nodeFileName = ""
      self.nodeList     = []

      self.toolInputTemplateFileName  = ""
      self.toolInputTemplate          = ""
      self.toolOutputTemplateFileName = ""
      self.toolOutputTemplate         = ""


   def _submissionScriptCommandsText(self):
      """Join the site's submission-script commands, prefixing each with
      SUBMISSIONSCRIPTCOMMANDPREFIX when one is defined; "" when there are none."""
      if not self.submissionScriptCommands:
         return ""
      prefix = self.SUBMISSIONSCRIPTCOMMANDPREFIX
      if prefix:
         commandSeparator = "\n%s " % (prefix)
         return prefix + " " + commandSeparator.join(self.submissionScriptCommands)
      return "\n".join(self.submissionScriptCommands)


   def _readPegasusTemplate(self,pegasusTemplatePath):
      """Return the text of pegasusTemplatePath, or None (after logging) on I/O failure."""
      try:
         with open(pegasusTemplatePath,'r') as fpTemplate:
            try:
               return fpTemplate.read()
            except (IOError,OSError):
               self.logger.log(logging.ERROR,getLogMessage("%s could not be read" % (pegasusTemplatePath)))
      except (IOError,OSError):
         self.logger.log(logging.ERROR,getLogMessage("%s could not be opened" % (pegasusTemplatePath)))
      return None


   @staticmethod
   def _parseRcArguments(settings,searchString):
      """Collect the values of '<searchString> = <value>' lines.

      Blank lines and properties-file comments ('#' or '!') are ignored, and
      only the first '=' separates key from value, so values such as
      '-Dprop=value' are kept intact.  Empty values are skipped.
      """
      arguments = []
      for setting in settings:
         setting = setting.strip()
         if not setting or setting.startswith(('#','!')):
            continue
         key,separator,value = setting.partition('=')
         value = value.strip()
         if separator and key.strip() == searchString and value:
            arguments.append(value)
      return arguments


   def _readRcAuxiliaryArguments(self):
      """Return extra arguments configured in the rc template for self.executable,
      e.g. the line 'pegasus-plan.arguments = --nocleanup' for pegasus-plan."""
      rcTemplatePath = self.pegasusTemplates.get('rc','')
      if not os.path.exists(rcTemplatePath):
         return []
      rcTemplate = self._readPegasusTemplate(rcTemplatePath)
      if rcTemplate is None:
         return []
      searchString = os.path.basename(self.executable) + '.arguments'
      return self._parseRcArguments(rcTemplate.splitlines(),searchString)


   def _planOptionDisposition(self,option):
      """Return (disposition,takesValue) for a pegasus-plan option.

      disposition is 'keep', 'discard', 'drop' or 'unknown'; takesValue is
      None for unknown options.
      """
      if option.startswith('-D'):
         return ('discard',False)
      if option in self.PLANOPTIONS:
         return self.PLANOPTIONS[option]
      if option.startswith('-v'):
         return ('keep',False)
      return ('unknown',None)


   def _classifyUserArguments(self):
      """Split self.arguments into kept, discarded and unrecognized tokens.

      Known options consume exactly the number of values they take, so a
      positional workflow file that follows a flag or an option value (e.g.
      '--output-site local workflow.yml') is not swallowed.  The values of
      unknown options are consumed greedily up to the next '-' token.

      Side effect: self.checkDaxPath is set from a positional .yml/.yaml
      argument, or from the -d/--dax value when self.checkDaxExecutables is
      true; the last occurrence wins.

      Returns:
         (keptArguments,discardedArguments,notRecognizedArguments)
      """
      keptArguments          = []
      discardedArguments     = []
      notRecognizedArguments = []

      userArguments = self.arguments.split()
      nArguments    = len(userArguments)
      location      = 0
      while location < nArguments:
         argument = userArguments[location]
         location += 1

         if not argument.startswith('-'):
            # Only a YAML workflow is meaningful as a positional argument;
            # other positional tokens are ignored.
            if argument.endswith(self.WORKFLOWSUFFIXES):
               self.checkDaxPath = argument
            continue

         disposition,takesValue = self._planOptionDisposition(argument)
         saveDaxPath = bool(self.checkDaxExecutables) and argument in ('-d','--dax')
         if   disposition == 'keep':
            # With saveDaxPath the workflow is passed via @@{DAX} instead.
            if not saveDaxPath:
               keptArguments.append(argument)
         elif disposition == 'discard':
            discardedArguments.append(argument)
         elif disposition == 'unknown':
            notRecognizedArguments.append(argument)

         if   disposition == 'unknown':
            valueLimit = nArguments
         elif takesValue:
            valueLimit = 1
         else:
            valueLimit = 0

         nValues = 0
         while nValues < valueLimit and location < nArguments and \
               not userArguments[location].startswith('-'):
            value = userArguments[location]
            if   saveDaxPath:
               self.checkDaxPath = value
            elif disposition == 'keep':
               keptArguments.append(value)
            elif disposition == 'unknown':
               notRecognizedArguments.append(value)
            else:
               # Values of discarded and dropped options are reported as discarded.
               discardedArguments.append(value)
            nValues  += 1
            location += 1

      return(keptArguments,discardedArguments,notRecognizedArguments)


   def _daxSubstitution(self):
      """Return the @@{DAX} text: a YAML workflow is passed positionally, any
      other workflow file via '--dax <path>', and "" when none was given."""
      if not self.checkDaxPath:
         return ""
      if self.checkDaxPath.endswith(self.WORKFLOWSUFFIXES):
         return self.checkDaxPath
      return "--dax %s" % (self.checkDaxPath)


   def __buildSerialFile(self):
      """Render the serial PEGASUS submission script.

      pegasus-plan arguments are the rc-template '<executable>.arguments'
      values followed by the user arguments kept by _classifyUserArguments().
      Returns "" (after logging) when template substitution fails.
      """
      rawSubmissionScript = self.submissionScriptsInfo.getSubmissionScript('Batch','PEGASUS','serial')

      rcArguments = self._readRcAuxiliaryArguments()
      keptArguments,discardedArguments,notRecognizedArguments = self._classifyUserArguments()
      auxiliaryArguments = rcArguments + keptArguments

      substitutions = {}
      substitutions["USESETUP"]                 = self.useSetup
      substitutions["PEGASUSVERSION"]           = self.pegasusVersion
      substitutions["RUNNAME"]                  = self.runName
      substitutions["JOBID"]                    = self.localJobId
      substitutions["INSTANCEID"]               = self.instanceId
      substitutions["INSTANCEDIRECTORY"]        = self.instanceDirectory
      substitutions["SCRATCHDIRECTORY"]         = self.scratchDirectory
      substitutions["EXECUTABLE"]               = self.executable
      substitutions["GRIDSITE"]                 = self.gridsite
      substitutions["ARGUMENTS"]                = ' '.join(auxiliaryArguments).strip()
      substitutions["DISCARDED"]                = ' '.join(discardedArguments).strip()
      substitutions["NOTRECOGNIZED"]            = ' '.join(notRecognizedArguments).strip()
      substitutions["TS_START"]                 = self.timestampStart
      substitutions["TS_FINISH"]                = self.timestampFinish
      substitutions["TIME_RESULTS"]             = self.timeResults
      substitutions["HUBUSERNAME"]              = self.hubUserName
      substitutions["HUBUSERID"]                = str(self.hubUserId)
      substitutions["SUBMISSIONSCRIPTCOMMANDS"] = self._submissionScriptCommandsText()
      substitutions["PREMANAGERCOMMANDS"]       = "\n".join(self.preManagerCommands)
      substitutions["POSTMANAGERCOMMANDS"]      = "\n".join(self.postManagerCommands)
      substitutions["DAX"]                      = self._daxSubstitution()
      substitutions["TIMEPATHS"]                = ' '.join(self.timePaths)

      template = ParameterTemplate(rawSubmissionScript)
      try:
         submissionScript = template.substitute_recur(substitutions)
      except KeyError as e:
         submissionScript = ""
         self.logger.log(logging.ERROR,getLogMessage("Pattern substitution failed for @@{%s}\n" % (e.args[0])))
         self.logger.log(logging.ERROR,getLogMessage("in SubmissionScripts/Distributor/Batch/PEGASUS/serial\n"))
      except TypeError:
         submissionScript = ""
         self.logger.log(logging.ERROR,getLogMessage("Submission script substitution failed:\n%s\n" % (rawSubmissionScript)))
      else:
         prefix = self.SUBMISSIONSCRIPTCOMMANDPREFIX
         if prefix:
            # Collapse blank lines in front of each prefixed directive; the
            # prefix is escaped because it may contain regex metacharacters.
            submissionScript = re.sub("\n+%s" % (re.escape(prefix)),
                                      lambda match: "\n%s" % (prefix),submissionScript)

      return(submissionScript)


   def _catalogFileName(self,templateType,pegasusTemplatePath):
      """Return the output file name for a catalog template, or None for an
      unrecognized templateType.  sites/tc outputs are .yml when the
      template path ends in '.yml'."""
      jobPrefix = "%s_%s" % (self.localJobId,self.instanceId)
      isYaml    = pegasusTemplatePath.endswith('.yml')
      if templateType == 'rc':
         return "%s.pegasusrc" % (jobPrefix)
      if templateType == 'sites':
         return "%s_sites.%s" % (jobPrefix,'yml' if isYaml else 'xml')
      if templateType == 'tc':
         return "%s_tc.%s" % (jobPrefix,'yml' if isYaml else 'txt')
      return None


   def _writePegasusFile(self,pegasusPath,pegasusText):
      """Write pegasusText to pegasusPath, logging (not raising) I/O failures."""
      try:
         with open(pegasusPath,'w') as fpPegasusFile:
            try:
               fpPegasusFile.write(pegasusText)
            except (IOError,OSError):
               self.logger.log(logging.ERROR,getLogMessage("%s could not be written" % (pegasusPath)))
      except (IOError,OSError):
         self.logger.log(logging.ERROR,getLogMessage("%s could not be opened" % (pegasusPath)))


   def __buildCatalogs(self):
      """Render each configured Pegasus catalog template into the instance directory.

      Templates with an empty path are skipped, as are unrecognized template
      types.  A template whose substitution fails is written as an empty
      file.  I/O and substitution errors are logged.
      """
      substitutions = {}
      substitutions["SUBMITTERCLASS"]       = str(self.submitterClass)
      substitutions["SESSION"]              = self.session
      substitutions["INSTANCETOKEN"]        = self.instanceToken
      substitutions["WSJOBID"]              = self.wsJobId
      substitutions["RUNNAME"]              = self.runName
      substitutions["JOBID"]                = self.localJobId
      substitutions["INSTANCEID"]           = self.instanceId
      substitutions["INSTANCEDIRECTORY"]    = self.instanceDirectory
      substitutions["BASESCRATCHDIRECTORY"] = os.path.basename(self.scratchDirectory)
      substitutions["SCRATCHDIRECTORY"]     = self.scratchDirectory
      substitutions["MEMORY"]               = self.memoryMB
      if self.x509SubmitProxy != "":
         substitutions["X509SUBMITPROXY"]   = self.x509SubmitProxy
      else:
         substitutions["X509SUBMITPROXY"]   = os.path.join(os.sep,'tmp','hub-proxy.%s' % (self.hubUserName))
      substitutions["PEGASUSHOME"]          = self.pegasusHome
      substitutions["SSHPRIVATEKEYPATH"]    = self.sshIdentityPath

      for templateType in self.pegasusTemplates:
         pegasusTemplatePath = self.pegasusTemplates[templateType]
         if pegasusTemplatePath == "":
            continue

         pegasusFile = self._catalogFileName(templateType,pegasusTemplatePath)
         if pegasusFile is None:
            # No output name is defined for this type; skip it rather than
            # reuse another catalog's file name.
            self.logger.log(logging.ERROR,getLogMessage("Unknown Pegasus template type %s (%s)" % (templateType,pegasusTemplatePath)))
            continue

         template = self._readPegasusTemplate(pegasusTemplatePath)
         if template is None:
            continue

         pegasusTemplate = ParameterTemplate(template)
         try:
            pegasusText = pegasusTemplate.substitute_recur(substitutions)
         except KeyError as e:
            pegasusText = ""
            self.logger.log(logging.ERROR,getLogMessage("Pattern substitution failed for @@{%s}\n" % (e.args[0])))
            self.logger.log(logging.ERROR,getLogMessage("in Pegasus %s template\n" % (templateType)))
         except TypeError:
            pegasusText = ""
            self.logger.log(logging.ERROR,getLogMessage("Pegasus template substitution failed:\n%s\n" % (template)))

         self._writePegasusFile(os.path.join(self.instanceDirectory,pegasusFile),pegasusText)


   def getBatchScript(self):
      """Return (scriptName,scriptText,isExecutable) and write the Pegasus catalogs.

      The script text is "" for multi-core requests.
      """
      batchScriptName = "%s_%s.pegasus" % (self.localJobId,self.instanceId)
      if self.isMultiCoreRequest:
         batchScript = ""
      else:
         batchScript = self.__buildSerialFile()
      batchScriptExecutable = True

      self.__buildCatalogs()

      return(batchScriptName,batchScript,batchScriptExecutable)


   def getBatchLog(self):
      """Return the batch log file name; PEGASUS jobs have none ("")."""
      batchLogName = ""

      return(batchLogName)


   def getBatchNodeList(self):
      """Return (nodeFileName,nodeList); both are empty for PEGASUS jobs."""
      return(self.nodeFileName,self.nodeList)


   def getBatchToolInputTemplate(self):
      """Return (fileName,template) for the tool input template; both empty here."""
      return(self.toolInputTemplateFileName,self.toolInputTemplate)


   def getBatchToolOutputTemplate(self):
      """Return (fileName,template) for the tool output template; both empty here."""
      return(self.toolOutputTemplateFileName,self.toolOutputTemplate)


   def getUserDaxExecutables(self):
      """Return the executables referenced by the user's workflow file, or []
      when no workflow file was identified in the arguments."""
      daxExecutables = []
      if self.checkDaxPath:
         daxExecutables = self.jobOutput.getDaxExecutables(self.checkDaxPath)

      return(daxExecutables)


   def getRemoteJobIdNumber(self,
                            remoteJobId):
      """Extract '<cluster>.0' from condor_submit output such as
      '1 job(s) submitted to cluster 105.'; return "-1" when absent or when
      remoteJobId is not a string."""
      try:
         match = re.search(r'cluster [0-9]+',remoteJobId)
      except TypeError:
         match = None
      if match is None:
         return("-1")

      return(match.group().split()[1] + ".0")

