##########################################################################
# test_task_msuvbin.py
#
# Copyright (C) 2026
# Associated Universities, Inc. Washington DC, USA.
#
# This script is free software; you can redistribute it and/or modify it
# under the terms of the GNU Library General Public License as published by
# the Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This library is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Library General Public
# License for more details.
#
# Parent feature ticket:
# https://open-jira.nrao.edu/browse/CAS-13794
# Test ticket:
# https://open-jira.nrao.edu/browse/CAS-14130
#
##########################################################################

import os
import shutil
import unittest
import numpy as np

from pathlib import Path

import casatools

from casatools import ctsys
from casatools import image as iatool
from casatools import ms as mstool
from casatasks import msuvbin, tclean, flagdata

## DATA ##

# Directory holding the msuvbin unit-test data sets, resolved through ctsys.
datapath = ctsys.resolve("unittest/msuvbin")

## TOLERANCE ##

# Upper bound on the relative difference between the image peaks obtained
# from the binned and the unbinned visibilities.
epsilon = 4.0e-5

class test_msuvbin(unittest.TestCase):
    def copy_input_ms(self):
        if Path(self.originalms).is_dir():
            shutil.rmtree(self.originalms)
        shutil.copytree(self.inputms, self.originalms)

    def setUp(self):
        self.basename = "test_msuvbin"
        self.inputms = os.path.join(datapath, "refim_point.ms")
        self.originalms = f"{self.basename}_original.ms"
        self.binnedms = f"{self.basename}_binned.ms"
        self.binnedflaggedms = f"{self.basename}_binned_flagged.ms"
        self.flaggedms = f"{self.basename}_flagged.ms"
        self.original_image = f"{self.basename}_with_original"
        self.binned_image = f"{self.basename}_with_binned"

        # Image parameters.
        self.imsize = 100
        self.cell = "10.0arcsec"
        self.start = "0.8GHz"
        self.width = "1.2GHz"
        self.nchan = 1
        self.ncorr = 1

        # Copy the input MS.
        self.copy_input_ms()

    def tearDown(self):
        # MSs.
        if Path(self.originalms).is_dir():
            shutil.rmtree(self.originalms)
        if Path(self.binnedms).is_dir():
            shutil.rmtree(self.binnedms)
        if Path(self.binnedflaggedms).is_dir():
            shutil.rmtree(self.binnedflaggedms)
        if Path(self.flaggedms).is_dir():
            shtuil.rmtree(self.flaggedms)

        # Images.
        for f in Path(".").glob(f"{self.original_image}*"):
            shutil.rmtree(f)
        for f in Path(".").glob(f"{self.binned_image}*"):
            shutil.rmtree(f)

    def test_msuvbin_bin(self):
        # Bin the visibilities.
        msuvbin(vis=self.originalms, outputvis=self.binnedms, 
                imsize=self.imsize, cell=self.cell, ncorr=self.ncorr, 
                nchan=self.nchan, start=self.start, width=self.width)

        # Deconvolution parameters.
        specmode = "mfs"
        niter = 100
        interpolation = "nearest"
        # Image the binned and original visibilities.
        tclean(vis=self.binnedms, imagename=self.binned_image, 
                imsize=self.imsize, cell=self.cell, specmode=specmode, 
                niter=niter, nchan=self.nchan, start=self.start, width=self.width, 
                interpolation=interpolation)
        tclean(vis=self.originalms, imagename=self.original_image,
                imsize=self.imsize, cell=self.cell, specmode=specmode, 
                niter=niter, nchan=self.nchan, start=self.start, width=self.width,
                interpolation=interpolation)
        # Compare the peak in the resulting images.
        myia = iatool()
        myia.open(f"{self.binned_image}.image")
        binned_stats = myia.statistics()
        myia.close()
        myia.open(f"{self.original_image}.image")
        original_stats = myia.statistics()
        myia.close()
        self.assertTrue(abs(original_stats["max"] - binned_stats["max"])/original_stats["max"] < epsilon)
        # Compare the location of the peak.
        np.testing.assert_array_equal(original_stats["maxpos"], binned_stats["maxpos"], 
                "Position of peak is different")
    
    def test_msuvbin_write_flags_back(self):
        msuvbin(vis=self.originalms, outputvis=self.binnedflaggedms,
                imsize=self.imsize, cell=self.cell, ncorr=self.ncorr,
                nchan=self.nchan, start=self.start, width=self.width)

        # Flag the inner 500 m.
        flagdata(vis=self.binnedflaggedms, mode='manual', uvrange='<500m', flagbackup=False)

        # Check that there is a hole.
        # All data inside the inner 500 m should be flagged.

        myms = mstool()
        myms.open(self.binnedflaggedms)
        flags = myms.getdata("flag")["flag"]
        uvdist = myms.getdata("uvdist")["uvdist"]
        self.assertTrue(np.all(flags[:,:,uvdist<500]))
        # There should be data left outside the hole.
        # Notice that the binned visibilities are on a regular grid,
        # so there's also flagged data outside the hole in bins where 
        # there were no visibilities.
        self.assertFalse(np.all(~flags[:,:,uvdist>500]))
        myms.close()

        # Write the flags back.
        msuvbin(vis=self.originalms, outputvis=self.binnedflaggedms, mode='write_flags_back', flagbackup=False)

        # The size of the hole will be a function of frequency.
        # Check at the longer wavelength.
        myms.open(self.originalms)
        flags = myms.getdata("flag")["flag"]
        uvdist = myms.getdata("uvdist")["uvdist"]
        self.assertTrue(np.all(flags[:,0:7,uvdist<500]))
        self.assertTrue(np.all(flags[:,7:,uvdist<300]))

if __name__ == '__main__':
    # Executed as a script: hand control to the unittest runner, which
    # discovers and runs every TestCase defined in this module.
    unittest.main()
