#!/usr/bin/python3

# TODO: Factor the mirror detection/analysis code out into mint-common.
# It is used here, but also in mintsources itself and in mintupdate.

import datetime
import fileinput
import json
import lsb_release
import os
import pycurl
import random
import re
import sys
import urllib.request as urllibrequest

# Catalogue of known mirrors, and the installed-target variant of the APT
# sources file.
MIRROR_LIST = "/usr/share/mint-mirrors/linuxmint.list"
TARGET_LIST = "/target/etc/apt/sources.list.d/official-package-repositories.list"

# In live installation mode, use the target list; otherwise fall back to the
# sources file of the running system.
SOURCES_LIST = (
    TARGET_LIST
    if os.path.exists(TARGET_LIST)
    else "/etc/apt/sources.list.d/official-package-repositories.list"
)

class Mirror:
    """Plain record describing one repository mirror.

    Attributes are set verbatim from the constructor arguments:
    ``country_code``, ``url`` and ``name``.
    """

    def __init__(self, country_code, url, name):
        self.country_code, self.url, self.name = country_code, url, name

class MirrorManager():

    def __init__(self, include_region=False):
        """Switch the APT sources file to a random, up-to-date nearby mirror.

        Constructing the object performs the whole job: it finds the mirror
        currently referenced in SOURCES_LIST, detects the local country,
        groups the mirrors from MIRROR_LIST by proximity, keeps those whose
        ``db/version`` file is recent enough, picks one at random and
        rewrites SOURCES_LIST in place.

        Every failure along the way prints a message and terminates the
        process through exit_if().

        Args:
            include_region: when true, mirrors of the local country, the
                bordering countries, the subregion and the region are all
                candidates. Otherwise each wider tier is only tried while
                no candidate has been found yet.
        """
        self.exit_if(not os.path.exists(SOURCES_LIST), "No sources found at %s. Exiting." % SOURCES_LIST)
        self.exit_if(not os.path.exists(MIRROR_LIST), "No mirrors found at %s. Exiting." % MIRROR_LIST)

        # Find the mirror currently in use
        self.mirror = self._find_current_mirror()
        self.exit_if(self.mirror is None, "Unable to find the mirror being used.. exiting.")
        print("Current mirror: %s" % self.mirror)

        # Establish list of local mirrors
        with open('/usr/lib/linuxmint/mintSources/countries.json') as data_file:
            self.countries = json.load(data_file)

        local_country_code = self._lookup_country_code()
        self.exit_if(local_country_code is None, "Unable to detect the local country code. Exiting.")
        print("Local country code: %s" % local_country_code)

        self.bordering_countries = []
        self.subregion = []
        self.region = []
        self.local_country = self.get_country(local_country_code)
        self.exit_if(self.local_country is None, "Unable to detect the local country. Exiting.")
        self._classify_countries()

        self.local_mirrors = []
        self.bordering_mirrors = []
        self.subregional_mirrors = []
        self.regional_mirrors = []
        # Kept as an attribute for compatibility; nothing here fills it.
        self.other_mirrors = []
        self._classify_mirrors(local_country_code)

        self.elligible_mirrors = self._select_eligible_mirrors(include_region)
        self.exit_if(not self.elligible_mirrors, "Found no elligible mirrors. Exiting")

        # Find the age of the Mint repositories
        self.mint_timestamp = self.get_url_last_modified("http://packages.linuxmint.com/db/version")
        self.exit_if(self.mint_timestamp is None, "Unable to find the age of the Mint repository. Exiting.")
        self.mint_date = datetime.datetime.fromtimestamp(self.mint_timestamp)

        self.valid_mirrors = []
        for mirror in self.elligible_mirrors:
            (mirror_up_to_date, mirror_age) = self.is_mirror_up_to_date(mirror)
            if mirror_up_to_date:
                print("UP TO DATE: ", mirror.country_code, mirror.name, mirror.url, "%d days" % mirror_age)
                self.valid_mirrors.append(mirror)
            else:
                print("OUTDATED", mirror.country_code, mirror.name, mirror.url, "%d days" % mirror_age)

        self.exit_if(not self.valid_mirrors, "Found no up to date mirrors")

        # Pick a random mirror from the list of valid alternatives
        selected_mirror = random.choice(self.valid_mirrors)
        print("Randomly picked: ", selected_mirror.name)

        # Update sources.list
        self._rewrite_sources(selected_mirror.url)
        print("Updated %s." % SOURCES_LIST)
        print("Type 'apt update' to update your APT cache.")

    def _find_current_mirror(self):
        """Return the mirror URL used in SOURCES_LIST (no trailing slash), or None."""
        codename = lsb_release.get_distro_information()['CODENAME']
        marker = "%s main upstream import" % codename
        with open(SOURCES_LIST, 'r') as sources_file:
            for line in sources_file:
                line = line.strip()
                if line.startswith("deb ") and marker in line:
                    url = line.split()[1]
                    return url[:-1] if url.endswith("/") else url
        return None

    def _lookup_country_code(self):
        """Return the country code reported by the geoip service, or None on any failure."""
        try:
            # Close the HTTP response deterministically and parse the decoded
            # text rather than the repr() of the raw bytes.
            with urllibrequest.urlopen('http://geoip.ubuntu.com/lookup', timeout=20) as response:
                lookup = response.read().decode("utf-8", errors="replace")
            country_code = re.search('<CountryCode>(.*)</CountryCode>', lookup).group(1)
        except Exception as error:
            # Best effort: report the problem and let the caller exit cleanly.
            print(error)
            return None
        # The service answers the literal text "None" when it has no match.
        return None if country_code == 'None' else country_code

    def _classify_countries(self):
        """Fill the subregion, region and bordering code lists relative to self.local_country."""
        home = self.local_country
        for country in self.countries:
            country_code = country["cca2"]
            if country["region"] == home["region"]:
                if country["subregion"] == home["subregion"]:
                    self.subregion.append(country_code)
                else:
                    self.region.append(country_code)
            # Borders are listed by "cca3" code in the local country's entry.
            if country["cca3"] in home["borders"]:
                self.bordering_countries.append(country_code)

    def _classify_mirrors(self, local_country_code):
        """Distribute the mirrors of MIRROR_LIST over the per-proximity lists."""
        for mirror in self.read_mirror_list(MIRROR_LIST):
            if mirror.country_code == local_country_code:
                self.local_mirrors.append(mirror)
            elif mirror.country_code in self.bordering_countries:
                self.bordering_mirrors.append(mirror)
            elif mirror.country_code in self.subregion:
                self.subregional_mirrors.append(mirror)
            elif mirror.country_code in self.region:
                self.regional_mirrors.append(mirror)

        # Non-local tiers are ordered by country code (stable sort).
        by_country = lambda x: x.country_code
        self.bordering_mirrors.sort(key=by_country)
        self.subregional_mirrors.sort(key=by_country)
        self.regional_mirrors.sort(key=by_country)

    def _select_eligible_mirrors(self, include_region):
        """Return the candidate mirrors, widening the search tier by tier when needed."""
        if include_region:
            return self.local_mirrors + self.bordering_mirrors + self.subregional_mirrors + self.regional_mirrors
        if self.local_mirrors:
            return self.local_mirrors
        print("Found no local mirrors, trying bordering countries..")
        if self.bordering_mirrors:
            return list(self.bordering_mirrors)
        print("Found no local mirrors in bordering countries, trying subregional countries..")
        if self.subregional_mirrors:
            return list(self.subregional_mirrors)
        print("Found no local mirrors in subregion, trying regional countries..")
        return list(self.regional_mirrors)

    def _rewrite_sources(self, new_url):
        """Replace self.mirror with new_url on every line of SOURCES_LIST, in place."""
        # In in-place mode print() writes into the file. The context manager
        # guarantees stdout is restored even if the loop raises, so a later
        # error message can never end up inside the sources file.
        with fileinput.input(SOURCES_LIST, inplace=True) as sources:
            for line in sources:
                print(line.replace(self.mirror, new_url).rstrip())

    def exit_if(self, condition, error=None):
        """Terminate the program when *condition* is truthy.

        If *error* is not None it is printed first. The exit is a plain
        ``sys.exit()`` with no explicit status. Returns None when
        *condition* is falsy.
        """
        if not condition:
            return
        if error is not None:
            print(error)
        sys.exit()

    def get_country(self, country_code):
        """Return the first entry of self.countries whose "cca2" equals *country_code*.

        Returns None when no entry matches.
        """
        matches = (entry for entry in self.countries if entry["cca2"] == country_code)
        return next(matches, None)

    def get_url_last_modified(self, url):
        """Return the remote modification time of *url*, or None.

        A body-less request is issued (NOBODY) with redirects followed and
        libcurl asked to record the remote file time (OPT_FILETIME).

        Returns:
            The value reported by INFO_FILETIME when it is non-negative.
            None when it is negative or when any error occurs while
            setting up or performing the transfer.
        """
        try:
            c = pycurl.Curl()
        except Exception:
            return None
        try:
            c.setopt(pycurl.URL, url)
            c.setopt(pycurl.CONNECTTIMEOUT, 5)
            c.setopt(pycurl.TIMEOUT, 30)
            c.setopt(pycurl.FOLLOWLOCATION, 1)
            c.setopt(pycurl.NOBODY, 1)
            c.setopt(pycurl.OPT_FILETIME, 1)
            c.perform()
            filetime = c.getinfo(pycurl.INFO_FILETIME)
            return filetime if filetime >= 0 else None
        except Exception:
            # Best effort: an unreachable or misbehaving mirror is simply
            # "unknown". KeyboardInterrupt/SystemExit are deliberately not
            # caught so the user can still abort the run.
            return None
        finally:
            # Always release the curl handle, on success and on failure.
            c.close()

    def read_mirror_list(self, path):
        """Parse the mirror list file at *path* into Mirror objects.

        Blank lines are ignored. A line containing ``#LOC:`` sets the
        country code (the text after the first colon) for the entries that
        follow it. Every other line is ``<url> [name...]``; entries seen
        before any ``#LOC:`` line, or containing ``ubuntu-ports``, are
        skipped. One trailing slash is removed from the URL, and the name
        defaults to the URL as written when no name is given.

        Returns:
            A list of Mirror objects in file order.
        """
        mirror_list = []
        country_code = None
        # The context manager closes the file even if parsing raises.
        with open(path, "r") as mirrorsfile:
            for line in mirrorsfile:
                line = line.strip()
                if not line:
                    continue
                if "#LOC:" in line:
                    # Tolerate stray whitespace around the code ("#LOC: US").
                    country_code = line.split(":")[1].strip()
                    continue
                if country_code is None or "ubuntu-ports" in line:
                    continue
                url, separator, name = line.partition(" ")
                if not separator:
                    name = url
                if url.endswith("/"):
                    url = url[:-1]
                mirror_list.append(Mirror(country_code, url, name))
        return mirror_list

    def is_mirror_up_to_date(self, mirror):
        """Compare a mirror's ``db/version`` file time with self.mint_date.

        Returns:
            A ``(up_to_date, age_in_days)`` tuple. The age is the whole-day
            part of ``self.mint_date`` minus the mirror's file time, and the
            mirror counts as up to date when that age is at most 2.
            ``(False, 0)`` is returned when the file time cannot be read.
        """
        timestamp = self.get_url_last_modified("%s/db/version" % mirror.url)
        if timestamp is None:
            return (False, 0)
        lag = self.mint_date - datetime.datetime.fromtimestamp(timestamp)
        age_in_days = lag.days
        return (age_in_days <= 2, age_in_days)

    def is_mirror_default(self):
        """Return True when self.mirror is exactly the main Linux Mint repository URL."""
        default_url = "http://packages.linuxmint.com"
        return default_url == self.mirror


if __name__ == "__main__":
    # Command-line entry point: requires root, prints usage for -h, and
    # otherwise runs the mirror switch (optionally with --include-region).
    # Any ordinary exception is reported as a one-line error.
    try:
        cli_args = sys.argv
        if os.getuid() != 0:
            print("Please run this command as root!")
        elif "-h" in cli_args:
            print("Usage mint-switch-to-local-mirror [--include-region]")
            print("")
        else:
            manager = MirrorManager(include_region="--include-region" in cli_args)
    except Exception as e:
        print("Error: %s" % e)
