I'm new in Go
and unit test
. In my project am using Go
with gorm
and connecting mysql
database.
my queries is how to unit test my code:
My code is below(main.go):
package main
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql"
)
type Jobs struct {
JobID uint `json: "jobId" gorm:"primary_key;auto_increment"`
SourcePath string `json: "sourcePath"`
Priority int64 `json: "priority"`
InternalPriority string `json: "internalPriority"`
ExecutionEnvironmentID string `json: "executionEnvironmentID"`
}
type ExecutionEnvironment struct {
ID uint `json: "id" gorm:"primary_key;auto_increment"`
ExecutionEnvironmentId string `json: "executionEnvironmentID"`
CloudProviderType string `json: "cloudProviderType"`
InfrastructureType string `json: "infrastructureType"`
CloudRegion string `json: "cloudRegion"`
CreatedAt time.Time `json: "createdAt"`
}
var db *gorm.DB
func initDB() {
var err error
dataSourceName := "root:@tcp(localhost:3306)/?parseTime=True"
db, err = gorm.Open("mysql", dataSourceName)
if err != nil {
fmt.Println(err)
panic("failed to connect database")
}
//db.Exec("CREATE DATABASE test")
db.LogMode(true)
db.Exec("USE test")
db.AutoMigrate(&Jobs{}, &ExecutionEnvironment{})
}
func GetAllJobs(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
fmt.Println("Executing Get All Jobs function")
var jobs []Jobs
if err := db.Select("jobs.*, execution_environments.*").Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").Find(&jobs).Error; err != nil {
fmt.Println(err)
}
fmt.Println()
if len(jobs) == 0 {
json.NewEncoder(w).Encode("No data found")
} else {
json.NewEncoder(w).Encode(jobs)
}
}
// create job
func createJob(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
fmt.Println("Executing Create Jobs function")
var jobs Jobs
json.NewDecoder(r.Body).Decode(&jobs)
db.Create(&jobs)
json.NewEncoder(w).Encode(jobs)
}
// get job by id
func GetJobById(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
params := mux.Vars(r)
jobId := params["jobId"]
//var job []Jobs
//db.Preload("Items").First(&job, jobId)
var jobs []Jobs
var executionEnvironments []ExecutionEnvironment
if err := db.Table("jobs").Select("jobs.*, execution_environments.*").Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").Where("job_id =?", jobId).Find(&jobs).Scan(&executionEnvironments).Error; err != nil {
fmt.Println(err)
}
if len(jobs) == 0 {
json.NewEncoder(w).Encode("No data found")
} else {
json.NewEncoder(w).Encode(jobs)
}
}
// Delete Job By Id
func DeleteJobById(w http.ResponseWriter, r *http.Request) {
params := mux.Vars(r)
jobId := params["jobId"]
// check data
var job []Jobs
db.Table("jobs").Select("jobs.*").Where("job_id=?", jobId).Find(&job)
if len(job) == 0 {
json.NewEncoder(w).Encode("Invalid JobId")
} else {
id64, _ := strconv.ParseUint(jobId, 10, 64)
idToDelete := uint(id64)
db.Where("job_id = ?", idToDelete).Delete(&Jobs{})
//db.Where("jobId = ?", idToDelete).Delete(&ExecutionEnvironment{})
json.NewEncoder(w).Encode("Job deleted successfully")
w.WriteHeader(http.StatusNoContent)
}
}
// create Execution Environments
func createEnvironments(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
fmt.Println("Executing Create Execution Environments function")
var executionEnvironments ExecutionEnvironment
json.NewDecoder(r.Body).Decode(&executionEnvironments)
db.Create(&executionEnvironments)
json.NewEncoder(w).Encode(executionEnvironments)
}
// Get Job Cloud Region
func GetJobCloudRegion(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
fmt.Println("Executing Get Job Cloud Region function")
params := mux.Vars(r)
jobId := params["jobId"]
//var jobs []Jobs
var executionEnvironment []ExecutionEnvironment
db.Table("jobs").Select("execution_environments.*").Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").Where("jobs.job_id =?", jobId).Find(&executionEnvironment)
var pUuid []string
for _, uuid := range executionEnvironment {
pUuid = append(pUuid, uuid.CloudRegion)
}
json.NewEncoder(w).Encode(pUuid)
}
func main() {
// router
router := mux.NewRouter()
// Access URL
router.HandleFunc("/GetAllJobs", GetAllJobs).Methods("GET")
router.HandleFunc("/createJob", createJob).Methods("POST")
router.HandleFunc("/GetJobById/{jobId}", GetJobById).Methods("GET")
router.HandleFunc("/DeleteJobById/{jobId}", DeleteJobById).Methods("DELETE")
router.HandleFunc("/createEnvironments", createEnvironments).Methods("POST")
router.HandleFunc("/GetJobCloudRegion/{jobId}", GetJobCloudRegion).Methods("GET")
// Initialize db connection
initDB()
// config port
fmt.Printf("Starting server at 8000 \n")
http.ListenAndServe(":8000", router)
}
I try to create unit test file below, but it is not running it shows like this
main_test.go:
package main
import (
"log"
"os"
"testing"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql"
)
func TestinitDB(m *testing.M) {
dataSourceName := "root:@tcp(localhost:3306)/?parseTime=True"
db, err := gorm.Open("mysql", dataSourceName)
if err != nil {
log.Fatal("failed to connect database")
}
//db.Exec("CREATE DATABASE test")
db.LogMode(true)
db.Exec("USE test111")
os.Exit(m.Run())
}
Please help me to write unit test file
CodePudding user response:
"How to unit test" is a very broad question since it depends on what you want to test. In your example you're working with remote connections to a database which is usually something that is mocked in unit testing. It's not clear if that's what you're looking for and it's not a requirement to do so either. By seeing you use different databases I would expect the intention is not to mock.
Start by looking at this post that has already answered your question around how TestMain
and testing.M
is intended to work.
What your code currently does (if your test name would be named TestMain
properly) is add a method around your other tests to do setup and teardown, however you don't have any other tests to make use of this setup and teardown thus you'll get the result no tests to run
.
It's not a part of your question but I would suggest try to avoid testing.M
until you feel confident in testing Go code. Using testing.T
and testing separate units might be easier to understand. You could achieve pretty much the same thing by just calling initDB()
in your test and making the initializer take an argument.
func initDB(dbToUse string) {
// ...
db.Exec("USE " dbToUse)
}
You would then call initDB("test")
from your main file and initDB("test111")
from your test.
You can read about the testing package for Go at pkg.go.dev/testing where you'll also find the differences between testing.T
and testing.M
.
Here's a shorter example with some basic testing that does not require any setup or teardown and that uses testing.T
instead of testing.M
.
main.go
package main
import "fmt"
func main() {
fmt.Println(add(1, 2))
}
func add(a, b int) int {
return a b
}
main_test.go
package main
import "testing"
func TestAdd(t *testing.T) {
t.Run("add 2 2", func(t *testing.T) {
want := 4
// Call the function you want to test.
got := add(2, 2)
// Assert that you got your expected response
if got != want {
t.Fail()
}
})
}
This test will test your method add
and ensure it returns the right value when you pass 2, 2
as argument. The use of t.Run
is optional but it creates a sub test for you which makes reading the output a bit easier.
Since you test on package level you'll need to specify what package to test if you're not using the triple dot format including every package recursively.
To run the test in the example above, specify your package and -v
for verbose output.
$ go test ./ -v
=== RUN TestAdd
=== RUN TestAdd/add_2_ _2
--- PASS: TestAdd (0.00s)
--- PASS: TestAdd/add_2_ _2 (0.00s)
PASS
ok x (cached)
There is a lot more to learn around this topic as well such as testing frameworks and testing patterns. As an example the testing framework testify
helps you do assertions and prints nice output when tests fail and table driven tests is a pretty common pattern in Go.
You're also writing a HTTP server which usually requires additional testing setup to test properly. Luckily the http
package in standard library comes with a sub package named httptest
which can help you record external requests or start local servers for external requests. You can also test your handlers by directly calling your handlers with a manually constructed request.
It would look something like this.
func TestSomeHandler(t *testing.T) {
// Create a request to pass to our handler. We don't have any query parameters for now, so we'll
// pass 'nil' as the third parameter.
req, err := http.NewRequest("GET", "/some-endpoint", nil)
if err != nil {
t.Fatal(err)
}
// We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
rr := httptest.NewRecorder()
handler := http.HandlerFunc(SomeHandler)
// Our handlers satisfy http.Handler, so we can call their ServeHTTP method
// directly and pass in our Request and ResponseRecorder.
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
Now, to test some of your code. We can run the init method and call any of your services with a response recorder.
package main
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
)
func TestGetAllJobs(t *testing.T) {
// Initialize the DB
initDB("test111")
req, err := http.NewRequest("GET", "/GetAllJobs", nil)
if err != nil {
t.Fatal(err)
}
rr := httptest.NewRecorder()
handler := http.HandlerFunc(GetAllJobs)
handler.ServeHTTP(rr, req)
// Check the status code is what we expect.
if status := rr.Code; status != http.StatusOK {
t.Errorf("handler returned wrong status code: got %v want %v",
status, http.StatusOK)
}
var response []Jobs
if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil {
t.Errorf("got invalid response, expected list of jobs, got: %v", rr.Body.String())
}
if len(response) < 1 {
t.Errorf("expected at least 1 job, got %v", len(response))
}
for _, job := range response {
if job.SourcePath == "" {
t.Errorf("expected job id %d to have a source path, was empty", job.JobID)
}
}
}