在单元测试中模拟 context.Done()

sys*_*cll 6 unit-testing go

我有一个 HTTP 处理程序,它为每个请求设置上下文截止日期:

func submitHandler(stream chan data) http.HandlerFunc {
    return func(w http.ResponseWriter, r *http.Request) {
        ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
        defer cancel()

        // read request body, etc.

        select {
        case stream <- req:
            w.WriteHeader(http.StatusNoContent)
        case <-ctx.Done():
            err := ctx.Err()
            if err == context.DeadlineExceeded {
                w.WriteHeader(http.StatusRequestTimeout)
            }
            log.Printf("context done: %v", err)
        }
    }
}
Run Code Online (Sandbox Code Playgroud)

我很容易测试http.StatusNoContent标题,但我不确定如何<-ctx.Done()在 select 语句中测试案例。

在我的测试用例中,我构建了一个模拟context.Context并将其传递给req.WithContext()我的模拟上的方法http.Request,但是,返回的状态代码总是http.StatusNoContent使我相信该select语句始终属于我的测试中的第一个案例。

type mockContext struct{}

func (ctx mockContext) Deadline() (deadline time.Time, ok bool) {
    return deadline, ok
}

func (ctx mockContext) Done() <-chan struct{} {
    ch := make(chan struct{})
    close(ch)
    return ch
}

func (ctx mockContext) Err() error {
    return context.DeadlineExceeded
}

func (ctx mockContext) Value(key interface{}) interface{} {
    return nil
}

func TestHandler(t *testing.T) {
    stream := make(chan data, 1)
    defer close(stream)

    handler := submitHandler(stream)
    req, err := http.NewRequest(http.MethodPost, "/submit", nil)
    if err != nil {
        t.Fatal(err)
    }
    req = req.WithContext(mockContext{})

    rec := httptest.NewRecorder()
    handler.ServeHTTP(rec, req)

    if rec.Code != http.StatusRequestTimeout {
        t.Errorf("expected status code: %d, got: %d", http.StatusRequestTimeout, rec.Code)
    }
}
Run Code Online (Sandbox Code Playgroud)

我怎么能模拟上下文截止日期已超过?

sys*_*cll 7

所以,经过多次反复试验,我发现我做错了什么。context.Context我没有尝试创建模拟,而是创建了一个具有过期截止日期的新模拟,并立即调用了返回的cancelFunc. 然后我把req.WithContext()它传给了,现在它就像一个魅力!

func TestHandler(t *testing.T) {
    stream := make(chan data, 1)
    defer close(stream)

    handler := submitHandler(stream)
    req, err := http.NewRequest(http.MethodPost, "/submit", nil)
    if err != nil {
        t.Fatal(err)
    }

    stream <- data{}
    ctx, cancel := context.WithDeadline(req.Context(), time.Now().Add(-7*time.Hour))
    cancel()
    req = req.WithContext(ctx)

    rec := httptest.NewRecorder()
    handler.ServeHTTP(rec, req)

    if rec.Code != http.StatusRequestTimeout {
        t.Errorf("expected status code: %d, got: %d", http.StatusRequestTimeout, rec.Code)
    }
}
Run Code Online (Sandbox Code Playgroud)