# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
from argparse import ArgumentParser

from modelscope.cli.base import CLICommand
from modelscope.cli.utils import concurrent_download
from modelscope.hub.api import HubApi
from modelscope.hub.constants import DEFAULT_MAX_WORKERS, DEFAULT_SKILLS_DIR
from modelscope.hub.file_download import (dataset_file_download,
                                          model_file_download)
from modelscope.hub.snapshot_download import (dataset_snapshot_download,
                                              snapshot_download)
from modelscope.hub.utils.utils import convert_patterns, resolve_endpoint
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
from modelscope.utils.logger import get_logger

logger = get_logger(log_level=logging.WARNING)


def subparser_func(args):
    """Build the handler for the ``download`` sub-command.

    Called by the CLI dispatcher with the parsed arguments of this
    sub-parser; returns the command object that performs the download.
    """
    return DownloadCMD(args)


class DownloadCMD(CLICommand):
    """CLI sub-command that downloads a model, a dataset, or all skills in
    a collection from the ModelScope hub.

    Exactly one of ``--model``, ``--dataset``, ``--collection`` or the
    positional ``repo_id`` must identify the entity to download.
    """
    name = 'download'

    def __init__(self, args):
        # Parsed argparse namespace produced by ``define_args``.
        self.args = args

    @staticmethod
    def define_args(parsers: ArgumentParser):
        """ define args for download command.
        """
        parser: ArgumentParser = parsers.add_parser(DownloadCMD.name)
        # The three entity selectors are mutually exclusive; the positional
        # repo_id is reconciled with them later in execute().
        group = parser.add_mutually_exclusive_group()
        group.add_argument(
            '--model',
            type=str,
            help='The id of the model to be downloaded. For download, '
            'the id of either a model or dataset must be provided.')
        group.add_argument(
            '--dataset',
            type=str,
            help='The id of the dataset to be downloaded. For download, '
            'the id of either a model or dataset must be provided.')
        group.add_argument(
            '--collection',
            type=str,
            default=None,
            help='The ID of the collection to download (skills only)')
        parser.add_argument(
            'repo_id',
            type=str,
            nargs='?',
            default=None,
            help='Optional, '
            'ID of the repo to download, It can also be set by --model or --dataset.'
        )
        parser.add_argument(
            '--repo-type',
            choices=['model', 'dataset'],
            default='model',
            help="Type of repo to download from (defaults to 'model').",
        )
        parser.add_argument(
            '--token',
            type=str,
            default=None,
            help='Optional. Access token to download controlled entities.')
        parser.add_argument(
            '--revision',
            type=str,
            default=None,
            help='Revision of the entity (e.g., model).')
        parser.add_argument(
            '--cache_dir',
            type=str,
            default=None,
            help='Cache directory to save entity (e.g., model).')
        parser.add_argument(
            '--local_dir',
            type=str,
            default=None,
            help='File will be downloaded to local location specified by'
            'local_dir, in this case, cache_dir parameter will be ignored.')
        parser.add_argument(
            'files',
            type=str,
            default=None,
            nargs='*',
            help='Specify relative path to the repository file(s) to download.'
            "(e.g 'tokenizer.json', 'onnx/decoder_model.onnx').")
        parser.add_argument(
            '--include',
            nargs='*',
            default=None,
            type=str,
            help='Glob patterns to match files to download.'
            'Ignored if file is specified')
        parser.add_argument(
            '--exclude',
            nargs='*',
            type=str,
            default=None,
            help='Glob patterns to exclude from files to download.'
            'Ignored if file is specified')
        parser.add_argument(
            '--endpoint',
            type=str,
            default=None,
            help='ModelScope server endpoint, e.g. modelscope.cn or '
            'modelscope.ai   Full URL like '
            'https://modelscope.cn is also accepted. Scheme (https://) is '
            'auto-completed if omitted. Falls back to env MODELSCOPE_DOMAIN, '
            'then defaults to https://www.modelscope.cn. '
            'When omitted, the CLI auto-detects the correct site '
            '(cn/intl) for download.')
        parser.add_argument(
            '--max-workers',
            type=int,
            default=DEFAULT_MAX_WORKERS,
            help='The maximum number of workers to download files.')

        parser.set_defaults(func=subparser_func)

    def execute(self):
        """Run the download described by ``self.args``.

        Raises:
            ValueError: if none of model/dataset/collection is specified.
        """
        self._normalize_repo_args()
        if not self.args.model and not self.args.dataset and not self.args.collection:
            raise ValueError('Model, dataset, or collection must be set.')
        endpoint = resolve_endpoint(
            self.args.endpoint) if self.args.endpoint else None
        cookies = None
        if self.args.token is not None:
            # Authenticate once up front so controlled entities can be fetched.
            api = HubApi(endpoint=endpoint)
            cookies = api.get_cookies(access_token=self.args.token)
        if self.args.model:
            self._download_repo(model_file_download, snapshot_download,
                                self.args.model, self.args.revision, endpoint,
                                cookies)
            print(f'\nSuccessfully Downloaded from model {self.args.model}.\n')
        elif self.args.dataset:
            dataset_revision: str = self.args.revision if self.args.revision else DEFAULT_DATASET_REVISION
            self._download_repo(dataset_file_download,
                                dataset_snapshot_download, self.args.dataset,
                                dataset_revision, endpoint, cookies)
            print(
                f'\nSuccessfully Downloaded from dataset {self.args.dataset}.\n'
            )
        elif self.args.collection:
            self._download_collection(endpoint)

    def _normalize_repo_args(self):
        """Reconcile the positional ``repo_id`` with --model/--dataset.

        When the repo is already named via a flag, the positional slot
        actually holds the first file path; otherwise the positional value
        becomes the model or dataset id according to --repo-type.
        """
        if self.args.model or self.args.dataset:
            # The positional argument intended as a file name was parsed
            # into repo_id; move it to the front of the file list.
            if self.args.repo_id is not None:
                if self.args.files:
                    self.args.files.insert(0, self.args.repo_id)
                else:
                    self.args.files = [self.args.repo_id]
        elif self.args.repo_id is not None:
            if self.args.repo_type == 'model':
                self.args.model = self.args.repo_id
            elif self.args.repo_type == 'dataset':
                self.args.dataset = self.args.repo_id
            else:
                # Defensive only: argparse 'choices' already restricts values.
                raise ValueError('Unsupported repo-type: %s'
                                 % self.args.repo_type)

    def _download_repo(self, file_fn, snapshot_fn, repo_id, revision,
                       endpoint, cookies):
        """Download one file, several named files, or the whole repo.

        Args:
            file_fn: single-file download function for this repo type.
            snapshot_fn: snapshot download function for this repo type.
            repo_id: id of the model or dataset repository.
            revision: revision to download (may be None for models).
            endpoint: resolved endpoint URL or None for the default.
            cookies: authentication cookies or None.
        """
        if len(self.args.files) == 1:  # download single file
            file_fn(
                repo_id,
                self.args.files[0],
                cache_dir=self.args.cache_dir,
                local_dir=self.args.local_dir,
                revision=revision,
                cookies=cookies,
                token=self.args.token,
                endpoint=endpoint)
        elif len(self.args.files) > 1:  # download specified multiple files.
            snapshot_fn(
                repo_id,
                revision=revision,
                cache_dir=self.args.cache_dir,
                local_dir=self.args.local_dir,
                allow_file_pattern=self.args.files,
                max_workers=self.args.max_workers,
                cookies=cookies,
                token=self.args.token,
                endpoint=endpoint)
        else:  # download repo, honoring --include/--exclude glob patterns
            snapshot_fn(
                repo_id,
                revision=revision,
                cache_dir=self.args.cache_dir,
                local_dir=self.args.local_dir,
                allow_file_pattern=convert_patterns(self.args.include),
                ignore_file_pattern=convert_patterns(self.args.exclude),
                max_workers=self.args.max_workers,
                cookies=cookies,
                token=self.args.token,
                endpoint=endpoint)

    def _download_collection(self, endpoint):
        """Download every skill in the collection concurrently."""
        api = HubApi(endpoint=endpoint, token=self.args.token)
        local_dir = self.args.local_dir or DEFAULT_SKILLS_DIR
        data = api.get_collection(
            self.args.collection, repo_type='skill', endpoint=endpoint)
        elements = data.get('CollectionElements',
                            {}).get('CollectionElementVoList', [])

        logger.info(
            f'Collection {self.args.collection} has {len(elements)} elements.'
        )

        if not elements:
            print(f'No skill elements found in collection: '
                  f'{self.args.collection}')
            return

        # Skip elements missing the fields needed to build a skill id.
        valid_elements = []
        for elem in elements:
            if not elem.get('ElementPath') or not elem.get('ElementName'):
                logger.warning('Skipping malformed collection element: %s',
                               elem)
                continue
            valid_elements.append(elem)

        if not valid_elements:
            print(f'No valid skill elements found in collection: '
                  f'{self.args.collection}')
            return

        print(f'Found {len(valid_elements)} skill(s) in collection, '
              f'downloading...')

        def _download_one_skill(element):
            # Returns (skill_id, local_path_or_None, error_or_None) so the
            # concurrent runner can report per-skill success or failure.
            element_path = element['ElementPath']
            element_name = element['ElementName']
            skill_id = f'{element_path}/{element_name}'
            try:
                skill_dir = api.download_skill(
                    skill_id=skill_id,
                    local_dir=local_dir,
                    endpoint=endpoint)
                return (skill_id, skill_dir, None)
            except Exception as e:
                return (skill_id, None, str(e))

        concurrent_download(
            _download_one_skill,
            valid_elements,
            max_workers=self.args.max_workers,
            item_name='skill')
