Skip to content

Commit

Permalink
Merge pull request #6 from ifad/unit-tests-and-refactoring
Browse files Browse the repository at this point in the history
Unit tests and refactoring
  • Loading branch information
vjt authored Mar 26, 2017
2 parents f0e3217 + 8802db4 commit 310b8d7
Show file tree
Hide file tree
Showing 7 changed files with 375 additions and 61 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
[![Build Status](https://travis-ci.org/ifad/clammit.svg)](https://travis-ci.org/ifad/clammit)
[![Code Climate](https://codeclimate.com/github/ifad/clammit/badges/gpa.svg)](https://codeclimate.com/github/ifad/clammit)

Clammit is a proxy that will perform virus scans of files uploaded via
`multipart/form-data`. If a virus exists, it will reject the request out of
Clammit is a proxy that will perform virus scans of files uploaded via http requests,
including `multipart/form-data`. If a virus exists, it will reject the request out of
hand. If no virus exists, the request is then forwarded to the application and
it's response returned in the upstream direction.

Expand Down
37 changes: 21 additions & 16 deletions src/clammit/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ package main
import (
"bytes"
"clammit/forwarder"
"clammit/scanner"
"encoding/json"
"flag"
"fmt"
clamd "github.com/dutchcoders/go-clamd"
"gopkg.in/gcfg.v1"
"io/ioutil"
"log"
Expand Down Expand Up @@ -102,7 +102,8 @@ var DefaultApplicationConfig = ApplicationConfig{
type Ctx struct {
Config Config
ApplicationURL *url.URL
ClamInterceptor *ClamInterceptor
ScanInterceptor *ScanInterceptor
Scanner scanner.Scanner
Logger *log.Logger
Listener net.Listener
ActivityChan chan int
Expand All @@ -113,7 +114,7 @@ type Ctx struct {
// JSON server information response
//
type Info struct {
ClamdURL string `json:"clam_server_url"`
Address string `json:"scan_server_url"`
PingResult string `json:"ping_result"`
Version string `json:"version"`
TestScanVirusResult string `json:"test_scan_virus"`
Expand All @@ -125,6 +126,7 @@ type Info struct {
//
var ctx *Ctx
var configFile string
var EICAR = []byte(`X5O!P%@AP[4\PZX54(P^)7CC)7}$EICAR-STANDARD-ANTIVIRUS-TEST-FILE!$H+H*`)

func init() {
flag.StringVar(&configFile, "config", "", "Configuration file")
Expand Down Expand Up @@ -166,9 +168,14 @@ func main() {
*/
ctx.ApplicationURL = checkURL(ctx.Config.App.ApplicationURL)
checkURL(ctx.Config.App.ClamdURL)
ctx.ClamInterceptor = &ClamInterceptor{
ClamdURL: ctx.Config.App.ClamdURL,

ctx.Scanner = new(scanner.Clamav)
ctx.Scanner.SetLogger(ctx.Logger, ctx.Config.App.Debug)
ctx.Scanner.SetAddress(ctx.Config.App.ClamdURL)

ctx.ScanInterceptor = &ScanInterceptor{
VirusStatusCode: ctx.Config.App.VirusStatusCode,
Scanner: ctx.Scanner,
}

/*
Expand Down Expand Up @@ -310,7 +317,7 @@ func scanHandler(w http.ResponseWriter, req *http.Request) {
ctx.ActivityChan <- 1
defer func() { ctx.ActivityChan <- -1 }()

if !ctx.ClamInterceptor.Handle(w, req, req.Body) {
if !ctx.ScanInterceptor.Handle(w, req, req.Body) {
w.Write([]byte("No virus found"))
}
}
Expand All @@ -327,15 +334,15 @@ func scanForwardHandler(w http.ResponseWriter, req *http.Request) {
ctx.ActivityChan <- 1
defer func() { ctx.ActivityChan <- -1 }()

fw := forwarder.NewForwarder(ctx.ApplicationURL, ctx.Config.App.ContentMemoryThreshold, ctx.ClamInterceptor)
fw := forwarder.NewForwarder(ctx.ApplicationURL, ctx.Config.App.ContentMemoryThreshold, ctx.ScanInterceptor)
fw.SetLogger(ctx.Logger, ctx.Config.App.Debug)
fw.HandleRequest(w, req)
}

/*
* Handler for /info
*
* Validates the Clamd connection
* Validates the Scanner connection
* Emits the information as a JSON response
*/
func infoHandler(w http.ResponseWriter, req *http.Request) {
Expand All @@ -345,16 +352,14 @@ func infoHandler(w http.ResponseWriter, req *http.Request) {
ctx.ActivityChan <- 1
defer func() { ctx.ActivityChan <- -1 }()

c := clamd.NewClamd(ctx.ClamInterceptor.ClamdURL)
info := &Info{
ClamdURL: ctx.ClamInterceptor.ClamdURL,
Address: ctx.Scanner.Address(),
}
if err := c.Ping(); err != nil {
// If we can't ping the Clamd server, no point in making the remaining requests
if err := ctx.Scanner.Ping(); err != nil {
info.PingResult = err.Error()
} else {
info.PingResult = "Connected to server OK"
if response, err := c.Version(); err != nil {
if response, err := ctx.Scanner.Version(); err != nil {
info.Version = err.Error()
} else {
for s := range response {
Expand All @@ -364,8 +369,8 @@ func infoHandler(w http.ResponseWriter, req *http.Request) {
/*
* Validate the Clamd response for a viral string
*/
reader := bytes.NewReader(clamd.EICAR)
if response, err := c.ScanStream(reader); err != nil {
reader := bytes.NewReader(EICAR)
if response, err := ctx.Scanner.Scan(reader); err != nil {
info.TestScanVirusResult = err.Error()
} else {
for s := range response {
Expand All @@ -376,7 +381,7 @@ func infoHandler(w http.ResponseWriter, req *http.Request) {
* Validate the Clamd response for a non-viral string
*/
reader = bytes.NewReader([]byte("foo bar mcgrew"))
if response, err := c.ScanStream(reader); err != nil {
if response, err := ctx.Scanner.Scan(reader); err != nil {
info.TestScanCleanResult = err.Error()
} else {
for s := range response {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,30 @@
package main

import (
"clammit/scanner"
"fmt"
"io"
"mime"
"mime/multipart"
"net/http"

clamd "github.com/dutchcoders/go-clamd"
)

//
// The implementation of the ClamAV interceptor
// The implementation of the Scan interceptor
//
type ClamInterceptor struct {
ClamdURL string
type ScanInterceptor struct {
VirusStatusCode int
Scanner scanner.Scanner
}

/*
* Interceptor implementation for Clamd
* Interceptor implementation
*
* Runs a multi-part parser across the request body and sends all file contents to Clamd
* Runs a multi-part parser across the request body and sends all file contents to Scanner
*
* returns True if the body contains a virus
*/
func (c *ClamInterceptor) Handle(w http.ResponseWriter, req *http.Request, body io.Reader) bool {
func (c *ScanInterceptor) Handle(w http.ResponseWriter, req *http.Request, body io.Reader) bool {
//
// Don't care unless it's a post
//
Expand Down Expand Up @@ -108,8 +107,8 @@ func (c *ClamInterceptor) Handle(w http.ResponseWriter, req *http.Request, body
*
* returns True if a virus has been found and a http error response has been written
*/
func (c *ClamInterceptor) respondOnVirus(w http.ResponseWriter, filename string, reader io.Reader) bool {
if hasVirus, err := c.scan(reader); err != nil {
func (c *ScanInterceptor) respondOnVirus(w http.ResponseWriter, filename string, reader io.Reader) bool {
if hasVirus, err := c.Scanner.HasVirus(reader); err != nil {
ctx.Logger.Printf("Unable to scan file (%s): %v\n", filename, err)
http.Error(w, "Internal Server Error", 500)
return true
Expand All @@ -120,36 +119,3 @@ func (c *ClamInterceptor) respondOnVirus(w http.ResponseWriter, filename string,
}
return false
}

/*
* This function performs the actual virus scan
*/
func (c *ClamInterceptor) scan(reader io.Reader) (bool, error) {

clam := clamd.NewClamd(c.ClamdURL)

if ctx.Config.App.Debug {
ctx.Logger.Println("Sending to clamav")
}

response, err := clam.ScanStream(reader)
if err != nil {
return false, err
}

hasVirus := false
for s := range response {
if s != "stream: OK" {
if ctx.Config.App.Debug {
ctx.Logger.Printf(" %v", s)
}
hasVirus = true
}
}

if ctx.Config.App.Debug {
ctx.Logger.Println(" result of scan:", hasVirus)
}

return hasVirus, nil
}
165 changes: 165 additions & 0 deletions src/clammit/scan_interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package main

import (
"bytes"
"clammit/scanner"
"io"
"log"
"mime/multipart"
"net/http"
"net/http/httptest"
"os"
"testing"
)

const virusCode = 418

var mockVirusFound = false

type MockScanner struct {
scanner.Engine
}

func (s MockScanner) HasVirus(reader io.Reader) (bool, error) {
return mockVirusFound, nil
}

var scanInterceptor = ScanInterceptor{
VirusStatusCode: virusCode,
Scanner: new(MockScanner),
}

var handler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { scanInterceptor.Handle(w, req, req.Body) })

func TestNonMultipartRequest_VirusFound_Without_ContentDisposition(t *testing.T) {
setup()
mockVirusFound = true
req := newHTTPRequest("POST", "application/octet-stream", bytes.NewReader([]byte(`<virus/>`)))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

if status := rr.Code; status != virusCode {
t.Errorf("handler returned wrong status code: got %v want %v",
status, virusCode)
}
expected := `File untitled has a virus!`
if rr.Body.String() != expected {
t.Errorf("handler returned unexpected body: got %v want %v",
rr.Body.String(), expected)
}
}

func TestNonMultipartRequest_VirusFound_With_ContentDisposition(t *testing.T) {
setup()
mockVirusFound = true
req := newHTTPRequest("POST", "application/octet-stream", bytes.NewReader([]byte(`<virus/>`)))
req.Header["Content-Disposition"] = []string{"attachment;filename=virus.dat"}
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

if status := rr.Code; status != virusCode {
t.Errorf("handler returned wrong status code: got %v want %v",
status, virusCode)
}
expected := `File virus.dat has a virus!`
if rr.Body.String() != expected {
t.Errorf("handler returned unexpected body: got %v want %v",
rr.Body.String(), expected)
}
}

func TestNonMultipartRequest_Clean(t *testing.T) {
setup()
mockVirusFound = false
req := newHTTPRequest("POST", "application/octet-stream", bytes.NewReader([]byte(`<clean/>`)))
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

if status := rr.Code; status != 200 {
t.Errorf("handler returned wrong status code: got %v want %v",
status, 200)
}
}

func TestMultipartRequest_VirusFound(t *testing.T) {
setup()
mockVirusFound = true

body, contentType := makeMultipartBody()

req := newHTTPRequest("POST", contentType, body)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

if status := rr.Code; status != virusCode {
t.Errorf("handler returned wrong status code: got %v want %v",
status, virusCode)
}
expected := `File foo.dat has a virus!`
if rr.Body.String() != expected {
t.Errorf("handler returned unexpected body: got %v want %v",
rr.Body.String(), expected)
}
}

func TestMultipartRequest_Clean(t *testing.T) {
setup()
mockVirusFound = false

body, contentType := makeMultipartBody()

req := newHTTPRequest("POST", contentType, body)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

if status := rr.Code; status != 200 {
t.Errorf("handler returned wrong status code: got %v want %v",
status, virusCode)
}
}

func makeMultipartBody() (*bytes.Buffer, string) {
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)

defer writer.Close()

addPart(writer, "file1", "foo.dat")
addPart(writer, "file2", "bar.dat")

err := writer.Close()
if err != nil {
log.Fatal("Can't close multipart writer: %v", err)
}

return body, writer.FormDataContentType()
}

func addPart(w *multipart.Writer, name, fileName string) {
part, err := w.CreateFormFile(name, fileName)
if err != nil {
log.Fatal("Cannot create multipart body: %v", err)
}

_, err = io.WriteString(part, name)
if err != nil {
log.Fatal("Can't write part to multipart body: %v", err)
}
return
}

func setup() {
ctx = &Ctx{
ShuttingDown: false,
}
ctx.Logger = log.New(os.Stdout, "", log.LstdFlags)
}

func newHTTPRequest(method string, contentType string, body io.Reader) *http.Request {
req, _ := http.NewRequest(method, "http://clammit/scan", body)
req.Header = map[string][]string{
"Content-Type": []string{contentType},
"X-Forwarded-For": []string{"kermit"},
}
return req
}
Loading

0 comments on commit 310b8d7

Please sign in to comment.