Microservices in Golang, pt4 - Validation
Middleware#
When posting data, it’s a good idea to validate it, and we’ll look into that a bit later. Before that, we need to look at middleware: an HTTP handler that intercepts a request and does some work before passing it on to another middleware or to the final handler. Good use cases are CORS and authentication. Read more in the Gorilla Mux middleware documentation.
A very basic use of middleware is to log a message. Notice that the order in which the middleware is assigned to the router matters:
func middlewareOne(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println("MiddlewareOne")
next.ServeHTTP(w, r)
})
}
func middlewareTwo(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Println("MiddlewareTwo")
next.ServeHTTP(w, r)
})
}
func main() {
...
api.Use(middlewareOne)
api.Use(middlewareTwo)
}
The output is:
1970/1/1 15:59:00 MiddlewareOne
1970/1/1 15:59:00 MiddlewareTwo
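To see the ordering without running the whole service, here is a minimal sketch that uses only the standard library. The `tag` and `chain` helpers are hypothetical names introduced for illustration; `chain` only imitates the "first registered runs first" behaviour we observed above, it is not how Gorilla Mux is implemented.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// middleware has the same shape as the functions we pass to api.Use.
type middleware func(http.Handler) http.Handler

// tag (hypothetical helper) returns a middleware that records its name
// in calls and then passes the request on to the next handler.
func tag(name string, calls *[]string) middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			*calls = append(*calls, name)
			next.ServeHTTP(w, r)
		})
	}
}

// chain (hypothetical helper) wraps h so that the first middleware in
// the list is the outermost one and therefore runs first.
func chain(h http.Handler, mws ...middleware) http.Handler {
	for i := len(mws) - 1; i >= 0; i-- {
		h = mws[i](h)
	}
	return h
}

// callOrder sends one request through the chain and reports who ran, in order.
func callOrder() []string {
	var calls []string
	final := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		calls = append(calls, "handler")
	})
	h := chain(final, tag("MiddlewareOne", &calls), tag("MiddlewareTwo", &calls))
	h.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	return calls
}

func main() {
	fmt.Println(callOrder())
}
```

Swapping the two `tag` arguments swaps the first two entries of the result, which is the same effect as swapping the two `api.Use` calls.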
In our PUT and POST handlers, we can get the Song from the request context and move the duplicated code, such as the unmarshalling, into the middleware to keep things DRY.
package handlers
import (
"context"
"log"
"net/http"
"strconv"
"example.com/go-intro-microservices-pt2/data"
"github.com/gorilla/mux"
)
// KeySong is a key used for the Song object in the context
type KeySong struct{}
type Songs struct {
l *log.Logger
}
func NewSongs(l *log.Logger) *Songs {
return &Songs{l}
}
func (s *Songs) Get(rw http.ResponseWriter, r *http.Request) {
s.l.Println("Handle GET Songs")
ls := data.GetSongs()
err := ls.ToJSON(rw)
if err != nil {
http.Error(rw, "Unable to marshal json of songs!", http.StatusInternalServerError)
}
}
func (s *Songs) Post(rw http.ResponseWriter, r *http.Request) {
s.l.Println("Handle POST Songs")
song := r.Context().Value(KeySong{}).(data.Song)
data.AddSong(&song)
}
func (s Songs) Put(rw http.ResponseWriter, r *http.Request) {
vars := mux.Vars(r)
id, err := strconv.Atoi(vars["id"])
if err != nil {
http.Error(rw, "Unable to convert id", http.StatusBadRequest)
return
}
s.l.Println("Handle PUT Songs, update song id", id)
// Value returns an interface{}, so we type-assert it back to data.Song
song := r.Context().Value(KeySong{}).(data.Song)
err = data.UpdateSong(id, &song)
if err == data.ErrSongNotFound {
http.Error(rw, "Song not found", http.StatusNotFound)
return
}
if err != nil {
// any other error is a server-side failure, not a missing song
http.Error(rw, "Unable to update song", http.StatusInternalServerError)
return
}
}
func (s Songs) MiddlewareSongValidation(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
song := data.Song{}
if r.Method == "POST" || r.Method == "PUT" {
err := song.FromJSON(r.Body)
if err != nil {
s.l.Println("[ERROR] deserializing song", err)
http.Error(rw, "Error reading song", http.StatusBadRequest)
return
}
}
// add the song to the context
// the preferred approach is to use a dedicated type as the key
ctx := context.WithValue(r.Context(), KeySong{}, song)
r = r.WithContext(ctx)
// Call the next handler, which can be another middleware in the chain, or the final handler.
next.ServeHTTP(rw, r)
})
}
Here’s the code we’ve removed from the handlers and moved to the middleware:
song := &data.Song{}
err = song.FromJSON(r.Body)
if err != nil {
http.Error(rw, "Unable to unmarshal json of song", http.StatusBadRequest)
}
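The handlers above use a bare type assertion, which panics if the middleware did not run and the key is missing. As a minimal sketch of a safer variant, the comma-ok form reports a missing value instead. The `Song` and `keySong` types here are simplified stand-ins for `data.Song` and `KeySong`, and `withSong`, `songFrom` and `demo` are hypothetical helpers, not part of the project.

```go
package main

import (
	"context"
	"fmt"
)

// Song is a simplified stand-in for data.Song.
type Song struct {
	Band string
}

// keySong mirrors KeySong: an empty struct type used only as a context key,
// so it cannot collide with keys defined by other packages.
type keySong struct{}

// withSong returns a copy of ctx that carries s.
func withSong(ctx context.Context, s Song) context.Context {
	return context.WithValue(ctx, keySong{}, s)
}

// songFrom reads the Song back. The comma-ok assertion returns ok == false
// when no Song was stored, instead of panicking.
func songFrom(ctx context.Context) (Song, bool) {
	s, ok := ctx.Value(keySong{}).(Song)
	return s, ok
}

// demo looks up a Song in a context that has one and in a context that does not.
func demo() (string, bool, bool) {
	ctx := withSong(context.Background(), Song{Band: "Mad Funk"})
	s, found := songFrom(ctx)
	_, foundEmpty := songFrom(context.Background())
	return s.Band, found, foundEmpty
}

func main() {
	band, found, foundEmpty := demo()
	fmt.Println(band, found, foundEmpty)
}
```

In a handler, an `ok == false` result would typically be turned into a 500 response, since it means the middleware was not wired in front of that route.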
JSON Validation#
We’re going to start validating our structs with the help of the Validator package. It’s a nice tool that allows us, based on struct tags, to check the minimum length of a field, whether a field is present, whether a field’s value has a certain format, and so on. The package documentation covers the available tags, and a separate list explains why we should sanitize input data.
Start by adding a validation method to our data model. To do the validation we also need to construct a validator object and add the validate tag to our struct.
In our data/songs.go:
type Song struct {
ID int `json:"id"`
Band string `json:"band" validate:"required"`
...
}
// Validate checks the struct against its validate tags
func (s *Song) Validate() error {
validate := validator.New()
return validate.Struct(s)
}
To test it, we create a basic unit test:
package data
import "testing"
func TestChecksValidation(t *testing.T) {
s := &Song{}
err := s.Validate()
if err != nil {
t.Fatal(err)
}
}
And then execute:
go test -timeout 30s example.com/go-intro-microservices-pt2/data -run ^TestChecksValidation
That should fail with:
--- FAIL: TestChecksValidation (0.00s)
songs_test.go:11: Key: 'Song.Band' Error:Field validation for 'Band' failed on the 'required' tag
FAIL
FAIL example.com/go-intro-microservices-pt2/data 0.055s
FAIL
So, let’s add a new validation tag, validate:"gt=0", to Price, and also fill in the required fields in our test.
package data
import "testing"
func TestChecksValidation(t *testing.T) {
s := &Song{
Band: "Mad Funk",
Price: 9.90,
}
err := s.Validate()
if err != nil {
t.Fatal(err)
}
}
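To make it concrete what the two tags check, here is a hand-written equivalent that does not use the Validator package at all. It is only a sketch of the rules: `required` rejects the zero value (an empty string for Band) and `gt=0` rejects any number that is not greater than zero. The `Song` struct is a reduced stand-in, the `float64` type for Price is an assumption, and the error messages are made up for illustration.

```go
package main

import (
	"errors"
	"fmt"
)

// Song is a reduced stand-in for data.Song with only the validated fields.
type Song struct {
	Band  string
	Price float64 // assumed type; the original struct is elided
}

// validate is a hand-rolled stand-in for the rules expressed by the tags
// validate:"required" on Band and validate:"gt=0" on Price.
func (s Song) validate() error {
	if s.Band == "" {
		return errors.New("band is required")
	}
	if s.Price <= 0 {
		return errors.New("price must be greater than 0")
	}
	return nil
}

func main() {
	fmt.Println(Song{}.validate())                                // fails: band is empty
	fmt.Println(Song{Band: "Mad Funk"}.validate())                // fails: price is zero
	fmt.Println(Song{Band: "Mad Funk", Price: 9.90}.validate())   // passes
}
```

The tag-based approach keeps these rules next to the field definitions, which is why we prefer it over writing checks like these by hand.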
Now that we understand how to validate, we should wire it into our API through the middleware we’ve just created.
func (s Songs) MiddlewareSongValidation(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
song := data.Song{}
if r.Method == "POST" || r.Method == "PUT" {
err := song.FromJSON(r.Body)
if err != nil {
s.l.Println("[ERROR] deserializing song", err)
http.Error(rw, "Error reading song", http.StatusBadRequest)
return
}
// validate the song
err = song.Validate()
if err != nil {
s.l.Println("[ERROR] validating song", err)
http.Error(
rw,
fmt.Sprintf("Error validating song: %s", err),
http.StatusBadRequest,
)
return
}
}
// add the song to the context
// the preferred approach is to use a dedicated type as the key
ctx := context.WithValue(r.Context(), KeySong{}, song)
r = r.WithContext(ctx)
// Call the next handler, which can be another middleware in the chain, or the final handler.
next.ServeHTTP(rw, r)
})
}
So, if we run:
curl -v -X PUT -d '{"band": "Zipzags", "title": "Songalicious", "price": 0, "sku": "abz1"}' localhost:9000/api/v1/songs/3
We’d get the output:
* Trying 127.0.0.1...
* TCP_NODELAY set
* Connected to localhost (127.0.0.1) port 9000 (#0)
> PUT /api/v1/songs/3 HTTP/1.1
> Host: localhost:9000
> User-Agent: curl/7.64.1
> Accept: */*
> Content-Length: 71
> Content-Type: application/x-www-form-urlencoded
>
* upload completely sent off: 71 out of 71 bytes
< HTTP/1.1 400 Bad Request
< Content-Type: text/plain; charset=utf-8
< X-Content-Type-Options: nosniff
< Date: Thu, 29 Oct 2020 18:21:30 GMT
< Content-Length: 99
<
Error validating song: Key: 'Song.Price' Error:Field validation for 'Price' failed on the 'gt' tag
* Connection #0 to host localhost left intact
* Closing connection 0
As you can see, we return a useful error message, and we also improve the security of our application by rejecting invalid data before it reaches the handlers.
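The same check can be automated instead of run by hand with curl. The sketch below uses net/http/httptest and a simplified stand-in for the middleware: `song`, `validateBody` and `statusFor` are hypothetical names, and the validation is hand-written so the example depends only on the standard library.

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
)

// song is a reduced stand-in for data.Song.
type song struct {
	Band  string  `json:"band"`
	Price float64 `json:"price"`
}

// validateBody is a simplified stand-in for MiddlewareSongValidation:
// it decodes the body, applies the two rules by hand, and only calls
// next when the song is valid.
func validateBody(next http.Handler) http.Handler {
	return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		var s song
		if err := json.NewDecoder(r.Body).Decode(&s); err != nil {
			http.Error(rw, "Error reading song", http.StatusBadRequest)
			return
		}
		if s.Band == "" || s.Price <= 0 {
			http.Error(rw, "Error validating song", http.StatusBadRequest)
			return
		}
		next.ServeHTTP(rw, r)
	})
}

// statusFor sends body as a PUT through the middleware and returns the
// status code recorded for the response.
func statusFor(body string) int {
	ok := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		rw.WriteHeader(http.StatusOK)
	})
	req := httptest.NewRequest(http.MethodPut, "/api/v1/songs/3", strings.NewReader(body))
	rec := httptest.NewRecorder()
	validateBody(ok).ServeHTTP(rec, req)
	return rec.Code
}

func main() {
	fmt.Println(statusFor(`{"band": "Zipzags", "price": 0}`))    // rejected by the middleware
	fmt.Println(statusFor(`{"band": "Zipzags", "price": 9.90}`)) // reaches the final handler
}
```

The first call never reaches the final handler, which is the behaviour we want from the real middleware as well.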
CORS#
To enable CORS, see the CORS package in Gorilla Handlers.
As an example, for our use case we’ll only allow the origin http://localhost:9000. Notice that we’ve created an alias, gorHandlers, for the Gorilla Handlers package, since the name handlers is already used by our own package.
package main
import (
"context"
"log"
"net/http"
"os"
"os/signal"
"time"
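"example.com/go-intro-microservices-pt2/handlers"
gorHandlers "github.com/gorilla/handlers"
"github.com/gorilla/mux"
// assumption: the env.String call below matches github.com/nicholasjackson/env
"github.com/nicholasjackson/env"
)
var bindAddress = env.String("BIND_ADDRESS", false, ":9000", "Bind address for the server")
func main() {
// assumption: with nicholasjackson/env, Parse must run before bindAddress is read
env.Parse()
l := log.New(os.Stdout, "rest-api", log.LstdFlags)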
hh := handlers.NewHello(l)
sh := handlers.NewSongs(l)
sm := mux.NewRouter()
api := sm.PathPrefix("/api/v1").Subrouter()
api.HandleFunc("", hh.Get).Methods(http.MethodGet)
api.HandleFunc("", hh.Post).Methods(http.MethodPost)
api.HandleFunc("", hh.Put).Methods(http.MethodPut)
api.HandleFunc("", hh.Delete).Methods(http.MethodDelete)
api.HandleFunc("/songs", sh.Get).Methods(http.MethodGet)
api.HandleFunc("/songs", sh.Post).Methods(http.MethodPost)
api.HandleFunc("/songs/{id:[0-9]+}", sh.Put).Methods(http.MethodPut)
api.HandleFunc("", hh.NotFound)
api.Use(sh.MiddlewareSongValidation)
// CORS
ch := gorHandlers.CORS(
gorHandlers.AllowedOrigins(
[]string{"http://localhost:9000"},
),
)
s := &http.Server{
Addr: *bindAddress, // configure the bind address
Handler: ch(sm), // the default handlers
IdleTimeout: 120 * time.Second, // max time for connections using TCP keep-alive
ReadTimeout: 20 * time.Second, // max time to read request from client
WriteTimeout: 30 * time.Second, // max time to write response to the client
}
go func() {
err := s.ListenAndServe()
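// ListenAndServe returns http.ErrServerClosed after Shutdown; that is not a failure
if err != nil && err != http.ErrServerClosed {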
l.Fatal(err)
}
}()
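// signal.Notify does not block when sending, so the channel must be buffered
sigChan := make(chan os.Signal, 1)
// os.Kill (SIGKILL) cannot be trapped, so only os.Interrupt is registered
signal.Notify(sigChan, os.Interrupt)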
sig := <-sigChan
l.Println("Terminate received, gracefully shutting down...", sig)
tc, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
s.Shutdown(tc)
}
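To get a feel for what restricting the allowed origins means, here is a heavily simplified sketch written with the standard library only. It is not the Gorilla Handlers implementation and it ignores preflight OPTIONS requests; `allowOrigin` and `corsHeaderFor` are hypothetical helpers. The idea is that the Access-Control-Allow-Origin header is only sent back when the request’s Origin header matches the one we allow, and the browser blocks cross-origin reads when that header is missing.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

// allowOrigin (hypothetical helper) adds the Access-Control-Allow-Origin
// header only when the request comes from the allowed origin. The request
// is passed on to next either way.
func allowOrigin(allowed string, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if origin := r.Header.Get("Origin"); origin == allowed {
			w.Header().Set("Access-Control-Allow-Origin", origin)
		}
		next.ServeHTTP(w, r)
	})
}

// corsHeaderFor sends a GET with the given Origin header and returns the
// Access-Control-Allow-Origin header found on the response.
func corsHeaderFor(origin string) string {
	noop := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
	h := allowOrigin("http://localhost:9000", noop)
	req := httptest.NewRequest(http.MethodGet, "/api/v1/songs", nil)
	req.Header.Set("Origin", origin)
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	return rec.Header().Get("Access-Control-Allow-Origin")
}

func main() {
	fmt.Printf("%q\n", corsHeaderFor("http://localhost:9000")) // header echoes the allowed origin
	fmt.Printf("%q\n", corsHeaderFor("http://evil.example"))   // header is absent
}
```

In the real service we keep using gorHandlers.CORS, which also takes care of preflight requests and the other CORS headers.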