# @package      hubzero-submit-server
# @file         MySQLDatabase.py
# @copyright    Copyright (c) 2012-2020 The Regents of the University of California.
# @license      http://opensource.org/licenses/MIT MIT
#
# Copyright (c) 2012-2020 The Regents of the University of California.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of The Regents of the University of California.
#
import time
import MySQLdb
import logging
import traceback

from hubzero.submit.LogMessage import getLogIDMessage as getLogMessage

class MySQLDatabase:
   """Convenience wrapper around a single MySQLdb connection and cursor.

   Keeps the connection settings, connects (with retry and a doubling delay)
   to one named database at a time, and runs SQL statements through a shared
   cursor.  Statement errors are logged and re-raised as
   MySQLDatabase.MyDatabaseException.  A bookkeeping table (self.schemaTable)
   records applied schema changes; its row count is the database
   "user version".
   """
   def __init__(self,
                mysqlHost="",
                mysqlUser="",
                mysqlPassword="",
                mysqlCA="",
                mysqlCiphers=None):
      """Store connection settings; no connection is opened here.

      mysqlHost, mysqlUser, mysqlPassword - passed to MySQLdb.connect()
      mysqlCA      - passed as ssl['ca'] when non-empty (presumably a CA
                     certificate file path - confirm against deployment config)
      mysqlCiphers - iterable of cipher names joined with ':' into
                     ssl['cipher'] when non-empty; None means no ciphers
      """
      self.logger        = logging.getLogger(__name__)
      self.mysqlHost     = mysqlHost
      self.mysqlUser     = mysqlUser
      self.mysqlPassword = mysqlPassword
      self.mysqlDB       = ""
      self.mysqlCA       = mysqlCA
      # None sentinel instead of a shared mutable [] default.
      self.mysqlCiphers  = mysqlCiphers if mysqlCiphers is not None else []
      self.userVersion   = 0
      self.schemaTable   = 'jobMonitorSchema'
      self.db            = None
      self.cursor        = None


   class MyDatabaseException(Exception):
      """Raised by the statement helpers after the underlying error has been logged."""
      pass


   @staticmethod
   def _describeMySQLError(e):
      """Format a MySQLdb error as 'code message' without failing on odd args.

      Driver errors usually carry (errno, message) in args but may carry a
      single message; indexing args[1] blindly (or formatting a non-int code
      with %d) would raise inside the except handler and hide the real error.
      """
      if len(e.args) >= 2:
         return("%s %s" % (e.args[0],e.args[1]))
      if e.args:
         return(str(e.args[0]))
      return(repr(e))


   def _execute(self,
                operation,
                sqlCommand,
                sqlParameters=None,
                fetchRows=False):
      """Execute one statement on the shared cursor.

      operation     - name of the public caller, used in log/exception text
      sqlParameters - bound through cursor.execute(args=...) when truthy
      fetchRows     - True: return cursor.fetchall(); False: return the value
                      returned by cursor.execute()
      Raises MySQLDatabase.MyDatabaseException (after logging) on any error,
      including a missing cursor when not connected.
      """
      try:
         if sqlParameters:
            count = self.cursor.execute(sqlCommand,args=sqlParameters)
         else:
            count = self.cursor.execute(sqlCommand)
         if fetchRows:
            return(self.cursor.fetchall())
         return(count)
      except MySQLdb.MySQLError as e:
         description = self._describeMySQLError(e)
         self.logger.log(logging.ERROR,getLogMessage("Exception in MySQLDatabase.%s: %s\n%s" % \
                                                     (operation,description,sqlCommand)))
         raise MySQLDatabase.MyDatabaseException("%s: %s" % (operation,description))
      except Exception:
         # Narrower than a bare except so KeyboardInterrupt/SystemExit still propagate.
         self.logger.log(logging.ERROR,getLogMessage("Some other MySQL exception."))
         self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))
         raise MySQLDatabase.MyDatabaseException("%s: unexpected error" % (operation))


   def getDBuserVersion(self):
      """Return the schema "user version" of the connected database.

      The version is the number of rows in self.schemaTable.  When that table
      does not exist in the current database it is created and 0 is returned.
      On any failure the problem is logged and -1 is returned.
      """
      userVersion = -1
      failureDetail = ""
      # Restrict the lookup to the current database: a table with the same name
      # in another schema on the same server must not be mistaken for ours.
      sqlCommand = "SELECT count(*) FROM information_schema.tables WHERE table_schema=DATABASE() AND table_name=%s"
      sqlParameters = (self.schemaTable,)
      try:
         result = self.select(sqlCommand,sqlParameters)
         failureDetail = str(result)
         if len(result) == 1:
            schemaTableExists = int(result[0][0])
            if not schemaTableExists:
               # DEFAULT CURRENT_TIMESTAMP lets updateDBuserVersion() insert only the
               # comment; a NOT NULL TIMESTAMP without a default is rejected when
               # explicit_defaults_for_timestamp is enabled.
               sqlScript = """CREATE TABLE %s (
                                 schemaId    INT          AUTO_INCREMENT,
                                 comment     VARCHAR(512) NOT NULL,
                                 whenAdded   TIMESTAMP    NOT NULL DEFAULT CURRENT_TIMESTAMP,
                                 PRIMARY KEY(schemaId)
                              ) ENGINE=InnoDB;""" % (self.schemaTable)
               self.script(sqlScript)
               userVersion = 0
            else:
               result = self.select("SELECT COUNT(schemaId) FROM %s" % (self.schemaTable))
               if len(result) == 1:
                  userVersion = int(result[0][0])
               else:
                  userVersion = 0
      except MySQLDatabase.MyDatabaseException as e:
         # select()/script() already logged the SQL error; keep its text for the
         # summary below (previously an unbound 'result' crashed this path).
         failureDetail = str(e) or failureDetail

      if userVersion < 0:
         self.logger.log(logging.ERROR,getLogMessage("Error in determining schema version.\n%s" % (failureDetail)))

      return(userVersion)


   def updateDBuserVersion(self,
                           comment):
      """Record a schema change in self.schemaTable and log the new user version.

      Best effort: insert()/select() log their own errors, so failures are
      swallowed here.
      NOTE(review): no commit() is issued; if autocommit is off the caller must
      commit for the row to persist.
      """
      sqlCommand = "INSERT INTO " + self.schemaTable + " (comment) VALUES(%s)"
      sqlParameters = (comment,)
      try:
         self.insert(sqlCommand,sqlParameters)
         result = self.select("SELECT COUNT(schemaId) FROM %s" % (self.schemaTable))
      except MySQLDatabase.MyDatabaseException:
         return

      if len(result) == 1:
         userVersion = int(result[0][0])
         self.logger.log(logging.INFO,getLogMessage("SQL database updated to user_version = %d\n   %s" % (userVersion,comment)))


   def setDBuserVersion(self,
                        userVersion):
      """Store userVersion on the instance and log it (the database is not touched)."""
      self.userVersion = userVersion
      self.logger.log(logging.INFO,getLogMessage("SQL database set to user_version = %d" % (self.userVersion)))


   def connect(self,
               mysqlDB,
               maximumAttempts=8):
      """Ensure a connection to database mysqlDB exists; return True on success.

      An existing connection to the same database is reused; a connection to a
      different database is closed first.  Failed attempts (MySQLdb.Error) are
      logged and retried up to maximumAttempts times, sleeping 1, 2, 4, ...
      seconds (capped at 256) between attempts - but not after the last one.
      Returns False when every attempt failed.
      """
      if self.db and (mysqlDB == self.mysqlDB):
         return(True)
      if self.db:
         self.disconnect()

      connectArgs = {'host':self.mysqlHost,
                     'user':self.mysqlUser,
                     'passwd':self.mysqlPassword,
                     'db':mysqlDB}
      sslSettings = {}
      if self.mysqlCA:
         sslSettings['ca'] = self.mysqlCA
      if self.mysqlCiphers:
         sslSettings['cipher'] = ':'.join(self.mysqlCiphers)
      if sslSettings:
         connectArgs['ssl'] = sslSettings

      delay = 1
      maxdelay = 256
      for attempt in range(1,maximumAttempts+1):
         try:
            db = MySQLdb.connect(**connectArgs)
            cursor = db.cursor()
         except MySQLdb.Error as e:
            self.logger.log(logging.ERROR,getLogMessage("Exception in MySQLDatabase.connect: %s" % \
                                                        (self._describeMySQLError(e))))
            if attempt < maximumAttempts:
               time.sleep(delay)
               delay = min(delay*2,maxdelay)
         else:
            # Commit the new state only once both connection and cursor exist.
            self.db      = db
            self.cursor  = cursor
            self.mysqlDB = mysqlDB
            return(True)

      return(False)


   def isConnected(self,
                   mysqlDB):
      """Return True when a connection object exists for database mysqlDB."""
      return(bool(self.db) and (mysqlDB == self.mysqlDB))


   def commit(self):
      """Commit the current transaction, if connected."""
      if self.db:
         self.db.commit()


   def disconnect(self):
      """Close the connection (if any) and clear the connection state."""
      if self.db:
         try:
            self.db.close()
         finally:
            # Reset even if close() raises so a later connect() starts clean.
            self.db      = None
            self.cursor  = None
            self.mysqlDB = ""


   def close(self):
      """Alias of disconnect()."""
      self.disconnect()


   def select(self,
              sqlCommand,
              sqlParameters=None):
      """Run a query and return all rows from cursor.fetchall().

      Raises MySQLDatabase.MyDatabaseException on failure.
      """
      return(self._execute('select',sqlCommand,sqlParameters,fetchRows=True))


   def update(self,
              sqlCommand,
              sqlParameters=None):
      """Execute an UPDATE-style statement (returns None).

      Raises MySQLDatabase.MyDatabaseException on failure.
      """
      self._execute('update',sqlCommand,sqlParameters)


   def insert(self,
              sqlCommand,
              sqlParameters=None):
      """Execute an INSERT-style statement (returns None).

      Raises MySQLDatabase.MyDatabaseException on failure.
      """
      self._execute('insert',sqlCommand,sqlParameters)


   def delete(self,
              sqlCommand,
              sqlParameters=None):
      """Execute a DELETE-style statement and return cursor.execute()'s count.

      Raises MySQLDatabase.MyDatabaseException on failure.
      """
      return(self._execute('delete',sqlCommand,sqlParameters))


   def script(self,
              sqlScript):
      """Execute each ';'-separated, non-blank statement of sqlScript in order.

      The split is naive: a ';' inside a string literal or comment splits the
      statement.  Stops at (and raises MySQLDatabase.MyDatabaseException for)
      the first failing statement; the log names that statement.
      """
      for sqlCommand in sqlScript.split(';'):
         if sqlCommand.strip():
            self._execute('script',sqlCommand)


   def dump(self):
      """Placeholder; intentionally does nothing."""
      pass


