import os
import time
import base64
import mimetypes
import argparse
from pathlib import Path
from volcenginesdkarkruntime import Ark


_ARK_BASE_URL = 'https://ark.cn-beijing.volces.com/api/v3'

# Task statuses that mean "finished without a result".  Anything that is
# neither 'succeeded' nor in this set is treated as still in progress.
# NOTE(review): 'cancelled' / 'expired' are taken from the Ark task-status
# enum as documented; confirm against the SDK version in use.
_FAILURE_STATUSES = frozenset({'failed', 'cancelled', 'expired'})


def _parse_args():
    """Define and parse the CLI; exit with a usage error (code 2) on bad values.

    Validation happens here, before any task is created, so that a bad
    ``--max-rounds`` / ``--poll-interval`` / ``--image-file`` never spends a
    (possibly billable) generation task.
    """
    ap = argparse.ArgumentParser(
        description='Create an Ark content-generation task and poll it until it finishes.',
    )
    ap.add_argument('--prompt', required=True, help='text prompt sent as the first content item')
    ap.add_argument('--model', default='doubao-seedance-1-5-pro-251215', help='Ark model id')
    ap.add_argument('--image-url', help='reference image URL')
    ap.add_argument(
        '--image-file',
        help='local reference image, sent inline as a base64 data URL; '
             'takes precedence over --image-url',
    )
    # float so sub-second intervals work; integer values parse unchanged.
    ap.add_argument('--poll-interval', type=float, default=5, help='seconds between status queries')
    ap.add_argument('--max-rounds', type=int, default=60, help='maximum number of status queries')
    # NOTE(review): accepted for CLI compatibility but not read anywhere in this script.
    ap.add_argument('--json', action='store_true', help='currently unused')
    args = ap.parse_args()

    if args.max_rounds < 1:
        ap.error('--max-rounds must be >= 1')
    if args.poll_interval < 0:
        ap.error('--poll-interval must be >= 0')
    if args.image_file and not Path(args.image_file).is_file():
        ap.error(f'--image-file not found: {args.image_file}')
    return args


def _build_content(prompt, image_url, image_file):
    """Build the task ``content`` list: the text prompt plus an optional image.

    When *image_file* is given it is read and embedded as a
    ``data:<mime>;base64,...`` URL, overriding *image_url*.
    """
    content = [{'type': 'text', 'text': prompt}]
    if image_file:
        image_path = Path(image_file)
        mime = mimetypes.guess_type(str(image_path))[0] or 'application/octet-stream'
        encoded = base64.b64encode(image_path.read_bytes()).decode('ascii')
        image_url = f'data:{mime};base64,{encoded}'
    if image_url:
        content.append({'type': 'image_url', 'image_url': {'url': image_url}})
    return content


def _poll_task(client, task_id, max_rounds, poll_interval):
    """Query *task_id* until it reaches a terminal status or rounds run out.

    Prints one ``round=N status=S`` line per query.  On success prints the
    full result and, when present, ``VIDEO_URL=...``, then returns.

    Raises:
        SystemExit: with code 1 when the status is in ``_FAILURE_STATUSES``;
            with message ``'Polling timed out'`` after *max_rounds* queries
            without a terminal status.
    """
    for i in range(1, max_rounds + 1):
        result = client.content_generation.tasks.get(task_id=task_id)
        status = result.status
        print(f'round={i} status={status}', flush=True)
        if status == 'succeeded':
            print('----- task succeeded -----', flush=True)
            print(result, flush=True)
            # getattr(None, ...) yields the default, so no separate None check.
            video_url = getattr(getattr(result, 'content', None), 'video_url', None)
            if video_url:
                print(f'VIDEO_URL={video_url}', flush=True)
            return
        if status in _FAILURE_STATUSES:
            # For 'failed' this prints exactly '----- task failed -----'.
            print(f'----- task {status} -----', flush=True)
            print(result, flush=True)
            err = getattr(result, 'error', None)
            if err:
                print(f'ERROR={err}', flush=True)
            raise SystemExit(1)
        # Don't wait after the final query — we are about to time out anyway.
        if i < max_rounds:
            time.sleep(poll_interval)

    raise SystemExit('Polling timed out')


def main():
    """Create an Ark content-generation task and poll it until it finishes.

    Reads the API key from the ``ARK_API_KEY`` environment variable, submits
    the prompt (plus an optional reference image), prints the create response
    and ``TASK_ID=...``, then polls the task status.

    Raises:
        SystemExit: code 2 on invalid CLI arguments; ``'Missing ARK_API_KEY'``
            when the key is unset; code 1 when the task ends in a failure
            status; ``'Polling timed out'`` when ``--max-rounds`` is exhausted.
    """
    args = _parse_args()

    api_key = os.environ.get('ARK_API_KEY')
    if not api_key:
        raise SystemExit('Missing ARK_API_KEY')

    client = Ark(base_url=_ARK_BASE_URL, api_key=api_key)
    content = _build_content(args.prompt, args.image_url, args.image_file)

    print('----- create request -----', flush=True)
    create_result = client.content_generation.tasks.create(model=args.model, content=content)
    print(create_result, flush=True)
    task_id = create_result.id
    print(f'TASK_ID={task_id}', flush=True)

    _poll_task(client, task_id, args.max_rounds, args.poll_interval)


if __name__ == '__main__':
    # Script entry point only; importing this module runs nothing.
    # main() returns None on success, so this exits with status 0.
    raise SystemExit(main())
