Cody Joyce 2 months ago
parent
commit
a541fa60b6
6 changed files with 617 additions and 1 deletions
  1. 62 0
      db/db.go
  2. 490 0
      db/query.go
  3. 36 0
      db/util.go
  4. 2 0
      go.mod
  5. 2 0
      go.sum
  6. 25 1
      sec/sec.go

+ 62 - 0
db/db.go

@@ -0,0 +1,62 @@
+package db
+
+import (
+	"database/sql"
+	"fmt"
+	"reflect"
+
+	_ "github.com/mattn/go-sqlite3"
+)
+
+var Conn *sql.DB
+
+func Connect(fp string) error {
+	var err error
+	Conn, err = sql.Open("sqlite3", fp)
+	if err != nil {
+		fmt.Println("db.Connect() Error connecting to database")
+		return err
+	}
+	Conn.Exec(`
+		PRAGMA journal_mode=ON
+		PRAGMA journal_mode=WAL;
+		PRAGMA synchronous=ON
+		PRAGMA auto_vacuum=ON
+		PRAGMA foreign_keys=ON
+	`)
+	return nil
+}
+
+// Pass fields struct as a pointer
+func Get(fieldsPtr interface{}, query string, vals []any) {
+	err := Conn.QueryRow(query, vals...).Scan(StructPointers(fieldsPtr)...)
+	if err != nil {
+		fmt.Printf("ERROR: %s\n", err)
+		return
+	}
+}
+
+func All(sliceInt interface{}, query string, vals []any) {
+	slicePtr := reflect.ValueOf(sliceInt)
+	sliceVal := slicePtr.Elem()
+	modelType := sliceVal.Type().Elem()
+
+	rows, _ := Conn.Query(query, vals...)
+
+	var set []reflect.Value
+	for rows.Next() {
+		tmpModel := reflect.New(modelType)
+
+		rows.Scan(StructPointers(tmpModel.Interface())...)
+
+		set = append(set, tmpModel.Elem())
+	}
+	rows.Close()
+
+	newSlice := reflect.Append(sliceVal, set...)
+	sliceVal.Set(newSlice)
+}
+
+func Close() {
+	Conn.Close()
+}

+ 490 - 0
db/query.go

@@ -0,0 +1,490 @@
+package db
+
+import (
+	"errors"
+	"fmt"
+	"math"
+	"net/http"
+	"reflect"
+	"strconv"
+	"strings"
+
+	"git.clearsky.net.au/cody/gex.git/sec"
+)
+
+type Pagination struct {
+	PrevPage  string
+	NextPage  string
+	Page      int
+	Pages     int
+	PerPage   int
+	Order     string
+	Direction string
+	Search    string
+	Start     int
+}
+
+type Query struct {
+	targetInt interface{}
+	targetPtr reflect.Value
+	targetVal reflect.Value
+	modelType reflect.Type
+
+	tblName string
+	pkName  string
+	Sess    sec.Sess
+	err     error
+	ready   bool
+
+	sel   string
+	where string
+	order string
+	vals  []any
+}
+
+func (q *Query) New(targetInt interface{}, sess sec.Sess) *Query {
+	q.targetInt = targetInt
+	q.targetPtr = reflect.ValueOf(q.targetInt)
+
+	if q.targetPtr.Kind().String() != "ptr" {
+		q.err = errors.New("ERROR: db.Query.New(), targetInt must be a pointer")
+		return q
+	}
+
+	q.targetVal = q.targetPtr.Elem()
+	q.sel = "*"
+
+	// Get the session
+	q.Sess = sess
+
+	if q.targetVal.Kind().String() == "slice" {
+		q.newSlice()
+		return q
+	}
+
+	//Get the type of the model
+	q.modelType = q.targetVal.Type()
+
+	// Get the table name
+	tblMth := q.targetPtr.MethodByName("Table")
+	if !tblMth.IsValid() {
+		q.err = errors.New("ERROR: db.Query.New(), target Table() method must be defined")
+		return q
+	}
+	q.tblName = tblMth.Call(nil)[0].String()
+
+	// Get the primary key name
+	pkMth := q.targetPtr.MethodByName("PK")
+	if !tblMth.IsValid() {
+		q.err = errors.New("ERROR: db.Query.New(), target PK() method must be defined")
+		return q
+	}
+	q.pkName = pkMth.Call(nil)[0].String()
+
+	q.ready = true
+	return q
+}
+
+func (q *Query) newSlice() {
+	// Get the type of the model
+	q.modelType = q.targetVal.Type().Elem()
+
+	// create a new model so we can run the table and pk methods
+	tmpModel := reflect.New(q.modelType)
+
+	// Get the table name
+	tblMth := tmpModel.MethodByName("Table")
+	if !tblMth.IsValid() {
+		q.err = errors.New("ERROR: db.Query.New(), target Table() method must be defined")
+	}
+	q.tblName = tblMth.Call(nil)[0].String()
+
+	// Get the primary key name
+	pkMth := tmpModel.MethodByName("PK")
+	if !tblMth.IsValid() {
+		q.err = errors.New("ERROR: db.Query.New(), target PK() method must be defined")
+	}
+	q.pkName = pkMth.Call(nil)[0].String()
+	q.ready = true
+}
+
+func (q *Query) Reset() {
+	q.sel = "*"
+	q.where = ""
+
+	var new []any
+	q.vals = new
+}
+
+func (q *Query) Select(s string) *Query {
+	if q.err != nil {
+		return q
+	}
+
+	if !q.ready {
+		q.err = errors.New("ERROR: db.Query.Select(), db.Query.New() needs to be ran first")
+		return q
+	}
+
+	q.sel = s
+	return q
+}
+
+func (q *Query) Where(str string, vals ...any) *Query {
+	if q.err != nil {
+		return q
+	}
+
+	if !q.ready {
+		q.err = errors.New("ERROR: db.Query.Where(), db.Query.New() needs to run first")
+		return q
+	}
+
+	q.where = fmt.Sprintf("WHERE (%s) ", str)
+	q.vals = append(q.vals, vals...)
+	return q
+}
+
+// Security check here to ensure str is a field (column)
+func (q *Query) Order(str string) *Query {
+	q.order += "ORDER BY " + str
+	return q
+}
+
+func (q *Query) Create() (int64, error) {
+	if q.err != nil {
+		return 0, q.err
+	}
+	if !q.ready {
+		q.err = errors.New("ERROR: db.Query.Create(), db.Query.New() needs to run first")
+		return 0, q.err
+	}
+
+	if q.targetVal.Kind().String() != "struct" {
+		return 0, errors.New("ERROR: db.Query.Create() model must be a struct (are you using a slice)")
+	}
+
+	fields, plcHdr, vals := q.getFields(q.targetPtr.Interface())
+
+	for i := 0; i < len(vals); i++ {
+		if vals[i] == "<nil>" {
+			vals[i] = nil
+		}
+	}
+
+	query := fmt.Sprintf(
+		"INSERT INTO %s (%s) VALUES (%s)", q.tblName, strings.Join(fields, ", "), strings.Join(plcHdr, ", "),
+	)
+
+	q.targetVal.FieldByName("Sess").Set(reflect.ValueOf(&q.Sess))
+	createRuleMethod := q.targetPtr.MethodByName("CreateRule")
+	createRulePassed := createRuleMethod.Call(nil)[0].Bool()
+
+	if !createRulePassed {
+		return 0, errors.New("ERROR: db.Query.Create(), create permissions not passed")
+	}
+
+	res, err := Conn.Exec(query, vals...)
+
+	if err != nil {
+		stringVals := make([]string, len(vals))
+		for i, v := range vals {
+			stringVals[i] = fmt.Sprint(v)
+		}
+		fmt.Printf("QUERY ERROR: db.Create(), Query: `%s` Values: `%s` Error: %s ", query, strings.Join(stringVals, ","), err.Error())
+
+		return 0, err
+	}
+
+	lastId, _ := res.LastInsertId()
+
+	q.targetVal.FieldByName(q.pkName).Set(reflect.ValueOf(int(lastId)))
+	return lastId, nil
+}
+
+func (q *Query) Read(id int) error {
+	if q.err != nil {
+		return q.err
+	}
+
+	if !q.ready {
+		q.err = errors.New("ERROR: db.Query.Read(), db.Query.New() needs to run first")
+		return q.err
+	}
+
+	if q.targetVal.Kind().String() != "struct" {
+		return errors.New("ERROR: db.Query.Read(), model must be a struct (are you using a slice)")
+	}
+
+	err := q.Where(q.pkName+" = ?", id).Get()
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func (q *Query) Update() error {
+	if q.err != nil {
+		return q.err
+	}
+	if !q.ready {
+		q.err = errors.New("ERROR: db.Query.Update(), db.Query.New() needs to run first")
+		return q.err
+	}
+
+	if q.targetVal.Kind().String() != "struct" {
+		return errors.New("ERROR: db.Query.Update() model must be a struct (are you using a slice)")
+	}
+
+	// Check the existing data from db to check rule
+	id := q.targetVal.FieldByName(q.pkName)
+	tmpModel := reflect.New(q.modelType)
+
+	tmpQuery := Query{}
+	q.err = tmpQuery.New(tmpModel.Interface(), q.Sess).Read(int(id.Int()))
+
+	if q.err != nil {
+		return q.err
+	}
+
+	existRuleMethod := tmpModel.MethodByName("UpdateRule")
+	existRulePassed := existRuleMethod.Call(nil)[0].Bool()
+
+	newRuleMethod := q.targetPtr.MethodByName("UpdateRule")
+	newRulePassed := newRuleMethod.Call(nil)[0].Bool()
+
+	if !existRulePassed || !newRulePassed {
+		return errors.New("ERROR: db.Query.Update() ReadRule not passed)")
+	}
+
+	fields, _, vals := q.getFields(q.targetInt)
+
+	for i := 0; i < len(fields); i++ {
+		fields[i] = fields[i] + "=?"
+	}
+
+	for i := 0; i < len(vals); i++ {
+		if vals[i] == "<nil>" {
+			vals[i] = nil
+		}
+	}
+
+	vals = append(vals, int(id.Int()))
+	query := fmt.Sprintf(
+		"UPDATE %s SET %s WHERE %s=?", q.tblName, strings.Join(fields, ", "), q.pkName,
+	)
+
+	_, err := Conn.Exec(query, vals...)
+	if err != nil {
+		fmt.Printf("--QUERY ERROR (db.Update()): `%s` %s ", query, err.Error())
+		fmt.Println("\nQuery Values: ", vals)
+		q.Reset()
+		return err
+	}
+
+	return nil
+}
+
+func (q *Query) Delete() error {
+	if q.err != nil {
+		return q.err
+	}
+	if !q.ready {
+		q.err = errors.New("db.Query.New() needs to run first")
+		return q.err
+	}
+	if q.targetVal.Kind().String() != "struct" {
+		return errors.New("db.Query.Delete() model must be a struct (are you trying a slice or array?)")
+	}
+
+	// Check the existing data from db to check rule
+	id := q.targetVal.FieldByName(q.pkName)
+	tmpModel := reflect.New(q.targetVal.Type())
+
+	tmpQuery := Query{}
+	tmpQuery.New(tmpModel.Interface(), q.Sess).Read(int(id.Int()))
+
+	existRuleMethod := tmpModel.MethodByName("DeleteRule")
+	existRulePassed := existRuleMethod.Call(nil)[0].Bool()
+
+	if !existRulePassed {
+		err := "Update rule failed"
+		fmt.Println(err)
+		return errors.New(err)
+	}
+
+	query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", q.tblName, q.pkName)
+
+	_, err := Conn.Exec(query, int(id.Int()))
+
+	if err != nil {
+		return errors.New("db.Query.Get(), " + err.Error())
+	}
+	q.Reset()
+	return nil
+}
+
+func (q *Query) Get() error {
+	q.targetVal.FieldByName("Sess").Set(reflect.ValueOf(&q.Sess))
+
+	// Create a temp model to store the data
+	tmpModel := reflect.New(q.modelType)
+	tmpModelElem := tmpModel.Elem()
+	tmpModelElem.FieldByName("Sess").Set(reflect.ValueOf(&q.Sess))
+	//
+
+	query := fmt.Sprintf("SELECT %s FROM %s %s %s", q.sel, q.tblName, q.where, q.order)
+	err := Conn.QueryRow(query, q.vals...).Scan(StructPointers(tmpModel.Interface())...)
+	if err != nil {
+		return errors.New("db.Query.Get(), " + err.Error())
+	}
+
+	readRuleMethod := tmpModel.MethodByName("ReadRule")
+	readRulePassed := readRuleMethod.Call(nil)[0].Bool()
+
+	if !readRulePassed {
+		return errors.New("ERROR: db.Query.Get(), read rules not passed")
+	}
+	for i := 0; i < tmpModelElem.NumField(); i++ {
+		q.targetVal.Field(i).Set(tmpModelElem.Field(i))
+	}
+	return nil
+}
+
+func (q *Query) All() error {
+	query := fmt.Sprintf("SELECT %s FROM %s %s %s", q.sel, q.tblName, q.where, q.order)
+	rows, err := Conn.Query(query, q.vals...)
+
+	if err != nil {
+		return err
+	}
+
+	var set []reflect.Value
+	for rows.Next() {
+		tmpModel := reflect.New(q.modelType)
+		tmpModel.Elem().FieldByName("Sess").Set(reflect.ValueOf(&q.Sess))
+		rows.Scan(StructPointers(tmpModel.Interface())...)
+
+		readRuleMethod := tmpModel.MethodByName("ReadRule")
+		readRulePassed := readRuleMethod.Call(nil)[0].Bool()
+
+		if !readRulePassed {
+			return errors.New("ERROR: db.Query.All(), read rules not passed")
+		}
+		set = append(set, tmpModel.Elem())
+	}
+	rows.Close()
+
+	newSlice := reflect.Append(q.targetVal, set...)
+	q.targetVal.Set(newSlice)
+	return nil
+}
+
+func (q *Query) getFields(s interface{}) ([]string, []string, []any) {
+	targetPtr := reflect.ValueOf(s)
+	targetVal := targetPtr.Elem()
+
+	numFields := targetVal.NumField()
+
+	var fields []string
+	var plcHdr []string
+	var vals []any
+
+	for i := 0; i < numFields; i++ {
+		fieldName := targetVal.Type().Field(i).Name
+		if fieldName == q.pkName {
+			continue
+		}
+
+		fieldVal := targetVal.Field(i)
+
+		tag := targetVal.Type().Field(i).Tag
+
+		if tag != "" {
+			continue
+		}
+
+		if fieldVal.Kind() == reflect.Struct {
+			fieldsVal, plcHdrVal, valsVal := q.getFields(fieldVal.Addr().Interface())
+			fields = append(fields, fieldsVal...)
+			plcHdr = append(plcHdr, plcHdrVal...)
+			vals = append(vals, valsVal...)
+			continue
+		}
+
+		fields = append(fields, fmt.Sprintf("%v", fieldName))
+		plcHdr = append(plcHdr, "?")
+
+		vals = append(vals, fmt.Sprintf("%v", fieldVal))
+	}
+
+	return fields, plcHdr, vals
+}
+
+func (q *Query) SortParams(r *http.Request, defaults *Pagination) {
+	if r.FormValue("prevPage") != "" {
+		defaults.PrevPage = r.FormValue("prevPage")
+	}
+
+	if r.FormValue("nextPage") != "" {
+		defaults.NextPage = r.FormValue("nextPage")
+	}
+
+	if r.FormValue("page") != "" {
+		defaults.Page, _ = strconv.Atoi(r.FormValue("page"))
+	}
+
+	if r.FormValue("perPage") != "" {
+		defaults.PerPage, _ = strconv.Atoi(r.FormValue("perPage"))
+	}
+
+	if r.FormValue("order") != "" {
+		defaults.Order = r.FormValue("order")
+	}
+
+	if r.FormValue("direction") != "" {
+		defaults.Direction = r.FormValue("direction")
+	}
+
+	defaults.Search = r.FormValue("search")
+
+}
+
+func (q *Query) Sort(defaults *Pagination) *Query {
+
+	var count int
+
+	query := fmt.Sprintf("select count(*) as count from %s %s", q.tblName, q.where)
+	_ = Conn.QueryRow(query, q.vals...).Scan(&count)
+
+	pages := math.Ceil(float64(count) / float64(defaults.PerPage))
+	defaults.Pages = int(pages)
+
+	if defaults.PrevPage != "" {
+		defaults.Page--
+	}
+
+	if defaults.NextPage != "" {
+		defaults.Page++
+	}
+
+	if defaults.Page <= 0 {
+		defaults.Page = 1
+	}
+
+	if defaults.Page > defaults.Pages {
+		defaults.Page = defaults.Pages
+
+	}
+
+	defaults.Start = (defaults.Page - 1) * defaults.PerPage
+
+	if defaults.Start < 0 {
+		defaults.Start = 0
+	}
+
+	q.Order(defaults.Order + " " + defaults.Direction)
+	// this.limit(sort.start, sort.perPage);
+
+	return q
+}

+ 36 - 0
db/util.go

@@ -0,0 +1,36 @@
+package db
+
+import (
+	"reflect"
+)
+
+func StructPointers(s interface{}) []interface{} {
+	model := reflect.ValueOf(s)
+	modelElem := model.Elem()
+
+	numFields := modelElem.NumField()
+
+	var Ptrs []interface{}
+
+	for i := 0; i < numFields; i++ {
+		field := modelElem.Field(i)
+
+		tag := modelElem.Type().Field(i).Tag
+
+		if tag != "" {
+			continue
+		}
+
+		// Merge
+
+		if field.Kind() == reflect.Struct {
+			Ptrs = append(Ptrs, StructPointers(field.Addr().Interface())...)
+			continue
+		}
+
+		Ptrs = append(Ptrs, field.Addr().Interface())
+
+	}
+
+	return Ptrs
+}

+ 2 - 0
go.mod

@@ -1,3 +1,5 @@
 module git.clearsky.net.au/cody/gex.git
 
 go 1.23.4
+
+require github.com/mattn/go-sqlite3 v1.14.24

+ 2 - 0
go.sum

@@ -0,0 +1,2 @@
+github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
+github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=

+ 25 - 1
sec/sec.go

@@ -1,6 +1,13 @@
 package sec
 
-import "git.clearsky.net.au/cody/gex.git/srv"
+import (
+	"crypto/hmac"
+	"crypto/sha256"
+	"encoding/hex"
+	"math/rand"
+
+	"git.clearsky.net.au/cody/gex.git/srv"
+)
 
 var permissions = make(map[string][]string)
 
@@ -8,3 +15,20 @@ func Route(pattern string, roles []string, handler func(req *srv.Req, res srv.Re
 	permissions[pattern] = roles
 	srv.Route(pattern, handler)
 }
+
+func Hash(key string, data string) string {
+	bKey := []byte(key)
+	bData := []byte(data)
+	h := hmac.New(sha256.New, bKey)
+	h.Write(bData)
+	return hex.EncodeToString(h.Sum(nil))
+}
+
+func Salt(length int) string {
+	const charset = "1234567890-=!@#$%^&*()_+abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+	b := make([]byte, length)
+	for i := range b {
+		b[i] = charset[rand.Intn(len(charset))]
+	}
+	return string(b)
+}