Let’s work through a Text2SQL use case where we are starting from scratch without a nice and clean dataset of questions, SQL queries, or expected responses.
Let’s start by implementing some simple text2sql logic.
```python
import os

import openai

client = openai.AsyncClient()

# `conn` is the DuckDB connection to the nba table created earlier in the tutorial
columns = conn.query("DESCRIBE nba").to_df().to_dict(orient="records")

# We will use GPT-4o to start
TASK_MODEL = "gpt-4o"
CONFIG = {"model": TASK_MODEL}

column_list = ",".join(
    column["column_name"] + ": " + column["column_type"] for column in columns
)
system_prompt = (
    "You are a SQL expert, and you are given a single table named nba with the following columns:\n"
    f"{column_list}\n"
    "Write a SQL query corresponding to the user's request. Return just the query text, "
    "with no formatting (backticks, markdown, etc.)."
)


async def generate_query(input):
    response = await client.chat.completions.create(
        model=TASK_MODEL,
        temperature=0,
        messages=[
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": input,
            },
        ],
    )
    return response.choices[0].message.content
```
```python
query = await generate_query("Who won the most games?")
print(query)
```
```sql
SELECT Team, COUNT(*) AS Wins FROM nba WHERE WINorLOSS = 'W' GROUP BY Team ORDER BY Wins DESC LIMIT 1;
```
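If you want to sanity-check the generated query before setting up an experiment, you can execute it against the same DuckDB connection used above; this optional check is not part of the original flow.

```python
# Optional: run the generated SQL against DuckDB and inspect the result
conn.query(query).to_df()
```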
To set up an experiment, we need a dataset, a task, and an evaluator. Let’s set up each.

Setup dataset
questions = [ "Which team won the most games?", "Which team won the most games in 2015?", "Who led the league in 3 point shots?", "Which team had the biggest difference in records across two consecutive years?", "What is the average number of free throws per year?",]
Let’s store the data above as a versioned dataset in Arize AX.
```python
import pandas as pd

# These imports assume the Arize datasets SDK layout; adjust if your version differs
from arize.experimental.datasets import ArizeDatasetsClient
from arize.experimental.datasets.utils.constants import GENERATIVE

arize_client = ArizeDatasetsClient(
    developer_key=os.environ.get("ARIZE_DEVELOPER_KEY"),
    api_key=os.environ.get("ARIZE_API_KEY"),
)

# Create a dataset from a DataFrame (add your own data here).
# `space_id` and `dataset_name` are assumed to be defined earlier in the tutorial.
test_df = pd.DataFrame([{"question": question} for question in questions])
dataset_id = arize_client.create_dataset(
    space_id=space_id,
    dataset_name=dataset_name,
    dataset_type=GENERATIVE,
    data=test_df,
)
dataset = arize_client.get_dataset(space_id=space_id, dataset_id=dataset_id)
dataset.head()
```
Setup task

Next, we’ll define the task. The task is to generate SQL queries from natural language questions.
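The evaluators in the next step parse the task output as a JSON string with query, results, and error fields, so a minimal sketch of the task might look like the following. This assumes the experiment runner passes each dataset row as a dict and accepts async task functions; adapt it to your setup.

```python
import json


async def task(dataset_row):
    # Sketch only: assumes each dataset row is a dict containing the
    # "question" column created above.
    question = dataset_row["question"]
    query = await generate_query(question)
    results, error = None, None
    try:
        # Execute the generated SQL against the same DuckDB connection
        results = conn.query(query).to_df().to_dict(orient="records")
    except Exception as e:
        error = str(e)
    # The evaluators below read these three fields from the JSON output
    return json.dumps({"query": query, "results": results, "error": error}, default=str)
```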
Setup evaluator

Finally, we’ll define the evaluators. We’ll use the following simple scoring functions to check whether the generated SQL queries are correct.
```python
# Test that there are no SQL execution errors
def no_error(output):
    output = json.loads(output)
    return 1.0 if output.get("error") is None else 0.0


# Test that the query returned results
def has_results(output):
    output = json.loads(output)
    results = output.get("results")
    has_results = results is not None and len(results) > 0
    return 1.0 if has_results else 0.0
```
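The next paragraph refers to an initial evaluation run. A sketch of that first experiment, mirroring the run_experiment call used later in this walkthrough (the experiment name is illustrative):

```python
# Run a baseline experiment with the simple prompt and the two evaluators above
experiment = arize_client.run_experiment(
    space_id=space_id,
    dataset_id=dataset_id,
    task=task,
    evaluators=[no_error, has_results],
    experiment_name="text2sql_baseline",  # illustrative name
)
```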
Now that we’ve run the initial evaluation, it looks like three of the results are valid, one produces SQL errors, and one has no results.

The second query, for `Which team won the most games in 2015`, looks for Date LIKE '2015%', which is not correct. The fourth query does not have Team in the GROUP BY clause.

Let’s try to improve the prompt by including sample data for each column and see if we can get better results.
```python
samples = conn.query("SELECT * FROM nba LIMIT 1").to_df().to_dict(orient="records")[0]
sample_rows = "\n".join(
    f"{column['column_name']} | {column['column_type']} | {samples[column['column_name']]}"
    for column in columns
)

system_prompt = (
    "You are a SQL expert, and you are given a single table named nba with the following columns:\n\n"
    "Column | Type | Example\n"
    "-------|------|--------\n"
    f"{sample_rows}\n"
    "\n"
    "Write a DuckDB SQL query corresponding to the user's request. "
    "Return just the query text, with no formatting (backticks, markdown, etc.)."
)


async def generate_query(input):
    response = await client.chat.completions.create(
        model=TASK_MODEL,
        temperature=0,
        messages=[
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": input,
            },
        ],
    )
    return response.choices[0].message.content


print(await generate_query("Which team won the most games in 2015?"))
```
```sql
SELECT Team, COUNT(*) AS Wins FROM nba WHERE WINorLOSS = 'W' AND Date LIKE '%/15' GROUP BY Team ORDER BY Wins DESC LIMIT 1;
```
Looking better! Finally, let’s add an LLM-based evaluator that checks whether the generated SQL is valid, and then run this as another experiment and compare the results.
```python
from phoenix.evals.models import OpenAIModel
from phoenix.evals.classify import llm_classify
from arize.experimental.datasets.experiments.types import EvaluationResult

IS_SQL_EVAL_TEMPLATE = """You are a SQL expert. Is the following a valid SQL query that executes without errors? Return the single word "valid" if it is valid, and "invalid" if it is not.

[BEGIN SQL QUERY]
{query}
[END SQL QUERY]
"""


def check_is_sql(output):
    output = json.loads(output)
    query = output.get("query")
    df_in = pd.DataFrame({"query": query}, index=[0]) if query else None
    eval_df = llm_classify(
        dataframe=df_in,
        template=IS_SQL_EVAL_TEMPLATE,
        model=OpenAIModel(model="gpt-4o"),
        rails=["valid", "invalid"],
        provide_explanation=True,
    )
    # Return score, label, and explanation; score 1 only when the query is judged valid
    label = eval_df["label"][0]
    return EvaluationResult(
        score=1 if label == "valid" else 0,
        label=label,
        explanation=eval_df["explanation"][0],
    )


experiment = arize_client.run_experiment(
    space_id=space_id,
    dataset_id=dataset_id,
    task=task,
    evaluators=[no_error, has_results, check_is_sql],
    experiment_name="text2sql_test_new_prompt_and_eval-6",
)
```
You can see that the new prompt has improved some cases, but there are still errors to iron out. As you experiment with different models, prompts, and techniques, you can continuously optimize your application until it reaches the performance thresholds you want.
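For example, one way to compare models is to parameterize the model name and run one experiment per candidate; a rough sketch (the candidate list and experiment names are illustrative):

```python
# Sketch: run one experiment per candidate model and compare them in Arize AX.
for model_name in ["gpt-4o", "gpt-4o-mini"]:  # illustrative candidates
    TASK_MODEL = model_name  # generate_query reads this module-level variable
    experiment = arize_client.run_experiment(
        space_id=space_id,
        dataset_id=dataset_id,
        task=task,
        evaluators=[no_error, has_results, check_is_sql],
        experiment_name=f"text2sql_{model_name}",
    )
```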