##########################################################################
# test_task_feather.py
#
# Copyright (C) 2018
# Associated Universities, Inc. Washington DC, USA.
#
# This script is free software; you can redistribute it and/or modify it
# under the terms of the GNU Library General Public License as published by
# the Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This library is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Library General Public
# License for more details.
#
# [Add the link to the JIRA ticket here once it exists]
#
# Based on the requirements listed in the CASA task documentation found here:
# https://casadocs.readthedocs.io/en/stable/api/tt/casatasks.imaging.feather.html
#
#
##########################################################################
import getpass
import os
import unittest
import shutil
import numpy as np

import casatools
from casatasks import feather, casalog
# Shared CASA tool instances: the table tool is used by get_map() and the
# image tool by feather_test.test_av_feather().
tb = casatools.table()
ia = casatools.image()

### DATA ###
# Location of the test data, resolved through casatools.ctsys.resolve.
datapath = casatools.ctsys.resolve('unittest/feather/')

# Input files (names inside datapath; setUp symlinks them into the current
# working directory under the same names)
intimg = 'orion_tfeather.im'
sdimg = 'orion_tsdmem.image'

# Output files, written relative to the current working directory and
# removed again in tearDown
output = 'feathered.im'
output2 = 'feathered2.im'

# Log file active at import time; tearDown points casalog back at it.
logpath = casalog.logfile()
# Scratch log file name; tearDown deletes it if present.
logname = 'testlog.log'

# Inputs, reference result (in datapath) and produced result (in the
# working directory) for the method="astroviper" feather test.
sd_image = os.path.join(datapath, "single_dish_model.im")
int_image = os.path.join(datapath, "interferometer_model.im")
av_exp_image = os.path.join(datapath, "av_feather.im")
av_got_image = "av_feather.im"

def get_map(infile):
    '''
        Return the contents of the 'map' column of the table at ``infile``.

        The shared table tool is closed even if reading the column fails,
        so a failure here does not leave ``infile`` open for later tests.
    '''
    tb.open(infile)
    try:
        return tb.getcol('map')
    finally:
        tb.close()

class feather_test(unittest.TestCase):
    '''
        Tests for the feather task, which combines an interferometric
        (highres) image with a single dish (lowres) image.
    '''

    def setUp(self):
        # Link the input images from the data directory into the working
        # directory unless something with that name is already there.
        for name in (intimg, sdimg):
            if not os.path.exists(name):
                os.symlink(os.path.join(datapath, name), name)

    def tearDown(self):
        for f in [output, output2, logname, av_got_image]:
            self._remove_path(f)
        casalog.setlogfile(logpath)

    @classmethod
    def tearDownClass(cls):
        # Only remove the symlinks created by setUp; never delete a real
        # copy of the input data that may exist in the working directory.
        for name in (intimg, sdimg):
            if os.path.islink(name):
                os.unlink(name)

    @staticmethod
    def _remove_path(path):
        '''
            Delete ``path`` if it exists: directories are removed
            recursively, anything else with os.remove.
        '''
        if os.path.isdir(path):
            shutil.rmtree(path)
        elif os.path.exists(path):
            os.remove(path)

    def _assert_outputs_differ(self):
        '''
            Assert that the 'map' data of ``output`` and ``output2`` are not
            all close to each other.
        '''
        self.assertFalse(np.allclose(get_map(output), get_map(output2)))

    @staticmethod
    def _get_pixels(imagename):
        '''
            Return the pixel values of ``imagename`` as read by
            ia.getchunk(), closing the image tool even if reading fails.
        '''
        ia.open(imagename)
        try:
            return ia.getchunk()
        finally:
            ia.done()

    @staticmethod
    def _run_pip(pip_args, failure_msg):
        '''
            Run ``python -m pip`` with ``pip_args`` using the interpreter
            that runs the tests. If pip fails, print ``failure_msg`` and
            pip's captured stdout/stderr, then re-raise the
            subprocess.CalledProcessError.
        '''
        import subprocess
        import sys

        try:
            subprocess.run(
                [sys.executable, "-m", "pip", *pip_args],
                capture_output=True, text=True, check=True
            )
        except subprocess.CalledProcessError as e:
            print(failure_msg)
            print("stdout:", e.stdout)
            print("stderr:", e.stderr)
            raise

    def _check_missing_package_is_reported(self, package):
        '''
            Uninstall ``package`` with pip (if it is importable), check that
            an astroviper feather run raises a RuntimeError whose message
            names the package, then (re)install the package.

            The reinstall runs even if the check fails, so a failing test
            does not leave the environment without the package.
        '''
        import importlib
        import importlib.util

        if importlib.util.find_spec(package):
            print(f"Try to uninstall {package} to test error if not installed")
            self._run_pip(
                ["uninstall", "-y", package],
                "🔴 pip uninstall failed for " + package
            )
            # Make the import system notice the removed files.
            importlib.invalidate_caches()
        # Start from a clean output path so a leftover image cannot be the
        # cause of the error checked below.
        self._remove_path(av_got_image)
        try:
            with self.assertRaises(RuntimeError) as cm:
                print("checking for expected error because of missing module here")
                feather(
                    imagename=av_got_image, highres=int_image,
                    lowres=sd_image, method="astroviper", ncores=1
                )
            exc = cm.exception
            expected_msg = (
                "The output suggests you may have an incomplete install of "
                f"{package}"
            )
            # The error must mention the missing package; anything else is
            # an unexpected failure.
            self.assertIn(
                expected_msg, str(exc),
                msg=f'Unexpected exception was thrown: {exc}'
            )
        finally:
            print(f"install package {package}")
            self._run_pip(
                ["-v", "--no-input", "install", package],
                "🔴 pip install failed for " + package
            )
            # Make the freshly installed package visible to find_spec.
            importlib.invalidate_caches()
        self.assertTrue(
            importlib.util.find_spec(package),
            f"Could not import {package}. Looks like it was not "
            "installed properly"
        )

    def test_combine(self):
        '''
            test_combine
            --------------

            Check that interferometric and single dish images can be combined
        '''

        feather(imagename=output, highres=intimg, lowres=sdimg)
        self.assertTrue(os.path.exists(output))

    def test_imagename(self):
        '''
            test_imagename
            ----------------

            Check that the imagename parameter gives the name of the output image file
        '''

        feather(imagename=output, highres=intimg, lowres=sdimg)
        feather(imagename=output2, highres=intimg, lowres=sdimg)

        self.assertTrue(os.path.exists(output))
        self.assertTrue(os.path.exists(output2))

    def test_highres(self):
        '''
            test_highres
            --------------

            Check that the interferometric image is provided with this parameter
            This parameter is necessary to run the task
        '''

        with self.assertRaises(AssertionError):
            feather(imagename=output, lowres=sdimg)

    def test_lowres(self):
        '''
            test_lowres
            -------------

            Check that the single dish image is provided with this parameter
            This parameter is necessary to run the task
        '''

        with self.assertRaises(AssertionError):
            feather(imagename=output, highres=intimg)

    def test_sdfactor(self):
        '''
            test_sdfactor
            ---------------

            Check that differing sdfactors results in differing image files
        '''

        feather(imagename=output, highres=intimg, lowres=sdimg)
        feather(imagename=output2, highres=intimg, lowres=sdimg, sdfactor=0.5)

        self._assert_outputs_differ()

    def test_effdishdiam(self):
        '''
            test_effdishdiam
            ------------------

            Check that changing the effective dish diameter results in
            differing image files, and that an effective dish diameter that
            is too large is rejected
        '''

        feather(imagename=output, highres=intimg, lowres=sdimg)
        feather(imagename=output2, highres=intimg, lowres=sdimg, effdishdiam=1)

        self._assert_outputs_differ()

        # Remove output2 first so the expected RuntimeError cannot be caused
        # by the output image already existing.
        self._remove_path(output2)
        with self.assertRaises(RuntimeError):
            feather(imagename=output2, highres=intimg, lowres=sdimg, effdishdiam=1000)

    def test_lowpassfiltersd(self):
        '''
            test_lowpassfiltersd
            ----------------------

            Check that lowpassfiltersd = True results in a different image than the default
        '''

        feather(imagename=output, highres=intimg, lowres=sdimg)
        feather(imagename=output2, highres=intimg, lowres=sdimg, lowpassfiltersd=True)

        self._assert_outputs_differ()

    def test_av_feather(self):
        '''
            test_av_feather
            -----------------

            Check that feather with method="astroviper" reproduces the
            reference image for several values of ncores. When run as user
            casatest, or when the file _feather_uninstall_reinstall_for_tests
            exists, first check that a missing toolviper/astroviper package
            is reported (this uninstalls and reinstalls those packages).
        '''

        if (
            getpass.getuser() == "casatest"
            or os.path.exists("_feather_uninstall_reinstall_for_tests")
        ):
            # running on bamboo or in an account that wants to test the
            # absence of packages
            for package in ["toolviper", "astroviper"]:
                self._check_missing_package_is_reported(package)
        else:
            print(
                "NOTE: Regular user account so not running module absence "
                "tests. If you want to run these tests, touch "
                "_feather_uninstall_reinstall_for_tests. WARNING: Doing so "
                "will cause this test to uninstall (if present) and (re)install "
                "modules toolviper and astroviper."
            )

        expec = self._get_pixels(av_exp_image)
        for ncores in [1, 2, 4, 8]:
            # every run starts from a clean output path
            self._remove_path(av_got_image)
            print("run feather with", ncores, "cores")
            feather(
                imagename=av_got_image, highres=int_image,
                lowres=sd_image, method="astroviper", ncores=ncores
            )
            got = self._get_pixels(av_got_image)
            self.assertTrue(
                np.allclose(got, expec),
                f"Incorrect astroviper feather output for ncores={ncores}"
            )

if __name__ == '__main__':
    unittest.main()
