# @package      hubzero-submit-monitors
# @file         Cloud.py
# @copyright    Copyright (c) 2012-2020 The Regents of the University of California.
# @license      http://opensource.org/licenses/MIT MIT
#
# Copyright (c) 2012-2020 The Regents of the University of California.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# HUBzero is a registered trademark of The Regents of the University of California.
#
import os
import sys
import socket
import time
import re
import traceback
import logging

from operator import attrgetter

import boto
from boto.s3.key import Key
from boto.s3.connection import OrdinaryCallingFormat
from boto.s3.connection import S3Connection
from boto.ec2.connection import EC2Connection
from boto.ec2.regioninfo import RegionInfo
from boto.ec2 import regions

from hubzero.submit.LogMessage import getLogIDMessage as getLogMessage

# iterRange is a lazy integer range on both Python 2 (xrange) and Python 3 (range).
try:
   iterRange = xrange
except NameError:
   # Python 3 has no xrange; its range is already lazy.
   iterRange = range

class Cloud:
   """Manage virtual machine instances on one EC2-compatible cloud through boto.

   Connection parameters and limits come from cloudsInfo.getCloudInfo(cloudName).
   Instances launched by createInstance are tracked in self.instances, keyed by
   instance id.  Each entry holds creationTime, shutdownTime,
   mostRecentJobAssignment and mostRecentJobTermination (values produced by
   time.time()) plus 'assignedJobs', the list of job ids currently assigned to it.
   """
   # Seconds after creation during which an instance may be given new jobs (one day).
   INSTANCE_LIFETIME       = 60.*60.*24.
   # Safety margin, in seconds, added to a job's wall time when choosing an instance.
   SHUTDOWN_MARGIN         = 60.*10.
   # Seconds between state polls while an instance is starting.
   STARTUP_POLL_INTERVAL   = 10.
   # Seconds to wait after an instance is 'running' so that its ssh process can start.
   SSH_SETTLE_TIME         = 30.
   # Seconds between state polls while an instance is terminating.
   TERMINATE_POLL_INTERVAL = 5.
   # Seconds between attempts when listing instances fails.
   RETRY_DELAY             = 10.

   def __init__(self,
                cloudsInfo,
                cloudName):
      """Read the configuration for cloudName from cloudsInfo and open an EC2 connection.

      maximumIdleTime is interpreted in minutes (see purgeIdleInstances).
      """
      self.logger = logging.getLogger(__name__)

      host,port,accessId,accessSecret,image,maximumInstances,maximumJobs,maximumIdleTime = cloudsInfo.getCloudInfo(cloudName)
      self.name             = cloudName
      self.host             = host
      self.port             = port
      self.accessId         = accessId
      self.accessSecret     = accessSecret
      self.imageName        = image
      self.maximumInstances = maximumInstances
      self.maximumJobs      = maximumJobs
      self.maximumIdleTime  = maximumIdleTime

      # Column definitions for instance reports: an attribute getter and a column width.
      # 'T:' only supplies the width used by every 'T:<tag>' (instance tag) column.
      self.HEADERS = {'ID': {'get': attrgetter('id'), 'length':15},
                      'Zone': {'get': attrgetter('placement'), 'length':20},
                      'Groups': {'get': attrgetter('groups'), 'length':30},
                      'Hostname': {'get': attrgetter('public_dns_name'), 'length':50},
                      'State': {'get': attrgetter('state'), 'length':15},
                      'Image': {'get': attrgetter('image_id'), 'length':25},
                      'Type': {'get': attrgetter('instance_type'), 'length':15},
                      'IP': {'get': attrgetter('ip_address'), 'length':16},
                      'PrivateIP': {'get': attrgetter('private_ip_address'), 'length':16},
                      'Key': {'get': attrgetter('key_name'), 'length':25},
                      'T:': {'length': 30},
                     }
      # Columns shown in reports; any HEADERS key or a 'T:<tag>' entry may be listed.
      self.headers = ('ID', 'Hostname', 'IP', 'Image', 'State')

      self.region = RegionInfo(name="nimbus",endpoint=self.host)
      self.ec2conn = boto.connect_ec2(self.accessId,self.accessSecret,region=self.region,port=self.port)

      self.instances = {}
      self.nextInstanceId = 1


   def __getColumn(self,
                   name,
                   instance):
      """Return the value of report column name for instance ('T:<tag>' reads a tag, '' if unset)."""
      if name.startswith('T:'):
         _, tag = name.split(':', 1)
         return instance.tags.get(tag,'')

      return self.HEADERS[name]['get'](instance)


   def __getFormatString(self):
      """Return a %-format string with one left-justified fixed-width field per column in self.headers."""
      formatString = ""
      for header in self.headers:
         if header.startswith('T:'):
            formatString += " %%-%ds" % self.HEADERS['T:']['length']
         else:
            formatString += " %%-%ds" % self.HEADERS[header]['length']
      formatString = formatString.lstrip()

      return(formatString)


   def __getInstanceProperty(self,
                             instance,
                             propertyName):
      """Return attribute propertyName of a boto instance object."""
      return(attrgetter(propertyName)(instance))


   def __getAllReservations(self):
      """Return every reservation from the EC2 connection, retrying until the call succeeds.

      The first attempt is immediate; later attempts are RETRY_DELAY seconds apart.
      socket errors are retried silently; any other Exception is logged, then retried.
      KeyboardInterrupt/SystemExit are not caught, so the retry loop can be interrupted.
      """
      delay = 0.
      while True:
         time.sleep(delay)
         delay = self.RETRY_DELAY
         try:
            return(self.ec2conn.get_all_instances())
         except socket.error:
            pass
         except Exception:
            self.logger.log(logging.ERROR,getLogMessage(traceback.format_exc()))


   def __terminateAndWait(self,
                          instance):
      """Request termination of a boto instance and poll until its state is 'terminated'."""
      instance.terminate()
      while instance.state != 'terminated':
         time.sleep(self.TERMINATE_POLL_INTERVAL)
         instance.update()


   def __getInstanceTable(self,
                          instances):
      """Return the column report (header, dashed rule, one row per instance) joined by newlines."""
      formatString = self.__getFormatString()
      headerLine = formatString % self.headers
      lines = [headerLine,"-" * len(headerLine)]
      for instance in instances:
         lines.append(formatString % tuple(self.__getColumn(header,instance) for header in self.headers))

      return('\n'.join(lines))


   def __getSummaryReport(self,
                          instanceIds):
      """Return the bookkeeping table (header line plus one line per id) for tracked instanceIds."""
      report = "%-15s %-24s %-12s %-24s %-24s %-13s\n" % ("Instance","CreationTime","JobsAssigned",
                                                          "MostRecentJobAssignment","MostRecentJobTermination",
                                                          "AssignedJobs")
      for instanceId in instanceIds:
         instanceInfo = self.instances[instanceId]
         report += "%-15s %-24s %-12d %-24s %-24s %-13s\n" % (instanceId,
                                                              time.ctime(instanceInfo['creationTime']),
                                                              len(instanceInfo['assignedJobs']),
                                                              time.ctime(instanceInfo['mostRecentJobAssignment']),
                                                              time.ctime(instanceInfo['mostRecentJobTermination']),
                                                              ' '.join(instanceInfo['assignedJobs']))

      return(report)


   def loadActiveInstances(self,
                           fpDump):
      """Restore nextInstanceId and self.instances from the text format written by dumpActiveInstances."""
      self.nextInstanceId = int(fpDump.readline())
      nInstances = int(fpDump.readline())
      for _ in iterRange(nInstances):
         instanceId = fpDump.readline().strip()
         creationTime,shutdownTime,mostRecentJobAssignment,mostRecentJobTermination = fpDump.readline().strip().split()
         nAssignedJobs = int(fpDump.readline())
         assignedJobs = []
         for _ in iterRange(nAssignedJobs):
            assignedJobs.append(fpDump.readline().strip())
         self.instances[instanceId] = {'creationTime':float(creationTime),
                                       'shutdownTime':float(shutdownTime),
                                       'mostRecentJobAssignment':float(mostRecentJobAssignment),
                                       'mostRecentJobTermination':float(mostRecentJobTermination),
                                       'assignedJobs':assignedJobs
                                      }


   def dumpActiveInstances(self,
                           fpDump):
      """Write nextInstanceId and every tracked instance to fpDump, one value group per line.

      Layout: nextInstanceId, instance count, then per instance its id, the four
      timestamps on one line, the number of assigned jobs, and one job id per line.
      """
      fpDump.write("%d\n" % (self.nextInstanceId))
      fpDump.write("%d\n" % (len(self.instances)))
      for instanceId, instanceInfo in self.instances.items():
         fpDump.write("%s\n" % (instanceId))
         fpDump.write("%f %f %f %f\n" % (instanceInfo['creationTime'],
                                         instanceInfo['shutdownTime'],
                                         instanceInfo['mostRecentJobAssignment'],
                                         instanceInfo['mostRecentJobTermination']))
         fpDump.write("%d\n" % (len(instanceInfo['assignedJobs'])))
         for assignedJob in instanceInfo['assignedJobs']:
            fpDump.write("%s\n" % (assignedJob))


   def getInstanceHostname(self,
                           instanceId):
      """Return the public DNS name of instanceId as reported by the cloud, or "" if not listed."""
      for reservation in self.__getAllReservations():
         for instance in reservation.instances:
            if self.__getInstanceProperty(instance,'id') == instanceId:
               return(self.__getInstanceProperty(instance,'public_dns_name'))

      return("")


   def terminate(self):
      """Terminate every instance listed by the cloud, waiting for each, and forget all tracked instances."""
      for reservation in self.__getAllReservations():
         for instance in reservation.instances:
            instanceId = self.__getInstanceProperty(instance,'id')
            self.__terminateAndWait(instance)
            self.logger.log(logging.INFO,getLogMessage("instance %s terminated on cloud %s" % (instanceId,self.name)))

      self.instances = {}


   def getNextInstanceId(self):
      """Return the counter incremented after each successful createInstance."""
      return(self.nextInstanceId)


   def createInstance(self):
      """Launch one instance of imageName if fewer than maximumInstances are tracked.

      Blocks until the instance reports 'running', then waits SSH_SETTLE_TIME seconds
      before recording it.  Returns the new instance id, or None when the instance
      limit is reached or the instance stopped before ever reaching 'running'.
      """
      if len(self.instances) >= self.maximumInstances:
         return(None)

      reservation = self.ec2conn.run_instances(self.imageName)
      instance = reservation.instances[0]
      instanceId = self.__getInstanceProperty(instance,'id')

      # wait for instance to start before advertising
      while instance.state != 'running':
         if instance.state in ('shutting-down','terminated'):
            # an instance in either state will never become 'running'; without this
            # check the polling loop would never end
            self.logger.log(logging.ERROR,getLogMessage("instance %s failed to start on cloud %s" % (instanceId,self.name)))
            return(None)
         time.sleep(self.STARTUP_POLL_INTERVAL)
         instance.update()

      # wait for good measure - allow ssh process to start
      time.sleep(self.SSH_SETTLE_TIME)

      now = time.time()
      self.instances[instanceId] = {'creationTime':now,
                                    'shutdownTime':now+self.INSTANCE_LIFETIME,
                                    'mostRecentJobAssignment':0,
                                    'mostRecentJobTermination':now,
                                    'assignedJobs':[]
                                   }
      self.nextInstanceId += 1

      return(instanceId)


   def __reportInstance(self,
                        instance):
      """Return the column report and bookkeeping row for one tracked boto instance."""
      instanceId = self.__getInstanceProperty(instance,'id')
      report  = self.__getInstanceTable([instance])
      report += "\n\n" + self.__getSummaryReport([instanceId])

      return(report)


   def reportAllInstances(self):
      """Return a report of every instance listed by the cloud plus bookkeeping for tracked instances.

      Each listed instance's 'groups' attribute is overwritten with the comma-joined
      security group ids of its reservation so the 'Groups' column can display it.
      """
      cloudInstances = []
      for reservation in self.__getAllReservations():
         groups = ','.join([group.id for group in reservation.groups])
         for instance in reservation.instances:
            instance.groups = groups
            cloudInstances.append(instance)

      report  = self.__getInstanceTable(cloudInstances)
      # sorted() works on Python 2 and 3; dict.keys() has no sort() on Python 3
      report += "\n\n" + self.__getSummaryReport(sorted(self.instances))

      return(report)


   def incrementInstance(self,
                         instanceId,
                         jobId):
      """Record jobId as assigned to instanceId.

      Returns True if recorded, False if the instance is untracked or already has jobId.
      """
      instanceInfo = self.instances.get(instanceId)
      if instanceInfo is None or jobId in instanceInfo['assignedJobs']:
         return(False)

      instanceInfo['mostRecentJobAssignment'] = time.time()
      instanceInfo['assignedJobs'].append(jobId)
      return(True)


   def decrementInstance(self,
                         instanceId,
                         jobId):
      """Remove jobId from instanceId's assigned jobs.

      Returns True if removed, False if the instance is untracked or does not have jobId.
      """
      instanceInfo = self.instances.get(instanceId)
      if instanceInfo is None or jobId not in instanceInfo['assignedJobs']:
         return(False)

      instanceInfo['mostRecentJobTermination'] = time.time()
      instanceInfo['assignedJobs'].remove(jobId)
      return(True)


   def terminateInstance(self,
                         instanceId):
      """Terminate instanceId if the cloud lists it, waiting for completion, and stop tracking it.

      Returns True if a matching instance was found and terminated.
      """
      instanceTerminated = False
      for reservation in self.__getAllReservations():
         for instance in reservation.instances:
            if self.__getInstanceProperty(instance,'id') == instanceId:
               self.__terminateAndWait(instance)
               instanceTerminated = True

      self.instances.pop(instanceId,None)

      return(instanceTerminated)


   def getAvailableInstanceWithMinimumJobs(self,
                                           wallTime):
      """Return the tracked instance id with the fewest jobs that can still accept one, or None.

      An instance qualifies if it has fewer than maximumJobs jobs and its shutdownTime
      is later than now + wallTime + SHUTDOWN_MARGIN seconds.  Ties keep the first
      instance found.
      """
      requiredUntil = time.time() + float(wallTime) + self.SHUTDOWN_MARGIN

      minimumJobsInstanceId = None
      minimumJobs = self.maximumJobs
      for instanceId, instanceInfo in self.instances.items():
         if requiredUntil < instanceInfo['shutdownTime']:
            nAssignedJobs = len(instanceInfo['assignedJobs'])
            if nAssignedJobs < minimumJobs:
               minimumJobs = nAssignedJobs
               minimumJobsInstanceId = instanceId

      return(minimumJobsInstanceId)


   def purgeIdleInstances(self):
      """Terminate idle tracked instances and untracked live instances; return how many were selected.

      A tracked instance is idle when it has no assigned jobs and its most recent job
      terminated more than maximumIdleTime minutes ago.  Untracked instances already
      in state 'terminated' are skipped (presumably they are still listed by the cloud
      after termination), so they are not re-terminated and counted on every pass.
      """
      oldestIdleTime = time.time() - 60*self.maximumIdleTime
      markedForDeletion = []
      for reservation in self.__getAllReservations():
         for instance in reservation.instances:
            instanceId = self.__getInstanceProperty(instance,'id')
            instanceInfo = self.instances.get(instanceId)
            if instanceInfo is None:
               if instance.state != 'terminated':
                  markedForDeletion.append(instanceId)
            elif not instanceInfo['assignedJobs'] and \
                 instanceInfo['mostRecentJobTermination'] < oldestIdleTime:
               markedForDeletion.append(instanceId)

      for instanceId in markedForDeletion:
         if self.terminateInstance(instanceId):
            self.logger.log(logging.INFO,getLogMessage("instance %s terminated on cloud %s" % (instanceId,self.name)))

      return(len(markedForDeletion))


