Twotm8 (p.2): Postgres and OpenSSL

Tags: scala, scala3, scala.js, scala-native, fly.io, sn-bindgen, c, postgres, openssl, jwt, series:twotm8


Interfacing with native code

The general interaction with the various native libraries that we'll use boils down to the following points:

  • Compile and build time are conflated - the bindings are generated only at build time, and the application code can be compiled against the bindings, which will have (hopefully) rich type information

  • The binding generator may well be a human (you, yes, you!) if the interface is small and simple. But bindings can also be generated automatically, which is what we will explore in this project.

  • The bindings only provide an interface - the implementation itself is treated as opaque. In Scala Native this is expressed by annotating methods with @extern, saying that the method will be available at runtime, and that we will attempt to call it with the parameters and types specified by our version of the interface

  • Interaction with the dynamic library, which holds compiled implementations of the interfaced methods, happens purely at the binary level - there are no types, just registers, pointers, the stack, and bytes.

  • Version mismatch between header files and runtime dynamic library can lead to pretty much any consequences you can think of.

Manually writing bindings

Before we dive into the deep end, let's demonstrate this process on a very small example.

Let's say we have a build like this:

build.sbt

lazy val bindingDemo =
  project
    .in(file("binding-demo"))
    .enablePlugins(ScalaNativePlugin)

project/plugins.sbt

addSbtPlugin("org.scala-native" % "sbt-scala-native" % "0.4.4")

And we are supposed to interface with a library defined only by this header file:

badmath.h

int badmath_add(int i, int j);

Then, we can manually define the binding to this library:

binding-demo/Main.scala

import scala.scalanative.unsafe.*

@main def hello = 
  println(BadMath.add(1, 1))
  println(BadMath.add(2, 1))

object BadMath:
  @name("badmath_add") /* 1 */ 
  @extern /* 2 */
  def add(i: CInt, j: CInt): CInt = extern /* 3 */
  1. We are telling Scala Native that the function we are defining is actually named badmath_add, otherwise it will name it add. Remember, there are no namespaces, objects, or packages at runtime - just a flat list of functions.

  2. This definition is designated as coming externally - its existence and location will be determined during linking

  3. The actual definition of the function - whenever we use it in our Scala code, Scala Native will emit a function call with these specific parameter types (this affects memory layout, remember!) and will assume the specified return type

If we try to run this program, we get a righteous error:

sbt:demos> bindingDemo/run
[info] Linking (563 ms)
[info] Discovered 668 classes and 3696 methods
[info] Optimizing (debug mode) (325 ms)
[info] Generating intermediate code (307 ms)
[info] Produced 8 files
[info] Compiling to native code (271 ms)
[error] Undefined symbols for architecture arm64:
[error]   "_badmath_add", referenced from:
[error]       __SM13Main$package$D5helloiEO in 5.ll.o
[error] ld: symbol(s) not found for architecture arm64
[error] clang: error: linker command failed with exit code 1 (use -v to see invocation)

Which is expected - we have not provided the definition of the badmath_add function, and the linker loudly complains about it.

Note: if you noticed this weird __SM13Main$package$D5helloiEO name, it's because it's mangled, which is necessary to map Scala's rich package/object/class structure to a flat list of uniquely named functions

First of all, let's provide the actual implementation for our function, in C:

int badmath_add(int i, int j) {
  if (i == 1 && j == 1)
    return 1; // MWAHAHAHAHA
  else
    return i + j;
}

To be able to link against it, we need to turn it into a binary artifact, in this case a dynamic library:

$ clang -dynamiclib binding-demo/badmath.c -o libbadmath.dylib

And to make our application aware of the library, we need to modify the linking options:

build.sbt

lazy val bindingDemo =
  project
    .in(file("binding-demo"))
    .enablePlugins(ScalaNativePlugin)
    .settings(scalaVersion := "3.1.1")
    .settings(nativeConfig := {
      val conf = nativeConfig.value
      val base = (ThisBuild / baseDirectory).value /* 1 */
      conf.withLinkingOptions(
        conf.linkingOptions ++ 
        List(
          "-lbadmath", /* 2 */
          s"-L$base" /* 3 */
        )
      )
    })
  1. Compute the root folder of our project

  2. Add -lbadmath to linking options - during linking, the library libbadmath.dylib will be looked up in the paths provided/set by default.

    Note that the lib prefix and the .dylib extension are added to the library name automatically

  3. Add root folder of the project to the list of paths where libraries will be searched during linking

After doing this, our application will run successfully:

sbt:demos> bindingDemo/run
[info] Linking (586 ms)
[info] Discovered 668 classes and 3696 methods
[info] Optimizing (debug mode) (446 ms)
[info] Generating intermediate code (439 ms)
[info] Produced 8 files
[info] Compiling to native code (277 ms)
[info] Linking native code (immix gc, none lto) (71 ms)
[info] Total (1849 ms)
1
3

And we can confirm our binary is dynamically linked against the library:

> otool -L binding-demo/target/scala-3.1.1/bindingdemo-out
binding-demo/target/scala-3.1.1/bindingdemo-out:
	libbadmath.dylib (compatibility version 0.0.0, current version 0.0.0)
	/usr/lib/libSystem.B.dylib (compatibility version 1.0.0, current version 1311.0.0)
	/usr/lib/libc++.1.dylib (compatibility version 1.0.0, current version 1200.3.0)

And that's it!

Now, for large libraries like OpenSSL and libpq we won't be compiling them ourselves - both in the runtime environment and during development we'll link against a globally installed development package for both of those.

For example, on Debian, the following two development packages are available: libpq-dev and libssl-dev. On Mac, I installed libpq and openssl packages with Homebrew.

Another option is to explicitly add Postgres and OpenSSL sources to your repository, and reference the necessary header files locally. The advantage here is the ability to check out a particular version of the interface (one that matches your future runtime) and, in the case of smaller libraries, to build the library from source and potentially link it statically into your application binary.

We won't be using this option, but do check out my examples repository where this approach is used for several libraries of varying size.

Using a binding generator

Manually defining bindings is easy and works really well for a small number of functions.

But when the library interface grows, and especially if it includes a lot of different types, writing typesafe bindings can become tedious.

That's where a binding generator comes in - its only purpose is to translate C's type system and definitions into idiomatic code on the desired platform - in our case Scala 3 on Scala Native.

The generator we will be using is sn-bindgen, homegrown and creatively named by yours truly.

So let's define a sibling project, which will use the generator instead:

project/plugins.sbt

addSbtPlugin("com.indoorvivants" % "bindgen-sbt-plugin" % "0.0.8")

build.sbt

// ...

lazy val bindingGenDemo =
  project
    .in(file("binding-generator-demo"))
    .enablePlugins(ScalaNativePlugin, BindgenPlugin) /* 4 */
    .settings(scalaVersion := "3.1.1")
    .settings(
      bindgenBindings +=
        bindgen.interface.Binding(
          headerFile = (ThisBuild / baseDirectory).value / "badmath.h", /* 1 */
          packageName = "lib_bad_math", /* 2 */
          linkName = Some("badmath")  /* 3 */
        )
    )
    .settings(nativeConfig := {
      val conf = nativeConfig.value
      val base = (ThisBuild / baseDirectory).value
      conf.withLinkingOptions(
        conf.linkingOptions ++ List(s"-L$base")
      )
    })
  1. We define a binding and point it at the same header file as before

  2. We want all the generated code to be put into lib_bad_math package

  3. The code should automatically be linked against a library named badmath

    There's a very nice feature of Scala Native which allows you to put an annotation on an extern function/object and the compiler will automatically adjust linking flags:

    @link("badmath")
    @extern 
    def badmath_add(i: CInt, j: CInt): CInt = extern
    

    Will automatically add a -lbadmath flag during linking

  4. To enable code generation, we need to enable BindgenPlugin

We can remove the manual binding and replace the main code with:

import scala.scalanative.unsafe.*
import lib_bad_math.functions.*

@main def hello = 
  println(badmath_add(1, 1))
  println(badmath_add(2, 1))

Now, when you run bindingGenDemo/run, you will get exactly the same result. But this time the bindings are placed in the generated code:

binding-generator-demo/target/scala-3.1.1/src_managed/main/lib_bad_math.scala

package lib_bad_math

import scala.scalanative.unsafe.*
import scala.scalanative.unsigned.*
import scalanative.libc.*

@link("badmath")
@extern
private[lib_bad_math] object extern_functions:
  def badmath_add(i: CInt, j: CInt): CInt = extern

object functions:
  import extern_functions.*
  export extern_functions.*

This is a very simple example, but the bindings can get pretty involved, and their encoding can get pretty complex because of the inherent differences between C and Scala.

Setting up Postgres bindings

As I mentioned, we will be using globally installed development packages, and the locations differ wildly between operating systems. The only two we care about at this point are MacOS (where I'm developing this) and Linux (which is what I use to test and deploy).

Let's define two helper functions in our build.sbt to handle platform-specific paths:

build.sbt

import bindgen.interface.{Binding, Platform}
import java.nio.file.Paths

def postgresInclude = {
  import Platform.*
  (os, arch) match {
    case (OS.Linux, _) => Paths.get("/usr/include/postgresql/")
    case (OS.MacOS, Arch.aarch64) =>
      Paths.get("/opt/homebrew/opt/libpq/include/")
    case (OS.MacOS, Arch.x86_64) => Paths.get("/usr/local/opt/libpq/include/")
  }
}

def postgresLib = {
  import Platform.*
  (os, arch) match {
    case (OS.MacOS, Arch.aarch64) =>
      Some(Paths.get("/opt/homebrew/opt/libpq/lib/"))
    case (OS.MacOS, Arch.x86_64) =>
      Some(Paths.get("/usr/local/opt/libpq/lib/"))
    case _ => None
  }
}

These paths should work well enough on our chosen platforms, so let's define our postgres module, which will contain both the generated bindings, and the handwritten code to make the interaction with postgres more idiomatic.

build.sbt

lazy val postgres =
  project
    .in(file("postgres"))
    .enablePlugins(ScalaNativePlugin, BindgenPlugin)
    .settings(
      scalaVersion := Versions.Scala,
      // Generate bindings to Postgres main API
      bindgenBindings +=
        Binding(
          postgresInclude.resolve("libpq-fe.h").toFile(),
          "libpq",
          linkName = Some("pq"),
          cImports = List("libpq-fe.h"), /* 1 */ 
          clangFlags = List( /* 2 */
            "-std=gnu99",
            s"-I$postgresInclude",
            "-fsigned-char"
          ) 
        ),
      nativeConfig ~= { conf =>
        conf.withLinkingOptions(
          conf.linkingOptions ++ postgresLib.toList.map("-L" + _)
        )
      }
    )

And if everything is set up correctly, binding code will be generated, and we can see in it examples of enums:

  opaque type PGpipelineStatus = CUnsignedInt
  object PGpipelineStatus extends CEnumU[PGpipelineStatus]:
    given _tag: Tag[PGpipelineStatus] = Tag.UInt
    inline def define(inline a: Long): PGpipelineStatus = a.toUInt
    val PQ_PIPELINE_OFF = define(0)
    val PQ_PIPELINE_ON = define(1)
    val PQ_PIPELINE_ABORTED = define(2)

and functions:

  def PQescapeLiteral(conn: Ptr[PGconn], str: CString, len: size_t): CString = extern

  def PQescapeString(to: CString, from: CString, length: size_t): size_t = extern

  def PQescapeStringConn(conn: Ptr[PGconn], to: CString, from: CString, length: size_t, error: Ptr[CInt]): size_t = extern

Overall, the binding is about 800 lines long, which is not huge, but defining all of the functions and enums would have been tedious.

Before we start wrapping this API, let's get the rest of the bindings out of the way.

Setting up OpenSSL bindings

The process is very similar, except that we don't have a single header file for the entire (huge) library - rather, it's conveniently broken into several parts, depending on what sort of cryptographic functions we need.

For our use case we need utilities to compute HMACs to sign our JWT tokens used in authentication, and SHA256 to hash the stored passwords for our thought leaders.

HMAC utilities are defined in the openssl/evp.h file, and SHA functions are in the openssl/sha.h, both relative to the installation location of OpenSSL development package.

build.sbt

def opensslInclude = {
  import Platform.*
  (os, arch) match {
    case (OS.Linux, _) => Paths.get("/usr/include/")
    case (OS.MacOS, Arch.aarch64) =>
      Paths.get("/opt/homebrew/opt/openssl/include/")
  }
}

def opensslLib = {
  import Platform.*
  (os, arch) match {
    case (OS.MacOS, Arch.aarch64) =>
      Some(Paths.get("/opt/homebrew/opt/openssl/lib/"))
    case _ => None
  }
}

And our module will look like this:

build.sbt

lazy val openssl =
  project
    .in(file("openssl"))
    .enablePlugins(ScalaNativePlugin, BindgenPlugin)
    .settings(
      scalaVersion := Versions.Scala,
      bindgenBindings := {
        Seq(
          Binding(
            opensslInclude.resolve("openssl/sha.h").toFile(),
            "libcrypto",
            linkName = Some("crypto"),
            cImports = List("openssl/sha.h"),
            clangFlags = List(
              "-std=gnu99",
              s"-I$opensslInclude",
              "-fsigned-char"
            )
          ),
          Binding(
            opensslInclude.resolve("openssl/evp.h").toFile(),
            "libhmac",
            linkName = Some("crypto"),
            cImports = List("openssl/evp.h"),
            clangFlags = List(
              "-std=gnu99",
              s"-I$opensslInclude",
              "-fsigned-char"
            )
          )
        )
      },
      nativeConfig ~= { conf =>
        conf.withLinkingOptions(
          conf.linkingOptions ++ opensslLib.toList.map("-L" + _)
        )
      }
    )

If everything is done correctly, we will have two new packages in generated sources:

  • libcrypto (~300 LOC) with functions like this:
  def SHA256(d: Ptr[CUnsignedChar], n: size_t, md: Ptr[CUnsignedChar]): Ptr[CUnsignedChar] = extern

  def SHA256_Final(md: Ptr[CUnsignedChar], c: Ptr[SHA256_CTX]): CInt = extern

  def SHA256_Init(c: Ptr[SHA256_CTX]): CInt = extern

  def SHA256_Transform(c: Ptr[SHA256_CTX], data: Ptr[CUnsignedChar]): Unit = extern
  • libhmac (~7,500 LOC), which has a lot of functions, not all of which we'll actually use.

Generated APIs are very C-like, and to use them efficiently from Scala we want (and need, really) to build better, more typesafe abstractions.

Building a Postgres data access library (Roach)

While we could use the very low-level functions offered by libpq to do pretty much anything, it would be much nicer to have a thin layer on top of it that fits in with the rest of our more idiomatic Scala code.

The library will be inspired by Skunk and Scodec in how the codecs are defined and used. As a natural follow-up to Skunk and Doobie, we will call the library Roach.

Helpers

Before we begin wrapping the library, let's define a couple of helper exceptions and an opaque type representing a result of some (fallible) operation:

package roach 

class RoachException(msg: String) extends Exception(msg)
class RoachFatalException(msg: String) extends Exception(msg)

opaque type Validated[A] = Either[String, A]
object Validated:
  extension [A](v: Validated[A])
    inline def getOrThrow: A =
      v match
        case Left(err) => throw new RoachFatalException(err)
        case Right(r)  => r
    inline def either: Either[String, A] = v

Connecting to Postgres

One of the fundamental concepts in libpq usage is a connection to a Postgres server. It's represented as Ptr[PGconn], where PGconn's implementation is opaque - it has no fields, and the state of the connection can only be interrogated using the functions provided by libpq.

To avoid exposing this unsafe construct (pointer to externally managed memory), let's define an opaque type:

opaque type Database = Ptr[PGconn]

The rest of the definitions will be put in the Database companion object.

Establishing the connection is pretty simple:

// def PQconnectdb(conninfo: CString): Ptr[PGconn] = extern 

import libpq.functions.*
import libpq.types.*

import scala.scalanative.unsigned.*
import scala.scalanative.unsafe.*

object Database:
  //...
  def apply(connString: String)(using Zone): Validated[Database] =
    val conn = PQconnectdb(toCString(connString))

    if PQstatus(conn) != ConnStatusType.CONNECTION_OK then
      // capture the error message before the connection is freed
      val error = fromCString(PQerrorMessage(conn))
      PQfinish(conn)
      Left(error)
    else Right(conn)

This method takes an implicit Zone - Scala Native's construct used for safely tracking allocated memory. We use it to convert a Scala String to a CString which maps 1-to-1 to a char * in C, which is what libpq expects.

We use libpq-provided functions to both check the connection status, and capture the error message in case there has been a failure.
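
Here's a minimal sketch of bringing a Zone into scope to call this constructor (using the same connection string as the examples below):

import scala.scalanative.unsafe.Zone
import roach.*

Zone { implicit z =>
  val db = Database("postgres://tester:pwd@localhost:5432/db").getOrThrow
  // ... use the connection ...
}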

Early on I wanted to use the Using construct to create safe regions where various Postgres resources are released when necessary. To support that, we need to provide an instance of Releasable[Database]:

  given Releasable[Database] with
    def release(db: Database) =
      if db != null && PQstatus(db) == ConnStatusType.CONNECTION_OK then
        PQfinish(db)

With this, we can open the connection to the database like this:

import scala.util.Using
import roach.*

val url = "postgres://tester:pwd@localhost:5432/db"

Using.resource(Database(url).getOrThrow) { db => 
  // db is safe to use
}

The connection will be closed after the block exits.

Because connection is stateful, it can become stale, or be closed by the server, or be disconnected because of some network effect, etc. It's not easy to recover from such a situation (more on that later), so let's just make sure that we detect it early and throw a fatal exception.

//.. 
object Database:
  // ..
  extension (d: Database)
    def connectionIsOkay: Boolean =
      val status = PQstatus(d)
      status != ConnStatusType.CONNECTION_NEEDED && status != ConnStatusType.CONNECTION_BAD

    def checkConnection(): Unit =
      val status = PQstatus(d)
      val broken = 
        status == ConnStatusType.CONNECTION_NEEDED || 
        status == ConnStatusType.CONNECTION_BAD
      if broken then throw new RoachFatalException("Postgres connection is down")

We can use those functions to early terminate any operations that need an active connection.

Sending queries to the server

Now that we have an active connection, we can execute server-side queries. In return, libpq will give us another opaque data structure, PGresult - or rather, a pointer to it.

So let's wrap that first:

opaque type Result = Ptr[PGresult]
object Result:
  extension (r: Result)
    inline def status: ExecStatusType = PQresultStatus(r)

For now we'll only define one operation, to get the status of the executed query.

With this, sending a query to the server is quite simple:

//.. 
object Database:
  // ..
  extension (d: Database)
    // ..
    def execute(query: String)(using Zone): Validated[Result] =
      checkConnection()
      val cstr = toCString(query)
      val res = PQexec(d, cstr)
      val status = PQresultStatus(res)
      import ExecStatusType.*
      val failed =
        status == PGRES_BAD_RESPONSE ||
          status == PGRES_NONFATAL_ERROR ||
          status == PGRES_FATAL_ERROR

      if failed then
        PQclear(res) // important!
        Left(fromCString(PQerrorMessage(d)))
      else Right(res)
    end execute

Note that we clear out the result even if the status is not what we expect - this is important to do to avoid memory leaks.

The result itself contains no rows or any data - we will request that later. In a similar fashion to the Database, query Results can and should be closed.

object Result:
  // ..
  given Releasable[Result] with
    def release(db: Result) =
      if db != null then PQclear(db)

Not all queries return data, so we can treat some of them as commands, and implement a separate command method:

//.. 
object Database:
  // ..
  extension (d: Database)
    // ..
    def command(query: String)(using Zone): Unit =
      checkConnection()
      Using.resource(d.execute(query).getOrThrow) { res =>
        PQresultStatus(res)
      }

This operation will throw an exception if the command was unsuccessful.

import scala.util.Using
import roach.*

val url = "postgres://tester:pwd@localhost:5432/db"

Using.resource(Database(url).getOrThrow) { db => 
  db.command("CREATE TABLE hello (id uuid primary key);")
}

If the query fails, its result will still be cleared, and an exception will be thrown to the caller afterwards.

Reading rows

Before we can make the API nicer to use from Scala, let's use the raw libpq API to verify that we can read the data from the Result correctly.

For now we will assume that the results are sent using the text protocol - Postgres can also use binary protocol for some field values, but this is beyond the scope of our library for now.

This assumption greatly simplifies our work - it means we don't have to convert between the network byte order and the host byte order (an int4 value, for example, arrives as the string "25" rather than as 4 raw bytes).

Before we write out the full function, let's introduce the libpq functions we need:

  • PQnfields(result) - number of columns in the result

  • PQntuples(result) - number of rows in the result

  • PQftype(result, col) - the Oid of the type of the col-th column

    Postgres sends type information in terms of integer identifiers, so later on we will have to write out a mapping between the identifiers and Scala types.

  • PQfname(result, col) - the name of the col-th column in the result

  • PQgetvalue(result, row, col) - the value (as a char * - a C string, or CString in Scala Native) of the col-th column in the row-th row.

Those are all the definitions we need to write a function that will read all the relevant information from the result:

object Result:
  // ..
  extension (r: Result)
    // ..
    def rows: (Vector[(Oid, String)], Vector[Vector[String]]) =
      val nFields = PQnfields(r)
      val nTuples = PQntuples(r)

      val meta = Vector.newBuilder[(Oid, String)]
      val tuples = Vector.newBuilder[Vector[String]]

      // Read all the column names and their types
      for i <- 0 until nFields do
        meta.addOne(PQftype(r, i) -> fromCString(PQfname(r, i)))
      
      // Read all the rows
      for t <- 0 until nTuples
      do
        tuples.addOne(
          (0 until nFields).map(f => fromCString(PQgetvalue(r, t, f))).toVector
        )

      meta.result -> tuples.result
    end rows

And let's make our example a bit more interesting:

import scala.util.Using
import roach.*

val url = "postgres://tester:pwd@localhost:5432/db"

Using.resource(Database(url).getOrThrow) { db => 
  db.command("DROP TABLE IF EXISTS hello")
  db.command(
    "CREATE TABLE hello(id uuid primary key, test int4, bla text)"
  )
  db.command(
    "insert into hello values (gen_random_uuid(), 25, 'howdyyy');"
  )
  db.command(
    "insert into hello values (gen_random_uuid(), 135, 'test test');"
  )

  Using.resource(db.execute("select * from hello").getOrThrow) { res =>
    println(res.rows)
  }
}

Which will print out something like:

(Vector((2950,id), (23,test), (25,bla)),Vector(Vector(bb1cf609-8147-45cc-9de0-c78b19a6bc0f, 25, howdyyy), Vector(e1689c3b-b4c6-4963-9930-1a5dca09d39b, 135, test test)))

And it is exciting, because we can now read data from Postgres and bring it into the warm embrace of Scala, where we can take this code even further.

Codecs for Scala types

To make this API a bit more usable, let's take a page out of Skunk's and Scodec's book, and define our codec API in terms of composable single-column codecs.

Defining a codec API for a relational database is simplified by the fact that the results are always of tabular format.

After not much deliberation this is the definition I've arrived at (mostly driven by desire to not do things the right way and not spend any considerable amount of time on it):

package roach 
import scala.scalanative.unsafe.*

trait Codec[T]:
  self =>
  def accepts(idx: Int): String
  def length: Int
  def decode(get: Int => CString)(using Zone): T
  def encode(value: T): Int => Zone ?=> CString
  def bimap[B](f: T => B, g: B => T): Codec[B] =
    new Codec[B]:
      def accepts(offset: Int) = self.accepts(offset)
      def length = self.length
      def decode(get: Int => CString)(using Zone) =
        f(self.decode(get))
      def encode(value: B) =
        self.encode(g(value))
end Codec
  1. The accepts method returns the name of the type this codec expects at a particular position in the row. Postgres type names are, for example, int2, varchar, bool, etc.

  2. length returns the number of columns this particular codec can decode

  3. decode receives a function that allows it to get the raw value at a position in the row, and using that function it must construct a value of type for which the codec is defined.

  4. encode returns a function which allows the caller to encode the value of type T. This operation will be required when we implement prepared statements

    Note that this function is of type Int => Zone ?=> CString, which allows the caller to handle the allocation (usage of Zone) at the time of encoding, and not when the encode method was first called.

    This is important because we'd prefer not to capture the Zone in the definition of the encoder - it may already be closed by the time encoding actually happens.

  5. bimap allows us to define codecs for types backed by a more primitive one.
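
    As a quick illustration of bimap, here's a hypothetical sketch that reuses the int4 codec defined in the next section to derive a codec for a made-up wrapper type:

    // Hypothetical wrapper type - not part of the library
    case class UserId(value: Int)

    // Decode int4 columns into UserId, and encode UserId back to Int
    val userId: Codec[UserId] = int4.bimap[UserId](UserId(_), _.value)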

Defining primitive codecs

If those functions don't make sense, don't worry, we can explain them further by defining a codec builder that assumes that the type is converted to and from String using the built-in toString method.

def stringLike[A](accept: String)(f: String => A): Codec[A] =
  new Codec[A]:
    inline def length: Int = 1
    inline def accepts(offset: Int) = accept

    def decode(get: Int => CString)(using Zone) =
      f(fromCString(get(0)))

    def encode(value: A) =
      _ => toCString(value.toString)

    override def toString() = s"Decode[$accept]"

This function creates a codec for a single column, decoding the desired value from its String representation, and encoding it to C string directly from a Scala String.

We can define several codecs this way:

  val int2 = stringLike[Short]("int2")(_.toShort)
  val int4 = stringLike[Int]("int4")(_.toInt)
  val int8 = stringLike[Long]("int8")(_.toLong)
  val float4 = stringLike[Float]("float4")(_.toFloat)
  val float8 = stringLike[Double]("float8")(_.toDouble)
  val uuid = stringLike[UUID]("uuid")(UUID.fromString(_))
  val bool = stringLike[Boolean]("bool")(_ == "t")
  val char = stringLike[Char]("char")(_(0).toChar)

And there's a whole host of textual datatypes which we will define as just Strings:

  private def textual(nm: String) = stringLike[String](nm)(identity)
  val name = textual("name")
  val varchar = textual("varchar")
  val bpchar = textual("bpchar")
  val text = textual("text")

And finally we can define a codec for Oid type itself, by using our bimap combinator:

  val oid =
    int4.bimap[Oid](i => Oid(i.toUInt), _.asInstanceOf[CUnsignedInt].toInt)

Concatenating codecs

To represent concatenated codecs at the type level we will be using Scala's tuples, which were supercharged in Scala 3. They're also much easier to use with Mirrors, which we'll need when we inevitably want to read complete case classes out of the database.

The API we want to support is chaining codecs like this:

val codec: Codec[(Int, String, Float)] = int4 ~ text ~ float4 

Unfortunately, if we simply define a single ~ method that produces a Codec[(A, B)], then because the calls associate to the left, the codec we will actually get is

val codec: Codec[((Int, String), Float)] = int4 ~ text ~ float4 

which will pose difficulties if we want to map it to a case class, which is isomorphic to a flat tuple of field types.

So we will have to get a bit more creative.

This problem can be broken down into two cases:

  1. Appending a codec to an existing codec of a tuple
  2. Putting two codecs in a tuple

Let's solve those two cases separately.

Appending a codec to a codec of a tuple

If we have

  • codecA: Codec[A] where A <: (N1, ..., Nm)
  • codecB: Codec[B]

then codecA ~ codecB is a Codec[T] where T = (N1, ..., Nm, B)

private[roach] class AppendCodec[A <: Tuple, B](
  a: Codec[A], 
  b: Codec[B]
) extends Codec[Tuple.Concat[A, (B *: EmptyTuple)]]:
  // This type is defined in Scala 3 standard library
  type T = Tuple.Concat[A, (B *: EmptyTuple)]

  // If the field is in the codec for A - delegate to it
  // otherwise - delegate to codec for B
  def accepts(offset: Int) =
    if (offset < a.length) then a.accepts(offset)
    else b.accepts(offset - a.length)

  def length = a.length + b.length

  // first we decode the fields covered by the codec for A,
  // then shift the column cursor, letting the codec for B
  //   decode the rest
  def decode(get: Int => CString)(using Zone): T =
    val left = a.decode(get)
    val right = b.decode((i: Int) => get(i + a.length))
    // in the end we concatenate the tuples
    left ++ (right *: EmptyTuple)
  
  def encode(value: T) =
    val (left, right) = value.splitAt(a.length).asInstanceOf[(A, Tuple1[B])]
    val leftEncode = a.encode(left)
    val rightEncode = b.encode(right._1)
    
    // Depending on where the column offset is,
    // we use different codecs to encode it
    (offset: Int) =>
      if (offset + 1 > a.length) then rightEncode(offset - a.length)
      else leftEncode(offset)

  override def toString() =
    s"AppendCodec[$a, $b]"

end AppendCodec

We can introduce this codec as an extension method:

object Codec:
  // ..
  extension [A <: Tuple](d: Codec[A])
    inline def ~[B](
        other: Codec[B]
    ): Codec[Tuple.Concat[A, B *: EmptyTuple]] =
      AppendCodec(d, other)
  end extension

Combining two codecs into a tuple

This codec is similar to the one we defined before, but we don't perform the concatenation:

private[roach] class CombineCodec[A, B](a: Codec[A], b: Codec[B])
    extends Codec[(A, B)]:
  type T = (A, B)
  def accepts(offset: Int) =
    if (offset < a.length) then a.accepts(offset)
    else b.accepts(offset - a.length)

  def length = a.length + b.length

  def decode(get: Int => CString)(using Zone): T =
    val left = a.decode(get)
    val right = b.decode((i: Int) => get(i + a.length))
    (left, right)

  def encode(value: T) =
    val leftEncode = a.encode(value._1)
    val rightEncode = b.encode(value._2)

    (offset: Int) =>
      if (offset + 1 > a.length) then rightEncode(offset - a.length)
      else leftEncode(offset)

  override def toString() =
    s"CombineCodec[$a, $b]"
end CombineCodec

And we introduce an extension method ~ with the restriction that type B cannot be another Tuple (the NotGiven evidence comes from scala.util):

  extension [A](d: Codec[A])
    inline def ~[B](
        other: Codec[B]
    )(using NotGiven[B <:< Tuple]): Codec[(A, B)] =
      CombineCodec(d, other)
  end extension

These two extension methods put together allow us to build long codecs:

val x: Codec[(Short, Int, Float, String)] = int2 ~ int4 ~ float4 ~ varchar
println(x)
// AppendCodec[AppendCodec[CombineCodec[Decode[int2], Decode[int4]], Decode[float4]], Decode[varchar]]

Mapping a codec to a case class

Note: this approach was adapted (or stolen, depending on how you look at it) from Scodec

Now that we can build flat tuples out of any data types, it's actually fairly easy to map those tuples to case classes that have fields of the same types in the same order.

Let's define a very simple typeclass:

trait Iso[A, B]:
  def convert(a: A): B
  def invert(b: B): A

It's just a simple abstraction that says "we can go from A to B and back again"

Now we want to express the idea that certain tuples are isomorphic to certain case classes.

import scala.deriving.Mirror

object Iso:
  given [X <: Tuple, A](using
      mir: Mirror.ProductOf[A] { type MirroredElemTypes = X }
  ): Iso[X, A] with
    def convert(a: X) =
      mir.fromProduct(a)
    def invert(a: A) =
      Tuple.fromProduct(a.asInstanceOf[Product]).asInstanceOf[X]

The type signature says that if X is a Tuple, and for some type A there exists a Mirror.ProductOf[A] such that its type member MirroredElemTypes is the same as X, then we can go between X and A back and forth.

For a case class like this:

case class Howdy(s: Short, i: Int, str: String)

The compiler will synthesise an instance of Mirror.ProductOf[Howdy] with

  type MirroredElemTypes = Short *: Int *: String *: EmptyTuple

and therefore you should successfully be able to summon the instance of Iso:

summon[Iso[(Short, Int, String), Howdy]] // compiles

The actual convert and invert methods are just runtime implementations of converting a case class to a Tuple and Tuple to an instance of case class.
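
A quick round-trip through the instance summoned above illustrates this:

val iso = summon[Iso[(Short, Int, String), Howdy]]

iso.convert((1.toShort, 2, "three"))     // Howdy(1,2,three)
iso.invert(Howdy(1.toShort, 2, "three")) // (1,2,three)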

And that's all we need to define an extension method:

object Codec:
  // ..
  extension [A](d: Codec[A])
    inline def as[T](using iso: Iso[A, T]) =
      new Codec[T]:
        def accepts(offset: Int) =
          d.accepts(offset)

        def length = d.length

        def decode(get: Int => CString)(using Zone) =
          iso.convert(d.decode(get))

        def encode(value: T) =
          d.encode(iso.invert(value))

And we can use it like this:

case class Howdy(s: Short, i: Int, str: String)

val c: Codec[Howdy] = (int2 ~ int4 ~ varchar).as[Howdy]

With this, finally, our codecs are powerful enough to read the rows from a query result in a typesafe fashion.

Decoding rows from a result

Defining a Postgres type mapping

Before we can start decoding, we need to define a mapping between Postgres' integer identifiers and textual names of types we used in our codecs definitions.

The number of types in Postgres is very large, so we'll just define a subset:

package roach

import libpq.types.Oid
import scala.scalanative.unsigned.*

trait OidMapping:
  def map(c: String): Oid
  def rev(oid: Oid): String

object OidMapping extends OidMapping:
  private val mapping =
    Map(
      21 -> "int2",
      23 -> "int4",
      20 -> "int8",
      700 -> "float4",
      701 -> "float8",
      26 -> "oid",
      25 -> "text",
      1042 -> "bpchar",
      1043 -> "varchar",
      18 -> "char",
      2950 -> "uuid",
      19 -> "name",
      16 -> "bool"
    ).map((k, v) => Oid(k.toUInt) -> v)

  private val reverse = mapping.map(_.swap)

  inline override def map(c: String): Oid =
    reverse(c)

  inline override def rev(oid: Oid): String = mapping(oid)

end OidMapping

We've made it an interface to make it more extensible - to allow the user to define their own types and codecs for them.
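
For example, here's a hypothetical sketch of a user-supplied mapping that adds the json type (oid 114 in pg_type) and delegates everything else to the default mapping:

import libpq.types.Oid
import scala.scalanative.unsigned.*

// Hypothetical extension of the default mapping - not part of the library above
object JsonAwareOids extends OidMapping:
  private val extra = Map("json" -> Oid(114.toUInt))
  private val extraRev = extra.map(_.swap)

  def map(c: String): Oid = extra.getOrElse(c, OidMapping.map(c))
  def rev(oid: Oid): String = extraRev.getOrElse(oid, OidMapping.rev(oid))

Such a mapping can then be passed explicitly wherever an OidMapping is expected, for example to the readAll and executeParams methods defined below.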

These identifiers are defined in Postgres' source code and can be retrieved from the database as well:

postgres=# select oid, typname from pg_type;
  oid  |                typname
-------+----------------------------------------
    16 | bool
    17 | bytea
    18 | char
    19 | name
    20 | int8
    21 | int2
    22 | int2vector
    23 | int4
    24 | regproc
    25 | text
    26 | oid
    27 | tid
    28 | xid
    29 | cid
    ... loads more

Reading the rows from the result

We've already introduced all the functions needed to retrieve the rows, we just need to invoke the correct codecs with the right offsets.

object Result:
  // ..
  extension (r: Result)
    // ..
    def readAll[A](
        codec: Codec[A]
    )(using z: Zone, oids: OidMapping = OidMapping): Vector[A] =
      val nFields = PQnfields(r)
      val nTuples = PQntuples(r)
      val tuples = Vector.newBuilder[A]

      if (codec.length != PQnfields(r)) then
        throw new RoachException(
          s"Provided codec is for ${codec.length} fields, while the result has ${PQnfields(r)} fields"
        )
  
      (0 until nFields).foreach { offset =>
        // make sure that the codec and the result set agree
        // on the types of the columns
        // this is no Skunk-level error reporting, but it's useful
        val expectedType = oids.rev(PQftype(r, offset))
        val fieldName = fromCString(PQfname(r, offset))

        if codec.accepts(offset) != expectedType then
          throw new RoachException(
            s"$offset: Field $fieldName is of type '$expectedType', " +
              s"but the decoder only accepts '${codec.accepts(offset)}'"
          )
      }
      
      // read each row
      (0 until nTuples).foreach { row =>
        val func =
          (i: Int) => PQgetvalue(r, row, i)
        tuples.addOne(codec.decode(func))
      }

      tuples.result
    end readAll

And here's our example, modified to use case classes:

import scala.util.Using
import java.util.UUID
import roach.*

val url = "postgres://tester:pwd@localhost:5432/db"

Using.resource(Database(url).getOrThrow) { db => 
  db.command("DROP TABLE IF EXISTS hello")
  db.command(
    "CREATE TABLE hello(id uuid primary key, test int4, bla text)"
  )
  db.command(
    "insert into hello values (gen_random_uuid(), 25, 'howdyyy');"
  )
  db.command(
    "insert into hello values (gen_random_uuid(), 135, 'test test');"
  )

  case class Howdy(s: UUID, i: Int, str: String)
  val howdyCodec = (uuid ~ int4 ~ text).as[Howdy]

  Using.resource(db.execute("select * from hello").getOrThrow) { res =>
    res.readAll(howdyCodec).foreach(println)
  }
}

Will output:

Howdy(af1d03bd-e5d3-4493-931e-e4291cc997a5,25,howdyyy)
Howdy(1e7c9e36-4d19-4055-9923-3dbc0e434ea5,135,test test)

Executing queries with parameters

Now, so far we've been passing queries that don't depend on any runtime information our app might be producing.

One obvious issue we should consider is SQL injection - care must be taken to escape the potentially malicious input.

Libpq does include several functions to escape string literals in order for them to be passed to the database safely.

That said, the result of those operations is allocated as a new string and as such must be freed when no longer necessary - whereas in the design of our Codec we already create the C string with the desired values.

Luckily, libpq has a PQexecParams function, which allows passing a SQL statement with placeholders and providing the values separately - eliminating the chance of those values being interpreted as part of the query.

The signature in C is as such:

PGresult *PQexecParams(PGconn *conn,
                       const char *command,
                       int nParams,
                       const Oid *paramTypes,
                       const char * const *paramValues,
                       const int *paramLengths,
                       const int *paramFormats,
                       int resultFormat);

And in generated Scala code:

  def PQexecParams(
      conn: Ptr[PGconn],
      command: CString,
      nParams: CInt,
      paramTypes: Ptr[Oid],
      paramValues: Ptr[CString],
      paramLengths: Ptr[CInt],
      paramFormats: Ptr[CInt],
      resultFormat: CInt
  ): Ptr[PGresult] = extern

Note that parameter types and values are passed as separate arrays. resultFormat will be set to 0, as we only deal with text-based transmission, and this also allows us to not provide paramLengths and paramFormats.

So let's lay out the parameter types and values for the execution:

    def executeParams[T](
        query: String,
        codec: Codec[T],
        data: T
    )(using z: Zone, oids: OidMapping = OidMapping): Validated[Result] =
      checkConnection()
      val nParams = codec.length

      // resolve and encode parameter types, as specified by 
      // the codec
      val paramTypes = stackalloc[Oid](nParams)
      for l <- 0 until nParams do paramTypes(l) = oids.map(codec.accepts(l))
      
      // encode each part of the data according to its codec
      val paramValues = stackalloc[CString](nParams)
      val encoder = codec.encode(data)
      for i <- 0 until nParams do paramValues(i) = encoder(i)
      
      val res = PQexecParams(
        d,                 
        toCString(query),  
        nParams,           
        paramTypes,
        paramValues,
        null,
        null,
        0
      )

      result(res)
    end executeParams

At the end we extracted the logic for handling a potentially failed result into a small helper function:

  private[roach] def result(res: Result): Validated[Result] =
    val status = PQresultStatus(res)
    import ExecStatusType.*

    val failed =
      status == PGRES_BAD_RESPONSE ||
        status == PGRES_NONFATAL_ERROR ||
        status == PGRES_FATAL_ERROR

    if failed then
      PQclear(res) // important!
      Left(fromCString(PQerrorMessage(d)))
    else Right(res)

This function allows us to safely pass values into a parameterised SQL query:

val res1 = db.executeParams(
  "select id from hello where test = $1 and bla = $2",
  int4 ~ text,
  25 -> "howdyyy"
).getOrThrow

println(res1.readAll(uuid))
// Vector(297f0da4-e5dd-4665-9f2c-ff0733401934)

Prepared statements

The way executeParams operates translates directly to how prepared statements are used in libpq.

Prepared statements allow us (with a live session) to send queries with placeholders to the server and have them parsed and analysed, so that we can send the values at a later point.

It's very useful when executing lots of repeated queries, where only the input data changes.

The way it's done in libpq is very similar to executeParams, just the stages are separated in time. Because of this separation, we want to provide the user with some typesafe value they can use later to provide the values to a prepared statement.

This value will capture the reference to the database connection and the codec that will be used to encode the provided parameters.

Note: capturing a stateful connection is a difficult problem - if the connection is terminated, the prepared statement will disappear with it. This means that an attempt to execute it later will fail - you need to both replace the connection and re-send the prepared statement to the server.

To actually execute the prepared statement, the only function we will need is PQexecPrepared, which looks and feels exactly like PQexecParams, with the exception that we pass the statement name, not the query - as the statement is already prepared on the server, we should be able to look it up by name.

class Prepared[T] private[roach] (
    db: Database,
    codec: Codec[T],
    statementName: String
):
  private val nParams = codec.length

  def execute(data: T)(using z: Zone): Validated[Result] =
    db.checkConnection()
    // encode the values
    val paramValues = stackalloc[CString](nParams)
    val encoder = codec.encode(data)
    for i <- 0 until nParams do paramValues(i) = encoder(i)

    // send just the values
    val res = PQexecPrepared(
      db,
      toCString(statementName), // sic! statement name, not the query
      nParams,
      paramValues,
      null,
      null,
      0
    )

    db.result(res)
  end execute
end Prepared

Now, constructing the prepared statement itself is easy - we just need to call PQprepare and pass it information about just the parameter types.

def prepare[T](
    query: String,
    statementName: String,
    codec: Codec[T]
)(using z: Zone, oids: OidMapping = OidMapping): Validated[Prepared[T]] =
  checkConnection()
  // encode params
  val nParams = codec.length
  val paramTypes = stackalloc[Oid](nParams)
  for l <- 0 until nParams do paramTypes(l) = oids.map(codec.accepts(l))
  
  // We send both query and the statement name
  val res = PQprepare(
    d,
    toCString(statementName), 
    toCString(query),
    nParams,
    paramTypes
  )

  result(res).map(_ => Prepared[T](d, codec, statementName))
end prepare
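
To tie it together, here's a hypothetical usage sketch: the statement name and inserted values are made up, it assumes prepare is exposed on Database like the other operations, and it reuses the hello table from the earlier examples.

import scala.util.Using
import scala.scalanative.unsafe.*
import java.util.UUID
import roach.*

val url = "postgres://tester:pwd@localhost:5432/db"

Zone { implicit z =>
  Using.resource(Database(url).getOrThrow) { db =>
    // prepare the statement once...
    val insert = db
      .prepare(
        "insert into hello values ($1, $2, $3)",
        "insert-hello",
        uuid ~ int4 ~ text
      )
      .getOrThrow

    // ...and execute it (potentially many times) with just the values
    Using.resource(
      insert.execute((UUID.randomUUID(), 25, "prepared howdyyy")).getOrThrow
    ) { res => println(res.status) }
  }
}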

Phew!

What we've built so far is a solid foundation to safely run queries on Postgres (or in fact postgres-compatible databases).

There's a ton of features we haven't implemented, and I'm pretty sure that even the stuff we do have goes against all the possible best practices.

Most importantly, what should we do with all those fatal exceptions we're throwing when the connection goes away or is terminated for different reasons?

My answer, for which I will never be forgiven in the functional Scala JVM community which I cherish so much, is "let the server crash and restart", instead of managing a mutable database connection that we refresh on demand.

There are several reasons for this:

  1. Our process' startup time is negligible in the grand scheme of things

  2. When we replace the Postgres connection, we also need to re-submit all of the prepared statements, meaning our mutable state is growing in complexity

  3. The server that we'll use will handle load re-balancing for us.

  4. But most importantly, for many years I wasn't allowed to crash processes and instead I focused on "safety" and "predictability" whatever those mean.

    No more. My time has come, and I have a thirst for blood. Processes are going to be killed.

Interacting with OpenSSL

Compared to our usage of Postgres, the functionality we need from OpenSSL is pretty limited:

  1. Compute password hashes (we use SHA256)
  2. Compute HMAC for JWT signatures (again, SHA256)

As such, I've decided to not produce a wrapper library for those two usecases.

Instead, let's just see how we can use the generated functions to compute a SHA256 hash (as a Scala String) of an arbitrary String.

Computing SHA256 hash of a String

This code has been adapted from a C example found on StackOverflow. As far as I can tell it produces correct hashes - I've verified the output against several other implementations.

Keep in mind, I'm not versed in cryptography at all.

The function we want to define has the following signature:

def sha256(plaintext: String)(using Zone): String

We will require Zone here to allocate memory for converting a Scala string to a C string. In fact, let's do that first:

  val str = toCString(plaintext)

Next, we need to allocate some memory for the datastructures required by OpenSSL. We will be allocating them on the stack (as opposed to the heap), because we don't need to re-use those structures anywhere after the function call ends, and because the size of the memory is well known and way below the maximum stack size.

val SHA256_DIGEST_LENGTH = 32
val sha256_ctx = stackalloc[SHA256_CTX](1)
val hash = stackalloc[CUnsignedChar](SHA256_DIGEST_LENGTH)
val outputBuffer = stackalloc[CChar](65)
  • sha256_ctx is OpenSSL's internal data structure

  • hash is a memory location where OpenSSL will write the hashed value

  • outputBuffer is a location where we will construct a C string, containing zero-padded hexadecimal values of each byte from the hash

The size of outputBuffer is 64 (2 hex digits per byte) + 1 (zero-byte, terminating C string)

As a convention, OpenSSL functions return 1 on success, and 0 on failure (remember, no exceptions in C!). So we need to run a chain of commands and assert on each return value:

assert(
  SHA256_Init(sha256_ctx) == 1,
  "failed to initialise sha context"
)
assert(
  SHA256_Update(sha256_ctx, str, string.strlen(str)) == 1,
  "failed to update sha context"
)
assert(
  SHA256_Final(hash, sha256_ctx) == 1,
  "failed to finalise sha context"
)

If the operations all succeeded, in our hash buffer we will have the bytes that we need to convert to a more familiar hexadecimal representation:

for i <- 0 until SHA256_DIGEST_LENGTH do
  stdio.sprintf(outputBuffer + (i * 2), c"%02x", hash(i))

outputBuffer(64) = 0.toByte

%02x means "convert the byte to its hexadecimal value, and pad it with 0 on the left if there's only 1 digit"

All we need to do after this is to convert the resulting C string to a Scala one:

fromCString(outputBuffer)

And that's it! Here's the full function:

import libcrypto.functions.*
import libcrypto.types.*

val SHA256_DIGEST_LENGTH = 32

def sha256(plaintext: String)(using Zone): String =
  val str = toCString(plaintext)

  val sha256_ctx = stackalloc[SHA256_CTX](1)
  val hash = stackalloc[CUnsignedChar](SHA256_DIGEST_LENGTH)
  val outputBuffer = stackalloc[CChar](65)

  assert(
    SHA256_Init(sha256_ctx) == 1,
    "failed to initialise sha context"
  )
  assert(
    SHA256_Update(sha256_ctx, str, string.strlen(str)) == 1,
    "failed to update sha context"
  )
  assert(
    SHA256_Final(hash, sha256_ctx) == 1,
    "failed to finalise sha context"
  )

  for i <- 0 until SHA256_DIGEST_LENGTH do
    stdio.sprintf(outputBuffer + (i * 2), c"%02x", hash(i))
  outputBuffer(64) = 0.toByte

  fromCString(outputBuffer)
end sha256

Not too bad, I think.

The HMAC procedure is a bit more involved, but it also boils down to squinting at C examples and converting them to Scala.

We will do it closer to the action, when we actually have JWTs to sign.

Conclusion

At this point we have

  1. Our build set up for Postgres and OpenSSL usage
  2. A wrapper API for libpq, along with codecs infrastructure
  3. A working sample of using OpenSSL functions to compute SHA256 hashes