Ver código fonte

feat: 支持嵌套结构的字段权限过滤

BaiLuoYan 1 semana atrás
pai
commit
955e05c3e0
2 arquivos alterados com 92 adições e 47 exclusões
  1. 48 34
      collect.go
  2. 44 13
      generate.go

+ 48 - 34
collect.go

@@ -18,21 +18,26 @@ type routePermDecl struct {
 	DataCode string
 }
 
-type fieldPermMap struct {
-	Request  map[string]string // json字段名 → permCode
-	Response map[string]string // json字段名 → permCode
+type fieldNode struct {
+	Fields map[string]string
+	Nested map[string]*fieldNode
 }
 
 type collectResult struct {
 	perms      []permDecl
 	routePerms []routePermDecl
-	fieldPerms map[string]fieldPermMap // "METHOD /path" → fieldPermMap
+	fieldPerms map[string]*fieldPermMap // "METHOD /path" → fieldPermMap
+}
+
+type fieldPermMap struct {
+	Request  *fieldNode
+	Response *fieldNode
 }
 
 func collect(input *PluginInput) *collectResult {
 	seen := make(map[string]bool)
 	result := &collectResult{
-		fieldPerms: make(map[string]fieldPermMap),
+		fieldPerms: make(map[string]*fieldPermMap),
 	}
 
 	if input == nil || input.Api == nil {
@@ -89,27 +94,23 @@ func collect(input *PluginInput) *collectResult {
 				DataCode: dataCode,
 			})
 
-			fm := fieldPermMap{
-				Request:  make(map[string]string),
-				Response: make(map[string]string),
-			}
-
+			var reqNode *fieldNode
 			if route.RequestType != nil {
-				for jsonField, permTag := range extractFieldPermsDeep(route.RequestType, typeIndex) {
-					fm.Request[jsonField] = permTag
-					addPerm(permTag, "")
-				}
+				reqNode = collectFieldPerms(route.RequestType.Members, typeIndex)
+				collectPermCodes(reqNode, addPerm)
 			}
 
+			var respNode *fieldNode
 			if route.ResponseType != nil {
-				for jsonField, permTag := range extractFieldPermsDeep(route.ResponseType, typeIndex) {
-					fm.Response[jsonField] = permTag
-					addPerm(permTag, "")
-				}
+				respNode = collectFieldPerms(route.ResponseType.Members, typeIndex)
+				collectPermCodes(respNode, addPerm)
 			}
 
-			if len(fm.Request) > 0 || len(fm.Response) > 0 {
-				result.fieldPerms[key] = fm
+			if reqNode != nil || respNode != nil {
+				result.fieldPerms[key] = &fieldPermMap{
+					Request:  reqNode,
+					Response: respNode,
+				}
 			}
 		}
 	}
@@ -117,17 +118,12 @@ func collect(input *PluginInput) *collectResult {
 	return result
 }
 
-// extractFieldPermsDeep 递归展开嵌套类型,提取所有 perm tag
-func extractFieldPermsDeep(t *TypeDef, typeIndex map[string]*TypeDef) map[string]string {
-	result := make(map[string]string)
-	if t == nil {
-		return result
+// collectFieldPerms 递归构建 fieldNode 树
+func collectFieldPerms(members []Member, typeIndex map[string]*TypeDef) *fieldNode {
+	node := &fieldNode{
+		Fields: make(map[string]string),
+		Nested: make(map[string]*fieldNode),
 	}
-	collectFieldPerms(t.Members, typeIndex, result)
-	return result
-}
-
-func collectFieldPerms(members []Member, typeIndex map[string]*TypeDef, result map[string]string) {
 	for _, m := range members {
 		jsonName := extractTagValue(m.Tag, "json")
 		if jsonName == "" {
@@ -136,16 +132,34 @@ func collectFieldPerms(members []Member, typeIndex map[string]*TypeDef, result m
 
 		permCode := extractTagValue(m.Tag, "perm")
 		if permCode != "" {
-			result[jsonName] = permCode
-			continue
+			node.Fields[jsonName] = permCode
 		}
 
-		// 没有 perm tag,尝试展开嵌套类型(去掉 [] 前缀)
 		rawName := strings.TrimPrefix(m.Type.RawName, "[]")
 		if nested, ok := typeIndex[rawName]; ok && len(nested.Members) > 0 {
-			collectFieldPerms(nested.Members, typeIndex, result)
+			child := collectFieldPerms(nested.Members, typeIndex)
+			if len(child.Fields) > 0 || len(child.Nested) > 0 {
+				node.Nested[jsonName] = child
+			}
 		}
 	}
+	if len(node.Fields) == 0 && len(node.Nested) == 0 {
+		return nil
+	}
+	return node
+}
+
+// collectPermCodes 从 fieldNode 树中提取所有 permCode 并注册
+func collectPermCodes(node *fieldNode, addPerm func(string, string)) {
+	if node == nil {
+		return
+	}
+	for _, code := range node.Fields {
+		addPerm(code, "")
+	}
+	for _, child := range node.Nested {
+		collectPermCodes(child, addPerm)
+	}
 }
 
 // extractTagValue 从 struct tag 字符串中提取指定 key 的值

+ 44 - 13
generate.go

@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"os"
 	"path/filepath"
+	"strings"
 	"text/template"
 )
 
@@ -32,18 +33,10 @@ var FieldPerms = map[string]permlib.FieldPermMap{
 {{- range $key, $fm := .FieldPerms}}
 	{{printf "%q" $key}}: {
 		{{- if $fm.Request}}
-		Request: map[string]string{
-			{{- range $field, $code := $fm.Request}}
-			{{printf "%q" $field}}: {{printf "%q" $code}},
-			{{- end}}
-		},
+		Request: {{renderFieldNode $fm.Request 2}},
 		{{- end}}
 		{{- if $fm.Response}}
-		Response: map[string]string{
-			{{- range $field, $code := $fm.Response}}
-			{{printf "%q" $field}}: {{printf "%q" $code}},
-			{{- end}}
-		},
+		Response: {{renderFieldNode $fm.Response 2}},
 		{{- end}}
 	},
 {{- end}}
@@ -53,7 +46,45 @@ var FieldPerms = map[string]permlib.FieldPermMap{
 type templateData struct {
 	Perms      []permDecl
 	RoutePerms []routePermDecl
-	FieldPerms map[string]fieldPermMap
+	FieldPerms map[string]*fieldPermMap
+}
+
+func renderFieldNode(node *fieldNode, indent int) string {
+	if node == nil {
+		return "nil"
+	}
+	prefix := strings.Repeat("\t", indent)
+	innerPrefix := strings.Repeat("\t", indent+1)
+
+	var b strings.Builder
+	b.WriteString("&permlib.FieldNode{\n")
+
+	b.WriteString(innerPrefix + "Fields: map[string]string{")
+	if len(node.Fields) > 0 {
+		b.WriteString("\n")
+		for field, code := range node.Fields {
+			b.WriteString(fmt.Sprintf("%s\t%q: %q,\n", innerPrefix, field, code))
+		}
+		b.WriteString(innerPrefix + "}")
+	} else {
+		b.WriteString("}")
+	}
+	b.WriteString(",\n")
+
+	b.WriteString(innerPrefix + "Nested: map[string]*permlib.FieldNode{")
+	if len(node.Nested) > 0 {
+		b.WriteString("\n")
+		for field, child := range node.Nested {
+			b.WriteString(fmt.Sprintf("%s\t%q: %s,\n", innerPrefix, field, renderFieldNode(child, indent+2)))
+		}
+		b.WriteString(innerPrefix + "}")
+	} else {
+		b.WriteString("}")
+	}
+	b.WriteString(",\n")
+
+	b.WriteString(prefix + "}")
+	return b.String()
 }
 
 func generate(result *collectResult, dir string) error {
@@ -65,7 +96,8 @@ func generate(result *collectResult, dir string) error {
 	outFile := filepath.Join(outDir, "perms.go")
 
 	tpl, err := template.New("perms").Funcs(template.FuncMap{
-		"printf": fmt.Sprintf,
+		"printf":          fmt.Sprintf,
+		"renderFieldNode": renderFieldNode,
 	}).Parse(permsTpl)
 	if err != nil {
 		return fmt.Errorf("解析模板失败: %w", err)
@@ -95,4 +127,3 @@ func generate(result *collectResult, dir string) error {
 	)
 	return nil
 }
-