# -*- coding: utf-8 -*-
#
#    Copyright (c) 2015 Billy2011, MediaPortal Team
#
# Copyright (C) 2009-2010 Fluendo, S.L. (www.fluendo.com).
# Copyright (C) 2009-2010 Marc-Andre Lureau <marcandre.lureau@gmail.com>

# This file may be distributed and/or modified under the terms of
# the GNU General Public License version 2 as published by
# the Free Software Foundation.
# This file is distributed without any warranty; without even the implied
# warranty of merchantability or fitness for a particular purpose.
# See "LICENSE" in the source distribution for more information.

from itertools import ifilter
import logging
import os, os.path
import tempfile
import urlparse

from twisted.python import log
from twisted.web import client
from twisted.internet import defer, reactor, task
from twisted.internet.task import deferLater

from m3u8 import M3U8

class HLSFetcher(object):
    """Download the segments of an HLS (m3u8) stream into a local cache.

    The fetcher loads ``url``. If it is a program (master) playlist, a variant
    is chosen with ``get_program_playlist(program, bitrate)`` and followed.
    Live media playlists are reloaded periodically on a twisted
    ``LoopingCall``. Each listed segment is downloaded into ``self.path``, and
    ``get_file(sequence)`` hands out the path of a cached segment, waiting for
    one if none is available yet.
    """

    def __init__(self, url, options=None, program=1, headers=None):
        """
        :param url: playlist URL. An optional ``|X-Forwarded-For=<value>``
            suffix is stripped and replaces ``headers`` with an
            ``X-Forwarded-For`` header.
        :param options: optional dict with keys ``path`` (cache directory,
            a temp dir is created when missing), ``referer``, ``bitrate``,
            ``keep`` (segments kept on disk; 0 = do not write segments,
            -1 = never delete) and ``buffer`` (see ``_next_file_delay``).
        :param program: program id passed to ``get_program_playlist``.
        :param headers: extra HTTP headers sent with every request.
        """
        if '|X-Forwarded-For=' in url:
            url, header_val = url.split('|X-Forwarded-For=')
            headers = {'X-Forwarded-For': header_val}
        self.url = url
        self.headers = headers
        self.program = program
        if options:
            self.path = options.get('path', None)
            self.referer = options.get('referer', None)
            self.bitrate = options.get('bitrate', 200000)
            # NOTE(review): default 'keep' is 2 here but 3 when no options
            # dict is passed; kept as-is, confirm which one is intended.
            self.n_segments_keep = options.get('keep', 2)
            self.nbuffer = options.get('buffer', 3)
        else:
            self.path = None
            self.referer = None
            self.bitrate = 200000
            self.n_segments_keep = 3
            self.nbuffer = 3
        if not self.path:
            self.path = tempfile.mkdtemp()

        self._program_playlist = None
        self._file_playlist = None
        self._cookies = {}
        self._cached_files = {}  # sequence number -> local file path
        self._run = True

        self._files = None  # iterator over the playlist's segment entries
        self._next_download = None  # the delayed download deferred, if any
        self._file_playlisted = None  # fires when the playlist is reloaded
        # Deferred waiting for the next saved segment (get_file/start).
        # Initialised here so stop() and the download callbacks work before
        # start() has been called.
        self._new_filed = None

        self._pl_task = None  # LoopingCall reloading a live playlist
        self._seg_task = None  # LoopingCall downloading segments

    @staticmethod
    def _stop_task(loop):
        """Stop ``loop`` if it exists and is running.

        ``LoopingCall.stop()`` asserts that the loop is running, so callers
        that may hit an already-stopped loop go through this helper.
        """
        if loop is not None and loop.running:
            loop.stop()

    def _get_page(self, url):
        """Fetch ``url`` with the shared cookie jar, Referer and extra headers.

        Returns a Deferred firing with the page body. On failure the error is
        logged and the Failure is propagated.
        """
        def got_page(content):
            print("Cookies: %r" % self._cookies)
            return content

        def got_page_error(e, url):
            print(url)
            log.err(e)
            return e  # keep the Deferred in the error state

        url = url.encode("utf-8")
        if 'HLS_RESET_COOKIES' in os.environ:
            self._cookies = {}
        headers = {}
        if self.referer:
            headers['Referer'] = self.referer
        if self.headers:
            headers.update(self.headers)
        d = client.getPage(url, cookies=self._cookies, headers=headers)
        d.addCallback(got_page)
        d.addErrback(got_page_error, url)
        return d

    def _download_page(self, url, path):
        """Fetch a segment body in memory and log its size.

        ``client.downloadPage`` does not support cookies, so the page is
        fetched through ``_get_page``. ``path`` is unused here; the caller
        writes the file.
        """
        def _check(x):
            print("Received segment of %r bytes." % len(x))
            return x

        d = self._get_page(url)
        d.addCallback(_check)
        return d

    def _download_segment(self, f):
        """Download playlist entry ``f`` (a dict with ``file``/``sequence``).

        When segments are kept (``n_segments_keep != 0``) the data is written
        to ``self.path`` and the Deferred fires with ``(path, url, f)`` from
        ``_got_file``. Otherwise it fires with ``(None, path, f)`` and
        nothing is written.
        """
        print('_download_segment:')
        url = make_url(self._file_playlist.url, f['file'])
        name = urlparse.urlparse(f['file']).path.split('/')[-1]
        path = os.path.join(self.path, name)
        d = self._download_page(url, path)
        if self.n_segments_keep != 0:
            out = open(path, 'wb')
            d.addCallback(out.write)
            # NOTE(review): addBoth returns None, so a failed download is
            # swallowed here and the (empty) file is still reported as
            # saved. Kept as-is; confirm whether this is intended.
            d.addBoth(lambda _: out.close())
            d.addCallback(lambda _: path)
            d.addErrback(self._got_file_failed)
            d.addCallback(self._got_file, url, f)
        else:
            d.addCallback(lambda _: (None, path, f))
        return d

    def delete_cache(self, f):
        """Delete cached segment files whose sequence number satisfies ``f``."""
        doomed = [seq for seq in self._cached_files if f(seq)]
        for seq in doomed:
            filename = self._cached_files[seq]
            print("Removing %r" % filename)
            os.remove(filename)
            # Only forget the entry once the file is actually gone.
            del self._cached_files[seq]

    def _got_file_failed(self, e):
        """Forward a segment-saving failure to the pending waiter, if any.

        The Failure is returned so the chain stays in the error state and
        ``_got_file`` is not called with a ``None`` path.
        """
        # Detach before firing: the waiter's callbacks may register a new one.
        waiter, self._new_filed = self._new_filed, None
        if waiter:
            waiter.errback(e)
        return e

    def _got_file(self, path, url, f):
        """Record a saved segment, trim the cache and wake the pending waiter."""
        print("Saved " + url + " in " + path)
        seq = f['sequence']
        self._cached_files[seq] = path
        if self.n_segments_keep != -1:
            # Keep only the newest n_segments_keep sequence numbers on disk.
            self.delete_cache(lambda x: x <= seq - self.n_segments_keep)
        # Detach before firing: get_file's retry callback may register a new
        # waiter synchronously, which must not be cleared afterwards.
        waiter, self._new_filed = self._new_filed, None
        if waiter:
            waiter.callback((path, url, f))
        return (path, url, f)

    def _get_next_file(self):
        """Start downloading the next playlist entry.

        If the iterator yields a falsy entry on a live playlist, the segment
        loop is paused. The returned Deferred fires after the next playlist
        reload, downloads the next entry and restarts the loop.
        ``StopIteration`` from the iterator propagates to the caller.
        """
        entry = next(self._files)
        if entry:
            return self._download_segment(entry)
        elif not self._file_playlist.endlist():
            # The loop is not running when called from _get_files_loop or
            # from the _file_playlisted callback, and stop() would assert.
            self._stop_task(self._seg_task)
            self._file_playlisted = defer.Deferred()
            self._file_playlisted.addCallback(lambda x: self._get_next_file())
            self._file_playlisted.addCallback(self._next_file_delay)
            self._file_playlisted.addCallback(self._seg_task.start)
            return self._file_playlisted

    def _handle_end(self, failure):
        """Errback: stop the reactor once the segment iterator is exhausted.

        ``failure.trap`` re-raises anything other than ``StopIteration``.
        """
        failure.trap(StopIteration)
        print("End of media")
        reactor.stop()

    def _next_file_delay(self, f):
        """Return the delay in seconds before the next segment download.

        ``f`` is the ``(path, url, entry)`` tuple of the last download. With
        buffering (``nbuffer > 0``), the entry's duration is returned while one
        of the last ``nbuffer`` sequence numbers is cached, otherwise 0.
        Without buffering, an ended playlist gives 1, else the duration.
        """
        delay = f[2]["duration"]
        if self.nbuffer > 0:
            print('_cached_files: %d' % len(self._cached_files))
            seq = f[2]['sequence']
            if any((seq - i) in self._cached_files for i in range(self.nbuffer)):
                return delay
            return 0
        if self._file_playlist.endlist():
            return 1
        return delay

    def _get_files_loop(self):
        """Download the first segment, then keep downloading on ``_seg_task``."""
        if not self._seg_task:
            self._seg_task = task.LoopingCall(self._get_next_file)
        d = self._get_next_file()
        d.addCallback(self._next_file_delay)
        d.addCallback(self._seg_task.start)
        return d

    def _playlist_updated(self, pl):
        """React to a freshly loaded playlist ``pl``.

        A program playlist is stored and the selected variant is loaded. A
        media playlist becomes the current file playlist: the segment
        iterator is created once, a reload loop is started for live
        playlists, and any download waiting for new entries is released.

        :raises ValueError: if ``pl`` has neither programs nor files.
        """
        print('_playlist_updated:')
        if pl is None:
            # _got_playlist_content returns None once stop() was called.
            return None
        if pl.has_programs():
            # Program playlist: save it and start the selected program.
            self._program_playlist = pl
            (program_url, _) = pl.get_program_playlist(self.program, self.bitrate)
            variant_url = make_url(self.url, program_url)
            return self._reload_playlist(M3U8(variant_url))
        elif pl.has_files():
            # Media playlist: reload it regularly and download its files.
            self._file_playlist = pl
            if not self._files:
                self._files = pl.iter_files()
            if not pl.endlist():
                if not self._pl_task:
                    self._pl_task = task.LoopingCall(self._reload_playlist, pl)
                    self._pl_task.start(10, False).addCallback(self.taskStopped, 'pl_task').addErrback(self.taskStopped, 'pl_task', True)
            if self._file_playlisted:
                waiter, self._file_playlisted = self._file_playlisted, None
                waiter.callback(pl)
        else:
            raise ValueError("playlist %r has neither programs nor files" % (pl.url,))
        return pl

    def _got_playlist_content(self, content, pl):
        """Feed fetched ``content`` into ``pl``; retry later if it does not load.

        Returns ``pl`` on success, ``None`` after ``stop()``, or a Deferred
        that retries the fetch after ``pl.reload_delay()`` seconds.
        """
        print('_got_playlist_content:')
        if not self._run:
            return None
        if not pl.update(content):
            print('pl:False')
            # The playlist could not be loaded: restart the reload timer (if
            # one exists yet) with the playlist's delay and retry the fetch.
            dly = pl.reload_delay()
            if self._pl_task is not None:
                self._stop_task(self._pl_task)
                self._pl_task.start(dly, False).addCallback(self.taskStopped, 'pl_task').addErrback(self.taskStopped, 'pl_task', True)
            d = deferLater(reactor, dly, self._fetch_playlist, pl)
            d.addCallback(self._got_playlist_content, pl)
            return d
        return pl

    def taskStopped(self, res, nm, err=False):
        """Log that LoopingCall ``nm`` ended, normally or with error ``res``."""
        print('%s :' % nm)
        if err:
            print(res)
        else:
            print('stopped.')

    def _fetch_playlist(self, pl):
        """Fetch the raw content of playlist ``pl``."""
        print('fetching %r' % pl.url)
        return self._get_page(pl.url)

    def _reload_playlist(self, pl):
        """Fetch, load and process ``pl``; returns None after ``stop()``."""
        if not self._run:
            return None
        d = self._fetch_playlist(pl)
        d.addCallback(self._got_playlist_content, pl)
        d.addCallback(self._playlist_updated)
        return d

    def _find_cached(self, sequence):
        """Return the smallest cached sequence number >= ``sequence``, or None."""
        candidates = [seq for seq in self._cached_files if seq >= sequence]
        return min(candidates) if candidates else None

    def get_file(self, sequence):
        """Return a Deferred firing with the path of segment ``sequence``.

        If that segment is not cached, the next newer cached one is used. If
        none is available, the Deferred waits for the next saved segment and
        then retries.
        """
        print('get_file:')
        d = defer.Deferred()
        found = self._find_cached(sequence)
        if found is not None:
            print('*got file*')
            d.callback(self._cached_files[found])
        else:
            d.addCallback(lambda x: self.get_file(sequence))
            self._new_filed = d
            print('waiting for %r (available: %r)' % (sequence, sorted(self._cached_files)))
        return d

    def _start_get_files(self, x):
        """Start the segment loop; the Deferred fires with the first saved file."""
        self._new_filed = defer.Deferred()
        self._get_files_loop()
        return self._new_filed

    def start(self):
        """Load the initial playlist and start downloading (None once stopped)."""
        if self._run:
            self._files = None
            d = self._reload_playlist(M3U8(self.url))
            d.addCallback(self._start_get_files)
            return d

    def stop(self):
        """Stop reloading and downloading and cancel a pending waiter."""
        self._run = False
        self._stop_task(self._pl_task)
        self._stop_task(self._seg_task)
        print("Canceling deferreds")
        if self._new_filed is not None:
            self._new_filed.cancel()

def make_url(base_url, url):
    """Resolve ``url`` against ``base_url`` and optionally shift its port.

    A URL without a scheme is joined onto ``base_url``. When the
    ``HLS_PLAYER_SHIFT_PORT`` environment variable is set, its integer value
    is added to the URL's explicit port. If there is no explicit port, it is
    added to the default port 80 for ``http`` URLs; other schemes without a
    port keep their netloc unchanged.
    """
    if urlparse.urlsplit(url).scheme == '':
        url = urlparse.urljoin(base_url, url)
    if 'HLS_PLAYER_SHIFT_PORT' in os.environ:
        shift = int(os.environ['HLS_PLAYER_SHIFT_PORT'])
        p = urlparse.urlparse(url)
        loc = p.netloc
        if p.port is not None:
            # Replace the explicit port (the text after the last ':'); using
            # p.port avoids mis-parsing "user:password@host" netlocs.
            loc = '%s:%d' % (loc[:loc.rfind(':')], p.port + shift)
        elif p.scheme == "http":
            # No explicit port: shift the http default port 80.
            loc = '%s:%d' % (loc, 80 + shift)
        url = urlparse.urlunparse(p._replace(netloc=loc))
    return url

