# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Command for creating TPU node and GCE VM combination."""

from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals

from apitools.base.py.exceptions import HttpConflictError

from googlecloudsdk.api_lib.compute import utils
from googlecloudsdk.calliope import base
from googlecloudsdk.command_lib.compute import flags
from googlecloudsdk.command_lib.compute.instances import flags as instance_flags
from googlecloudsdk.command_lib.compute.tpus import flags as tpus_flags
from googlecloudsdk.command_lib.compute.tpus.execution_groups import util as tpu_utils


@base.ReleaseTracks(base.ReleaseTrack.ALPHA)
class Create(base.CreateCommand):
  """Create Google Compute TPUs along with VMs."""

  # Flow: optionally start a TPU node, start a GCE VM, wait on whichever
  # operations were started, then SSH into the VM. With dry_run set, only
  # descriptive messages are produced and no API helpers are constructed.

  @classmethod
  def Args(cls, parser):
    """Registers every flag this command accepts on `parser`.

    Registration order is kept stable on purpose so that help output does not
    shift.

    Args:
      parser: the parser object calliope hands to command classes.
    """
    flags.AddZoneFlag(parser, resource_type='tpu', operation_type='create')
    # The accelerator-type flag is built lazily (inside the lambda) so the
    # flag object is only created at the point it is registered.
    tpu_flag_registrars = (
        tpus_flags.AddTpuNameOverrideArg,
        tpus_flags.AddPreemptibleFlag,
        tpus_flags.AddTfVersionFlag,
        tpus_flags.AddVmOnlyFlag,
        tpus_flags.AddDeepLearningImagesFlag,
        tpus_flags.AddDryRunFlag,
        lambda p: tpus_flags.GetAcceleratorTypeFlag().AddToParser(p),
        tpus_flags.AddPreemptibleVmFlag,
        tpus_flags.AddPortForwardingFlag,
        tpus_flags.AddGceImageFlag,
        tpus_flags.AddDiskSizeFlag,
    )
    for register in tpu_flag_registrars:
      register(parser)

    instance_flags.AddMachineTypeArgs(parser)

  def Run(self, args):
    """Creates (or, in dry-run mode, describes) the TPU node and GCE VM.

    Args:
      args: parsed command arguments. Reads `dry_run`, `vm_only`, `name`,
        `zone`, `accelerator_type`, `tf_version`, `preemptible`,
        `machine_type`, `disk_size` and `preemptible_vm`.

    Returns:
      A list with one entry per step performed. Dry-run entries and
      name-conflict entries are strings. Otherwise, entries are whatever the
      helpers' `WaitForOperation` and `SSHToInstance` return.
    """
    if args.dry_run:
      return self._DescribeDryRun(args)
    return self._Provision(args)

  def _DescribeDryRun(self, args):
    """Builds the messages describing what a real run would do."""
    plan = []
    if not args.vm_only:
      plan.append(
          'Creating TPU with Name:{}, Accelerator type:{}, TF version:{}, '
          'Zone:{}'.format(args.name, args.accelerator_type, args.tf_version,
                           args.zone))
    disk_size_gb = utils.BytesToGb(args.disk_size)
    plan.append(
        'Creating GCE VM with Name:{}, Zone:{}, Machine Type:{},'
        ' Disk Size(GB):{}, Preemptible:{}'.format(
            args.name, args.zone, args.machine_type, disk_size_gb,
            args.preemptible_vm))
    plan.append('SSH to GCE VM:{}'.format(args.name))
    return plan

  def _Provision(self, args):
    """Starts the TPU and VM, waits for them, then SSHes into the VM.

    Stops early, returning a single explanatory message, if either create
    call raises HttpConflictError.
    """
    outcome = []
    with_tpu = not args.vm_only

    if with_tpu:
      tpu = tpu_utils.TPUNode(self.ReleaseTrack())
      try:
        tpu_op = tpu.Create(args.name, args.accelerator_type,
                            args.tf_version, args.zone, args.preemptible)
      except HttpConflictError:
        outcome.append('TPU Node with name:{} already exists, '
                       'try a different name'.format(args.name))
        return outcome

    instance = tpu_utils.Instance(self.ReleaseTrack())
    try:
      vm_op = instance.Create(
          args.name, args.zone, args.machine_type,
          utils.BytesToGb(args.disk_size), args.preemptible_vm)
    except HttpConflictError:
      conflict = ('GCE VM with name:{} already exists, '
                  'try a different name.').format(args.name)
      if with_tpu:
        # The TPU create call above already succeeded, so warn that it was
        # left running.
        conflict += (' TPU Node:{} creation is underway and will '
                     'need to be deleted.'.format(args.name))
      outcome.append(conflict)
      return outcome

    if with_tpu:
      outcome.append(tpu.WaitForOperation(
          tpu_op, 'Creating TPU node:{}'.format(args.name)))

    created_vm = instance.WaitForOperation(
        vm_op, 'Creating GCE VM:{}'.format(args.name))
    outcome.append(created_vm)

    ssh = tpu_utils.SSH(self.ReleaseTrack())
    outcome.append(ssh.SSHToInstance(args, created_vm))
    return outcome
