DEV Community

HAP

Record Insert with Relational Validation in One SQL Statement

With a Django implementation.

Hello, folks! I'm back with another tidbit about extending Django to make it more useful. This time it's about objects that are created inside a transaction. This can be problematic in Django because (on PostgreSQL, for example) it creates foreign key constraints as DEFERRABLE INITIALLY DEFERRED. "Why is that a problem?" I hear you ask. Because deferred constraints are not evaluated until the transaction is committed. So an insert that references a nonexistent row doesn't raise an exception at the insert itself; the error only surfaces at commit, which can be confusing (especially in test-harness situations like pytest, tox and the like).
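Here's a minimal sketch of that timing, using Python's built-in sqlite3 module as a stand-in for PostgreSQL (chosen only so it runs without a server; SQLite accepts the same DEFERRABLE INITIALLY DEFERRED clause once foreign key enforcement is switched on). The two-table schema and the where_fk_error_surfaces helper are hypothetical. Without an open transaction the bad insert fails on the spot; inside one, it only fails at commit.

```python
import sqlite3

def where_fk_error_surfaces(autocommit: bool) -> str:
    """Insert an employee that points at a missing person, under a
    DEFERRABLE INITIALLY DEFERRED foreign key, and report which step fails."""
    # isolation_level=None: each statement commits on its own (autocommit).
    # isolation_level="": sqlite3 opens a transaction before the INSERT.
    conn = sqlite3.connect(":memory:", isolation_level=None if autocommit else "")
    try:
        conn.execute("PRAGMA foreign_keys = ON")  # SQLite enforces FKs only when asked
        conn.execute("CREATE TABLE person (id INTEGER PRIMARY KEY)")
        conn.execute(
            "CREATE TABLE employee (id INTEGER PRIMARY KEY, person_id INTEGER NOT NULL "
            "REFERENCES person (id) DEFERRABLE INITIALLY DEFERRED)"
        )
        try:
            conn.execute("INSERT INTO employee (person_id) VALUES (-1)")
        except sqlite3.IntegrityError:
            return "insert"  # no open transaction: deferred behaves like immediate
        try:
            conn.commit()
        except sqlite3.IntegrityError:
            conn.rollback()
            return "commit"  # the check only ran when the transaction ended
        return "never"
    finally:
        conn.close()
```

In this sketch, autocommit=True reports "insert" and autocommit=False reports "commit". That second case is the window where application code can keep working with a row that will never make it into the database.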

So let's say you have a situation where you have models defined similar to these:

from django.db import models

class Company(models.Model):
    id = models.BigAutoField(primary_key=True)
    name = models.TextField(null=False)
    ...

class Person(models.Model):
    id = models.BigAutoField(primary_key=True)
    surname = models.TextField(null=False)
    forename = models.TextField()
    midname = models.TextField()
    ...

class Employee(models.Model):
    id = models.BigAutoField(primary_key=True)
    company = models.ForeignKey("Company", null=False, on_delete=models.CASCADE)
    person = models.ForeignKey("Person", null=False, on_delete=models.CASCADE)
    start_date = models.DateTimeField(null=False)
    end_date = models.DateTimeField(null=True)  # no end date while still employed
    ...

Your standard-ish employee example, right? So let's say we have a company record:

{"id": 10, "name": "Saskatoon Widgets, Inc."}

And a person record:

{"id": 45651, "surname": "Bond", "forename": "James", "midname": "Pootwaddle"}

So now he's been hired, and we want to make him an employee. What if this is an API call? How will we know if the foreign key values are correct? I know you're all waving your arms like Horshack and yelling "Serializers!" But what if you're not validating with serializers? (It could happen!) Just follow along with me a bit longer.

In autocommit mode, the database checks the constraint as soon as the statement finishes, so Django raises an IntegrityError from the create() call itself, where it can be handled. (A short-lived transaction that commits right after the insert, with nothing else kicked off inside it, fails almost as quickly, at the commit.)

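from typing import Dict

from django.db import IntegrityError
from django.http import HttpResponse

def create_employee(emp_record: Dict):
    try:
        emp = Employee.objects.create(**emp_record)
    except IntegrityError as e:
        # Django has no Http422Response, so set the status code explicitly
        return HttpResponse(str(e), status=422)
    return HttpResponse(emp)

...

rec = {"company_id": 10, "person_id": -1, ...}
create_employee(rec)  # In autocommit mode the IntegrityError is raised here: you get the 422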

But what if you use the newly created emp object with other functions while within the transaction?

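def create_employee(emp_record: Dict):
    try:
        emp = Employee.objects.create(**emp_record)
        # Inside a transaction the deferred FK check hasn't run yet,
        # so emp is handed to other code before anyone knows it's bad.
        queue_new_hire_actions(emp)  # Uh-oh!
    except IntegrityError as e:
        return HttpResponse(str(e), status=422)
    return HttpResponse(emp)

...

rec = {"company_id": 10, "person_id": -1, ...}
create_employee(rec)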

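Inside a transaction, create() succeeds here because the foreign key check is deferred. The IntegrityError is raised only when the transaction commits, which is typically after this function has returned (for example, when a transaction.atomic() block or an ATOMIC_REQUESTS view exits), so the except block above may never see it. Worse, you now have queued new-hire work for an employee row that will never be committed, and that will fail later as well.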

If you use serializers for validation, or write your own validation functions, each foreign key check usually costs an extra query before the insert even runs. On an endpoint with a lot of activity, that adds up to a lot of extra query traffic against the database.
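As a rough illustration of that cost, here's a hypothetical pre-insert validator (precheck_foreign_keys is my name, not a Django or DRF API) run against an in-memory SQLite database, with sqlite3's set_trace_callback logging every statement it sends. The trimmed tables are assumptions standing in for the models above.

```python
import sqlite3
from typing import Dict, List, Tuple

def precheck_foreign_keys(
    conn: sqlite3.Connection, record: Dict[str, int], fk_tables: Dict[str, str]
) -> List[str]:
    """Hypothetical pre-insert validator: one existence query per foreign key."""
    missing = []
    for column, table in fk_tables.items():
        # Table names come from our own mapping, never from the request.
        row = conn.execute(
            f"SELECT 1 FROM {table} WHERE id = ?", (record[column],)
        ).fetchone()
        if row is None:
            missing.append(column)
    return missing

def count_precheck_queries() -> Tuple[List[str], int]:
    """Validate one Employee-like record and count the queries it took."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute("CREATE TABLE company (id INTEGER PRIMARY KEY, name TEXT)")
        conn.execute("CREATE TABLE person (id INTEGER PRIMARY KEY, surname TEXT)")
        conn.execute("INSERT INTO company VALUES (10, 'Saskatoon Widgets, Inc.')")
        conn.execute("INSERT INTO person VALUES (45651, 'Bond')")
        conn.commit()
        statements: List[str] = []
        conn.set_trace_callback(statements.append)  # log every statement SQLite runs
        missing = precheck_foreign_keys(
            conn,
            {"company_id": 10, "person_id": -1},
            {"company_id": "company", "person_id": "person"},
        )
        conn.set_trace_callback(None)
        return missing, len(statements)
    finally:
        conn.close()
```

Two foreign keys means two extra round trips per request, and the insert itself still hasn't happened.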

There is an alternative. In PostgreSQL, you can put the INSERT ... RETURNING inside a CTE (a data-modifying WITH clause) and use the returned record to left-join the related tables, all in one statement. Something like this:

with new_rec as (
insert into employee (company_id, person_id, start_date)
values (10, -1, now())
returning *
)
select nr.id,
       nr.start_date, 
       cmp.id as "company_id",
       prs.id as "person_id"
  from new_rec as nr
  left
  join company as cmp
    on cmp.id = nr.company_id
  left
  join person as prs
    on prs.id = nr.person_id;

This returns a row whose columns map onto the Employee model's fields. But now we can check for bad data immediately: person -1 doesn't exist, so the left join finds no person row and the returned "person_id" is null rather than -1. Comparing the returned keys with the submitted ones right after the insert exposes the missing reference without another query.
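The check itself is just a comparison. Here's a minimal, Django-free sketch; missing_related_keys is a hypothetical helper, and it assumes the returned row has already been turned into a dict keyed by column name.

```python
from typing import Any, Dict, Iterable, List

def missing_related_keys(
    submitted: Dict[str, Any], returned: Dict[str, Any], fk_columns: Iterable[str]
) -> List[str]:
    """Return the FK columns whose submitted value didn't come back.

    The SELECT takes each key from the joined table, so a missing related
    row turns that key into NULL (None); any mismatch means a dangling reference.
    """
    missing = []
    for col in fk_columns:
        sent = submitted.get(col)
        if sent is None:
            continue  # nothing submitted (or a nullable FK left empty)
        if returned.get(col) != sent:
            missing.append(col)
    return missing
```

With the statement above, the submitted {"company_id": 10, "person_id": -1} comes back with person_id set to None, so only "person_id" is reported.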

This can be extended to fetch the immediately related records as well (that is, only direct relations, not walking down every relation).

with new_rec as (
insert into employee (company_id, person_id, start_date)
values (10, -1, now())
returning *
)
select nr.id,
       nr.start_date, 
       cmp.id as "company_id",
       row_to_json(cmp) as "company",
       prs.id as "person_id",
       row_to_json(prs) as "person"
  from new_rec as nr
  left
  join company as cmp
    on cmp.id = nr.company_id
  left
  join person as prs
    on prs.id = nr.person_id;

So now you get both the keys and the related table records (as JSON) back from the query. If a related row is missing, its key and its record will both be null.
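On the Python side, unpacking that row might look like the following sketch. split_related is a hypothetical helper; psycopg2, for one, normally decodes json columns into dicts already, but the sketch also accepts JSON text in case your driver hands it back undecoded.

```python
import json
from typing import Any, Dict, Iterable, Optional, Tuple

def split_related(
    row: Dict[str, Any], json_columns: Iterable[str]
) -> Tuple[Dict[str, Any], Dict[str, Optional[dict]]]:
    """Separate the row_to_json() columns from the base record.

    A missing related row comes back as None and stays None.
    """
    base = dict(row)  # don't mutate the caller's row
    related: Dict[str, Optional[dict]] = {}
    for col in json_columns:
        value = base.pop(col, None)
        if isinstance(value, (str, bytes, bytearray)):
            value = json.loads(value)  # driver returned JSON text
        related[col] = value
    return base, related
```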

This type of query could more easily be done with SQLAlchemy, but we're talking about Django here. So we have to use some raw SQL building and (eventually) return a model instance.

Here's what I did for a Django implementation:

The input to the whole thing will be a model class for the target table and a dict for the data to be inserted.

These are the imports I've used, along with some module-level globals. The ALIASES list supplies the table aliases that will be used consistently throughout the statement build.

import os
from datetime import datetime, timezone
from enum import Enum
from typing import List

from django.db import IntegrityError, connection, models

LETTERS = "abcdefghijklmnopqrstuvwxyz"
ALIASES = [f"{'_' * i}{letter}" for i in range(1, 3) for letter in LETTERS]
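ALIASES ends up holding 52 entries: "_a" through "_z", then "__a" through "__z". Because the select builder further down indexes it by a foreign key's position among the model's concrete fields, a foreign key at position 52 or beyond would hit an IndexError. If that's a concern, here's a hypothetical alternative (alias_for is my name, not part of the original code) that derives the alias from the index and never runs out:

```python
import string

def alias_for(index: int) -> str:
    """Hypothetical replacement for ALIASES[index].

    0 -> "_a", 25 -> "_z", 26 -> "__a", 51 -> "__z", 52 -> "___a", ...
    """
    if index < 0:
        raise ValueError("index must be non-negative")
    # One underscore per full pass through the alphabet, plus one.
    extra_underscores, letter = divmod(index, len(string.ascii_lowercase))
    return "_" * (extra_underscores + 1) + string.ascii_lowercase[letter]
```

For indexes 0 through 51 it yields the same strings as ALIASES, so using alias_for(fnum) in place of ALIASES[fnum] wouldn't change the SQL generated for existing models.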

So, first, I need to ensure that the model defaults are applied to the record.

def apply_model_defaults(model: models.Model, record: dict) -> dict:
    now = datetime.now(tz=timezone.utc)
    full_record = record.copy()

    for field in model._meta.concrete_fields:
        if field.primary_key:
            continue

        if isinstance(field, models.ForeignKey):
            fname = field.get_attname_column()[-1]
            # A foreign key may arrive under its field name ("company")
            # or its column name ("company_id"); check both.
            keys = (field.name, fname)
        else:
            fname = field.name
            keys = (fname,)

        if getattr(field, "auto_now_add", False) or getattr(
            field, "auto_now", False
        ):
            default = now
        else:
            default = field.default

        if default is not models.NOT_PROVIDED:
            if not any(k in record for k in keys):
                if callable(default):
                    full_record[fname] = default()
                elif isinstance(default, Enum):
                    full_record[fname] = str(default.value)
                else:
                    full_record[fname] = default

        if not field.null and all(full_record.get(k) is None for k in keys):
            raise ValueError(f"{model.__name__}.{fname} cannot be None.")

    return full_record

And I also need a function to resolve a model reference in the record to the model's primary key value.

def fk_or_model_pk(fkeys: List[models.Field], record: dict) -> dict:
    for f in fkeys:
        if f.name in record:
            if isinstance(record[f.name], models.Model):
                record[f.name] = getattr(
                    record[f.name], f.target_field.name, None
                )
    return record

Next, I need to be able to generate the insert statement.

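def build_insert(
    model: models.Model, fkeys: List[models.Field], record: dict
) -> str:
    """Build the base insert."""
    field_to_target_col = {f.name: f.get_attname_column()[-1] for f in fkeys}
    # Record keys are interpolated into the SQL as column and placeholder
    # names, so reject anything that isn't a model field name or column.
    allowed = set()
    for f in model._meta.concrete_fields:
        allowed.update((f.name, f.get_attname_column()[-1]))
    unknown = set(record) - allowed
    if unknown:
        raise ValueError(f"Unknown {model.__name__} fields: {sorted(unknown)}")
    return """
insert into {table} ({column_list})
values ({data_list})
returning *
    """.format(
        table=model._meta.db_table,
        column_list=", ".join(
            field_to_target_col.get(col, col) for col in record
        ),
        data_list=", ".join(f"%({col})s" for col in record),
    )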

I also need a select statement generator that wraps the insert in a CTE, then selects all of the foreign key table info I need.

def build_insert_validated_select(
    model: models.Model, record: dict, fetch_related: bool = False
) -> str:
    """Wrap the base insert in a CTE for related table data verification.

    Immediate relations of the model can be fetched as json. (fetch_related)
    """
    remote_fields = []
    # Get the foreign key fields and build a list of dicts holding
    # the necessary parts to generate the select columns
    # and left joins
    fk_fields = [
        {
            "remote_alias": ALIASES[fnum],
            "join_local_key": f"_nr.{f.get_attname_column()[-1]}",
            "local_key": f.get_attname_column()[-1],
            "remote_table": f.related_model._meta.db_table,
            "remote_key": f"{ALIASES[fnum]}.{f.target_field.name}",
            "remote_row": f"row_to_json({ALIASES[fnum]}) as {f.related_model._meta.db_table}_rec",  # noqa E501
            "_field": f,
        }
        for fnum, f in enumerate(model._meta.concrete_fields)
        if isinstance(f, models.ForeignKey)
    ]
    # Local fields are all defined fields that are
    # not ForeignKey instances.
    local_fields = [
        f"_nr.{f.name}"
        for f in model._meta.concrete_fields
        if not isinstance(f, models.ForeignKey)
    ]
    # Build a list of remote fields with aliases from the 
    # fk_fields list
    remote_fields = [
        f"{fk['remote_key']} as {fk['local_key']}" for fk in fk_fields
    ]
    # If fetch related, include the row_to_json() calls
    if fetch_related:
        remote_fields.extend(fk["remote_row"] for fk in fk_fields)
    # Build the left joins from the fk_fields (a model with no
    # foreign keys gets no joins instead of an unbound name)
    left_joins = ""
    if fk_fields:
        left_joins = f"{os.linesep}  ".join(
            "left join {remote_table} as {remote_alias} "
            "on {remote_key} = {join_local_key}".format(**fk)
            for fk in fk_fields
        )
    # Some pretty print formatting
    sep_indent = f",{os.linesep}       "
    select_cols = sep_indent.join(local_fields + remote_fields)
    insert_sql = build_insert(model, [f["_field"] for f in fk_fields], record)
    return f"""
with new_rec as (
{insert_sql}
)
select {select_cols}
  from new_rec as _nr
  {left_joins}
;
    """

And finally, I need the main function call that will build and execute the statement and perform the post-insert data check.

def validated_create(
    model: models.Model, record: dict, fetch_related: bool = False
) -> models.Model:
    """Create record for model and verify existence of related data."""
    # Apply the model defaults to the data record dict
    record = apply_model_defaults(model, record)
    fk_fields = [
        f
        for f in model._meta.concrete_fields
        if isinstance(f, models.ForeignKey)
    ]
    # Resolve model references to primary key values
    record = fk_or_model_pk(fk_fields, record)

    # Build the statement
    val_ins_sel_sql = build_insert_validated_select(
        model, record, fetch_related=fetch_related
    )
    val_rec = None

    # Execute the statement and fetch the result 
    # as a dict
    with connection.cursor() as cur:
        cur.execute(val_ins_sel_sql, record)
        val_rec = dict(zip([d[0] for d in cur.description], cur.fetchone()))

    # Validate foreign key existence for any foreign keys in the input.
    # This is necessary in case the SQL is executed as part of a transaction
    # which will result in deferred constraint validation as django sets
    # foreign key constraints as deferred.
    for f in fk_fields:
        # Resolve key names for input and output
        # foreign key references
        if f.name in record:
            fkname = f.name
        elif f.get_attname_column()[-1] in record:
            fkname = f.get_attname_column()[-1]
        else:
            continue
        vfkname = f.get_attname_column()[-1]

        # Check to see if the values of the foreign
        # keys are different. If so, throw exception.
        # This is how the integrity checking is done
        # during a transaction.
        if val_rec[vfkname] != record[fkname]:
            related_table = f"{f.related_model._meta.db_table}"
            msg = "is not present in table"
            raise IntegrityError(
                f"Key ({fkname})=({record[fkname]}) {msg} {related_table}"
            )

        # instantiate related model reference if it exists and is not None
        remote_table_name = f.related_model._meta.db_table
        remote_table_ref = f"{remote_table_name}_rec"
        if remote_table_ref in val_rec:
            remote_table_rec = val_rec.pop(remote_table_ref)
            if remote_table_rec:
                val_rec[f.name] = f.related_model(**remote_table_rec)

    # return a target model instance using the 
    # fetched data
    return model(**val_rec)
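One caveat worth knowing: when validated_create raises its IntegrityError, the bad row has already been inserted. If the caller catches the exception and carries on inside the same transaction, that row is still pending, and the deferred check will fail again at commit. Running the call inside a savepoint (in Django, a nested transaction.atomic() block around the call, placed inside the try) lets the failure undo just that insert. The sketch below shows the savepoint mechanics with Python's sqlite3 module as a stand-in for PostgreSQL; insert_with_savepoint, run_demo and the trimmed tables are hypothetical.

```python
import sqlite3
from typing import List, Tuple

def insert_with_savepoint(conn: sqlite3.Connection, person_id: int) -> bool:
    """Insert an employee inside a savepoint and undo it if the person is missing."""
    conn.execute("SAVEPOINT validated_insert")
    cur = conn.execute("INSERT INTO employee (person_id) VALUES (?)", (person_id,))
    # Same idea as the CTE: LEFT JOIN the new row against the related table.
    found = conn.execute(
        "SELECT p.id FROM employee AS e LEFT JOIN person AS p ON p.id = e.person_id "
        "WHERE e.id = ?",
        (cur.lastrowid,),
    ).fetchone()[0]
    if found is None:
        conn.execute("ROLLBACK TO validated_insert")  # discard only this insert
    conn.execute("RELEASE validated_insert")
    return found is not None

def run_demo() -> Tuple[List[bool], int]:
    conn = sqlite3.connect(":memory:", isolation_level=None)  # we issue BEGIN/COMMIT
    try:
        conn.execute("PRAGMA foreign_keys = ON")
        conn.execute("CREATE TABLE person (id INTEGER PRIMARY KEY)")
        conn.execute(
            "CREATE TABLE employee (id INTEGER PRIMARY KEY, person_id INTEGER NOT NULL "
            "REFERENCES person (id) DEFERRABLE INITIALLY DEFERRED)"
        )
        conn.execute("INSERT INTO person (id) VALUES (45651)")
        conn.execute("BEGIN")  # the outer transaction, like an atomic request
        results = [insert_with_savepoint(conn, 45651), insert_with_savepoint(conn, -1)]
        conn.execute("COMMIT")  # succeeds: the orphan row was rolled back
        count = conn.execute("SELECT count(*) FROM employee").fetchone()[0]
        return results, count
    finally:
        conn.close()
```

Without the ROLLBACK TO, the orphan row would still be in the transaction and the final COMMIT would fail on the deferred foreign key check.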

So by probing the model class structure, I can now build SQL that should work for any properly defined Django ORM model class (on PostgreSQL, since it relies on a data-modifying CTE and row_to_json()).

The point behind all of this is to validate the existence of related records without sending an extra query per related model.
