Skip to content

Commit 7142c62

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: GenAI Client(evals) - auto-infer metric/candidate and validate inputs for generate_loss_clusters
PiperOrigin-RevId: 897413410
1 parent 727b8e0 commit 7142c62

File tree

3 files changed

+370
-10
lines changed

3 files changed

+370
-10
lines changed

tests/unit/vertexai/genai/test_evals.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,170 @@ def test_loss_analysis_result_show(self, capsys):
569569
assert "c1" in captured.out
570570

571571

572+
def _make_eval_result(
    metrics=None,
    candidate_names=None,
):
    """Builds a minimal EvaluationResult for the given metrics and candidates.

    Falls back to a single metric ("task_success_v1") and a single candidate
    ("agent-1") when the corresponding argument is empty or omitted.
    """
    metrics = metrics or ["task_success_v1"]
    candidate_names = candidate_names or ["agent-1"]

    # One metric result per requested metric name.
    metric_results = {
        name: common_types.EvalCaseMetricResult(metric_name=name) for name in metrics
    }

    candidate_result = common_types.ResponseCandidateResult(
        response_index=0,
        metric_results=metric_results,
    )
    case_result = common_types.EvalCaseResult(
        eval_case_index=0,
        response_candidate_results=[candidate_result],
    )
    return common_types.EvaluationResult(
        eval_case_results=[case_result],
        metadata=common_types.EvaluationRunMetadata(
            candidate_names=candidate_names,
        ),
    )
602+
603+
604+
class TestResolveMetricName:
    """Unit tests for _resolve_metric_name."""

    def test_none_returns_none(self):
        assert _evals_utils._resolve_metric_name(None) is None

    def test_string_passes_through(self):
        resolved = _evals_utils._resolve_metric_name("task_success_v1")
        assert resolved == "task_success_v1"

    def test_metric_object_extracts_name(self):
        metric_obj = common_types.Metric(name="multi_turn_task_success_v1")
        resolved = _evals_utils._resolve_metric_name(metric_obj)
        assert resolved == "multi_turn_task_success_v1"

    def test_object_with_name_attr(self):
        """Any object exposing a .name attribute works (e.g., LazyLoadedPrebuiltMetric)."""

        class _NamedStub:
            name = "tool_use_quality_v1"

        resolved = _evals_utils._resolve_metric_name(_NamedStub())
        assert resolved == "tool_use_quality_v1"

    def test_lazy_loaded_prebuilt_metric_resolves_versioned_name(self):
        """LazyLoadedPrebuiltMetric resolves to the versioned API spec name."""

        class _LazyStub:
            name = "MULTI_TURN_TASK_SUCCESS"

            def _get_api_metric_spec_name(self):
                return "multi_turn_task_success_v1"

        resolved = _evals_utils._resolve_metric_name(_LazyStub())
        assert resolved == "multi_turn_task_success_v1"

    def test_lazy_loaded_prebuilt_metric_falls_back_to_name(self):
        """Falls back to .name when _get_api_metric_spec_name returns None."""

        class _LazyStubNoSpec:
            name = "CUSTOM_METRIC"

            def _get_api_metric_spec_name(self):
                return None

        resolved = _evals_utils._resolve_metric_name(_LazyStubNoSpec())
        assert resolved == "CUSTOM_METRIC"
651+
652+
653+
class TestResolveLossAnalysisConfig:
    """Unit tests for _resolve_loss_analysis_config."""

    def test_auto_infer_single_metric_and_candidate(self):
        result = _make_eval_result(
            metrics=["task_success_v1"], candidate_names=["agent-1"]
        )
        resolved = _evals_utils._resolve_loss_analysis_config(eval_result=result)
        assert resolved.metric == "task_success_v1"
        assert resolved.candidate == "agent-1"

    def test_explicit_metric_and_candidate(self):
        result = _make_eval_result(metrics=["m1", "m2"], candidate_names=["c1", "c2"])
        resolved = _evals_utils._resolve_loss_analysis_config(
            eval_result=result, metric="m1", candidate="c2"
        )
        assert (resolved.metric, resolved.candidate) == ("m1", "c2")

    def test_config_provides_metric_and_candidate(self):
        result = _make_eval_result(metrics=["m1"], candidate_names=["c1"])
        base_config = common_types.LossAnalysisConfig(
            metric="m1", candidate="c1", predefined_taxonomy="my_taxonomy"
        )
        resolved = _evals_utils._resolve_loss_analysis_config(
            eval_result=result, config=base_config
        )
        assert (resolved.metric, resolved.candidate) == ("m1", "c1")
        # Non-metric/candidate fields of the config are carried through.
        assert resolved.predefined_taxonomy == "my_taxonomy"

    def test_explicit_args_override_config(self):
        result = _make_eval_result(metrics=["m1", "m2"], candidate_names=["c1", "c2"])
        base_config = common_types.LossAnalysisConfig(metric="m1", candidate="c1")
        resolved = _evals_utils._resolve_loss_analysis_config(
            eval_result=result, config=base_config, metric="m2", candidate="c2"
        )
        assert (resolved.metric, resolved.candidate) == ("m2", "c2")

    def test_error_multiple_metrics_no_explicit(self):
        result = _make_eval_result(metrics=["m1", "m2"], candidate_names=["c1"])
        with pytest.raises(ValueError, match="multiple metrics"):
            _evals_utils._resolve_loss_analysis_config(eval_result=result)

    def test_error_multiple_candidates_no_explicit(self):
        result = _make_eval_result(metrics=["m1"], candidate_names=["c1", "c2"])
        with pytest.raises(ValueError, match="multiple candidates"):
            _evals_utils._resolve_loss_analysis_config(eval_result=result)

    def test_error_invalid_metric(self):
        result = _make_eval_result(metrics=["m1"], candidate_names=["c1"])
        with pytest.raises(ValueError, match="not found in eval_result"):
            _evals_utils._resolve_loss_analysis_config(
                eval_result=result, metric="nonexistent"
            )

    def test_error_invalid_candidate(self):
        result = _make_eval_result(metrics=["m1"], candidate_names=["c1"])
        with pytest.raises(ValueError, match="not found in eval_result"):
            _evals_utils._resolve_loss_analysis_config(
                eval_result=result, candidate="nonexistent"
            )

    def test_no_candidates_defaults_to_candidate_1(self):
        # Replace metadata with an empty one so no candidate names are known.
        result = _make_eval_result(metrics=["m1"], candidate_names=[]).model_copy(
            update={"metadata": common_types.EvaluationRunMetadata()}
        )
        resolved = _evals_utils._resolve_loss_analysis_config(eval_result=result)
        assert resolved.metric == "m1"
        assert resolved.candidate == "candidate_1"

    def test_no_eval_case_results_raises(self):
        empty_result = common_types.EvaluationResult()
        with pytest.raises(ValueError, match="no metric results"):
            _evals_utils._resolve_loss_analysis_config(eval_result=empty_result)
734+
735+
572736
class TestEvals:
573737
"""Unit tests for the GenAI client."""
574738

vertexai/_genai/_evals_utils.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,148 @@ def _display_loss_analysis_result(
449449
print(df.to_string()) # pylint: disable=print-function
450450

451451

452+
def _resolve_metric_name(
453+
metric: Optional[Any],
454+
) -> Optional[str]:
455+
"""Extracts a metric name string from a metric argument.
456+
457+
Accepts a string, a Metric object, or a LazyLoadedPrebuiltMetric
458+
(RubricMetric) and returns the metric name as a string.
459+
460+
For LazyLoadedPrebuiltMetric (e.g., RubricMetric.MULTI_TURN_TASK_SUCCESS),
461+
this resolves to the API metric spec name (e.g.,
462+
"multi_turn_task_success_v1") so it matches the keys in eval results.
463+
464+
Args:
465+
metric: A metric name string, Metric object, RubricMetric enum value, or
466+
None.
467+
468+
Returns:
469+
The metric name as a string, or None if metric is None.
470+
"""
471+
if metric is None:
472+
return None
473+
if isinstance(metric, str):
474+
return metric
475+
# LazyLoadedPrebuiltMetric: resolve to versioned API spec name.
476+
if hasattr(metric, "_get_api_metric_spec_name"):
477+
spec_name: Optional[str] = metric._get_api_metric_spec_name()
478+
if spec_name:
479+
return spec_name
480+
# Metric objects and other types with a .name attribute.
481+
if hasattr(metric, "name"):
482+
return str(metric.name)
483+
return str(metric)
484+
485+
486+
def _resolve_loss_analysis_config(
    eval_result: types.EvaluationResult,
    config: Optional[types.LossAnalysisConfig] = None,
    metric: Optional[str] = None,
    candidate: Optional[str] = None,
) -> types.LossAnalysisConfig:
    """Resolves and validates the LossAnalysisConfig for generate_loss_clusters.

    Auto-infers `metric` and `candidate` from the EvaluationResult when not
    explicitly provided. Validates that provided values exist in the eval
    result.

    Args:
      eval_result: The EvaluationResult from client.evals.evaluate().
      config: Optional explicit LossAnalysisConfig. Its fields other than
        metric/candidate (e.g. predefined_taxonomy) are preserved in the
        returned config.
      metric: Optional metric name override. Takes precedence over
        config.metric when both are provided.
      candidate: Optional candidate name override. Takes precedence over
        config.candidate when both are provided.

    Returns:
      A resolved LossAnalysisConfig with metric and candidate populated.

    Raises:
      ValueError: If metric/candidate cannot be inferred or are invalid.
    """
    # Start from config if provided, otherwise create a new one. Explicit
    # `metric`/`candidate` arguments override the corresponding config fields.
    if config is not None:
        resolved_config = config.model_copy(
            update={
                "metric": metric or config.metric,
                "candidate": candidate or config.candidate,
            }
        )
    else:
        resolved_config = types.LossAnalysisConfig(metric=metric, candidate=candidate)

    # Collect the metric names that actually appear in the eval result.
    available_metrics: set[str] = {
        metric_name
        for case_result in eval_result.eval_case_results or []
        for resp_cand in case_result.response_candidate_results or []
        for metric_name in (resp_cand.metric_results or {})
    }

    # Collect available candidate names from metadata (may be absent).
    available_candidates: list[str] = []
    if eval_result.metadata and eval_result.metadata.candidate_names:
        available_candidates = list(eval_result.metadata.candidate_names)

    # Auto-infer metric if not provided; unambiguous only when exactly one
    # metric is present in the results.
    if not resolved_config.metric:
        if len(available_metrics) == 1:
            resolved_config = resolved_config.model_copy(
                update={"metric": next(iter(available_metrics))}
            )
        elif len(available_metrics) == 0:
            raise ValueError(
                "Cannot infer metric: no metric results found in eval_result."
                " Please provide metric explicitly via"
                " config=types.LossAnalysisConfig(metric='...')."
            )
        else:
            raise ValueError(
                "Cannot infer metric: multiple metrics found in eval_result:"
                f" {sorted(available_metrics)}. Please provide metric"
                " explicitly via config=types.LossAnalysisConfig(metric='...')."
            )

    # Validate the (possibly explicit) metric. Skipped when no metric results
    # are visible, so an explicit metric still works on sparse results.
    if available_metrics and resolved_config.metric not in available_metrics:
        raise ValueError(
            f"Metric '{resolved_config.metric}' not found in eval_result."
            f" Available metrics: {sorted(available_metrics)}."
        )

    # Auto-infer candidate if not provided.
    if not resolved_config.candidate:
        if len(available_candidates) == 1:
            resolved_config = resolved_config.model_copy(
                update={"candidate": available_candidates[0]}
            )
        elif len(available_candidates) == 0:
            # Fallback: use default candidate naming convention from SDK.
            resolved_config = resolved_config.model_copy(
                update={"candidate": "candidate_1"}
            )
            logger.warning(
                "No candidate names found in eval_result.metadata."
                " Defaulting to 'candidate_1'. If this is incorrect, provide"
                " candidate explicitly via"
                " config=types.LossAnalysisConfig(candidate='...')."
            )
        else:
            raise ValueError(
                "Cannot infer candidate: multiple candidates found in"
                f" eval_result: {available_candidates}. Please provide"
                " candidate explicitly via"
                " config=types.LossAnalysisConfig(candidate='...')."
            )

    # Validate candidate only when candidate names are known; the
    # 'candidate_1' fallback above cannot be verified by construction.
    if available_candidates and resolved_config.candidate not in available_candidates:
        raise ValueError(
            f"Candidate '{resolved_config.candidate}' not found in"
            f" eval_result. Available candidates: {available_candidates}."
        )

    return resolved_config
592+
593+
452594
def _poll_operation(
453595
api_client: BaseApiClient,
454596
operation: types.GenerateLossClustersOperation,

0 commit comments

Comments
 (0)