
from datetime import UTC, datetime

from airflow import models
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator


def parse_result(result):
    """Return the first column of the first row of a query result.

    Used as the ``handler`` of the ``trino_query`` task below.

    Every row is read (``list(result)``), not just the first one, so the
    result is fully consumed before the value is returned.

    :param result: An iterable of rows, each row indexable (e.g. a tuple).
    :return: The first column of the first row.
    :raises ValueError: If the result contains no rows.
    """
    rows = list(result)
    if not rows:
        # Fail with an explicit message instead of a bare IndexError so an
        # empty query result is easy to diagnose from the task log.
        raise ValueError("Query returned no rows; expected at least one row.")
    return rows[0][0]


with models.DAG(
    dag_id="trino",
    schedule=None,
    start_date=datetime(2024, 1, 1, tzinfo=UTC),
    catchup=False,
    tags=["trino"],
) as dag:
    # Triggered with a ``dag_run.conf`` payload (schedule=None). Keys read
    # below: ``sql``, ``expected_result``, ``query_run_id``.
    #
    # 1. ``trino_query`` runs ``conf['sql']`` on the ``trino`` connection and
    #    reduces the result to a single value via ``parse_result``.
    # 2. ``save_result`` writes that value, and whether it equals
    #    ``conf['expected_result']``, into ``ray.query_run`` on ``ray_db``.

    # NOTE(review): this executes arbitrary SQL taken from the trigger
    # payload; only users trusted to query Trino should be able to trigger
    # this DAG.
    trino_query = SQLExecuteQueryOperator(
        task_id="trino_query",
        sql="{{ dag_run.conf['sql'] }}",
        handler=parse_result,
        conn_id="trino",
    )

    # The Trino result and the conf values are sent as bind parameters
    # (``parameters`` is a templated field, so the Jinja expressions are still
    # rendered) rather than being pasted into the SQL text inside quotes.
    # Values containing a single quote therefore can neither break the
    # statement nor inject SQL. The ``%(name)s`` placeholders assume ``ray_db``
    # is a PostgreSQL/psycopg2 connection, as ``json_build_object`` implies.
    # The ``::text`` casts keep the original text-vs-text comparison.
    save_result = SQLExecuteQueryOperator(
        task_id="save_result",
        sql=(
            "update ray.query_run"
            " set"
            " result = (%(actual_result)s::text = %(expected_result)s::text),"
            " return_value = json_build_object('value', %(actual_result)s::text),"
            " result_status = 'success',"
            " query_run_status = 'success'"
            " where query_run_id = %(query_run_id)s"
        ),
        parameters={
            "actual_result": "{{ task_instance.xcom_pull(task_ids='trino_query') }}",
            "expected_result": "{{ dag_run.conf['expected_result'] }}",
            "query_run_id": "{{ dag_run.conf['query_run_id'] }}",
        },
        conn_id="ray_db",
    )

    trino_query >> save_result
