diff --git a/superset-frontend/spec/javascripts/views/CRUD/alert/AlertReportModal_spec.jsx b/superset-frontend/spec/javascripts/views/CRUD/alert/AlertReportModal_spec.jsx index f5edd0719b..b474ea7ea9 100644 --- a/superset-frontend/spec/javascripts/views/CRUD/alert/AlertReportModal_spec.jsx +++ b/superset-frontend/spec/javascripts/views/CRUD/alert/AlertReportModal_spec.jsx @@ -260,4 +260,39 @@ describe('AlertReportModal', () => { expect(addWrapper.find('input[name="grace_period"]')).toExist(); expect(wrapper.find('input[name="grace_period"]')).toHaveLength(0); }); + + it('only allows grace period values > 1', async () => { + const props = { + ...mockedProps, + isReport: false, + }; + + const addWrapper = await mountAndWait(props); + + const input = addWrapper.find('input[name="grace_period"]'); + + input.simulate('change', { target: { name: 'grace_period', value: 7 } }); + expect(input.instance().value).toEqual('7'); + + input.simulate('change', { target: { name: 'grace_period', value: 0 } }); + expect(input.instance().value).toEqual(''); + + input.simulate('change', { target: { name: 'grace_period', value: -1 } }); + expect(input.instance().value).toEqual('1'); + }); + + it('only allows working timeout values > 1', () => { + const input = wrapper.find('input[name="working_timeout"]'); + + input.simulate('change', { target: { name: 'working_timeout', value: 7 } }); + expect(input.instance().value).toEqual('7'); + + input.simulate('change', { target: { name: 'working_timeout', value: 0 } }); + expect(input.instance().value).toEqual(''); + + input.simulate('change', { + target: { name: 'working_timeout', value: -1 }, + }); + expect(input.instance().value).toEqual('1'); + }); }); diff --git a/superset-frontend/src/views/CRUD/alert/AlertReportModal.tsx b/superset-frontend/src/views/CRUD/alert/AlertReportModal.tsx index 45a6199c6b..4e2d288021 100644 --- a/superset-frontend/src/views/CRUD/alert/AlertReportModal.tsx +++ b/superset-frontend/src/views/CRUD/alert/AlertReportModal.tsx @@ -42,6 +42,7 @@ import { } from './types'; const SELECT_PAGE_SIZE = 2000; // temporary fix for paginated query +const TIMEOUT_MIN = 1; type SelectValue = { value: string; @@ -837,6 +838,23 @@ const AlertReportModal: FunctionComponent = ({ updateAlertState(target.name, target.value); }; + const onTimeoutVerifyChange = ( + event: React.ChangeEvent, + ) => { + const { target } = event; + const value = +target.value; + + // Need to make sure grace period is not lower than TIMEOUT_MIN + if (value === 0) { + updateAlertState(target.name, null); + } else { + updateAlertState( + target.name, + value ? Math.max(value, TIMEOUT_MIN) : value, + ); + } + }; + const onSQLChange = (value: string) => { updateAlertState('sql', value || ''); }; @@ -1283,10 +1301,11 @@ const AlertReportModal: FunctionComponent = ({
seconds
@@ -1297,10 +1316,11 @@ const AlertReportModal: FunctionComponent = ({
seconds
diff --git a/superset/reports/schemas.py b/superset/reports/schemas.py index 0a872fab5f..fe9a2fbe7c 100644 --- a/superset/reports/schemas.py +++ b/superset/reports/schemas.py @@ -17,8 +17,9 @@ from typing import Any, Dict, Union from croniter import croniter +from flask_babel import gettext as _ from marshmallow import fields, Schema, validate, validates_schema -from marshmallow.validate import Length, ValidationError +from marshmallow.validate import Length, Range, ValidationError from superset.models.reports import ( ReportRecipientType, @@ -158,14 +159,22 @@ class ReportSchedulePostSchema(Schema): ), ) validator_config_json = fields.Nested(ValidatorConfigJSONSchema) - log_retention = fields.Integer(description=log_retention_description, example=90) + log_retention = fields.Integer( + description=log_retention_description, + example=90, + validate=[Range(min=1, error=_("Value must be greater than 0"))], + ) grace_period = fields.Integer( - description=grace_period_description, example=60 * 60 * 4, default=60 * 60 * 4 + description=grace_period_description, + example=60 * 60 * 4, + default=60 * 60 * 4, + validate=[Range(min=1, error=_("Value must be greater than 0"))], ) working_timeout = fields.Integer( description=working_timeout_description, example=60 * 60 * 1, default=60 * 60 * 1, + validate=[Range(min=1, error=_("Value must be greater than 0"))], ) recipients = fields.List(fields.Nested(ReportRecipientSchema)) @@ -225,15 +234,22 @@ class ReportSchedulePutSchema(Schema): ) validator_config_json = fields.Nested(ValidatorConfigJSONSchema, required=False) log_retention = fields.Integer( - description=log_retention_description, example=90, required=False + description=log_retention_description, + example=90, + required=False, + validate=[Range(min=1, error=_("Value must be greater than 0"))], ) grace_period = fields.Integer( - description=grace_period_description, example=60 * 60 * 4, required=False + description=grace_period_description, + example=60 * 60 * 4, + required=False, + validate=[Range(min=1, error=_("Value must be greater than 0"))], ) working_timeout = fields.Integer( description=working_timeout_description, example=60 * 60 * 1, allow_none=True, required=False, + validate=[Range(min=1, error=_("Value must be greater than 0"))], ) recipients = fields.List(fields.Nested(ReportRecipientSchema), required=False) diff --git a/tests/reports/api_tests.py b/tests/reports/api_tests.py index 7b87f4d9b5..ddd528a8f8 100644 --- a/tests/reports/api_tests.py +++ b/tests/reports/api_tests.py @@ -433,6 +433,8 @@ class TestReportSchedulesApi(SupersetTestCase): "recipient_config_json": {"target": "channel"}, }, ], + "grace_period": 14400, + "working_timeout": 3600, "chart": chart.id, "database": example_db.id, } @@ -443,6 +445,8 @@ class TestReportSchedulesApi(SupersetTestCase): created_model = db.session.query(ReportSchedule).get(data.get("id")) assert created_model is not None assert created_model.name == report_schedule_data["name"] + assert created_model.grace_period == report_schedule_data["grace_period"] + assert created_model.working_timeout == report_schedule_data["working_timeout"] assert created_model.description == report_schedule_data["description"] assert created_model.crontab == report_schedule_data["crontab"] assert created_model.chart.id == report_schedule_data["chart"] @@ -514,6 +518,78 @@ class TestReportSchedulesApi(SupersetTestCase): rv = self.client.post(uri, json=report_schedule_data) assert rv.status_code == 400 + # Test that report can be created with null grace period + report_schedule_data = { + "type": ReportScheduleType.ALERT, + "name": "new3", + "description": "description", + "crontab": "0 9 * * *", + "recipients": [ + { + "type": ReportRecipientType.EMAIL, + "recipient_config_json": {"target": "target@superset.org"}, + }, + { + "type": ReportRecipientType.SLACK, + "recipient_config_json": {"target": "channel"}, + }, + ], + "working_timeout": 3600, + "chart": chart.id, + "database": example_db.id, + } + uri = "api/v1/report/" + rv = self.client.post(uri, json=report_schedule_data) + assert rv.status_code == 201 + + # Test that grace period and working timeout cannot be < 1 + report_schedule_data = { + "type": ReportScheduleType.ALERT, + "name": "new3", + "description": "description", + "crontab": "0 9 * * *", + "recipients": [ + { + "type": ReportRecipientType.EMAIL, + "recipient_config_json": {"target": "target@superset.org"}, + }, + { + "type": ReportRecipientType.SLACK, + "recipient_config_json": {"target": "channel"}, + }, + ], + "working_timeout": -10, + "chart": chart.id, + "database": example_db.id, + } + uri = "api/v1/report/" + rv = self.client.post(uri, json=report_schedule_data) + assert rv.status_code == 400 + + report_schedule_data = { + "type": ReportScheduleType.ALERT, + "name": "new3", + "description": "description", + "crontab": "0 9 * * *", + "recipients": [ + { + "type": ReportRecipientType.EMAIL, + "recipient_config_json": {"target": "target@superset.org"}, + }, + { + "type": ReportRecipientType.SLACK, + "recipient_config_json": {"target": "channel"}, + }, + ], + "grace_period": -10, + "working_timeout": 3600, + "chart": chart.id, + "database": example_db.id, + } + uri = "api/v1/report/" + rv = self.client.post(uri, json=report_schedule_data) + assert rv.status_code == 400 + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_create_report_schedule_chart_dash_validation(self): """