##########################################################################
# test_task_msuvbinflag.py
#
# Copyright (C) 2018
# Associated Universities, Inc. Washington DC, USA.
#
# This script is free software; you can redistribute it and/or modify it
# under the terms of the GNU Library General Public License as published by
# the Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This library is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Library General Public
# License for more details.
#
# [Add the link to the JIRA ticket here once it exists]
#
# Based on the requirements listed in the msuvbinflag task documentation:
# https://casadocs.readthedocs.io/en/stable/api/tt/casatasks.manipulation.msuvbinflag.html
#
#
##########################################################################
import sys
import os
import numpy as np
import unittest
import shutil
import casatools
from casatasks import flagdata
from casatasks import msuvbinflag, msuvbin

# Module-wide table tool; the tests use it to open the binned MS and read
# its FLAG column.
tb = casatools.table()


# Shorthand for casatools' data-path resolver, used to locate the input
# measurement set (presumably searched for in the configured CASA data
# directories -- see casatools.ctsys.resolve).
ctsys_resolve = casatools.ctsys.resolve

## DATA ##

# Directory containing the input MS (`test_msuvbinflag.origms`).
datapath = ctsys_resolve('measurementset/evla')


class test_msuvbinflag(unittest.TestCase):
    """Tests for the msuvbinflag task on a uv-binned measurement set.

    unittest calls ``setUp`` before and ``tearDown`` after every test, so each
    test starts from a freshly binned copy of ``origms`` written to ``gridms``.
    """

    # Input MS, resolved relative to ``datapath``.
    origms = 'evla_15A-397_spw1_7_scan_4_6.ms'
    # Binned MS produced by msuvbin in setUp and deleted in tearDown.
    gridms = 'test_flag_uvgrid.ms'

    def setUp(self):
        """Bin ``origms`` onto a uv grid and write the result to ``gridms``."""
        # Clear leftovers from an earlier aborted run so msuvbin always
        # writes into an empty location.
        if os.path.exists(self.gridms):
            shutil.rmtree(self.gridms)
        msuvbin(vis=os.path.join(datapath, self.origms), outputvis=self.gridms,
                field='1', spw='0', imsize=300, cell='0.5arcsec', ncorr=2,
                nchan=2, mode='bin')

    def tearDown(self):
        """Remove the binned MS created by ``setUp``."""
        # Guarded so cleanup never raises and masks the test's real failure.
        if os.path.exists(self.gridms):
            shutil.rmtree(self.gridms)

    def _read_flags(self):
        """Return the FLAG column of ``gridms``.

        The table is closed even if ``getcol`` raises, so it is never left
        open when ``tearDown`` deletes the MS.
        """
        tb.open(self.gridms)
        try:
            return tb.getcol('FLAG')
        finally:
            tb.done()

    def _assert_flags_not_reduced(self, **flag_kwargs):
        """Run msuvbinflag on ``gridms`` and check the flag count did not drop.

        Parameters
        ----------
        **flag_kwargs
            Keyword arguments forwarded to ``msuvbinflag`` (e.g. ``method``,
            ``nsigma``, ``doplot``); ``binnedvis`` is always ``gridms``.
        """
        flgorig = self._read_flags()
        msuvbinflag(binnedvis=self.gridms, **flag_kwargs)
        flgpost = self._read_flags()
        print(f'Number of flags post flagging {flgpost.sum()} and pre-flagging {flgorig.sum()}')
        self.assertGreaterEqual(flgpost.sum(), flgorig.sum())

    def test_radial_perplane_method(self):
        """
        test_radial_perplane_method: test the radial profile per plane
        """
        self._assert_flags_not_reduced(method='radial_per_plane', nsigma=3.0)

    @unittest.skip("Skip test. till impemented")
    def test_regionalMean_method(self):
        """Placeholder for the not-yet-implemented 'regionalMean' method."""
        # NOTE(review): the active tests pass ``nsigma``; confirm whether
        # ``sigma``/``sizeRegion``/``ignorPoint`` are the intended msuvbinflag
        # parameters, and re-derive the expected flag counts for ``gridms``
        # before enabling this test.
        res = flagdata(vis=self.gridms, mode='summary')
        self.assertEqual(res['flagged'], 16671)
        msuvbinflag(binnedvis=self.gridms, method='regionalMean', sizeRegion=20, sigma=5, ignorPoint=True)
        res = flagdata(vis=self.gridms, mode='summary')
        self.assertEqual(res['flagged'], 15223)

    def test_radial_mean_annular_method(self):
        """
        test_radial_mean_annular_method: test the radial one plane mean annular algorithm
        """
        # NOTE(review): doplot=True may require a display backend in headless
        # CI runs -- confirm.
        self._assert_flags_not_reduced(method='radial_mean_annular', nsigma=3.0, doplot=True)

    @unittest.skip("Skip test. Not implemented")
    def test_gradient_method(self):
        """Placeholder for the not-yet-implemented 'gradient' method."""
        # NOTE(review): the expected count below decreases after flagging,
        # contradicting the non-decreasing check used by the active tests;
        # re-derive these numbers for ``gridms`` before enabling this test.
        res = flagdata(vis=self.gridms, mode='summary')
        self.assertEqual(res['flagged'], 14324)
        msuvbinflag(binnedvis=self.gridms, method='gradient')
        res = flagdata(vis=self.gridms, mode='summary')
        self.assertEqual(res['flagged'], 13324)

# Allow running this file directly; unittest discovers the TestCase above.
if __name__ == '__main__':
    unittest.main()
        