Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions pkg/config/v2/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"encoding/json"
"fmt"

"github.com/goccy/go-yaml"

"github.com/docker/cagent/pkg/config/types"
)

Expand Down Expand Up @@ -315,6 +317,105 @@ func (s *RAGStrategyConfig) UnmarshalYAML(unmarshal func(any) error) error {
return nil
}

// MarshalYAML encodes the strategy as YAML using the flattened layout produced
// by buildFlattenedMap, so Params entries appear next to the known fields
// instead of under a nested key.
func (s RAGStrategyConfig) MarshalYAML() ([]byte, error) {
	return yaml.Marshal(s.buildFlattenedMap())
}

// MarshalJSON encodes the strategy as JSON using the same flattened layout as
// MarshalYAML: both encode the map returned by buildFlattenedMap, which keeps
// the two formats consistent.
func (s RAGStrategyConfig) MarshalJSON() ([]byte, error) {
	return json.Marshal(s.buildFlattenedMap())
}

// UnmarshalJSON decodes the flattened JSON form of a RAG strategy.
//
// The known keys "type", "docs", "database", "chunking" and "limit" are moved
// into their dedicated fields; every remaining key is stored in Params as
// strategy-specific configuration. This is the inverse of the flattened layout
// written by MarshalJSON.
//
// For "type", "docs", "database" and "limit" a JSON null is treated as "not
// set". A known key that holds a value of the wrong type is reported as an
// error rather than being silently dropped, replaced by a zero value, or left
// behind in Params.
func (s *RAGStrategyConfig) UnmarshalJSON(data []byte) error {
	// Decode into a generic map first so unknown keys are preserved.
	var raw map[string]any
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}

	if v, ok := raw["type"]; ok {
		switch t := v.(type) {
		case nil:
		case string:
			s.Type = t
		default:
			return fmt.Errorf("type must be a string, got %T", v)
		}
		delete(raw, "type")
	}

	if v, ok := raw["docs"]; ok {
		switch list := v.(type) {
		case nil:
		case []any:
			docs := make([]string, len(list))
			for i, d := range list {
				str, isStr := d.(string)
				if !isStr {
					return fmt.Errorf("docs[%d] must be a string, got %T", i, d)
				}
				docs[i] = str
			}
			s.Docs = docs
		default:
			return fmt.Errorf("docs must be a list of strings, got %T", v)
		}
		delete(raw, "docs")
	}

	if v, ok := raw["database"]; ok {
		switch dbStr := v.(type) {
		case nil:
		case string:
			var db RAGDatabaseConfig
			db.value = dbStr
			s.Database = db
		default:
			return fmt.Errorf("database must be a string path to a sqlite database")
		}
		delete(raw, "database")
	}

	if v, ok := raw["chunking"]; ok {
		// The value was decoded generically above; re-encode it so the
		// chunking type's own JSON decoding rules apply.
		chunkBytes, err := json.Marshal(v)
		if err != nil {
			return fmt.Errorf("re-encoding chunking: %w", err)
		}
		var chunk RAGChunkingConfig
		if err := json.Unmarshal(chunkBytes, &chunk); err != nil {
			return fmt.Errorf("decoding chunking: %w", err)
		}
		s.Chunking = chunk
		delete(raw, "chunking")
	}

	if v, ok := raw["limit"]; ok {
		switch n := v.(type) {
		case nil:
		case float64:
			// encoding/json decodes every number as float64; refuse values
			// that would lose information when stored as an int.
			limit := int(n)
			if float64(limit) != n {
				return fmt.Errorf("limit must be an integer, got %v", n)
			}
			s.Limit = limit
		default:
			return fmt.Errorf("limit must be a number, got %T", v)
		}
		delete(raw, "limit")
	}

	// Everything that is not a known key is strategy-specific configuration.
	s.Params = raw

	return nil
}

// buildFlattenedMap returns the single-level map that both MarshalYAML and
// MarshalJSON encode, so the two formats share one layout.
//
// Params entries are copied first and the dedicated fields are written
// afterwards. A Params key that collides with a known key ("type", "docs",
// "database", "chunking", "limit") therefore cannot shadow the value of a
// typed field that is set. Unset fields (empty type, no docs, empty database,
// all-zero chunking, non-positive limit) are omitted.
func (s RAGStrategyConfig) buildFlattenedMap() map[string]any {
	// Room for every Params entry plus the five known keys.
	result := make(map[string]any, len(s.Params)+5)

	// Strategy-specific parameters go in first; typed fields overwrite them below.
	for k, v := range s.Params {
		result[k] = v
	}

	if s.Type != "" {
		result["type"] = s.Type
	}
	if len(s.Docs) > 0 {
		result["docs"] = s.Docs
	}
	if !s.Database.IsEmpty() {
		// This helper has no error return, so a database value that cannot
		// be rendered as a string is left out instead of being written as a
		// placeholder.
		if dbStr, err := s.Database.AsString(); err == nil {
			result["database"] = dbStr
		}
	}
	// Only include chunking if any of its fields are set.
	if s.Chunking.Size > 0 || s.Chunking.Overlap > 0 || s.Chunking.RespectWordBoundaries {
		result["chunking"] = s.Chunking
	}
	if s.Limit > 0 {
		result["limit"] = s.Limit
	}

	return result
}

// unmarshalDatabaseConfig handles DatabaseConfig unmarshaling from raw YAML data.
// For RAG strategies, the database configuration is intentionally simple:
// a single string value under the `database` key that points to the SQLite
Expand Down Expand Up @@ -407,6 +508,32 @@ func (d *RAGDatabaseConfig) IsEmpty() bool {
return d.value == nil
}

// MarshalYAML encodes the stored database value directly, without any wrapper
// key. When no value is set, the YAML encoding of nil is returned.
func (d RAGDatabaseConfig) MarshalYAML() ([]byte, error) {
	if d.value != nil {
		return yaml.Marshal(d.value)
	}
	return yaml.Marshal(nil)
}

// MarshalJSON encodes the stored database value directly, without any wrapper
// object. When no value is set, the JSON encoding of nil (null) is returned.
func (d RAGDatabaseConfig) MarshalJSON() ([]byte, error) {
	if d.value != nil {
		return json.Marshal(d.value)
	}
	return json.Marshal(nil)
}

// UnmarshalJSON decodes a database setting from JSON. The only accepted value
// is a JSON string, which is stored as the database value.
//
// A JSON null is a no-op, following the encoding/json convention for
// Unmarshalers. The receiver is left untouched, so an unset config still
// reports IsEmpty and the null written by MarshalJSON round-trips. Any other
// JSON value returns an error.
func (d *RAGDatabaseConfig) UnmarshalJSON(data []byte) error {
	// Decoding into *string distinguishes null (nil pointer) from "" (pointer
	// to an empty string); decoding into a plain string cannot.
	var str *string
	if err := json.Unmarshal(data, &str); err != nil {
		return fmt.Errorf("database must be a string path to a sqlite database")
	}
	if str == nil {
		return nil
	}
	d.value = *str
	return nil
}

// RAGChunkingConfig represents text chunking configuration
type RAGChunkingConfig struct {
Size int `json:"size,omitempty"`
Expand Down
184 changes: 184 additions & 0 deletions pkg/config/v2/types_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package v2

import (
"encoding/json"
"testing"

"github.com/goccy/go-yaml"
Expand Down Expand Up @@ -32,3 +33,186 @@ func TestCommandsUnmarshal_List(t *testing.T) {
require.Equal(t, "check disk", c["df"])
require.Equal(t, "list files", c["ls"])
}

// TestRAGStrategyConfig_MarshalUnmarshal_FlattenedParams checks that
// strategy-specific YAML keys are collected into Params when decoding, are
// written back at the top level (never under a "params:" key) when encoding,
// and are still present after a second decode.
func TestRAGStrategyConfig_MarshalUnmarshal_FlattenedParams(t *testing.T) {
	t.Parallel()

	src := []byte(`type: chunked-embeddings
model: embeddinggemma
database: ./rag/test.db
threshold: 0.5
vector_dimensions: 768
`)

	// Decode: known keys land in typed fields, everything else in Params.
	var decoded RAGStrategyConfig
	require.NoError(t, yaml.Unmarshal(src, &decoded))
	require.Equal(t, "chunked-embeddings", decoded.Type)
	require.Equal(t, "./rag/test.db", mustGetDBString(t, decoded.Database))
	require.NotNil(t, decoded.Params)
	require.Equal(t, "embeddinggemma", decoded.Params["model"])
	require.InEpsilon(t, 0.5, decoded.Params["threshold"], 0.001)
	// The YAML decoder may choose int, uint64 or float64 for a number, so
	// normalize before comparing.
	require.InEpsilon(t, float64(768), toFloat64(decoded.Params["vector_dimensions"]), 0.001)

	// Encode again.
	encoded, err := yaml.Marshal(decoded)
	require.NoError(t, err)

	// The output must stay flat: every key at the top level, no "params:" key.
	text := string(encoded)
	for _, want := range []string{
		"type: chunked-embeddings",
		"model: embeddinggemma",
		"threshold: 0.5",
		"vector_dimensions: 768",
	} {
		require.Contains(t, text, want)
	}
	require.NotContains(t, text, "params:")

	// Decode the encoded form to confirm the round trip.
	var again RAGStrategyConfig
	require.NoError(t, yaml.Unmarshal(encoded, &again))
	require.Equal(t, decoded.Type, again.Type)
	require.Equal(t, decoded.Params["model"], again.Params["model"])
	require.Equal(t, decoded.Params["threshold"], again.Params["threshold"])
	// Only the numeric value matters here, not the concrete Go type.
	require.InEpsilon(t, float64(768), toFloat64(again.Params["vector_dimensions"]), 0.001)
}

// TestRAGStrategyConfig_MarshalUnmarshal_WithDatabase checks that a strategy
// decoded from YAML encodes its database as a plain string and keeps extra
// keys flattened.
func TestRAGStrategyConfig_MarshalUnmarshal_WithDatabase(t *testing.T) {
	t.Parallel()

	src := []byte(`type: chunked-embeddings
database: ./test.db
model: test-model
`)

	var cfg RAGStrategyConfig
	require.NoError(t, yaml.Unmarshal(src, &cfg))

	encoded, err := yaml.Marshal(cfg)
	require.NoError(t, err)
	text := string(encoded)

	// The database is a simple scalar, not a mapping exposing internal fields.
	require.Contains(t, text, "database: ./test.db")
	require.NotContains(t, text, " value:")

	// Strategy-specific keys stay at the top level.
	require.Contains(t, text, "model: test-model")
	require.NotContains(t, text, "params:")
}

// TestRAGStrategyConfig_MarshalJSON checks that JSON encoding writes Params
// entries next to the known fields and never emits a "params" key.
func TestRAGStrategyConfig_MarshalJSON(t *testing.T) {
	t.Parallel()

	cfg := RAGStrategyConfig{
		Type:     "chunked-embeddings",
		Docs:     []string{"doc1.md", "doc2.md"},
		Database: RAGDatabaseConfig{value: "./test.db"},
		Limit:    10,
		Params: map[string]any{
			"model":     "embedding-model",
			"threshold": 0.75,
		},
	}

	encoded, err := json.Marshal(cfg)
	require.NoError(t, err)

	// Every expected key/value pair appears at the top level of the object.
	text := string(encoded)
	for _, want := range []string{
		`"type":"chunked-embeddings"`,
		`"database":"./test.db"`,
		`"model":"embedding-model"`,
		`"threshold":0.75`,
	} {
		require.Contains(t, text, want)
	}
	require.NotContains(t, text, `"params"`)
}

// TestRAGStrategyConfig_UnmarshalJSON checks that JSON decoding fills the
// typed fields from the known keys and collects the remaining keys in Params.
func TestRAGStrategyConfig_UnmarshalJSON(t *testing.T) {
	t.Parallel()

	payload := []byte(`{
"type": "bm25",
"database": "./bm25.db",
"limit": 20,
"k1": 1.2,
"b": 0.75
}`)

	var cfg RAGStrategyConfig
	require.NoError(t, json.Unmarshal(payload, &cfg))

	// Known keys.
	require.Equal(t, "bm25", cfg.Type)
	require.Equal(t, "./bm25.db", mustGetDBString(t, cfg.Database))
	require.Equal(t, 20, cfg.Limit)

	// Strategy-specific keys.
	require.NotNil(t, cfg.Params)
	require.InEpsilon(t, 1.2, toFloat64(cfg.Params["k1"]), 0.001)
	require.InEpsilon(t, 0.75, toFloat64(cfg.Params["b"]), 0.001)
}

// TestRAGStrategyConfig_JSONRoundTrip encodes a fully populated strategy to
// JSON, decodes it again, and checks that the typed fields and the Params
// entries come back unchanged.
func TestRAGStrategyConfig_JSONRoundTrip(t *testing.T) {
	t.Parallel()

	want := RAGStrategyConfig{
		Type:     "chunked-embeddings",
		Docs:     []string{"readme.md"},
		Database: RAGDatabaseConfig{value: "./vectors.db"},
		Limit:    15,
		Params: map[string]any{
			"embedding_model":   "openai/text-embedding-3-small",
			"vector_dimensions": float64(1536),
			"threshold":         0.6,
		},
	}

	encoded, err := json.Marshal(want)
	require.NoError(t, err)

	var got RAGStrategyConfig
	require.NoError(t, json.Unmarshal(encoded, &got))

	// Typed fields.
	require.Equal(t, want.Type, got.Type)
	require.Equal(t, want.Docs, got.Docs)
	require.Equal(t, want.Limit, got.Limit)
	require.Equal(t, mustGetDBString(t, want.Database), mustGetDBString(t, got.Database))

	// Strategy-specific parameters.
	for _, key := range []string{"embedding_model", "vector_dimensions", "threshold"} {
		require.Equal(t, want.Params[key], got.Params[key])
	}
}

// mustGetDBString returns the string form of db as reported by AsString and
// fails the calling test immediately if AsString returns an error.
func mustGetDBString(t *testing.T, db RAGDatabaseConfig) string {
	t.Helper()

	path, err := db.AsString()
	require.NoError(t, err)

	return path
}

// toFloat64 converts a value of any built-in integer or floating-point type to
// float64, so tests can compare numbers without depending on the concrete type
// a decoder picked. Values of any other type, including nil, yield 0.
func toFloat64(v any) float64 {
	switch val := v.(type) {
	case int:
		return float64(val)
	case int8:
		return float64(val)
	case int16:
		return float64(val)
	case int32:
		return float64(val)
	case int64:
		return float64(val)
	case uint:
		return float64(val)
	case uint8:
		return float64(val)
	case uint16:
		return float64(val)
	case uint32:
		return float64(val)
	case uint64:
		return float64(val)
	case float32:
		return float64(val)
	case float64:
		return val
	default:
		return 0
	}
}