Эх сурвалжийг харах

Session & auth now built in

Cody Joyce 2 сар өмнө
parent
commit
a0a7992b14
12 өөрчлөгдсөн 116 нэмэгдсэн , 704 устгасан
  1. 0 495
      db/query.dis
  2. 0 81
      sec/middleware.go
  3. 0 75
      sec/sec.go
  4. 0 11
      sess/middleware.go
  5. 12 3
      srv/req.go
  6. 3 1
      srv/res.go
  7. 23 2
      srv/rtr.go
  8. 36 33
      srv/sess.go
  9. 20 2
      tpl/tpl.go
  10. 0 0
      utils/base64/base64url.go
  11. 1 1
      utils/jwt/jwt.go
  12. 21 0
      utils/utils.go

+ 0 - 495
db/query.dis

@@ -1,495 +0,0 @@
-package db
-
-import (
-	"errors"
-	"fmt"
-	"math"
-	"net/http"
-	"reflect"
-	"strconv"
-	"strings"
-
-	"git.clearsky.net.au/cody/gex.git/sec"
-)
-
-type Pagination struct {
-	PrevPage  string
-	NextPage  string
-	Page      int
-	Pages     int
-	PerPage   int
-	Order     string
-	Direction string
-	Search    string
-	Start     int
-}
-
-type Query struct {
-	targetInt interface{}
-	targetPtr reflect.Value
-	targetVal reflect.Value
-	modelType reflect.Type
-
-	tblName string
-	pkName  string
-	Auth    *sec.Auth
-	err     error
-	ready   bool
-
-	sel   string
-	where string
-	order string
-	vals  []any
-}
-
-func (q *Query) New(targetInt interface{}, auth *sec.Auth) *Query {
-
-	q.targetInt = targetInt
-	q.targetPtr = reflect.ValueOf(q.targetInt)
-
-	if q.targetPtr.Kind().String() != "ptr" {
-		q.err = errors.New("ERROR: db.Query.New(), targetInt must be a pointer")
-		return q
-	}
-
-	q.targetVal = q.targetPtr.Elem()
-	q.sel = "*"
-
-	// Get the secion
-	q.Auth = auth
-
-	if q.targetVal.Kind().String() == "slice" {
-		q.newSlice()
-		return q
-	}
-
-	//Get the type of the model
-	q.modelType = q.targetVal.Type()
-
-	// Get the table name
-	tblMth := q.targetPtr.MethodByName("Table")
-	if !tblMth.IsValid() {
-		q.err = errors.New("ERROR: db.Query.New(), target Table() method must be defined")
-		return q
-	}
-	q.tblName = tblMth.Call(nil)[0].String()
-
-	// Get the primary key name
-	pkMth := q.targetPtr.MethodByName("PK")
-	if !tblMth.IsValid() {
-		q.err = errors.New("ERROR: db.Query.New(), target PK() method must be defined")
-		return q
-	}
-	q.pkName = pkMth.Call(nil)[0].String()
-
-	q.ready = true
-	return q
-}
-
-func (q *Query) newSlice() {
-	// Get the type of the model
-	q.modelType = q.targetVal.Type().Elem()
-
-	// create a new model so we can run the table and pk methods
-	tmpModel := reflect.New(q.modelType)
-
-	// Get the table name
-	tblMth := tmpModel.MethodByName("Table")
-	if !tblMth.IsValid() {
-		q.err = errors.New("ERROR: db.Query.New(), target Table() method must be defined")
-	}
-	q.tblName = tblMth.Call(nil)[0].String()
-
-	// Get the primary key name
-	pkMth := tmpModel.MethodByName("PK")
-	if !tblMth.IsValid() {
-		q.err = errors.New("ERROR: db.Query.New(), target PK() method must be defined")
-	}
-	q.pkName = pkMth.Call(nil)[0].String()
-	q.ready = true
-}
-
-func (q *Query) Reset() {
-	q.sel = "*"
-	q.where = ""
-
-	var new []any
-	q.vals = new
-}
-
-func (q *Query) Select(s string) *Query {
-	if q.err != nil {
-		return q
-	}
-
-	if !q.ready {
-		q.err = errors.New("ERROR: db.Query.Select(), db.Query.New() needs to be ran first")
-		return q
-	}
-
-	q.sel = s
-	return q
-}
-
-func (q *Query) Where(str string, vals ...any) *Query {
-	if q.err != nil {
-		return q
-	}
-
-	if !q.ready {
-		q.err = errors.New("ERROR: db.Query.Where(), db.Query.New() needs to run first")
-		return q
-	}
-
-	q.where = fmt.Sprintf("WHERE (%s) ", str)
-	q.vals = append(q.vals, vals...)
-	return q
-}
-
-// Security check here to ensure str is a field (column)
-func (q *Query) Order(str string) *Query {
-	q.order += "ORDER BY " + str
-	return q
-}
-
-func (q *Query) Create() (int64, error) {
-	if q.err != nil {
-		return 0, q.err
-	}
-	if !q.ready {
-		q.err = errors.New("ERROR: db.Query.Create(), db.Query.New() needs to run first")
-		return 0, q.err
-	}
-
-	if q.targetVal.Kind().String() != "struct" {
-		return 0, errors.New("ERROR: db.Query.Create() model must be a struct (are you using a slice)")
-	}
-
-	fields, plcHdr, vals := q.getFields(q.targetPtr.Interface())
-
-	for i := 0; i < len(vals); i++ {
-		if vals[i] == "<nil>" {
-			vals[i] = nil
-		}
-	}
-
-	query := fmt.Sprintf(
-		"INSERT INTO %s (%s) VALUES (%s)", q.tblName, strings.Join(fields, ", "), strings.Join(plcHdr, ", "),
-	)
-
-	q.targetVal.FieldByName("sec").Set(reflect.ValueOf(&q.Auth))
-	createRuleMethod := q.targetPtr.MethodByName("CreateRule")
-	createRulePassed := createRuleMethod.Call(nil)[0].Bool()
-
-	if !createRulePassed {
-		return 0, errors.New("ERROR: db.Query.Create(), create permissions not passed")
-	}
-
-	res, err := Conn.Exec(query, vals...)
-
-	if err != nil {
-		stringVals := make([]string, len(vals))
-		for i, v := range vals {
-			stringVals[i] = fmt.Sprint(v)
-		}
-		fmt.Printf("QUERY ERROR: db.Create(), Query: `%s` Values: `%s` Error: %s ", query, strings.Join(stringVals, ","), err.Error())
-
-		return 0, err
-	}
-
-	lastId, _ := res.LastInsertId()
-
-	q.targetVal.FieldByName(q.pkName).Set(reflect.ValueOf(int(lastId)))
-	return lastId, nil
-}
-
-func (q *Query) Read(id int) error {
-	if q.err != nil {
-		return q.err
-	}
-
-	if !q.ready {
-		q.err = errors.New("ERROR: db.Query.Read(), db.Query.New() needs to run first")
-		return q.err
-	}
-
-	if q.targetVal.Kind().String() != "struct" {
-		return errors.New("ERROR: db.Query.Read(), model must be a struct (are you using a slice)")
-	}
-
-	err := q.Where(q.pkName+" = ?", id).Get()
-	if err != nil {
-		return err
-	}
-	return nil
-}
-
-func (q *Query) Update() error {
-	if q.err != nil {
-		return q.err
-	}
-	if !q.ready {
-		q.err = errors.New("ERROR: db.Query.Update(), db.Query.New() needs to run first")
-		return q.err
-	}
-
-	if q.targetVal.Kind().String() != "struct" {
-		return errors.New("ERROR: db.Query.Update() model must be a struct (are you using a slice)")
-	}
-
-	// Check the existing data from db to check rule
-	id := q.targetVal.FieldByName(q.pkName)
-	tmpModel := reflect.New(q.modelType)
-
-	tmpQuery := Query{}
-	q.err = tmpQuery.New(tmpModel.Interface(), q.Auth).Read(int(id.Int()))
-
-	if q.err != nil {
-		return q.err
-	}
-
-	existRuleMethod := tmpModel.MethodByName("UpdateRule")
-	existRulePassed := existRuleMethod.Call(nil)[0].Bool()
-
-	newRuleMethod := q.targetPtr.MethodByName("UpdateRule")
-	newRulePassed := newRuleMethod.Call(nil)[0].Bool()
-
-	if !existRulePassed || !newRulePassed {
-		return errors.New("ERROR: db.Query.Update() ReadRule not passed)")
-	}
-
-	fields, _, vals := q.getFields(q.targetInt)
-
-	for i := 0; i < len(fields); i++ {
-		fields[i] = fields[i] + "=?"
-	}
-
-	for i := 0; i < len(vals); i++ {
-		if vals[i] == "<nil>" {
-			vals[i] = nil
-		}
-	}
-
-	vals = append(vals, int(id.Int()))
-	query := fmt.Sprintf(
-		"UPDATE %s SET %s WHERE %s=?", q.tblName, strings.Join(fields, ", "), q.pkName,
-	)
-
-	_, err := Conn.Exec(query, vals...)
-	if err != nil {
-		fmt.Printf("--QUERY ERROR (db.Update()): `%s` %s ", query, err.Error())
-		fmt.Println("\nQuery Values: ", vals)
-		q.Reset()
-		return err
-	}
-
-	return nil
-}
-
-func (q *Query) Delete() error {
-	if q.err != nil {
-		return q.err
-	}
-	if !q.ready {
-		q.err = errors.New("db.Query.New() needs to run first")
-		return q.err
-	}
-	if q.targetVal.Kind().String() != "struct" {
-		return errors.New("db.Query.Delete() model must be a struct (are you trying a slice or array?)")
-	}
-
-	// Check the existing data from db to check rule
-	id := q.targetVal.FieldByName(q.pkName)
-	tmpModel := reflect.New(q.targetVal.Type())
-
-	tmpQuery := Query{}
-	tmpQuery.New(tmpModel.Interface(), q.Auth).Read(int(id.Int()))
-
-	existRuleMethod := tmpModel.MethodByName("DeleteRule")
-	existRulePassed := existRuleMethod.Call(nil)[0].Bool()
-
-	if !existRulePassed {
-		err := "Update rule failed"
-		fmt.Println(err)
-		return errors.New(err)
-	}
-
-	query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", q.tblName, q.pkName)
-
-	_, err := Conn.Exec(query, int(id.Int()))
-
-	if err != nil {
-		return errors.New("db.Query.Get(), " + err.Error())
-	}
-	q.Reset()
-	return nil
-}
-
-func (q *Query) Get() error {
-
-	q.targetVal.FieldByName("Auth").Set(reflect.ValueOf(q.Auth))
-
-	// Create a temp model to store the data
-	tmpModel := reflect.New(q.modelType)
-	tmpModelElem := tmpModel.Elem()
-	tmpModelElem.FieldByName("Auth").Set(reflect.ValueOf(q.Auth))
-	//
-
-	query := fmt.Sprintf("SELECT %s FROM %s %s %s", q.sel, q.tblName, q.where, q.order)
-
-	err := Conn.QueryRow(query, q.vals...).Scan(StructPointers(tmpModel.Interface())...)
-	if err != nil {
-		return errors.New("db.Query.Get(), " + err.Error())
-	}
-
-	readRuleMethod := tmpModel.MethodByName("ReadRule")
-	readRulePassed := readRuleMethod.Call(nil)[0].Bool()
-
-	if !readRulePassed {
-		return errors.New("ERROR: db.Query.Get(), read rules not passed")
-	}
-	for i := 0; i < tmpModelElem.NumField(); i++ {
-		q.targetVal.Field(i).Set(tmpModelElem.Field(i))
-	}
-	return nil
-}
-
-func (q *Query) All() error {
-
-	query := fmt.Sprintf("SELECT %s FROM %s %s %s", q.sel, q.tblName, q.where, q.order)
-
-	rows, err := Conn.Query(query, q.vals...)
-
-	if err != nil {
-		return err
-	}
-
-	var set []reflect.Value
-	for rows.Next() {
-		tmpModel := reflect.New(q.modelType)
-		tmpModel.Elem().FieldByName("sec").Set(reflect.ValueOf(q.Auth))
-		rows.Scan(StructPointers(tmpModel.Interface())...)
-
-		readRuleMethod := tmpModel.MethodByName("ReadRule")
-		readRulePassed := readRuleMethod.Call(nil)[0].Bool()
-
-		if !readRulePassed {
-			return errors.New("ERROR: db.Query.All(), read rules not passed")
-		}
-		set = append(set, tmpModel.Elem())
-	}
-	rows.Close()
-
-	newSlice := reflect.Append(q.targetVal, set...)
-	q.targetVal.Set(newSlice)
-	return nil
-}
-
-func (q *Query) getFields(s interface{}) ([]string, []string, []any) {
-	targetPtr := reflect.ValueOf(s)
-	targetVal := targetPtr.Elem()
-
-	numFields := targetVal.NumField()
-
-	var fields []string
-	var plcHdr []string
-	var vals []any
-
-	for i := 0; i < numFields; i++ {
-		fieldName := targetVal.Type().Field(i).Name
-		if fieldName == q.pkName {
-			continue
-		}
-
-		fieldVal := targetVal.Field(i)
-
-		tag := targetVal.Type().Field(i).Tag
-
-		if tag != "" {
-			continue
-		}
-
-		if fieldVal.Kind() == reflect.Struct {
-			fieldsVal, plcHdrVal, valsVal := q.getFields(fieldVal.Addr().Interface())
-			fields = append(fields, fieldsVal...)
-			plcHdr = append(plcHdr, plcHdrVal...)
-			vals = append(vals, valsVal...)
-			continue
-		}
-
-		fields = append(fields, fmt.Sprintf("%v", fieldName))
-		plcHdr = append(plcHdr, "?")
-
-		vals = append(vals, fmt.Sprintf("%v", fieldVal))
-	}
-
-	return fields, plcHdr, vals
-}
-
-func (q *Query) SortParams(r *http.Request, defaults *Pagination) {
-	if r.FormValue("prevPage") != "" {
-		defaults.PrevPage = r.FormValue("prevPage")
-	}
-
-	if r.FormValue("nextPage") != "" {
-		defaults.NextPage = r.FormValue("nextPage")
-	}
-
-	if r.FormValue("page") != "" {
-		defaults.Page, _ = strconv.Atoi(r.FormValue("page"))
-	}
-
-	if r.FormValue("perPage") != "" {
-		defaults.PerPage, _ = strconv.Atoi(r.FormValue("perPage"))
-	}
-
-	if r.FormValue("order") != "" {
-		defaults.Order = r.FormValue("order")
-	}
-
-	if r.FormValue("direction") != "" {
-		defaults.Direction = r.FormValue("direction")
-	}
-
-	defaults.Search = r.FormValue("search")
-
-}
-
-func (q *Query) Sort(defaults *Pagination) *Query {
-
-	var count int
-
-	query := fmt.Sprintf("select count(*) as count from %s %s", q.tblName, q.where)
-	_ = Conn.QueryRow(query, q.vals...).Scan(&count)
-
-	pages := math.Ceil(float64(count) / float64(defaults.PerPage))
-	defaults.Pages = int(pages)
-
-	if defaults.PrevPage != "" {
-		defaults.Page--
-	}
-
-	if defaults.NextPage != "" {
-		defaults.Page++
-	}
-
-	if defaults.Page <= 0 {
-		defaults.Page = 1
-	}
-
-	if defaults.Page > defaults.Pages {
-		defaults.Page = defaults.Pages
-
-	}
-
-	defaults.Start = (defaults.Page - 1) * defaults.PerPage
-
-	if defaults.Start < 0 {
-		defaults.Start = 0
-	}
-
-	q.Order(defaults.Order + " " + defaults.Direction)
-	// this.limit(sort.start, sort.perPage);
-
-	return q
-}

+ 0 - 81
sec/middleware.go

@@ -1,81 +0,0 @@
-package sec
-
-import (
-	"errors"
-
-	"git.clearsky.net.au/cody/gex.git/sess"
-	"git.clearsky.net.au/cody/gex.git/srv"
-	"git.clearsky.net.au/cody/gex.git/utils"
-)
-
-func Middleware(req *srv.Req, res *srv.Res) error {
-
-	sess, err := sess.GetCtxSess(req)
-	if err != nil {
-		utils.Err(err)
-		return err
-	}
-
-	// default auth
-	auth := Auth{0, "Guest", []string{"Guest", "Everyone"}}
-
-	// if auth context exists, convert it to an Auth type
-	if sess.Data["Auth"] != nil {
-		sessAuth, ok := sess.Data["Auth"].(map[string]any)
-		if !ok {
-			err := errors.New("auth context in session data is not of the expected type, request cancelled")
-			utils.Err(err)
-			return err
-		}
-
-		auth.User_id, ok = sessAuth["User_id"].(float64)
-		if !ok {
-			err := errors.New("auth context in session data is not of the expected type, request cancelled")
-			utils.Err(err)
-			return err
-		}
-
-		auth.User_name, ok = sessAuth["User_name"].(string)
-		if !ok {
-			err := errors.New("auth context in session data is not of the expected type, request cancelled")
-			utils.Err(err)
-			return err
-		}
-
-		sessAuthRoles, ok := sessAuth["Roles"].([]any)
-		if !ok {
-			err := errors.New("auth context in session data is not of the expected type, request cancelled")
-			utils.Err(err)
-			return err
-		}
-
-		auth.Roles = []string{}
-		for _, v := range sessAuthRoles {
-			val, ok := v.(string)
-			if !ok {
-				err := errors.New("auth context in session data is not of the expected type, request cancelled")
-				utils.Err(err)
-				return err
-			}
-			auth.Roles = append(auth.Roles, val)
-		}
-	}
-
-	sess.Data["Auth"] = auth
-	sess.Save()
-
-	// Route Access Check
-	pattern := req.Pattern
-	if permissions[pattern] == nil {
-		return nil
-	}
-
-	for _, val := range permissions[pattern] {
-		if auth.HasRole(val) {
-			return nil
-		}
-	}
-
-	res.Send("No Access")
-	return errors.New("no access, request cancelled")
-}

+ 0 - 75
sec/sec.go

@@ -1,75 +0,0 @@
-package sec
-
-import (
-	"crypto/hmac"
-	"crypto/sha256"
-	"encoding/hex"
-	"errors"
-	"math/rand"
-
-	"git.clearsky.net.au/cody/gex.git/sess"
-	"git.clearsky.net.au/cody/gex.git/srv"
-	"git.clearsky.net.au/cody/gex.git/utils"
-)
-
-type Auth struct {
-	User_id   float64
-	User_name string
-	Roles     []string
-}
-
-func (auth *Auth) HasRole(role string) bool {
-	for _, val := range auth.Roles {
-		if role == val {
-			return true
-		}
-	}
-	return false
-}
-
-var permissions = make(map[string][]string)
-
-func GetCtxAuth(req *srv.Req) (Auth, error) {
-	sess, err := sess.GetCtxSess(req)
-	if err != nil {
-		utils.Err(err)
-		return Auth{}, err
-	}
-
-	if sess.Data["Auth"] == nil {
-		err := errors.New("no auth context in session data")
-		utils.Err(err)
-		return Auth{}, err
-	}
-
-	auth, ok := sess.Data["Auth"].(Auth)
-	if !ok {
-		err := errors.New("auth context in session data is not of the expected type")
-		utils.Err(err)
-		return Auth{}, err
-	}
-
-	return auth, nil
-}
-
-func Route(pattern string, roles []string, handler func(req *srv.Req, res *srv.Res)) {
-	permissions[pattern] = roles
-	srv.Route(pattern, handler)
-}
-
-func Hash(key string, data string) string {
-	bKey := []byte(key)
-	bData := []byte(data)
-	h := hmac.New(sha256.New, bKey)
-	h.Write(bData)
-	return hex.EncodeToString(h.Sum(nil))
-}
-
-func Salt(length int) string {
-	const charset = "1234567890-=!@#$%^&*()_+abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
-	b := make([]byte, length)
-	for i := range b {
-		b[i] = charset[rand.Intn(len(charset))]
-	}
-	return string(b)
-}

+ 0 - 11
sess/middleware.go

@@ -1,11 +0,0 @@
-package sess
-
-import "git.clearsky.net.au/cody/gex.git/srv"
-
-func Middleware(req *srv.Req, res *srv.Res) error {
-	var sess Sess
-	sess.Construct(req, res)
-	sess.Save()
-	req.Ctx["Sess"] = &sess
-	return nil
-}

+ 12 - 3
srv/req.go

@@ -11,17 +11,18 @@ type Req struct {
 	Path       string
 	ParentPath string
 	Pattern    string
-	Ctx        map[string]any
+	Sess       Sess
+	Data       map[string]string
 }
 
 func (req *Req) Construct(r *http.Request) {
-	r.ParseForm()
 	req.r = r
+	req.r.ParseForm()
 	req.AppURL = "/" + r.Host
 	req.Path = r.RequestURI
 	req.ParentPath = req.getParentPath()
 	req.Pattern = req.getPattern()
-	req.Ctx = make(map[string]any)
+	req.Data = make(map[string]string)
 }
 
 func (req *Req) Body(key string) string {
@@ -61,3 +62,11 @@ func (req *Req) getPattern() string {
 	_, pattern := rtr.Handler(req.r)
 	return pattern
 }
+
+func (req *Req) Set(key string, val string) {
+	req.Data[key] = val
+}
+
+func (req *Req) Get(key string) string {
+	return req.Data[key]
+}

+ 3 - 1
srv/res.go

@@ -31,6 +31,9 @@ func (res *Res) Send(txt string) {
 }
 
 func (res *Res) Redirect(url string) {
+	for _, v := range res.cookies {
+		http.SetCookie(res.w, v)
+	}
 	http.Redirect(res.w, res.r, url, http.StatusTemporaryRedirect)
 }
 
@@ -46,7 +49,6 @@ func (res *Res) Header(k string, v string) *Res {
 
 func (res *Res) JSON(json string) {
 	res.Header("Content-type", "application/json").Send(json)
-	res.w.WriteHeader(res.status)
 }
 
 func (res *Res) Cookie(name string, val string, expires ...time.Time) *Res {

+ 23 - 2
srv/rtr.go

@@ -7,21 +7,42 @@ import (
 )
 
 var rtr *http.ServeMux
+var permissions = make(map[string][]string)
 
 /*
 Our Special HandleFunc
 */
-func Route(pattern string, handler func(req *Req, res *Res)) {
+func Route(pattern string, roles []string, handler func(req *Req, res *Res)) {
+	permissions[pattern] = roles
 
 	realHandler := func(w http.ResponseWriter, r *http.Request) {
 
 		var req Req
 		var res Res
+		var sess Sess
 
 		req.Construct(r)
 		res.Construct(w, r)
+		sess.Construct(&req, &res)
+		req.Sess = sess
+		req.Sess.Save()
 
-		// run middleware on each request
+		// Route Access Check
+		proceed := false
+		pattern := req.Pattern
+		for _, val := range permissions[pattern] {
+			if sess.HasRole(val) {
+				proceed = true
+				break
+			}
+		}
+
+		if !proceed {
+			res.Redirect("/403")
+			return
+		}
+
+		// run middleware on request
 		for _, fun := range middleware {
 			err := fun(&req, &res)
 			if err != nil {

+ 36 - 33
sess/sess.go → srv/sess.go

@@ -1,37 +1,37 @@
-package sess
+package srv
 
 import (
 	"encoding/json"
-	"errors"
 	"fmt"
 
-	"git.clearsky.net.au/cody/gex.git/sess/jwt"
-	"git.clearsky.net.au/cody/gex.git/srv"
 	"git.clearsky.net.au/cody/gex.git/utils"
+	"git.clearsky.net.au/cody/gex.git/utils/jwt"
 
 	"time"
 )
 
 type Sess struct {
-	req     *srv.Req
-	res     *srv.Res
+	req     *Req
+	res     *Res
 	Expires time.Time
-	Data    map[string]any
+	User_id int
+	Roles   []string
+	Data    map[string]string
 }
 
 // config defaults
-var TOKENNAME string = "SessToken"
-var TIMEOUT time.Duration = 1 * time.Hour
-var SECRET string = "secret"
+var SESS_TOKEN string = "SessToken"
+var SESS_TIMEOUT time.Duration = 1 * time.Hour
+var SESS_SECRET string = "secret"
 
-func (sess *Sess) Construct(req *srv.Req, res *srv.Res) {
+func (sess *Sess) Construct(req *Req, res *Res) {
 	sess.req = req
 	sess.res = res
-
-	sess.setDefaults()
+	sess.Data = make(map[string]string)
+	sess.SetDefaults()
 
 	// check cookie is valid and not expired
-	cookie, err := req.Cookie(TOKENNAME)
+	cookie, err := req.Cookie(SESS_TOKEN)
 
 	if err != nil {
 		//utils.Err(err)
@@ -39,7 +39,7 @@ func (sess *Sess) Construct(req *srv.Req, res *srv.Res) {
 	}
 
 	// decode jwt to json bytes
-	jsonByt, err := jwt.Decode(cookie, SECRET)
+	jsonByt, err := jwt.Decode(cookie, SESS_SECRET)
 	if err != nil {
 		utils.Err(err)
 		return
@@ -54,19 +54,20 @@ func (sess *Sess) Construct(req *srv.Req, res *srv.Res) {
 
 	// if session token has expired, return default session
 	if time.Now().After(sess.Expires) {
-		if time.Now().After(sess.Expires.Add(TIMEOUT)) {
+		if time.Now().After(sess.Expires.Add(SESS_TIMEOUT)) {
 			sess.Expires = time.Now().Add(20 * time.Minute)
 			return
 		}
 		fmt.Println("session expired")
-		sess.setDefaults()
+		sess.SetDefaults()
 	}
 
 }
 
-func (sess *Sess) setDefaults() {
-	sess.Data = make(map[string]any)
+func (sess *Sess) SetDefaults() {
 	sess.Expires = time.Now().Add(20 * time.Minute)
+	sess.User_id = 0
+	sess.Roles = []string{"Guest", "Everyone"}
 }
 
 func (sess *Sess) Token() (string, error) {
@@ -76,7 +77,7 @@ func (sess *Sess) Token() (string, error) {
 	}
 
 	// encode the json to jwt and set the cookie
-	token, err := jwt.Encode(jsonStr, SECRET)
+	token, err := jwt.Encode(jsonStr, SESS_SECRET)
 	if err != nil {
 		return "", err
 	}
@@ -94,20 +95,22 @@ func (sess *Sess) Save() {
 	}
 
 	// set the token cookie
-	sess.res.Cookie(TOKENNAME, token)
+	sess.res.Cookie(SESS_TOKEN, token)
 }
 
-func GetCtxSess(req *srv.Req) (*Sess, error) {
-	if req.Ctx["Sess"] == nil {
-		err := errors.New("no session context, did you add the session middleware?")
-		utils.Err(err)
-		return nil, err
-	}
-	sess, ok := req.Ctx["Sess"].(*Sess)
-	if !ok {
-		err := errors.New("session from context is not of type *Sess")
-		utils.Err(err)
-		return nil, err
+func (sess *Sess) HasRole(role string) bool {
+	for _, val := range sess.Roles {
+		if role == val {
+			return true
+		}
 	}
-	return sess, nil
+	return false
+}
+
+func (sess *Sess) Set(key string, val string) {
+	sess.Data[key] = val
+}
+
+func (sess *Sess) Get(key string) string {
+	return sess.Data[key]
 }

+ 20 - 2
tpl/tpl.go

@@ -2,6 +2,7 @@ package tpl
 
 import (
 	"bytes"
+	"errors"
 	"html/template"
 	"path/filepath"
 	"strings"
@@ -28,12 +29,27 @@ func Render(fp string, req *srv.Req, propFunc ...interface{}) string {
 	var funcs = Funcs{}
 
 	if len(propFunc) > 0 {
-		props = propFunc[0].(Props)
+		var ok bool
+		props, ok = propFunc[0].(Props)
+		if !ok {
+			err := errors.New("invalid props type passed to Render(), must be map[string]any")
+			utils.Err(err)
+			return err.Error()
+		}
 	}
+
 	if len(propFunc) > 1 {
-		funcs = propFunc[1].(Funcs)
+		var ok bool
+		funcs, ok = propFunc[1].(Funcs)
+		if !ok {
+			err := errors.New("invalid funcs type passed to Render(), must be map[string]func (template.funcMap)")
+			utils.Err(err)
+			return err.Error()
+		}
 	}
 
+	props["Req"] = req
+
 	Include(req, props, funcs)
 
 	tmpl, err := template.
@@ -42,6 +58,7 @@ func Render(fp string, req *srv.Req, propFunc ...interface{}) string {
 		ParseFiles(fp)
 
 	if err != nil {
+		utils.Err(err)
 		return err.Error()
 	}
 
@@ -49,6 +66,7 @@ func Render(fp string, req *srv.Req, propFunc ...interface{}) string {
 	err = tmpl.Execute(&buf, props)
 
 	if err != nil {
+		utils.Err(err)
 		return err.Error()
 	}
 

+ 0 - 0
sess/base64/base64url.go → utils/base64/base64url.go


+ 1 - 1
sess/jwt/jwt.go → utils/jwt/jwt.go

@@ -7,8 +7,8 @@ import (
 	"errors"
 	"strings"
 
-	"git.clearsky.net.au/cody/gex.git/sess/base64"
 	"git.clearsky.net.au/cody/gex.git/utils"
+	"git.clearsky.net.au/cody/gex.git/utils/base64"
 )
 
 // encode json bytes to a jwt token string

+ 21 - 0
utils/utils.go

@@ -1,8 +1,12 @@
 package utils
 
 import (
+	"crypto/hmac"
+	"crypto/sha256"
+	"encoding/hex"
 	"fmt"
 	"log"
+	"math/rand"
 	"os"
 	"runtime"
 	"strconv"
@@ -31,6 +35,23 @@ func Err(err error) {
 
 }
 
+func Hash(key string, data string) string {
+	bKey := []byte(key)
+	bData := []byte(data)
+	h := hmac.New(sha256.New, bKey)
+	h.Write(bData)
+	return hex.EncodeToString(h.Sum(nil))
+}
+
+func Salt(length int) string {
+	const charset = "1234567890-=!@#$%^&*()_+abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+	b := make([]byte, length)
+	for i := range b {
+		b[i] = charset[rand.Intn(len(charset))]
+	}
+	return string(b)
+}
+
 func CheckArgs() {
 	if len(os.Args) < 2 {
 		log.Fatalf("ERROR: port number required: eg ./bin 3000")