Cody Joyce 2 months ago
parent
commit
66d51c0606
21 changed files with 445 additions and 238 deletions
  1. 97 15
      db/db.go
  2. 15 10
      db/query.dis
  3. 0 36
      db/util.go
  4. 5 2
      gen/domquery/htmlquery.go
  5. 4 4
      gen/img/img.go
  6. 3 3
      gen/layout/main.go
  7. 24 17
      gen/main.go
  8. 3 3
      gen/partial/partial.go
  9. 63 17
      sec/middleware.go
  10. 42 1
      sec/sec.go
  11. 0 108
      sec/sess.go
  12. 0 0
      sess/base64/base64url.go
  13. 9 4
      sess/jwt/jwt.go
  14. 11 0
      sess/middleware.go
  15. 113 0
      sess/sess.go
  16. 1 1
      srv/middleware.go
  17. 5 6
      srv/req.go
  18. 27 4
      srv/res.go
  19. 7 4
      srv/rtr.go
  20. 1 3
      srv/srv.go
  21. 15 0
      utils/utils.go

+ 97 - 15
db/db.go

@@ -2,21 +2,30 @@ package db
 
 import (
 	"database/sql"
-	"fmt"
+	"errors"
+	"os"
 	"reflect"
 
+	"git.clearsky.net.au/cody/gex.git/utils"
 	_ "github.com/mattn/go-sqlite3"
 )
 
 var Conn *sql.DB
 
 func Connect(fp string) error {
-	var err error
-	Conn, err = sql.Open("sqlite3", fp)
+	_, err := os.Stat(fp)
 	if err != nil {
-		fmt.Println("db.Connect() Error connecting to database")
+		utils.Err(err)
+		os.Exit(1)
 		return err
 	}
+
+	if Conn, err = sql.Open("sqlite3", fp); err != nil {
+		utils.Err(err)
+		os.Exit(1)
+		return err
+	}
+
 	Conn.Exec(`
 		PRAGMA journal_mode=ON
 		PRAGMA journal_mode=WAL;
@@ -24,39 +33,112 @@ func Connect(fp string) error {
 		PRAGMA auto_vacuum=ON
 		PRAGMA foreign_keys=ON
 	`)
+
 	return nil
 }
 
-// Pass fields struct as a pointer
-func Get(fieldsPtr interface{}, query string, vals []any) {
-	err := Conn.QueryRow(query, vals...).Scan(StructPointers(fieldsPtr)...)
+func Get(model interface{}, query string, params ...any) error {
+	if reflect.TypeOf(model).Kind() != reflect.Pointer {
+		err := errors.New("model must be a pointer")
+		utils.Err(err)
+		return err
+	}
+
+	if reflect.TypeOf(model).Elem().Kind() != reflect.Struct {
+		err := errors.New("model must be a struct")
+		utils.Err(err)
+		return err
+	}
+
+	p := []any{}
+	p = append(p, params...)
+
+	err := Conn.QueryRow(query, p...).Scan(StructPointers(model)...)
 	if err != nil {
-		fmt.Printf("ERROR: %s\n", err)
-		return
+		err = errors.New("SQL " + err.Error())
+		utils.Err(err)
+		return err
 	}
+	return nil
 }
 
-func All(sliceInt interface{}, query string, vals []any) {
-	slicePtr := reflect.ValueOf(sliceInt)
+func All(model interface{}, query string, params ...any) error {
+	p := []any{}
+	p = append(p, params...)
+	if reflect.TypeOf(model).Kind() != reflect.Pointer {
+		err := errors.New("db.All(): model must be a pointer")
+		utils.Err(err)
+		return err
+	}
+
+	slicePtr := reflect.ValueOf(model)
+
 	sliceVal := slicePtr.Elem()
+
+	if sliceVal.Kind() != reflect.Slice {
+		err := errors.New("db.All(): model must be a slice of model")
+		utils.Err(err)
+		return err
+	}
+
 	modelType := sliceVal.Type().Elem()
 
-	rows, _ := Conn.Query(query, vals...)
+	rows, err := Conn.Query(query, p...)
+	if err != nil {
+		err := errors.New("db.All(): SQL " + err.Error())
+		utils.Err(err)
+		return err
+	}
 
 	var set []reflect.Value
 	for rows.Next() {
 		tmpModel := reflect.New(modelType)
 
-		rows.Scan(StructPointers(tmpModel.Interface())...)
+		err = rows.Scan(StructPointers(tmpModel.Interface())...)
 
 		set = append(set, tmpModel.Elem())
 	}
+	if err != nil {
+		utils.Err(err)
+		return err
+	}
 	rows.Close()
 
-	newSlice := reflect.Append(sliceVal, set...)
-	sliceVal.Set(newSlice)
+	sliceVal.Set(reflect.Append(sliceVal, set...))
+	return nil
 }
 
 func Close() {
 	Conn.Close()
 }
+
+func StructPointers(s interface{}) []interface{} {
+	model := reflect.ValueOf(s)
+	modelElem := model.Elem()
+
+	numFields := modelElem.NumField()
+
+	var Ptrs []interface{}
+
+	for i := 0; i < numFields; i++ {
+		field := modelElem.Field(i)
+
+		tag := modelElem.Type().Field(i).Tag
+
+		if tag != "" {
+			continue
+		}
+
+		// Merge
+
+		if field.Kind() == reflect.Struct {
+			Ptrs = append(Ptrs, StructPointers(field.Addr().Interface())...)
+			continue
+		}
+
+		Ptrs = append(Ptrs, field.Addr().Interface())
+
+	}
+
+	return Ptrs
+}

+ 15 - 10
db/query.go → db/query.dis

@@ -32,7 +32,7 @@ type Query struct {
 
 	tblName string
 	pkName  string
-	Sess    sec.Sess
+	Auth    *sec.Auth
 	err     error
 	ready   bool
 
@@ -42,7 +42,8 @@ type Query struct {
 	vals  []any
 }
 
-func (q *Query) New(targetInt interface{}, sess sec.Sess) *Query {
+func (q *Query) New(targetInt interface{}, auth *sec.Auth) *Query {
+
 	q.targetInt = targetInt
 	q.targetPtr = reflect.ValueOf(q.targetInt)
 
@@ -54,8 +55,8 @@ func (q *Query) New(targetInt interface{}, sess sec.Sess) *Query {
 	q.targetVal = q.targetPtr.Elem()
 	q.sel = "*"
 
-	// Get the session
-	q.Sess = sess
+	// Get the secion
+	q.Auth = auth
 
 	if q.targetVal.Kind().String() == "slice" {
 		q.newSlice()
@@ -176,7 +177,7 @@ func (q *Query) Create() (int64, error) {
 		"INSERT INTO %s (%s) VALUES (%s)", q.tblName, strings.Join(fields, ", "), strings.Join(plcHdr, ", "),
 	)
 
-	q.targetVal.FieldByName("Sess").Set(reflect.ValueOf(&q.Sess))
+	q.targetVal.FieldByName("sec").Set(reflect.ValueOf(&q.Auth))
 	createRuleMethod := q.targetPtr.MethodByName("CreateRule")
 	createRulePassed := createRuleMethod.Call(nil)[0].Bool()
 
@@ -241,7 +242,7 @@ func (q *Query) Update() error {
 	tmpModel := reflect.New(q.modelType)
 
 	tmpQuery := Query{}
-	q.err = tmpQuery.New(tmpModel.Interface(), q.Sess).Read(int(id.Int()))
+	q.err = tmpQuery.New(tmpModel.Interface(), q.Auth).Read(int(id.Int()))
 
 	if q.err != nil {
 		return q.err
@@ -302,7 +303,7 @@ func (q *Query) Delete() error {
 	tmpModel := reflect.New(q.targetVal.Type())
 
 	tmpQuery := Query{}
-	tmpQuery.New(tmpModel.Interface(), q.Sess).Read(int(id.Int()))
+	tmpQuery.New(tmpModel.Interface(), q.Auth).Read(int(id.Int()))
 
 	existRuleMethod := tmpModel.MethodByName("DeleteRule")
 	existRulePassed := existRuleMethod.Call(nil)[0].Bool()
@@ -325,15 +326,17 @@ func (q *Query) Delete() error {
 }
 
 func (q *Query) Get() error {
-	q.targetVal.FieldByName("Sess").Set(reflect.ValueOf(&q.Sess))
+
+	q.targetVal.FieldByName("Auth").Set(reflect.ValueOf(q.Auth))
 
 	// Create a temp model to store the data
 	tmpModel := reflect.New(q.modelType)
 	tmpModelElem := tmpModel.Elem()
-	tmpModelElem.FieldByName("Sess").Set(reflect.ValueOf(&q.Sess))
+	tmpModelElem.FieldByName("Auth").Set(reflect.ValueOf(q.Auth))
 	//
 
 	query := fmt.Sprintf("SELECT %s FROM %s %s %s", q.sel, q.tblName, q.where, q.order)
+
 	err := Conn.QueryRow(query, q.vals...).Scan(StructPointers(tmpModel.Interface())...)
 	if err != nil {
 		return errors.New("db.Query.Get(), " + err.Error())
@@ -352,7 +355,9 @@ func (q *Query) Get() error {
 }
 
 func (q *Query) All() error {
+
 	query := fmt.Sprintf("SELECT %s FROM %s %s %s", q.sel, q.tblName, q.where, q.order)
+
 	rows, err := Conn.Query(query, q.vals...)
 
 	if err != nil {
@@ -362,7 +367,7 @@ func (q *Query) All() error {
 	var set []reflect.Value
 	for rows.Next() {
 		tmpModel := reflect.New(q.modelType)
-		tmpModel.Elem().FieldByName("Sess").Set(reflect.ValueOf(&q.Sess))
+		tmpModel.Elem().FieldByName("sec").Set(reflect.ValueOf(q.Auth))
 		rows.Scan(StructPointers(tmpModel.Interface())...)
 
 		readRuleMethod := tmpModel.MethodByName("ReadRule")

+ 0 - 36
db/util.go

@@ -1,36 +0,0 @@
-package db
-
-import (
-	"reflect"
-)
-
-func StructPointers(s interface{}) []interface{} {
-	model := reflect.ValueOf(s)
-	modelElem := model.Elem()
-
-	numFields := modelElem.NumField()
-
-	var Ptrs []interface{}
-
-	for i := 0; i < numFields; i++ {
-		field := modelElem.Field(i)
-
-		tag := modelElem.Type().Field(i).Tag
-
-		if tag != "" {
-			continue
-		}
-
-		// Merge
-
-		if field.Kind() == reflect.Struct {
-			Ptrs = append(Ptrs, StructPointers(field.Addr().Interface())...)
-			continue
-		}
-
-		Ptrs = append(Ptrs, field.Addr().Interface())
-
-	}
-
-	return Ptrs
-}

+ 5 - 2
gen/domquery/htmlquery.go

@@ -1,8 +1,10 @@
 package domquery
 
 import (
-	"fmt"
+	"errors"
 	"strings"
+
+	"git.clearsky.net.au/cody/gex.git/utils"
 )
 
 /* End Node Functions */
@@ -210,7 +212,8 @@ func getCloseNode(node *Node) *Node {
 	}
 	idx = idx + 1
 	if idx > len(node.Parent.Children)-1 {
-		fmt.Println("Parse Error: Unclosed tag in " + node.token)
+		err := errors.New("Parse Error: Unclosed tag in " + node.token)
+		utils.Err(err)
 		idx--
 	}
 

+ 4 - 4
gen/img/img.go

@@ -1,13 +1,13 @@
 package img
 
 import (
-	"fmt"
 	"os"
 	"os/exec"
 	"path/filepath"
 	"strings"
 
 	"git.clearsky.net.au/cody/gex.git/gen/domquery"
+	"git.clearsky.net.au/cody/gex.git/utils"
 )
 
 func Process(htmlDir string, tag *domquery.Node) string {
@@ -22,7 +22,7 @@ func Process(htmlDir string, tag *domquery.Node) string {
 	imgSrc = strings.Replace(imgSrc, "/app/", "app/", 1)
 	_, err := os.ReadFile(imgSrc)
 	if err != nil {
-		fmt.Println(err)
+		utils.Err(err)
 		return ""
 	}
 
@@ -70,7 +70,7 @@ func Process(htmlDir string, tag *domquery.Node) string {
 		tag.SetAttribute("src", "/"+newImgSrc)
 		return ""
 	} else {
-		fmt.Println(err)
+		utils.Err(err)
 	}
 
 	command += " " + newImgSrc
@@ -79,7 +79,7 @@ func Process(htmlDir string, tag *domquery.Node) string {
 	_, err = cmd.Output()
 
 	if err != nil {
-		fmt.Println(err)
+		utils.Err(err)
 	}
 
 	tag.SetAttribute("src", "/"+newImgSrc)

+ 3 - 3
gen/layout/main.go

@@ -1,12 +1,12 @@
 package layout
 
 import (
-	"fmt"
 	"os"
 	"path"
 	"strings"
 
 	"git.clearsky.net.au/cody/gex.git/gen/domquery"
+	"git.clearsky.net.au/cody/gex.git/utils"
 )
 
 func ProcessHTML(htmlStr string, fp string) (string, error) {
@@ -31,7 +31,7 @@ func ProcessHTML(htmlStr string, fp string) (string, error) {
 
 	layoutFileByt, err := os.ReadFile(attr)
 	if err != nil {
-		fmt.Printf("ERROR: %s", err)
+		utils.Err(err)
 		return "", err
 	}
 
@@ -41,7 +41,7 @@ func ProcessHTML(htmlStr string, fp string) (string, error) {
 	layoutFileHtml := string(layoutFileByt)
 	layoutFileHtml, err = ProcessHTML(layoutFileHtml, fp)
 	if err != nil {
-		fmt.Printf("ERROR: %s\n", err)
+		utils.Err(err)
 	}
 
 	layoutFileDom := domquery.LoadHTML(layoutFileHtml)

+ 24 - 17
gen/main.go

@@ -11,6 +11,7 @@ import (
 	"git.clearsky.net.au/cody/gex.git/gen/img"
 	"git.clearsky.net.au/cody/gex.git/gen/layout"
 	"git.clearsky.net.au/cody/gex.git/gen/partial"
+	"git.clearsky.net.au/cody/gex.git/utils"
 )
 
 var appRuleList cssRuleList
@@ -18,37 +19,32 @@ var appRuleList cssRuleList
 //var appJsStr string
 
 func GeneratePages(dir string) error {
-	err := processTemplate(dir)
-	if err != nil {
-		fmt.Println(err)
+	if err := processTemplate(dir); err != nil {
+		utils.Err(err)
 		return err
 	}
 
 	// Process Partials
-	err = processPartials(dir)
-	if err != nil {
-		fmt.Println(err)
+	if err := processPartials(dir); err != nil {
+		utils.Err(err)
 		return err
 	}
 
 	// Build the layout & blocks
-	err = processLayout(dir)
-	if err != nil {
-		fmt.Println(err)
+	if err := processLayout(dir); err != nil {
+		utils.Err(err)
 		return err
 	}
 
 	// Process IMG on the generated files
-	err = ProcessImg(dir)
-	if err != nil {
-		fmt.Println(err)
+	if err := ProcessImg(dir); err != nil {
+		utils.Err(err)
 		return err
 	}
 
 	// Clean block tags
-	err = cleanBlocks(dir)
-	if err != nil {
-		fmt.Println(err)
+	if err := cleanBlocks(dir); err != nil {
+		utils.Err(err)
 		return err
 	}
 
@@ -74,6 +70,7 @@ func processTemplate(dir string) error {
 		// get the html string
 		data, err := os.ReadFile(fp)
 		if err != nil {
+			utils.Err(err)
 			return err
 		}
 		htmlStr := string(data)
@@ -103,6 +100,7 @@ func processTemplate(dir string) error {
 
 	walk := func(path string, info os.FileInfo, err error) error {
 		if err != nil {
+			utils.Err(err)
 			return err
 		}
 
@@ -124,13 +122,14 @@ func processPartials(dir string) error {
 
 		data, err := os.ReadFile(fp)
 		if err != nil {
+			utils.Err(err)
 			return err
 		}
 		htmlStr := string(data)
 
 		htmlStr, err = partial.ProcessHTML(htmlStr, fp)
 		if err != nil {
-			fmt.Println("\n" + fp)
+			utils.Err(err)
 			return err
 		}
 
@@ -139,6 +138,7 @@ func processPartials(dir string) error {
 
 	walk := func(path string, info os.FileInfo, err error) error {
 		if err != nil {
+			utils.Err(err)
 			return err
 		}
 
@@ -151,6 +151,7 @@ func processPartials(dir string) error {
 	}
 	err := filepath.Walk(dir, walk)
 	if err != nil {
+		utils.Err(err)
 		return err
 	}
 	return nil
@@ -164,13 +165,14 @@ func processLayout(dir string) error {
 
 		data, err := os.ReadFile(fp)
 		if err != nil {
+			utils.Err(err)
 			return err
 		}
 		htmlStr := string(data)
 
 		htmlStr, err = layout.ProcessHTML(htmlStr, fp)
 		if err != nil {
-			fmt.Println("\n" + fp)
+			utils.Err(err)
 			return err
 		}
 
@@ -180,6 +182,7 @@ func processLayout(dir string) error {
 
 	walk := func(path string, info os.FileInfo, err error) error {
 		if err != nil {
+			utils.Err(err)
 			return err
 		}
 
@@ -268,6 +271,7 @@ func ProcessImg(dir string) error {
 		// get the html string
 		data, err := os.ReadFile(fp)
 		if err != nil {
+			utils.Err(err)
 			return err
 		}
 		htmlStr := string(data)
@@ -291,6 +295,7 @@ func ProcessImg(dir string) error {
 
 	walk := func(path string, info os.FileInfo, err error) error {
 		if err != nil {
+			utils.Err(err)
 			return err
 		}
 
@@ -314,6 +319,7 @@ func cleanBlocks(dir string) error {
 		// get the html string
 		data, err := os.ReadFile(fp)
 		if err != nil {
+			utils.Err(err)
 			return err
 		}
 		htmlStr := string(data)
@@ -339,6 +345,7 @@ func cleanBlocks(dir string) error {
 
 	walk := func(dir string, info os.FileInfo, err error) error {
 		if err != nil {
+			utils.Err(err)
 			return err
 		}
 

+ 3 - 3
gen/partial/partial.go

@@ -1,12 +1,12 @@
 package partial
 
 import (
-	"fmt"
 	"os"
 	"path"
 	"strings"
 
 	"git.clearsky.net.au/cody/gex.git/gen/domquery"
+	"git.clearsky.net.au/cody/gex.git/utils"
 )
 
 func ProcessHTML(htmlStr string, fp string) (string, error) {
@@ -30,7 +30,7 @@ func ProcessHTML(htmlStr string, fp string) (string, error) {
 
 	partFileByt, err := os.ReadFile(attr)
 	if err != nil {
-		fmt.Printf("ERROR: %s", err)
+		utils.Err(err)
 		return "", err
 	}
 	partFileHtml := string(partFileByt)
@@ -44,7 +44,7 @@ func ProcessHTML(htmlStr string, fp string) (string, error) {
 
 	htmlStr, err = ProcessHTML(dom.InnerHTML(), fp)
 	if err != nil {
-		fmt.Printf("ERROR: %s", err)
+		utils.Err(err)
 		return "", err
 	}
 	return htmlStr, nil

+ 63 - 17
sec/middleware.go

@@ -1,35 +1,81 @@
 package sec
 
 import (
+	"errors"
+
+	"git.clearsky.net.au/cody/gex.git/sess"
 	"git.clearsky.net.au/cody/gex.git/srv"
+	"git.clearsky.net.au/cody/gex.git/utils"
 )
 
-func Middleware(req *srv.Req, res srv.Res) bool {
-	//Session
+func Middleware(req *srv.Req, res *srv.Res) error {
 
-	var sess Sess
-	sess.Construct(req, res)
-	sess.Save()
-	req.Ctx["Sess"] = sess
+	sess, err := sess.GetCtxSess(req)
+	if err != nil {
+		utils.Err(err)
+		return err
+	}
 
-	pattern := req.Pattern
+	// default auth
+	auth := Auth{0, "Guest", []string{"Guest", "Everyone"}}
 
-	// Route Access Check
+	// if auth context exists, convert it to an Auth type
+	if sess.Data["Auth"] != nil {
+		sessAuth, ok := sess.Data["Auth"].(map[string]any)
+		if !ok {
+			err := errors.New("auth context in session data is not of the expected type, request cancelled")
+			utils.Err(err)
+			return err
+		}
+
+		auth.User_id, ok = sessAuth["User_id"].(float64)
+		if !ok {
+			err := errors.New("auth context in session data is not of the expected type, request cancelled")
+			utils.Err(err)
+			return err
+		}
+
+		auth.User_name, ok = sessAuth["User_name"].(string)
+		if !ok {
+			err := errors.New("auth context in session data is not of the expected type, request cancelled")
+			utils.Err(err)
+			return err
+		}
+
+		sessAuthRoles, ok := sessAuth["Roles"].([]any)
+		if !ok {
+			err := errors.New("auth context in session data is not of the expected type, request cancelled")
+			utils.Err(err)
+			return err
+		}
 
-	// Cancel the security check as there are no permissions for this route
-	if len(permissions[pattern]) == 0 {
-		return true
+		auth.Roles = []string{}
+		for _, v := range sessAuthRoles {
+			val, ok := v.(string)
+			if !ok {
+				err := errors.New("auth context in session data is not of the expected type, request cancelled")
+				utils.Err(err)
+				return err
+			}
+			auth.Roles = append(auth.Roles, val)
+		}
 	}
 
-	handlerRoles := permissions[pattern]
-	sess = req.Ctx["Sess"].(Sess)
+	sess.Data["Auth"] = auth
+	sess.Save()
+
+	// Route Access Check
+	pattern := req.Pattern
+	if permissions[pattern] == nil {
+		return nil
+	}
 
-	for _, role := range handlerRoles {
-		if sess.HasRole(role) {
-			return true
+	for _, val := range permissions[pattern] {
+		if auth.HasRole(val) {
+			return nil
 		}
 	}
 
 	res.Send("No Access")
-	return false
+	return nil
 }

+ 42 - 1
sec/sec.go

@@ -4,14 +4,55 @@ import (
 	"crypto/hmac"
 	"crypto/sha256"
 	"encoding/hex"
+	"errors"
 	"math/rand"
 
+	"git.clearsky.net.au/cody/gex.git/sess"
 	"git.clearsky.net.au/cody/gex.git/srv"
+	"git.clearsky.net.au/cody/gex.git/utils"
 )
 
+type Auth struct {
+	User_id   float64
+	User_name string
+	Roles     []string
+}
+
+func (auth *Auth) HasRole(role string) bool {
+	for _, val := range auth.Roles {
+		if role == val {
+			return true
+		}
+	}
+	return false
+}
+
 var permissions = make(map[string][]string)
 
-func Route(pattern string, roles []string, handler func(req *srv.Req, res srv.Res)) {
+func GetCtxAuth(req *srv.Req) (Auth, error) {
+	sess, err := sess.GetCtxSess(req)
+	if err != nil {
+		utils.Err(err)
+		return Auth{}, err
+	}
+
+	if sess.Data["Auth"] == nil {
+		err := errors.New("no auth context in session data")
+		utils.Err(err)
+		return Auth{}, err
+	}
+
+	auth, ok := sess.Data["Auth"].(Auth)
+	if !ok {
+		err := errors.New("auth context in session data is not of the expected type")
+		utils.Err(err)
+		return Auth{}, err
+	}
+
+	return auth, nil
+}
+
+func Route(pattern string, roles []string, handler func(req *srv.Req, res *srv.Res)) {
 	permissions[pattern] = roles
 	srv.Route(pattern, handler)
 }

+ 0 - 108
sec/sess.go

@@ -1,108 +0,0 @@
-package sec
-
-import (
-	"encoding/json"
-	"fmt"
-
-	"git.clearsky.net.au/cody/gex.git/sec/jwt"
-	"git.clearsky.net.au/cody/gex.git/srv"
-
-	"time"
-)
-
-type Sess struct {
-	req       *srv.Req
-	res       srv.Res
-	User_id   int
-	User_name string
-	Roles     []string
-	Expires   time.Time
-	Props     map[string]any
-}
-
-var TokenName string = "GexToken"
-var Expires time.Time = time.Now().Add(24 * time.Hour)
-var Secret string = "secret"
-
-func (sess *Sess) setDefaults() {
-	sess.User_id = 0
-	sess.User_name = "Guest"
-	sess.Roles = []string{"Guest", "Everyone"}
-	sess.Expires = Expires
-}
-
-func (sess *Sess) Construct(req *srv.Req, res srv.Res) {
-	sess.req = req
-	sess.res = res
-	sess.Props = make(map[string]any)
-	sess.setDefaults()
-
-	// check cookie is valid (not expired too)
-	cookie, err := req.Cookie(TokenName)
-
-	if err != nil {
-		//fmt.Println("cookie error")
-		return
-	}
-
-	// decode jwt to json bytes
-	jsonByt, err := jwt.Decode(cookie, Secret)
-	if err != nil {
-		fmt.Println("jwt decode error")
-		return
-	}
-
-	// decode json bytes to session
-	err = json.Unmarshal(jsonByt, &sess)
-	if err != nil {
-		fmt.Println("jwt to session error")
-		return
-	}
-
-	// if session token has expired, return default session
-	if time.Now().After(sess.Expires) {
-		fmt.Println("session expired")
-		sess.setDefaults()
-	}
-
-	sess.Expires = Expires
-
-}
-
-func (sess *Sess) Token() (string, error) {
-	jsonStr, err := json.Marshal(sess)
-	if err != nil {
-		return "", err
-	}
-
-	// encode the json to jwt and set the cookie
-	token, err := jwt.Encode(jsonStr, Secret)
-	if err != nil {
-		return "", err
-	}
-	return token, nil
-}
-
-func (sess *Sess) HasRole(roleName string) bool {
-	for _, role := range sess.Roles {
-		if role == roleName {
-			return true
-		}
-	}
-	return false
-}
-
-// Saves token to cookie
-func (sess *Sess) Save() {
-	// get existing session or create new one
-
-	// get the session token
-	token, err := sess.Token()
-	if err != nil {
-		sess.res.Send(err.Error())
-		return
-	}
-
-	// set the token cookie
-	sess.res.Cookie(TokenName, token)
-}

+ 0 - 0
sec/base64/base64url.go → sess/base64/base64url.go


+ 9 - 4
sec/jwt/jwt.go → sess/jwt/jwt.go

@@ -5,10 +5,10 @@ import (
 	"crypto/hmac"
 	"crypto/sha256"
 	"errors"
-	"fmt"
 	"strings"
 
-	"git.clearsky.net.au/cody/gex.git/sec/base64"
+	"git.clearsky.net.au/cody/gex.git/sess/base64"
+	"git.clearsky.net.au/cody/gex.git/utils"
 )
 
 // encode json bytes to a jwt token string
@@ -26,11 +26,16 @@ func Encode(jsonStr []byte, secret string) (string, error) {
 // decode a jwt token string to a json string to be processed
 func Decode(tokenStr string, secret string) ([]byte, error) {
 	parts := strings.Split(tokenStr, ".")
+	if len(parts) < 3 {
+		err := errors.New("cannot decode JWT")
+		utils.Err(err)
+		return []byte(""), err
+	}
 	header := parts[0]
 	payload := parts[1]
 	sig, err := base64.DecodeURL(parts[2])
 	if err != nil {
-		fmt.Printf("ERROR: %s", err)
+		utils.Err(err)
 		return []byte(""), err
 	}
 
@@ -45,7 +50,7 @@ func Decode(tokenStr string, secret string) ([]byte, error) {
 	jsonStr, err := base64.DecodeURL(payload)
 
 	if err != nil {
-		fmt.Printf("ERROR: %s", err)
+		utils.Err(err)
 		return []byte(""), err
 	}
 

+ 11 - 0
sess/middleware.go

@@ -0,0 +1,11 @@
+package sess
+
+import "git.clearsky.net.au/cody/gex.git/srv"
+
+func Middleware(req *srv.Req, res *srv.Res) error {
+	var sess Sess
+	sess.Construct(req, res)
+	sess.Save()
+	req.Ctx["Sess"] = &sess
+	return nil
+}

+ 113 - 0
sess/sess.go

@@ -0,0 +1,113 @@
+package sess
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+
+	"git.clearsky.net.au/cody/gex.git/sess/jwt"
+	"git.clearsky.net.au/cody/gex.git/srv"
+	"git.clearsky.net.au/cody/gex.git/utils"
+
+	"time"
+)
+
+type Sess struct {
+	req     *srv.Req
+	res     *srv.Res
+	Expires time.Time
+	Data    map[string]any
+}
+
+// config defaults
+var TOKENNAME string = "SessToken"
+var TIMEOUT time.Duration = 1 * time.Hour
+var SECRET string = "secret"
+
+func (sess *Sess) Construct(req *srv.Req, res *srv.Res) {
+	sess.req = req
+	sess.res = res
+
+	sess.setDefaults()
+
+	// check cookie is valid and not expired
+	cookie, err := req.Cookie(TOKENNAME)
+
+	if err != nil {
+		//utils.Err(err)
+		return
+	}
+
+	// decode jwt to json bytes
+	jsonByt, err := jwt.Decode(cookie, SECRET)
+	if err != nil {
+		utils.Err(err)
+		return
+	}
+
+	// decode json bytes to session
+	err = json.Unmarshal(jsonByt, &sess)
+	if err != nil {
+		utils.Err(err)
+		return
+	}
+
+	// if session token has expired, return default session
+	if time.Now().After(sess.Expires) {
+		if time.Now().After(sess.Expires.Add(TIMEOUT)) {
+			sess.Expires = time.Now().Add(20 * time.Minute)
+			return
+		}
+		fmt.Println("session expired")
+		sess.setDefaults()
+	}
+
+}
+
+func (sess *Sess) setDefaults() {
+	sess.Data = make(map[string]any)
+	sess.Expires = time.Now().Add(20 * time.Minute)
+}
+
+func (sess *Sess) Token() (string, error) {
+	jsonStr, err := json.Marshal(sess)
+	if err != nil {
+		return "", err
+	}
+
+	// encode the json to jwt and set the cookie
+	token, err := jwt.Encode(jsonStr, SECRET)
+	if err != nil {
+		return "", err
+	}
+	return token, nil
+}
+
+// Saves token to cookie
+func (sess *Sess) Save() {
+
+	// get the session token
+	token, err := sess.Token()
+	if err != nil {
+		sess.res.Send(err.Error())
+		return
+	}
+
+	// set the token cookie
+	sess.res.Cookie(TOKENNAME, token)
+}
+
+func GetCtxSess(req *srv.Req) (*Sess, error) {
+	if req.Ctx["Sess"] == nil {
+		err := errors.New("no session context, did you add the session middleware?")
+		utils.Err(err)
+		return nil, err
+	}
+	sess, ok := req.Ctx["Sess"].(*Sess)
+	if !ok {
+		err := errors.New("session from context is not of type *Sess")
+		utils.Err(err)
+		return nil, err
+	}
+	return sess, nil
+}

+ 1 - 1
srv/middleware.go

@@ -1,6 +1,6 @@
 package srv
 
-type middlewareHandler func(req *Req, res Res) bool
+type middlewareHandler func(req *Req, res *Res) error
 
 var middleware []middlewareHandler
 

+ 5 - 6
srv/req.go

@@ -36,12 +36,6 @@ func (req *Req) Param(key string) string {
 	return req.r.PathValue(key)
 }
 
-// The route pattern eg: /user
-func (req *Req) getPattern() string {
-	_, pattern := rtr.Handler(req.r)
-	return pattern
-}
-
 func (req *Req) Cookie(name string) (string, error) {
 	cookie, err := req.r.Cookie(name)
 	if err != nil {
@@ -62,3 +56,8 @@ func (req *Req) getParentPath() string {
 	}
 	return parentPath
 }
+
+func (req *Req) getPattern() string {
+	_, pattern := rtr.Handler(req.r)
+	return pattern
+}

+ 27 - 4
srv/res.go

@@ -7,16 +7,19 @@ import (
 )
 
 type Res struct {
-	w http.ResponseWriter
-	r *http.Request
+	w      http.ResponseWriter
+	r      *http.Request
+	status int
 }
 
 func (res *Res) Construct(w http.ResponseWriter, r *http.Request) {
 	res.w = w
 	res.r = r
+	res.status = 200
 }
 
 func (res *Res) Send(txt string) {
+	res.w.WriteHeader(res.status)
 	fmt.Fprint(res.w, txt)
 }
 
@@ -24,11 +27,30 @@ func (res *Res) Redirect(url string) {
 	http.Redirect(res.w, res.r, url, http.StatusTemporaryRedirect)
 }
 
-func (res *Res) Cookie(name string, val string) {
+func (res *Res) Status(code int) *Res {
+	res.status = code
+	return res
+}
+
+func (res *Res) Header(k string, v string) *Res {
+	res.w.Header().Add(k, v)
+	return res
+}
+
+func (res *Res) JSON(json string) {
+	res.Header("Content-type", "application/json").Send(json)
+	res.w.WriteHeader(res.status)
+}
+
+func (res *Res) Cookie(name string, val string, expires ...time.Time) *Res {
+	Expires := time.Now().Add(365 * 24 * time.Hour)
+	for _, v := range expires {
+		Expires = v
+	}
 	cookie := &http.Cookie{
 		Name:     name,
 		Value:    val,
-		Expires:  time.Now().Add(365 * 24 * time.Hour),
+		Expires:  Expires,
 		Path:     "/",
 		Secure:   true,
 		HttpOnly: true,
@@ -36,4 +58,5 @@ func (res *Res) Cookie(name string, val string) {
 	}
 
 	http.SetCookie(res.w, cookie)
+	return res
 }

+ 7 - 4
srv/rtr.go

@@ -2,6 +2,8 @@ package srv
 
 import (
 	"net/http"
+
+	"git.clearsky.net.au/cody/gex.git/utils"
 )
 
 var rtr *http.ServeMux
@@ -9,7 +11,7 @@ var rtr *http.ServeMux
 /*
 Our Special HandleFunc
 */
-func Route(pattern string, handler func(req *Req, res Res)) {
+func Route(pattern string, handler func(req *Req, res *Res)) {
 
 	realHandler := func(w http.ResponseWriter, r *http.Request) {
 
@@ -21,13 +23,14 @@ func Route(pattern string, handler func(req *Req, res Res)) {
 
 		// run middleware on each request
 		for _, fun := range middleware {
-			cont := fun(&req, res)
-			if !cont {
+			err := fun(&req, &res)
+			if err != nil {
+				utils.Err(err)
 				return
 			}
 		}
 
-		handler(&req, res)
+		handler(&req, &res)
 
 	}
 

+ 1 - 3
srv/srv.go

@@ -7,12 +7,10 @@ import (
 	"os"
 )
 
-var Config map[string]string
+var PORT string = "3000"
 
 func StartServer() {
-	var PORT string = os.Args[1]
 
-	Config = make(map[string]string)
 	server := http.Server{
 		Addr:    fmt.Sprintf(":%v", PORT),
 		Handler: rtr,

+ 15 - 0
utils/utils.go

@@ -1,11 +1,13 @@
 package utils
 
 import (
+	"fmt"
 	"log"
 	"os"
 	"runtime"
 	"strconv"
 	"strings"
+	"time"
 )
 
 func Cwd(i int) string {
@@ -16,6 +18,19 @@ func Cwd(i int) string {
 	return path
 }
 
+func Err(err error) {
+	_, fp, line, _ := runtime.Caller(1)
+
+	errStr := fmt.Sprintf("Error at %s:%d: %s\n", fp, line, err)
+	fmt.Printf("%s", errStr)
+
+	log, _ := os.OpenFile("error.log", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0655)
+	defer log.Close()
+	errLogStr := fmt.Sprintf("%s %s", time.Now(), errStr)
+	log.WriteString(errLogStr)
+
+}
+
 func CheckArgs() {
 	if len(os.Args) < 2 {
 		log.Fatalf("ERROR: port number required: eg ./bin 3000")