Series TL;DR
- We are deploying a Scala Native backend (with NGINX Unit) and Scala.js frontend (with Laminar) on Fly.io
- We are using my SN binding generator
- We are using Scala 3 heavily
- Code on Github
- Deployed app
- Roach - postgres bindings and interface
- Navigation
- Previous: Part I - Introduction
- Next: Part III - NGINX Unit and Fly.io
- Full series
Interfacing with native code
The general interaction with the various native libraries that we'll use is described in this diagram:
- Compile and build time are conflated - the bindings are generated only at build time, and the application code can be compiled against the bindings, which will have (hopefully) rich type information.
- The binding generator may as well be a human (you, yes, you!) if the interface is small and simple. But bindings can also be generated automatically, which is what we will explore in this project.
- The bindings only provide an interface - the implementation itself is treated as opaque. In Scala Native this is expressed by annotating methods with `@extern`, saying that the method will be available at runtime, and that we will attempt to call it with the parameters and types specified by our version of the interface.
- Interaction with the dynamic library, which holds the compiled implementations of the interfaced methods, happens only in binary format - there are no types, just registers, pointers, stack, and bytes.
- A version mismatch between the header files and the runtime dynamic library can lead to pretty much any consequence you can think of.
Manually writing bindings
Before we dive into the deep end, let's demonstrate this process on a very small example.
Let's say we have a build like this:
build.sbt
lazy val bindingDemo =
project
.in(file("binding-demo"))
.enablePlugins(ScalaNativePlugin)
project/plugins.sbt
addSbtPlugin("org.scala-native" % "sbt-scala-native" % "0.4.4")
And we are supposed to interface with a library defined only by this header file:
bad_math.h
int badmath_add(int i, int j);
Then, we can manually define the binding to this library:
binding-demo/Main.scala
import scala.scalanative.unsafe.*
@main def hello =
println(BadMath.add(1, 1))
println(BadMath.add(2, 1))
object BadMath:
@name("badmath_add") /* 1 */
@extern /* 2 */
def add(i: CInt, j: CInt): CInt = extern /* 3 */
1. We are telling Scala Native that the function we are defining is actually named `badmath_add`, otherwise it will name it `add`. Remember, there are no namespaces, objects, or packages at runtime - just a flat list of functions.
2. This definition is designated as coming from an external source - its existence and location will be determined during linking.
3. The actual definition of the function - whenever we use it in our Scala code, Scala Native will emit a function call with these specific parameter types (this affects memory layout, remember!) and will assume the specified return type.
If we try to run this program, we get a righteous error:
sbt:demos> bindingDemo/run
[info] Linking (563 ms)
[info] Discovered 668 classes and 3696 methods
[info] Optimizing (debug mode) (325 ms)
[info] Generating intermediate code (307 ms)
[info] Produced 8 files
[info] Compiling to native code (271 ms)
[error] Undefined symbols for architecture arm64:
[error] "_badmath_add", referenced from:
[error] __SM13Main$package$D5helloiEO in 5.ll.o
[error] ld: symbol(s) not found for architecture arm64
[error] clang: error: linker command failed with exit code 1 (use -v to see invocation)
Which is expected - we have not provided the definition of the `badmath_add` function, and the linker loudly complains about it.
Note: if you noticed this weird `__SM13Main$package$D5helloiEO` name, it's because it's mangled, which is necessary to map Scala's rich package/object/class structure to a flat list of uniquely named functions.
First of all, let's provide the actual implementation for our function, in C:
int badmath_add(int i, int j) {
if (i == 1 && j == 1)
return 1; // MWAHAHAHAHA
else
return i + j;
}
To be able to link against it, we need to turn it into a binary artifact, in this case a dynamic library:
$ clang -dynamiclib binding-demo/badmath.c -o libbadmath.dylib
And to make our application aware of the library, we need to modify the linking options:
build.sbt
lazy val bindingDemo =
project
.in(file("binding-demo"))
.enablePlugins(ScalaNativePlugin)
.settings(scalaVersion := "3.1.1")
.settings(nativeConfig := {
val conf = nativeConfig.value
val base = (ThisBuild / baseDirectory).value /* 1 */
conf.withLinkingOptions(
conf.linkingOptions ++
List(
"-lbadmath", /* 2 */
s"-L$base" /* 3 */
)
)
})
1. Compute the root folder of our project.
2. Add `-lbadmath` to the linking options - during linking, the library `libbadmath.dylib` will be looked up in the paths provided/set by default. Note that `lib` and `.dylib` are added to the library name.
3. Add the root folder of the project to the list of paths where libraries will be searched during linking.
After doing this, our application will run successfully:
sbt:demos> bindingDemo/run
[info] Linking (586 ms)
[info] Discovered 668 classes and 3696 methods
[info] Optimizing (debug mode) (446 ms)
[info] Generating intermediate code (439 ms)
[info] Produced 8 files
[info] Compiling to native code (277 ms)
[info] Linking native code (immix gc, none lto) (71 ms)
[info] Total (1849 ms)
1
3
And we can confirm our binary is dynamically linked against the library:
> otool -L binding-demo/target/scala-3.1.1/bindingdemo-out
binding-demo/target/scala-3.1.1/bindingdemo-out:
libbadmath.dylib (compatibility version 0.0.0, current version 0.0.0)
/usr/lib/libSystem.B.dylib (compatibility version 1.0.0, current version 1311.0.0)
/usr/lib/libc++.1.dylib (compatibility version 1.0.0, current version 1200.3.0)
And that's it!
Now, for large libraries like OpenSSL and libpq we won't be compiling them ourselves - both at runtime and during development we'll link against globally installed development packages.
For example, on Debian, the following two development packages are available: `libpq-dev` and `libssl-dev`.
On macOS, I installed the `libpq` and `openssl` packages with Homebrew.
Another option is to explicitly add the Postgres and OpenSSL sources to your repository and reference the necessary header files locally. The advantage here is the ability to check out a particular version of the interface (one that matches your future runtime) and, in the case of smaller libraries, to build the library from source and potentially link it statically into your application binary.
We won't be using this option, but do check out my examples repository where this approach is used for several libraries of varying size.
Using a binding generator
Manually defining bindings is easy and works really well for a small number of functions.
But when the library interface grows, and especially if it includes a lot of different types, writing typesafe bindings can become tedious.
That's where a binding generator comes in - its only purpose is to translate C's type system and definitions into idiomatic code on the desired platform - in our case Scala 3 on Scala Native.
The generator we will be using is sn-bindgen, homegrown and creatively named by yours truly.
So let's define a sibling project, which will use the generator instead:
project/plugins.sbt
addSbtPlugin("com.indoorvivants" % "bindgen-sbt-plugin" % "0.0.8")
build.sbt
// ...
lazy val bindingGenDemo =
project
.in(file("binding-generator-demo"))
.enablePlugins(ScalaNativePlugin, BindgenPlugin) /* 4 */
.settings(scalaVersion := "3.1.1")
.settings(
bindgenBindings +=
bindgen.interface.Binding(
headerFile = (ThisBuild / baseDirectory).value / "badmath.h", /* 1 */
packageName = "lib_bad_math", /* 2 */
linkName = Some("badmath") /* 3 */
)
)
.settings(nativeConfig := {
val conf = nativeConfig.value
val base = (ThisBuild / baseDirectory).value
conf.withLinkingOptions(
conf.linkingOptions ++ List(s"-L$base")
)
})
1. We define a binding and point it at the same header file as before.
2. We want all the generated code to be put into the `lib_bad_math` package.
3. The code should automatically be linked against a library named `badmath`. There's a very nice feature of Scala Native which allows you to put an annotation on an extern function/object so that the compiler automatically adjusts the linking flags: `@link("badmath") @extern def badmath_add(i: CInt, j: CInt): CInt = extern` will automatically add a `-lbadmath` flag during linking.
4. To enable code generation, we need to enable the `BindgenPlugin`.
We can remove the manual binding and replace the main code with:
import scala.scalanative.unsafe.*
import lib_bad_math.functions.*
@main def hello =
println(badmath_add(1, 1))
println(badmath_add(2, 1))
Now, when you run `bindingGenDemo/run`, you will get exactly the same result.
But this time the bindings are placed in the generated code:
binding-generator-demo/target/scala-3.1.1/src_managed/main/lib_bad_math.scala
package lib_bad_math
import scala.scalanative.unsafe.*
import scala.scalanative.unsigned.*
import scalanative.libc.*
@link("badmath")
@extern
private[lib_bad_math] object extern_functions:
def badmath_add(i: CInt, j: CInt): CInt = extern
object functions:
import extern_functions.*
export extern_functions.*
This is a very simple example, but the bindings can get pretty involved, and their encoding can get pretty complex because of the inherent differences between C and Scala.
Setting up Postgres bindings
As I mentioned, we will be using globally installed development packages, and their locations differ wildly between operating systems. The only two we care about at this point are macOS (where I'm developing this) and Linux (which is what I use to test and deploy).
Let's define two helper functions in our build.sbt
to handle platform-specific paths:
build.sbt
import bindgen.interface.Platform
def postgresInclude = {
import Platform.*
(os, arch) match {
case (OS.Linux, _) => Paths.get("/usr/include/postgresql/")
case (OS.MacOS, Arch.aarch64) =>
Paths.get("/opt/homebrew/opt/libpq/include/")
case (OS.MacOS, Arch.x86_64) => Paths.get("/usr/local/opt/libpq/include/")
}
}
def postgresLib = {
import Platform.*
(os, arch) match {
case (OS.MacOS, Arch.aarch64) =>
Some(Paths.get("/opt/homebrew/opt/libpq/lib/"))
case (OS.MacOS, Arch.x86_64) =>
Some(Paths.get("/usr/local/opt/libpq/lib/"))
case _ => None
}
}
These paths should work well enough on our chosen platforms, so let's define our postgres module, which will contain both the generated bindings and the handwritten code that makes interaction with Postgres more idiomatic.
build.sbt
lazy val postgres =
project
.in(file("postgres"))
.enablePlugins(ScalaNativePlugin, BindgenPlugin)
.settings(
scalaVersion := Versions.Scala,
// Generate bindings to Postgres main API
bindgenBindings +=
Binding(
postgresInclude.resolve("libpq-fe.h").toFile(),
"libpq",
linkName = Some("pq"),
cImports = List("libpq-fe.h"), /* 1 */
clangFlags = List( /* 2 */
"-std=gnu99",
s"-I$postgresInclude",
"-fsigned-char"
)
),
nativeConfig ~= { conf =>
conf.withLinkingOptions(
conf.linkingOptions ++ postgresLib.toList.map("-L" + _)
)
}
)
And if everything is set up correctly, binding code will be generated, and we can see in it examples of enums:
opaque type PGpipelineStatus = CUnsignedInt
object PGpipelineStatus extends CEnumU[PGpipelineStatus]:
given _tag: Tag[PGpipelineStatus] = Tag.UInt
inline def define(inline a: Long): PGpipelineStatus = a.toUInt
val PQ_PIPELINE_OFF = define(0)
val PQ_PIPELINE_ON = define(1)
val PQ_PIPELINE_ABORTED = define(2)
and functions:
def PQescapeLiteral(conn: Ptr[PGconn], str: CString, len: size_t): CString = extern
def PQescapeString(to: CString, from: CString, length: size_t): size_t = extern
def PQescapeStringConn(conn: Ptr[PGconn], to: CString, from: CString, length: size_t, error: Ptr[CInt]): size_t = extern
Overall, the binding is about 800 lines long, which is not huge, but defining all of the functions and enums would have been tedious.
Before we start wrapping this API, let's get the rest of the bindings out of the way.
Setting up OpenSSL bindings
The process is very similar, with the exception that we don't have a single header file for the entire (huge) library - rather, it's conveniently broken into several parts, depending on what sort of cryptographic functions we need.
For our use case we need utilities to compute HMACs to sign the JWT tokens used in authentication, and SHA256 to hash the stored passwords of our thought leaders.
HMAC utilities are defined in the `openssl/evp.h` file, and SHA functions are in `openssl/sha.h`, both relative to the installation location of the OpenSSL development package.
build.sbt
def opensslInclude = {
import Platform.*
(os, arch) match {
case (OS.Linux, _) => Paths.get("/usr/include/")
case (OS.MacOS, Arch.aarch64) =>
Paths.get("/opt/homebrew/opt/openssl/include/")
}
}
def opensslLib = {
import Platform.*
(os, arch) match {
case (OS.MacOS, Arch.aarch64) =>
Some(Paths.get("/opt/homebrew/opt/openssl/lib/"))
case _ => None
}
}
And our module will look like this:
build.sbt
lazy val openssl =
project
.in(file("openssl"))
.enablePlugins(ScalaNativePlugin, BindgenPlugin)
.settings(
scalaVersion := Versions.Scala,
bindgenBindings := {
Seq(
Binding(
opensslInclude.resolve("openssl/sha.h").toFile(),
"libcrypto",
linkName = Some("crypto"),
cImports = List("openssl/sha.h"),
clangFlags = List(
"-std=gnu99",
s"-I$opensslInclude",
"-fsigned-char"
)
),
Binding(
opensslInclude.resolve("openssl/evp.h").toFile(),
"libhmac",
linkName = Some("crypto"),
cImports = List("openssl/evp.h"),
clangFlags = List(
"-std=gnu99",
s"-I$opensslInclude",
"-fsigned-char"
)
)
)
},
nativeConfig ~= { conf =>
conf.withLinkingOptions(
conf.linkingOptions ++ opensslLib.toList.map("-L" + _)
)
}
)
If everything is done correctly, we will have two new packages in generated sources:
`libcrypto` (~300 LOC) with functions like this:
def SHA256(d: Ptr[CUnsignedChar], n: size_t, md: Ptr[CUnsignedChar]): Ptr[CUnsignedChar] = extern
def SHA256_Final(md: Ptr[CUnsignedChar], c: Ptr[SHA256_CTX]): CInt = extern
def SHA256_Init(c: Ptr[SHA256_CTX]): CInt = extern
def SHA256_Transform(c: Ptr[SHA256_CTX], data: Ptr[CUnsignedChar]): Unit = extern
and `libhmac` (~7500 LOC), which has a lot of functions, not all of which we'll actually use.
Generated APIs are very C-like, and to use them efficiently from Scala we want (and need, really) to build better, more typesafe abstractions.
Building a Postgres data access library (Roach)
While we could use the very low-level functions offered by libpq to do pretty much anything, it would be much nicer to write a thin layer on top of it that fits in with the rest of our more idiomatic Scala code.
The library will be inspired by Skunk and Scodec in how the codecs are defined and used. As a natural follow-up to Skunk and Doobie, we will call the library Roach.
Helpers
Before we begin wrapping the library, let's define a couple of helper exceptions and an opaque type representing a result of some (fallible) operation:
package roach
class RoachException(msg: String) extends Exception(msg)
class RoachFatalException(msg: String) extends Exception(msg)
opaque type Validated[A] = Either[String, A]
object Validated:
extension [A](v: Validated[A])
inline def getOrThrow: A =
v match
case Left(err) => throw new RoachFatalException(err)
case Right(r) => r
inline def either: Either[String, A] = v
Connecting to Postgres
One of the fundamental concepts in libpq usage is a connection to a Postgres server. It's represented as `Ptr[PGconn]`, where `PGconn`'s implementation is opaque - it has no fields, and the state of the connection can only be interrogated by using the functions provided by libpq.
To avoid exposing this unsafe construct (pointer to externally managed memory), let's define an opaque type:
opaque type Database = Ptr[PGconn]
The rest of the definitions will be put in the `Database` companion object.
Establishing the connection is pretty simple:
// def PQconnectdb(conninfo: CString): Ptr[PGconn] = extern
import libpq.functions.*
import libpq.types.*
import scala.scalanative.unsigned.*
import scala.scalanative.unsafe.*
object Database:
//...
def apply(connString: String)(using Zone): Validated[Database] =
val conn = PQconnectdb(toCString(connString))
if PQstatus(conn) != ConnStatusType.CONNECTION_OK then
PQfinish(conn)
Left(fromCString(PQerrorMessage(conn)))
else Right(conn)
This method takes an implicit `Zone` - Scala Native's construct for safely tracking allocated memory. We use it to convert a Scala `String` to a `CString`, which maps 1-to-1 to a `char *` in C and is what libpq expects.
We use libpq-provided functions to both check the connection status, and capture the error message in case there has been a failure.
Early on I wanted to use the `Using` construct to create safe regions where various Postgres resources are released when necessary. To support that, we need to provide an instance of `Releasable[Database]`:
given Releasable[Database] with
def release(db: Database) =
if db != null && PQstatus(db) == ConnStatusType.CONNECTION_OK then
PQfinish(db)
With this, we can open a connection to the database like this:
import scala.util.Using
import roach.*
val url = "postgres://tester:pwd@localhost:5432/db"
Using.resource(Database(url).getOrThrow) { db =>
// db is safe to use
}
The connection will be closed after the block exits.
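One detail the snippets gloss over: `Database.apply` (and most of the extension methods below) takes a `using Zone`, so if no implicit `Zone` is already in scope you would wrap the interaction in one. A minimal sketch of that (my addition, not from the original code):

```scala
import scala.scalanative.unsafe.Zone
import scala.util.Using
import roach.*

Zone { implicit z => // provides the `using Zone` that Database.apply requires
  val url = "postgres://tester:pwd@localhost:5432/db"
  Using.resource(Database(url).getOrThrow) { db =>
    // db is safe to use here; the connection is released when the block exits
  }
}
```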
Because the connection is stateful, it can become stale, be closed by the server, be disconnected because of some network issue, etc. It's not easy to recover from such a situation (more on that later), so let's just make sure we detect it early and throw a fatal exception.
//..
object Database:
// ..
extension (d: Database)
def connectionIsOkay: Boolean =
val status = PQstatus(d)
status != ConnStatusType.CONNECTION_NEEDED && status != ConnStatusType.CONNECTION_BAD
def checkConnection(): Unit =
val status = PQstatus(d)
val broken =
status == ConnStatusType.CONNECTION_NEEDED ||
status == ConnStatusType.CONNECTION_BAD
if broken then throw new RoachFatalException("Postgres connection is down")
We can use those functions to early terminate any operations that need an active connection.
Sending queries to the server
Now that we have an active connection, we can execute server-side queries.
In return, libpq gives us another opaque data structure, `PGresult` - or rather a pointer to it.
So let's wrap that first:
opaque type Result = Ptr[PGresult]
object Result:
extension (r: Result)
inline def status: ExecStatusType = PQresultStatus(r)
For now we'll only define one operation, to get the status of the executed query.
With this, sending a query to the server is quite simple:
//..
object Database:
// ..
extension (d: Database)
// ..
def execute(query: String)(using Zone): Validated[Result] =
checkConnection()
val cstr = toCString(query)
val res = PQexec(d, cstr)
val status = PQresultStatus(res)
import ExecStatusType.*
val failed =
status == PGRES_BAD_RESPONSE ||
status == PGRES_NONFATAL_ERROR ||
status == PGRES_FATAL_ERROR
if failed then
PQclear(res) // important!
Left(fromCString(PQerrorMessage(d)))
else Right(res)
end execute
Note that we clear out the result even if the status is not what we expect - this is important to do to avoid memory leaks.
The result itself contains no rows or any data - we will request that later.
In a similar fashion to the `Database`, query `Result`s can and should be closed.
object Result:
// ..
given Releasable[Result] with
def release(db: Result) =
if db != null then PQclear(db)
Not all queries return data, so we can treat some of them as commands, and implement a `command` helper:
//..
object Database:
// ..
extension (d: Database)
// ..
def command(query: String)(using Zone): Unit =
checkConnection()
Using.resource(d.execute(query).getOrThrow) { res =>
PQresultStatus(res)
}
This operation will throw an exception if the command was unsuccessful.
import scala.util.Using
import roach.*
val url = "postgres://tester:pwd@localhost:5432/db"
Using.resource(Database(url).getOrThrow) { db =>
db.command("CREATE TABLE hello (id uuid primary key);")
}
If the query fails, its result will still be cleared, and an exception will then be thrown to the caller.
Reading rows
Before we can make the API nicer to use from Scala, let's use the raw libpq API to verify that we can read the data from the `Result` correctly.
For now we will assume that the results are sent using the text protocol - Postgres can also use a binary protocol for some field values, but this is beyond the scope of our library for now.
This assumption greatly simplifies our work - it means we don't have to convert between network byte order and host byte order.
Before we write out the full function, let's introduce the libpq functions we need:
- `PQnfields(result)` - the number of columns in the result
- `PQntuples(result)` - the number of rows in the result
- `PQftype(result, col)` - the `Oid` of the type of the `col`-th column. Postgres sends type information in terms of integer identifiers, so later on we will have to write out a mapping between the identifiers and Scala types.
- `PQfname(result, col)` - the name of the `col`-th column in the result
- `PQgetvalue(result, row, col)` - the value (as a `char *` C string, `CString` in Scala Native) of the `col`-th column in the `row`-th row
Those are all the definitions we need to write a function that will read all the relevant information from the result:
object Result:
// ..
extension (r: Result)
// ..
def rows: (Vector[(Oid, String)], Vector[Vector[String]]) =
val nFields = PQnfields(r)
val nTuples = PQntuples(r)
val meta = Vector.newBuilder[(Oid, String)]
val tuples = Vector.newBuilder[Vector[String]]
// Read all the column names and their types
for i <- 0 until nFields do
meta.addOne(PQftype(r, i) -> fromCString(PQfname(r, i)))
// Read all the rows
for t <- 0 until nTuples
do
tuples.addOne(
(0 until nFields).map(f => fromCString(PQgetvalue(r, t, f))).toVector
)
meta.result -> tuples.result
end rows
And let's make our example a bit more interesting:
import scala.util.Using
import roach.*
val url = "postgres://tester:pwd@localhost:5432/db"
Using.resource(Database(url).getOrThrow) { db =>
db.command("DROP TABLE IF EXISTS hello")
db.command(
"CREATE TABLE hello(id uuid primary key, test int4, bla text)"
)
db.command(
"insert into hello values (gen_random_uuid(), 25, 'howdyyy');"
)
db.command(
"insert into hello values (gen_random_uuid(), 135, 'test test');"
)
Using.resource(db.execute("select * from hello").getOrThrow) { res =>
println(res.rows)
}
}
Which will print out something like:
(Vector((2950,id), (23,test), (25,bla)),Vector(Vector(bb1cf609-8147-45cc-9de0-c78b19a6bc0f, 25, howdyyy), Vector(e1689c3b-b4c6-4963-9930-1a5dca09d39b, 135, test test)))
And it is exciting, because we can now read data from Postgres and bring it into the warm embrace of Scala, where we can take this code even further.
Codecs for Scala types
To make this API a bit more usable, let's take a page out of Skunk's and Scodec's book, and define our codec API in terms of composable single-column codecs.
Defining a codec API for a relational database is simplified by the fact that the results are always in a tabular format.
After not much deliberation, this is the definition I've arrived at (mostly driven by the desire to not do things the "right way" and to not spend any considerable amount of time on it):
package roach
import scala.scalanative.unsafe.*
trait Codec[T]:
self =>
def accepts(idx: Int): String
def length: Int
def decode(get: Int => CString)(using Zone): T
def encode(value: T): Int => Zone ?=> CString
def bimap[B](f: T => B, g: B => T): Codec[B] =
new Codec[B]:
def accepts(offset: Int) = self.accepts(offset)
def length = self.length
def decode(get: Int => CString)(using Zone) =
f(self.decode(get))
def encode(value: B) =
self.encode(g(value))
end Codec
- The `accepts` method returns the name of the type this codec expects at a particular position in the row. Postgres type names are, for example, `int2`, `varchar`, `bool`, etc.
- `length` returns the number of columns this particular codec can decode.
- `decode` receives a function that allows it to get the raw value at a position in the row, and using that function it must construct a value of the type for which the codec is defined.
- `encode` returns a function which allows the caller to encode a value of type T. This operation will be required when we implement prepared statements. Note that this function is of type `Int => Zone ?=> CString`, which allows the caller to handle the allocation (usage of `Zone`) at the time of encoding, and not when the `encode` method was first called. This is important as we'd prefer not to capture the `Zone` in the definition of the encoder, as it may be closed by the time encoding happens.
- `bimap` allows us to define codecs for types backed by a more primitive one.
Defining primitive codecs
If those functions don't make sense, don't worry - we can explain them further by defining a codec builder that assumes the type can be converted to and from `String` using the built-in `toString` method.
def stringLike[A](accept: String)(f: String => A): Codec[A] =
new Codec[A]:
inline def length: Int = 1
inline def accepts(offset: Int) = accept
def decode(get: Int => CString)(using Zone) =
f(fromCString(get(0)))
def encode(value: A) =
_ => toCString(value.toString)
override def toString() = s"Decode[$accept]"
This function creates a codec for a single column, decoding the desired value from its `String` representation and encoding it to a C string directly from a Scala `String`.
We can define several codecs this way:
val int2 = stringLike[Short]("int2")(_.toShort)
val int4 = stringLike[Int]("int4")(_.toInt)
val int8 = stringLike[Long]("int8")(_.toLong)
val float4 = stringLike[Float]("float4")(_.toFloat)
val float8 = stringLike[Double]("float8")(_.toDouble)
val uuid = stringLike[UUID]("uuid")(UUID.fromString(_))
val bool = stringLike[Boolean]("bool")(_ == "t")
val char = stringLike[Char]("char")(_(0).toChar)
And there's a whole host of textual datatypes which we will define as just `String`s:
private def textual(nm: String) = stringLike[String](nm)(identity)
val name = textual("name")
val varchar = textual("varchar")
val bpchar = textual("bpchar")
val text = textual("text")
And finally we can define a codec for the `Oid` type itself, by using our `bimap` combinator:
val oid =
int4.bimap[Oid](i => Oid(i.toUInt), _.asInstanceOf[CUnsignedInt].toInt)
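As another illustration of `bimap` (this wrapper type is hypothetical, not part of Roach), a codec for a small domain type can be derived from an existing primitive codec:

```scala
import java.util.UUID

// Hypothetical wrapper type, used only to illustrate bimap
case class UserId(value: UUID)

// Decode a UUID and wrap it; unwrap it again when encoding
val userId: Codec[UserId] = uuid.bimap[UserId](UserId.apply, _.value)
```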
Concatenating codecs
To represent concatenated codecs on the type level we will be using Scala's tuples, which were supercharged in Scala 3. They're also much easier to use with Mirrors, which we'll need when we inevitably want to read complete case classes out of the database.
The API we want to support is chaining codecs like this:
val codec: Codec[(Int, String, Float)] = int4 ~ text ~ float4
Unfortunately, if we simply define a `~` method that produces an `(A, B)` codec, then because the calls chain left to right, the codec we will actually get is
val codec: Codec[((Int, String), Float)] = int4 ~ text ~ float4
which will pose difficulties if we want to map it to a case class, which is isomorphic to a flat tuple of field types.
So we will have to get a bit more creative.
This problem can be broken down into two cases:
- Appending a codec to an existing codec of a tuple
- Putting two codecs in a tuple
Let's solve those two cases separately.
Appending a codec to a codec of a tuple
If we have

- `codecA: Codec[A]` where `A <: (N1, ..., Nm)`
- `codecB: Codec[B]`

then `codecA ~ codecB` is a `Codec[T]` where `T = (N1, ..., Nm, B)`.
private[roach] class AppendCodec[A <: Tuple, B](
a: Codec[A],
b: Codec[B]
) extends Codec[Tuple.Concat[A, (B *: EmptyTuple)]]:
// This type is defined in Scala 3 standard library
type T = Tuple.Concat[A, (B *: EmptyTuple)]
// If the field is in the codec for A - delegate to it
// otherwise - delegate to codec for B
def accepts(offset: Int) =
if (offset < a.length) then a.accepts(offset)
else b.accepts(offset - a.length)
def length = a.length + b.length
// first we decode the fields covered by codec for A
// then shift the column cursor and letting codec for B
// to decode the rest
def decode(get: Int => CString)(using Zone): T =
val left = a.decode(get)
val right = b.decode((i: Int) => get(i + a.length))
// in the end we concatenate the tuples
left ++ (right *: EmptyTuple)
def encode(value: T) =
val (left, right) = value.splitAt(a.length).asInstanceOf[(A, Tuple1[B])]
val leftEncode = a.encode(left)
val rightEncode = b.encode(right._1)
// Depending on where the column offset is,
// we use different codecs to encode it
(offset: Int) =>
if (offset + 1 > a.length) then rightEncode(offset - a.length)
else leftEncode(offset)
override def toString() =
s"AppendCodec[$a, $b]"
end AppendCodec
We can introduce this codec as an extension method:
object Codec:
// ..
extension [A <: Tuple](d: Codec[A])
inline def ~[B](
other: Codec[B]
): Codec[Tuple.Concat[A, B *: EmptyTuple]] =
AppendCodec(d, other)
end extension
Combining two codecs into a tuple
This codec is similar to the one we defined before, but we don't perform the concatenation:
private[roach] class CombineCodec[A, B](a: Codec[A], b: Codec[B])
extends Codec[(A, B)]:
type T = (A, B)
def accepts(offset: Int) =
if (offset < a.length) then a.accepts(offset)
else b.accepts(offset - a.length)
def length = a.length + b.length
def decode(get: Int => CString)(using Zone): T =
val left = a.decode(get)
val right = b.decode((i: Int) => get(i + a.length))
(left, right)
def encode(value: T) =
val leftEncode = a.encode(value._1)
val rightEncode = b.encode(value._2)
(offset: Int) =>
if (offset + 1 > a.length) then rightEncode(offset - a.length)
else leftEncode(offset)
override def toString() =
s"CombineCodec[$a, $b]"
end CombineCodec
And we introduce an extension method `~` with the restriction that the type `B` cannot be another `Tuple`:
extension [A](d: Codec[A])
inline def ~[B](
other: Codec[B]
)(using NotGiven[B <:< Tuple]): Codec[(A, B)] =
CombineCodec(d, other)
end extension
These two extension methods put together allow us to build long codecs:
val x: Codec[(Short, Int, Float, String)] = int2 ~ int4 ~ float4 ~ varchar
println(x)
// AppendCodec[AppendCodec[CombineCodec[Decode[int2], Decode[int4]], Decode[float4]], Decode[varchar]]
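To make the column-cursor shifting concrete, here's a small sketch (mine, not from the post) that drives a combined codec by hand, stubbing out the column reader that `decode` normally gets from a query result:

```scala
import scala.scalanative.unsafe.*

Zone { implicit z =>
  val codec: Codec[(Int, String)] = int4 ~ text
  // decode only ever sees a function from column index to a raw C string,
  // so we can feed it hand-made values: column 0 -> "42", column 1 -> "hello"
  val fakeRow: Int => CString =
    i => if i == 0 then toCString("42") else toCString("hello")
  println(codec.decode(fakeRow)) // (42,hello)
}
```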
Mapping a codec to a case class
Note: this approach was adapted (or stolen, depending on how you look at it) from Scodec.
Now that we can build flat tuples out of any data types, it's actually fairly easy to map those tuples to case classes that have fields of the same types in the same order.
Let's define a very simple typeclass:
trait Iso[A, B]:
def convert(a: A): B
def invert(b: B): A
It's just a simple abstraction to say "we can go between A and B in both directions".
Now we want to express the idea that certain tuples are isomorphic to certain case classes.
object Iso:
given [X <: Tuple, A](using
mir: Mirror.ProductOf[A] { type MirroredElemTypes = X }
): Iso[X, A] with
def convert(a: X) =
mir.fromProduct(a)
def invert(a: A) =
Tuple.fromProduct(a.asInstanceOf[Product]).asInstanceOf[X]
The type signature says that if `X` is a `Tuple`, and for some type `A` there exists a `Mirror.ProductOf[A]` such that its type member `MirroredElemTypes` is the same as `X`, then we can go back and forth between `X` and `A`.
For a case class like this:
case class Howdy(s: Short, i: Int, str: String)
The compiler will synthesise an instance of `Mirror.ProductOf[Howdy]` with
type MirroredElemTypes = Short *: Int *: String *: EmptyTuple
and therefore you should successfully be able to summon an instance of `Iso`:
summon[Iso[(Short, Int, String), Howdy]] // compiles
The actual `convert` and `invert` methods are just runtime implementations of converting a case class to a tuple and a tuple back to an instance of the case class.
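Continuing the `Howdy` example, the two directions look like this:

```scala
val iso = summon[Iso[(Short, Int, String), Howdy]]

iso.convert((1.toShort, 2, "three"))     // Howdy(1,2,three)
iso.invert(Howdy(1.toShort, 2, "three")) // (1,2,three)
```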
And that's all we need to define an extension method:
object Codec:
// ..
extension [A](d: Codec[A])
inline def as[T](using iso: Iso[A, T]) =
new Codec[T]:
def accepts(offset: Int) =
d.accepts(offset)
def length = d.length
def decode(get: Int => CString)(using Zone) =
iso.convert(d.decode(get))
def encode(value: T) =
d.encode(iso.invert(value))
And we can use it like this:
case class Howdy(s: Short, i: Int, str: String)
val c: Codec[Howdy] = (int2 ~ int4 ~ varchar).as[Howdy]
With this, finally, our codecs are powerful enough to read the rows from a query result in a typesafe fashion.
Decoding rows from a result
Defining a Postgres type mapping
Before we can start decoding, we need to define a mapping between Postgres' integer type identifiers and the textual type names we used in our codec definitions.
The number of types in Postgres is very large, so we'll just define a subset:
package roach
import libpq.types.Oid
import scala.scalanative.unsigned.*
trait OidMapping:
def map(c: String): Oid
def rev(oid: Oid): String
object OidMapping extends OidMapping:
private val mapping =
Map(
21 -> "int2",
23 -> "int4",
20 -> "int8",
700 -> "float4",
701 -> "float8",
26 -> "oid",
25 -> "text",
1042 -> "bpchar",
1043 -> "varchar",
18 -> "char",
2950 -> "uuid",
19 -> "name",
16 -> "bool"
).map((k, v) => Oid(k.toUInt) -> v)
private val reverse = mapping.map(_.swap)
inline override def map(c: String): Oid =
reverse(c)
inline override def rev(oid: Oid): String = mapping(oid)
end OidMapping
We've made it an interface to make it more extensible - to allow the user to define their own types and codecs for them.
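For example (a sketch of mine, not part of Roach), a user could layer an extra type on top of the default mapping - say `bytea`, whose oid 17 appears in the listing below:

```scala
import libpq.types.Oid
import scala.scalanative.unsigned.*

// Hypothetical extension of the default mapping with the bytea type (oid 17)
object ExtendedOidMapping extends OidMapping:
  private val extra    = Map(Oid(17.toUInt) -> "bytea")
  private val extraRev = extra.map(_.swap)

  override def map(c: String): Oid =
    extraRev.getOrElse(c, OidMapping.map(c))
  override def rev(oid: Oid): String =
    extra.getOrElse(oid, OidMapping.rev(oid))
```

A mapping like this could then be supplied as the `oids` given parameter that `readAll`, `executeParams` and `prepare` accept below.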
These identifiers are defined in Postgres' source code and can be retrieved from the database as well:
postgres=# select oid, typname from pg_type;
oid | typname
-------+----------------------------------------
16 | bool
17 | bytea
18 | char
19 | name
20 | int8
21 | int2
22 | int2vector
23 | int4
24 | regproc
25 | text
26 | oid
27 | tid
28 | xid
29 | cid
... loads more
Reading the rows from the result
We've already introduced all the functions needed to retrieve the rows, we just need to invoke the correct codecs with the right offsets.
object Result:
// ..
extension (r: Result)
// ..
def readAll[A](
codec: Codec[A]
)(using z: Zone, oids: OidMapping = OidMapping): Vector[A] =
val nFields = PQnfields(r)
val nTuples = PQntuples(r)
val tuples = Vector.newBuilder[A]
if (codec.length != PQnfields(r)) then
throw new RoachException(
s"Provided codec is for ${codec.length} fields, while the result has ${PQnfields(r)} fields"
)
(0 until nFields).foreach { offset =>
// make sure that the codec and the result set agree
// on the types of the columns
// this is no Skunk-level error reporting, but it's useful
val expectedType = oids.rev(PQftype(r, offset))
val fieldName = fromCString(PQfname(r, offset))
if codec.accepts(offset) != expectedType then
throw new RoachException(
s"$offset: Field $fieldName is of type '$expectedType', " +
s"but the decoder only accepts '${codec.accepts(offset)}'"
)
}
// read each row
(0 until nTuples).foreach { row =>
val func =
(i: Int) => PQgetvalue(r, row, i)
tuples.addOne(codec.decode(func))
}
tuples.result
end readAll
And here's our example, modified to use case classes:
import scala.util.Using
import roach.*
val url = "postgres://tester:pwd@localhost:5432/db"
Using.resource(Database(url).getOrThrow) { db =>
db.command("DROP TABLE IF EXISTS hello")
db.command(
"CREATE TABLE hello(id uuid primary key, test int4, bla text)"
)
db.command(
"insert into hello values (gen_random_uuid(), 25, 'howdyyy');"
)
db.command(
"insert into hello values (gen_random_uuid(), 135, 'test test');"
)
case class Howdy(s: UUID, i: Int, str: String)
val howdyCodec = (uuid ~ int4 ~ text).as[Howdy]
Using.resource(db.execute("select * from hello").getOrThrow) { res =>
res.readAll(howdyCodec).foreach(println)
}
}
Will output:
Howdy(af1d03bd-e5d3-4493-931e-e4291cc997a5,25,howdyyy)
Howdy(1e7c9e36-4d19-4055-9923-3dbc0e434ea5,135,test test)
Executing queries with parameters
Now, so far we've been passing queries that don't depend on any runtime information our app might be producing.
One obvious issue we should consider is SQL injection - care must be taken to escape potentially malicious input.
Libpq does include several functions to escape string literals in order for them to be passed to the database safely.
That said, the result of those operations is allocated as a new string and as such must be freed when no longer necessary - whereas in the design of our Codec we already create the C string with the desired values.
Luckily, libpq has a `PQexecParams` function that allows passing a SQL statement with placeholders and providing the values separately - which eliminates the chance of those values being interpreted as part of the query.
The signature in C is as follows:
PGresult *PQexecParams(PGconn *conn,
const char *command,
int nParams,
const Oid *paramTypes,
const char * const *paramValues,
const int *paramLengths,
const int *paramFormats,
int resultFormat);
And in generated Scala code:
def PQexecParams(
conn: Ptr[PGconn],
command: CString,
nParams: CInt,
paramTypes: Ptr[Oid],
paramValues: Ptr[CString],
paramLengths: Ptr[CInt],
paramFormats: Ptr[CInt],
resultFormat: CInt
): Ptr[PGresult] = extern
Note that parameter types and values are passed as separate arrays.
`resultFormat` will be set to 0 as we only deal with text-based transmission, which also allows us to not provide `paramLengths` and `paramFormats`.
So let's lay out the parameter types and values for the execution:
def executeParams[T](
query: String,
codec: Codec[T],
data: T
)(using z: Zone, oids: OidMapping = OidMapping): Validated[Result] =
checkConnection()
val nParams = codec.length
// resolve and encode parameter types, as specified by
// the codec
val paramTypes = stackalloc[Oid](nParams)
for l <- 0 until nParams do paramTypes(l) = oids.map(codec.accepts(l))
// encode each part of the data according to its codec
val paramValues = stackalloc[CString](nParams)
val encoder = codec.encode(data)
for i <- 0 until nParams do paramValues(i) = encoder(i)
val res = PQexecParams(
d,
toCString(query),
nParams,
paramTypes,
paramValues,
null,
null,
0
)
result(res)
end executeParams
At the end we extracted the logic for handling a potentially failed result into a small helper function:
private[roach] def result(res: Result): Validated[Result] =
val status = PQresultStatus(res)
import ExecStatusType.*
val failed =
status == PGRES_BAD_RESPONSE ||
status == PGRES_NONFATAL_ERROR ||
status == PGRES_FATAL_ERROR
if failed then
PQclear(res) // important!
Left(fromCString(PQerrorMessage(d)))
else Right(res)
This function allows us to safely pass values to a parametric SQL query:
val res1 = db.executeParams(
"select id from hello where test = $1 and bla = $2",
int4 ~ text,
25 -> "howdyyy"
).getOrThrow
println(res1.readAll(uuid))
// Vector(297f0da4-e5dd-4665-9f2c-ff0733401934)
Prepared statements
The way `executeParams` operates translates directly to how prepared statements are used in libpq.
Prepared statements allow us (with a live session) to send queries with placeholders to the server and have them parsed and analysed, so that we can send the values at a later point.
It's very useful when executing lots of repeated queries, where only the input data changes.
The way it's done in libpq is very similar to `executeParams`, just with the stages separated in time. Because of this separation, we want to provide the user with some typesafe value they can use later to supply the values to a prepared statement.
This value will capture the reference to the database connection and the codec that will be used to encode the provided parameters.
Note: capturing a stateful connection is a difficult problem to solve - if the connection is terminated, the prepared statement will disappear with it. This means that an attempt to execute it later will fail - you need both to replace the connection and to re-send the prepared statement to the server.
To actually execute the prepared statement, the only function we will need is `PQexecPrepared`, which looks and feels exactly like `PQexecParams`, with the exception that we pass the statement name, not the query - as the statement is already prepared on the server, we should be able to look it up by name.
class Prepared[T] private[roach] (
db: Database,
codec: Codec[T],
statementName: String
):
private val nParams = codec.length
def execute(data: T)(using z: Zone): Validated[Result] =
db.checkConnection()
// encode the values
val paramValues = stackalloc[CString](nParams)
val encoder = codec.encode(data)
for i <- 0 until nParams do paramValues(i) = encoder(i)
// send just the values
val res = PQexecPrepared(
db,
toCString(statementName), // sic! statement name, not the query
nParams,
paramValues,
null,
null,
0
)
db.result(res)
end execute
end Prepared
Now, constructing the prepared statement itself is easy - we just need to execute `PQprepare` and pass it information about just the parameter types.
def prepare[T](
query: String,
statementName: String,
codec: Codec[T]
)(using z: Zone, oids: OidMapping = OidMapping): Validated[Prepared[T]] =
checkConnection()
// encode params
val nParams = codec.length
val paramTypes = stackalloc[Oid](nParams)
for l <- 0 until nParams do paramTypes(l) = oids.map(codec.accepts(l))
// We send both query and the statement name
val res = PQprepare(
d,
toCString(statementName),
toCString(query),
nParams,
paramTypes
)
result(res).map(_ => Prepared[T](d, codec, statementName))
end prepare
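Putting the two pieces together, usage might look roughly like this (a sketch of mine; it assumes `prepare` is exposed as an extension method on `Database` next to `execute`, that an implicit `Zone` is in scope, that `java.util.UUID` is imported, and that it reuses the `hello` table and codecs from the earlier examples):

```scala
val insertStatement = db
  .prepare(
    "insert into hello values ($1, $2, $3)",
    "insert-hello", // a statement name we pick ourselves
    uuid ~ int4 ~ text
  )
  .getOrThrow

// later, possibly many times, only the values change
insertStatement.execute((UUID.randomUUID(), 25, "howdyyy")).getOrThrow
insertStatement.execute((UUID.randomUUID(), 135, "test test")).getOrThrow
```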
Phew!
What we've built so far is a solid foundation to safely run queries on Postgres (or in fact postgres-compatible databases).
There's a ton of features we haven't implemented, and I'm pretty sure that even the stuff we do have goes against all the possible best practices.
Most importantly, what should we do with all those fatal exceptions we're throwing when the connection goes away or is terminated for different reasons?
My answer, for which I will never be forgiven in the functional Scala JVM community which I cherish so much, is "let the server crash and restart", instead of managing a mutable database connection that we refresh on demand.
There are several reasons for this:
- Our process' startup time is negligible in the grand scheme of things
- When we replace the Postgres connection, we also need to re-submit all of the prepared statements, meaning our mutable state grows in complexity
- The server that we'll use will handle load re-balancing for us
- But most importantly, for many years I wasn't allowed to crash processes, and instead I focused on "safety" and "predictability", whatever those mean
No more. My time has come, and I have a thirst for blood. Processes are going to be killed.
Interacting with OpenSSL
Compared to our usage of Postgres, the functionality we need from OpenSSL is pretty limited:
- Compute password hashes (we use SHA256)
- Compute HMAC for JWT signatures (again, SHA256)
As such, I've decided to not produce a wrapper library for those two use cases.
Instead, let's just see how we can use the generated functions to compute a SHA256 hash (as a Scala `String`) of an arbitrary `String`.
Computing SHA256 hash of a String
This code has been adapted from a C example found on StackOverflow. As far as I can tell it produces correct hashes - I've verified the output against several other implementations.
Keep in mind, I'm not versed in cryptography at all.
The function we want to define has the following signature:
def sha256(plaintext: String)(using Zone): String
We will require a `Zone` here to allocate memory for converting a Scala string to a C string. In fact, let's do that first:
val str = toCString(plaintext)
Next, we need to allocate some memory for the data structures required by OpenSSL. We will be allocating them on the stack (as opposed to the heap), because we don't need to re-use those structures anywhere after the function call ends, and because the size of the memory is well known and way below the maximum stack size.
val SHA256_DIGEST_LENGTH = 32
val sha256_ctx = stackalloc[SHA256_CTX](1)
val hash = stackalloc[CUnsignedChar](SHA256_DIGEST_LENGTH)
val outputBuffer = stackalloc[CChar](65)
- `sha256_ctx` is OpenSSL's internal data structure
- `hash` is a memory location where OpenSSL will write the hashed value
- `outputBuffer` is a location where we will construct a C string containing zero-padded hexadecimal values of each byte from the `hash`

The size of `outputBuffer` is 64 (2 hex digits per byte) + 1 (the zero byte terminating the C string).
By convention, OpenSSL functions return `1` on success and `0` on failure (remember, there are no exceptions in C!).
So we need to run a chain of commands and assert on each return value:
assert(
SHA256_Init(sha256_ctx) == 1,
"failed to initialise sha context"
)
assert(
SHA256_Update(sha256_ctx, str, string.strlen(str)) == 1,
"failed to update sha context"
)
assert(
SHA256_Final(hash, sha256_ctx) == 1,
"failed to finalise sha context"
)
If the operations all succeeded, our `hash` buffer will contain the bytes that we need to convert to a more familiar hexadecimal representation:
for i <- 0 until SHA256_DIGEST_LENGTH do
stdio.sprintf(outputBuffer + (i * 2), c"%02x", hash(i))
outputBuffer(64) = 0.toByte
`%02x` means "convert the byte to its hexadecimal value, and pad it with 0 on the left if there's only 1 digit".
All we need to do after this is to convert the resulting C string to a Scala one:
fromCString(outputBuffer)
And that's it! Here's the full function:
import libcrypto.functions.*
import libcrypto.types.*
val SHA256_DIGEST_LENGTH = 32
def sha256(plaintext: String)(using Zone): String =
val str = toCString(plaintext)
val sha256_ctx = stackalloc[SHA256_CTX](1)
val hash = stackalloc[CUnsignedChar](SHA256_DIGEST_LENGTH)
val outputBuffer = stackalloc[CChar](65)
assert(
SHA256_Init(sha256_ctx) == 1,
"failed to initialise sha context"
)
assert(
SHA256_Update(sha256_ctx, str, string.strlen(str)) == 1,
"failed to update sha context"
)
assert(
SHA256_Final(hash, sha256_ctx) == 1,
"failed to finalise sha context"
)
for i <- 0 until SHA256_DIGEST_LENGTH do
stdio.sprintf(outputBuffer + (i * 2), c"%02x", hash(i))
outputBuffer(64) = 0.toByte
fromCString(outputBuffer)
end sha256
Not too bad, I think.
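Calling it only requires a `Zone` in scope; a trivial usage sketch:

```scala
import scala.scalanative.unsafe.Zone

Zone { implicit z =>
  println(sha256("correct horse battery staple")) // prints a 64-character hex digest
}
```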
The HMAC procedure is a bit more involved, but it also boils down to squinting at C examples and converting them to Scala.
We will do it closer to the action, when we actually have JWTs to sign.
Conclusion
At this point we have
- Our build set up for Postgres and OpenSSL usage
- A wrapper API for libpq, along with codecs infrastructure
- Working sample of using OpenSSL functions to compute SHA256 hashes