"""
Branch statement builders - SQLAlchemy style API for branch operations.
Similar to select(), insert(), delete(), update() from SQLAlchemy.
These builders produce SQL strings via compile() and are executed
by passing them to client.execute() or session.execute() as strings.
"""
from typing import Optional, Union, Type
from enum import Enum
from .branch import MergeConflictStrategy
from ._utils import get_table_name as _get_table_name, require_non_empty as _require_non_empty
[docs]
class DiffOutputOption(str, Enum):
    """Output modes that can be attached to a branch diff statement.

    Because the enum also derives from ``str``, every member compares
    equal to its plain lowercase string value.
    """

    # Rendered by DiffTableBranch.compile() as " output count".
    COUNT = "count"
    # Rendered as " output limit <n>".
    LIMIT = "limit"
    # Rendered as " output file '<path>'".
    FILE = "file"
    # Rendered as " output as <table>".
    AS = "as"
[docs]
class BranchStatement:
    """Base class for branch statements.

    These are NOT SQLAlchemy ClauseElements. They produce raw SQL strings
    via compile() and should be executed as: client.execute(str(stmt))
    """

    def compile(self) -> str:
        """Compile to SQL string.

        Returns:
            The SQL text for this statement.

        Raises:
            NotImplementedError: Always, on the base class; subclasses
                override this method.
        """
        raise NotImplementedError

    def __str__(self) -> str:
        """Return the compiled SQL; propagates any error from compile()."""
        return self.compile()

    def __repr__(self) -> str:
        """Return a debugging representation that never raises.

        repr() is used by debuggers, loggers and interactive shells, often
        on builders that are only half configured. If compile() cannot
        produce SQL yet, the reason is embedded in the representation
        instead of being raised.
        """
        try:
            sql = self.compile()
        except (NotImplementedError, ValueError) as exc:
            # Incomplete builders raise ValueError; the bare base class
            # raises NotImplementedError. Neither should break repr().
            sql = f"<not compilable: {exc}>"
        return f"<{type(self).__name__}: {sql}>"
[docs]
class CreateTableBranch(BranchStatement):
    """Fluent builder for a ``data branch create table`` statement."""

    def __init__(self, target_table: Union[str, Type]):
        """Start a statement that creates the branch table ``target_table``.

        The argument is resolved through ``_get_table_name`` and then
        checked with ``_require_non_empty`` (presumably raising on an
        empty name -- confirm against ``_utils``).
        """
        resolved = _get_table_name(target_table)
        _require_non_empty(resolved, "target_table")
        self._target = resolved
        # Optional parts; filled in by from_table() / to_account().
        self._source = self._snapshot = self._account = None

    def from_table(self, source: Union[str, Type], snapshot: Optional[str] = None) -> 'CreateTableBranch':
        """Record the table to branch from, plus an optional snapshot name.

        Returns ``self`` so calls can be chained.
        """
        resolved = _get_table_name(source)
        _require_non_empty(resolved, "source_table")
        self._source, self._snapshot = resolved, snapshot
        return self

    def to_account(self, account: str) -> 'CreateTableBranch':
        """Record the target account for cross-tenant branching (sys tenant only).

        Returns ``self`` so calls can be chained.
        """
        _require_non_empty(account, "account")
        self._account = account
        return self

    def compile(self) -> str:
        """Render the statement as SQL text.

        Raises:
            ValueError: If no source table was set via from_table().
        """
        if not self._source:
            raise ValueError("source_table must be set via from_table()")
        fragments = [f"data branch create table {self._target} from {self._source}"]
        if self._snapshot:
            # The snapshot clause is glued directly onto the source name.
            fragments.append(f'{{snapshot="{self._snapshot}"}}')
        if self._account:
            fragments.append(f" to account {self._account}")
        return "".join(fragments)
[docs]
class CreateDatabaseBranch(BranchStatement):
    """Fluent builder for a ``data branch create database`` statement."""

    def __init__(self, target_db: str):
        """Start a statement that creates the branch database ``target_db``.

        The name is checked with ``_require_non_empty`` (presumably raising
        on an empty value -- confirm against ``_utils``).
        """
        _require_non_empty(target_db, "target_database")
        self._target = target_db
        # Optional parts; filled in by from_database() / to_account().
        self._source = self._snapshot = self._account = None

    def from_database(self, source: str, snapshot: Optional[str] = None) -> 'CreateDatabaseBranch':
        """Record the database to branch from, plus an optional snapshot name.

        Returns ``self`` so calls can be chained.
        """
        _require_non_empty(source, "source_database")
        self._source, self._snapshot = source, snapshot
        return self

    def to_account(self, account: str) -> 'CreateDatabaseBranch':
        """Record the target account for cross-tenant branching (sys tenant only).

        Returns ``self`` so calls can be chained.
        """
        _require_non_empty(account, "account")
        self._account = account
        return self

    def compile(self) -> str:
        """Render the statement as SQL text.

        Raises:
            ValueError: If no source database was set via from_database().
        """
        if not self._source:
            raise ValueError("source_database must be set via from_database()")
        fragments = [f"data branch create database {self._target} from {self._source}"]
        if self._snapshot:
            # The snapshot clause is glued directly onto the source name.
            fragments.append(f'{{snapshot="{self._snapshot}"}}')
        if self._account:
            fragments.append(f" to account {self._account}")
        return "".join(fragments)
[docs]
class DeleteTableBranch(BranchStatement):
    """Builder for a ``data branch delete table`` statement."""

    def __init__(self, table: Union[str, Type]):
        """Remember which branch table to delete.

        The argument is resolved through ``_get_table_name`` and checked
        with ``_require_non_empty`` before being stored.
        """
        resolved = _get_table_name(table)
        _require_non_empty(resolved, "table")
        self._table = resolved

    def compile(self) -> str:
        """Render the statement as SQL text."""
        sql = f"data branch delete table {self._table}"
        return sql
[docs]
class DeleteDatabaseBranch(BranchStatement):
    """Builder for a ``data branch delete database`` statement."""

    def __init__(self, database: str):
        """Remember which branch database to delete.

        The name is checked with ``_require_non_empty`` before being stored.
        """
        _require_non_empty(database, "database")
        self._database = database

    def compile(self) -> str:
        """Render the statement as SQL text."""
        sql = f"data branch delete database {self._database}"
        return sql
[docs]
class DiffTableBranch(BranchStatement):
    """DIFF TABLE BRANCH statement builder.

    Configure the comparison with snapshot() / against(), optionally pick
    one output mode with an ``output_*`` method (the last call wins), then
    call compile() to obtain the SQL text.
    """

    def __init__(self, target_table: Union[str, Type]):
        """Start a diff whose left-hand side is ``target_table``.

        The argument is resolved through ``_get_table_name`` and checked
        with ``_require_non_empty`` before being stored.
        """
        name = _get_table_name(target_table)
        _require_non_empty(name, "target_table")
        self._target = name
        self._target_snapshot = None
        self._base = None
        self._base_snapshot = None
        # Output mode (a DiffOutputOption) and its argument, if any.
        self._output_type = None
        self._output_value = None

    def snapshot(self, name: str) -> 'DiffTableBranch':
        """Set snapshot for target table. Returns ``self`` for chaining."""
        _require_non_empty(name, "snapshot_name")
        self._target_snapshot = name
        return self

    def against(self, base: Union[str, Type], snapshot: Optional[str] = None) -> 'DiffTableBranch':
        """Set base table (and optional base snapshot) to compare against.

        Returns ``self`` for chaining.
        """
        name = _get_table_name(base)
        _require_non_empty(name, "base_table")
        self._base = name
        self._base_snapshot = snapshot
        return self

    def output_count(self) -> 'DiffTableBranch':
        """Output only count of differences. Returns ``self`` for chaining."""
        self._output_type = DiffOutputOption.COUNT
        self._output_value = None
        return self

    def output_limit(self, limit: int) -> 'DiffTableBranch':
        """Limit returned difference rows.

        Args:
            limit: Non-negative integer row limit.

        Raises:
            ValueError: If ``limit`` is not an integer, is a bool, or is
                negative.
        """
        # bool is a subclass of int; without the explicit check True/False
        # would pass validation and compile to "output limit True".
        if isinstance(limit, bool) or not isinstance(limit, int) or limit < 0:
            raise ValueError("limit must be a non-negative integer")
        self._output_type = DiffOutputOption.LIMIT
        self._output_value = limit
        return self

    def output_file(self, path: str) -> 'DiffTableBranch':
        """Export differences to file (local path or stage:// URL).

        Returns ``self`` for chaining.
        """
        _require_non_empty(path, "path")
        self._output_type = DiffOutputOption.FILE
        self._output_value = path
        return self

    def output_as(self, table_name: str) -> 'DiffTableBranch':
        """Save differences to table (not yet supported by MatrixOne).

        Returns ``self`` for chaining.
        """
        _require_non_empty(table_name, "table_name")
        self._output_type = DiffOutputOption.AS
        self._output_value = table_name
        return self

    def compile(self) -> str:
        """Render the statement as SQL text.

        Raises:
            ValueError: If no base table was set via against().
        """
        if not self._base:
            raise ValueError("base_table must be set via against()")

        parts = [f"data branch diff {self._target}"]
        if self._target_snapshot:
            # Snapshot clauses attach directly to the table name.
            parts.append(f'{{snapshot="{self._target_snapshot}"}}')
        parts.append(f" against {self._base}")
        if self._base_snapshot:
            parts.append(f'{{snapshot="{self._base_snapshot}"}}')

        mode = self._output_type
        if mode == DiffOutputOption.COUNT:
            parts.append(" output count")
        elif mode == DiffOutputOption.LIMIT:
            parts.append(f" output limit {self._output_value}")
        elif mode == DiffOutputOption.FILE:
            parts.append(f" output file '{self._output_value}'")
        elif mode == DiffOutputOption.AS:
            parts.append(f" output as {self._output_value}")
        return "".join(parts)
[docs]
class MergeTableBranch(BranchStatement):
    """MERGE TABLE BRANCH statement builder.

    The conflict strategy defaults to ``MergeConflictStrategy.SKIP`` and
    can be changed with when_conflict().
    """

    def __init__(self, source_table: Union[str, Type]):
        """Start a merge whose source is ``source_table``.

        The argument is resolved through ``_get_table_name`` and checked
        with ``_require_non_empty`` before being stored.
        """
        name = _get_table_name(source_table)
        _require_non_empty(name, "source_table")
        self._source = name
        self._target = None
        self._strategy = MergeConflictStrategy.SKIP

    def into(self, target: Union[str, Type]) -> 'MergeTableBranch':
        """Set target table to merge into. Returns ``self`` for chaining."""
        name = _get_table_name(target)
        _require_non_empty(name, "target_table")
        self._target = name
        return self

    def when_conflict(self, strategy: Union[str, MergeConflictStrategy]) -> 'MergeTableBranch':
        """Set conflict resolution strategy: 'skip' or 'accept'.

        Args:
            strategy: A ``MergeConflictStrategy`` member or its string value.

        Raises:
            ValueError: If ``strategy`` is not a valid strategy. This now
                also covers non-string junk such as ``None``, which used to
                be stored unchecked and only failed later in compile().
        """
        try:
            # Enum lookup accepts both raw values and existing members, so
            # every input goes through the same validation.
            resolved = MergeConflictStrategy(strategy)
        except ValueError as exc:
            raise ValueError(
                f"Invalid conflict strategy: '{strategy}'. Must be 'skip' or 'accept'."
            ) from exc
        self._strategy = resolved
        return self

    def compile(self) -> str:
        """Render the statement as SQL text.

        Raises:
            ValueError: If no target table was set via into().
        """
        if not self._target:
            raise ValueError("target_table must be set via into()")
        return (
            f"data branch merge {self._source} into {self._target} "
            f"when conflict {self._strategy.value}"
        )
# ---------------------------------------------------------------------------
# Top-level builder functions (like select(), insert(), etc.)
# ---------------------------------------------------------------------------
[docs]
def create_table_branch(target_table: Union[str, Type]) -> CreateTableBranch:
    """Return a new :class:`CreateTableBranch` builder for ``target_table``.

    Example::

        stmt = create_table_branch('dev').from_table('prod', snapshot='snap1')
        client.execute(str(stmt))
    """
    builder = CreateTableBranch(target_table)
    return builder
[docs]
def create_database_branch(target_db: str) -> CreateDatabaseBranch:
    """Return a new :class:`CreateDatabaseBranch` builder for ``target_db``.

    Example::

        stmt = create_database_branch('dev_db').from_database('prod_db')
        client.execute(str(stmt))
    """
    builder = CreateDatabaseBranch(target_db)
    return builder
[docs]
def delete_table_branch(table: Union[str, Type]) -> DeleteTableBranch:
    """Return a new :class:`DeleteTableBranch` builder for ``table``.

    Example::

        stmt = delete_table_branch('branch_table')
        client.execute(str(stmt))
    """
    builder = DeleteTableBranch(table)
    return builder
[docs]
def delete_database_branch(database: str) -> DeleteDatabaseBranch:
    """Return a new :class:`DeleteDatabaseBranch` builder for ``database``.

    Example::

        stmt = delete_database_branch('branch_db')
        client.execute(str(stmt))
    """
    builder = DeleteDatabaseBranch(database)
    return builder
[docs]
def diff_table_branch(target_table: Union[str, Type]) -> DiffTableBranch:
    """Return a new :class:`DiffTableBranch` builder for ``target_table``.

    Example::

        stmt = diff_table_branch('t1').against('t2').output_count()
        client.execute(str(stmt))
    """
    builder = DiffTableBranch(target_table)
    return builder
[docs]
def merge_table_branch(source_table: Union[str, Type]) -> MergeTableBranch:
    """Return a new :class:`MergeTableBranch` builder for ``source_table``.

    Example::

        stmt = merge_table_branch('source').into('target').when_conflict('accept')
        client.execute(str(stmt))
    """
    builder = MergeTableBranch(source_table)
    return builder