Skip to content
Merged
56 changes: 53 additions & 3 deletions extensions/tn_attestation/canonical.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

// CanonicalPayload represents the eight attestation fields stored in result_canonical.
// The byte layout mirrors the SQL migration: fixed-width integers followed by
// length-prefixed blobs (little-endian 4-byte prefixes for variable sections).
// length-prefixed blobs (big-endian 4-byte prefixes for variable sections).
//
// Layout:
//
Expand All @@ -34,6 +34,29 @@ type CanonicalPayload struct {
raw []byte
}

// BuildCanonicalPayload assembles the canonical byte layout for the provided fields.
//
// Layout: version (1 byte), algorithm (1 byte), block height (8 bytes,
// big-endian), then data provider, stream id, action id (2 bytes, big-endian),
// args, and result, where each variable-width section is preceded by its
// length as a big-endian uint32.
func BuildCanonicalPayload(version, algorithm uint8, blockHeight uint64, dataProvider, streamID []byte, actionID uint16, args, result []byte) []byte {
	// Pre-size the output: 2 header bytes, 8 height bytes, 2 action-id bytes,
	// four 4-byte length prefixes, plus the variable sections themselves.
	out := make([]byte, 0, 2+8+2+4*4+len(dataProvider)+len(streamID)+len(args)+len(result))

	var scratch [8]byte

	out = append(out, version, algorithm)

	binary.BigEndian.PutUint64(scratch[:8], blockHeight)
	out = append(out, scratch[:8]...)

	// writeSection appends a big-endian uint32 length prefix followed by the
	// section bytes; a nil section encodes as a zero-length prefix.
	writeSection := func(section []byte) {
		binary.BigEndian.PutUint32(scratch[:4], uint32(len(section)))
		out = append(out, scratch[:4]...)
		out = append(out, section...)
	}

	writeSection(dataProvider)
	writeSection(streamID)

	binary.BigEndian.PutUint16(scratch[:2], actionID)
	out = append(out, scratch[:2]...)

	writeSection(args)
	writeSection(result)

	return out
}

// ParseCanonicalPayload decodes the canonical payload into structured fields.
// The function validates every length prefix and returns descriptive errors so
// future maintainers can diagnose storage corruption quickly.
Expand Down Expand Up @@ -95,13 +118,30 @@ func (p *CanonicalPayload) SigningDigest() [sha256.Size]byte {
return sha256.Sum256(p.SigningBytes())
}

// readLengthPrefixed decodes a little-endian uint32 length followed by that many bytes.
// ValidateForEVM ensures canonical fields conform to the expectations of the EVM decoder.
//
// The decoder requires a 20-byte data provider address, a 32-byte stream id,
// algorithm tag 0 (secp256k1), and an action id small enough to fit in one byte.
func (p *CanonicalPayload) ValidateForEVM() error {
	switch {
	case len(p.DataProvider) != 20:
		return fmt.Errorf("data provider must be 20 bytes, got %d", len(p.DataProvider))
	case len(p.StreamID) != 32:
		return fmt.Errorf("stream id must be 32 bytes, got %d", len(p.StreamID))
	case p.Algorithm != 0:
		return fmt.Errorf("algorithm must be 0 (secp256k1), got %d", p.Algorithm)
	case p.ActionID > 255:
		return fmt.Errorf("action id must be <=255, got %d", p.ActionID)
	default:
		return nil
	}
}

// readLengthPrefixed decodes a big-endian uint32 length followed by that many bytes.
func readLengthPrefixed(data []byte, cursor int) ([]byte, int, error) {
if len(data) < cursor+4 {
return nil, cursor, fmt.Errorf("truncated length prefix at offset %d", cursor)
}

length := binary.LittleEndian.Uint32(data[cursor : cursor+4])
length := binary.BigEndian.Uint32(data[cursor : cursor+4])
cursor += 4

if len(data) < cursor+int(length) {
Expand All @@ -112,3 +152,13 @@ func readLengthPrefixed(data []byte, cursor int) ([]byte, int, error) {
cursor += int(length)
return bytes.Clone(chunk), cursor, nil
}

// lengthPrefixBigEndian returns data prefixed with its length encoded as a
// big-endian uint32, matching the canonical payload's variable-section layout.
// A nil slice encodes identically to an empty one: just the zero length prefix.
func lengthPrefixBigEndian(data []byte) []byte {
	// No nil guard needed: len(nil) == 0 and copy with a nil source is a no-op.
	prefixed := make([]byte, 4+len(data))
	binary.BigEndian.PutUint32(prefixed[:4], uint32(len(data)))
	copy(prefixed[4:], data)
	return prefixed
}
47 changes: 6 additions & 41 deletions extensions/tn_attestation/canonical_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,22 @@ package tn_attestation
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"testing"

"github.com/stretchr/testify/require"
)

func TestParseCanonicalPayload_Success(t *testing.T) {
version := uint8(1)
algo := uint8(1)
algo := uint8(0)
height := uint64(12345)
actionID := uint16(9)
dataProvider := []byte("provider-1")
streamID := []byte("stream-xyz")
dataProvider := bytes.Repeat([]byte{0x11}, 20)
streamID := bytes.Repeat([]byte{0x22}, 32)
args := []byte{0x01, 0x02, 0x03}
result := []byte{0xAA, 0xBB}

raw := buildCanonical(version, algo, height, dataProvider, streamID, actionID, args, result)
raw := BuildCanonicalPayload(version, algo, height, dataProvider, streamID, actionID, args, result)

payload, err := ParseCanonicalPayload(raw)
require.NoError(t, err)
Expand All @@ -41,7 +40,7 @@ func TestParseCanonicalPayload_Success(t *testing.T) {
}

func TestParseCanonicalPayload_TruncatedPrefix(t *testing.T) {
base := buildCanonical(1, 1, 1, []byte("a"), []byte("b"), 1, []byte{0x01}, []byte{0x02})
base := BuildCanonicalPayload(1, 0, 1, []byte("a"), []byte("b"), 1, []byte{0x01}, []byte{0x02})
// Corrupt by chopping last byte
corrupted := base[:len(base)-1]

Expand All @@ -51,44 +50,10 @@ func TestParseCanonicalPayload_TruncatedPrefix(t *testing.T) {
}

func TestParseCanonicalPayload_ExtraBytes(t *testing.T) {
base := buildCanonical(1, 1, 1, []byte("a"), []byte("b"), 1, []byte{0x01}, []byte{0x02})
base := BuildCanonicalPayload(1, 0, 1, []byte("a"), []byte("b"), 1, []byte{0x01}, []byte{0x02})
extra := append(base, []byte{0xFF, 0xFF}...)

_, err := ParseCanonicalPayload(extra)
require.Error(t, err)
require.Contains(t, err.Error(), "trailing bytes")
}

// buildCanonical mirrors the SQL encoder to generate canonical payloads.
//
// Note the mixed endianness this helper reproduces: block height and action id
// are big-endian, while the uint32 length prefixes are little-endian.
func buildCanonical(version, algo uint8, height uint64, provider, stream []byte, actionID uint16, args, result []byte) []byte {
	var out []byte

	out = append(out, version, algo)

	var word [8]byte
	binary.BigEndian.PutUint64(word[:], height)
	out = append(out, word[:]...)

	// appendSection writes a little-endian uint32 length followed by the bytes.
	appendSection := func(section []byte) {
		binary.LittleEndian.PutUint32(word[:4], uint32(len(section)))
		out = append(out, word[:4]...)
		out = append(out, section...)
	}

	appendSection(provider)
	appendSection(stream)

	binary.BigEndian.PutUint16(word[:2], actionID)
	out = append(out, word[:2]...)

	appendSection(args)
	appendSection(result)

	return out
}
141 changes: 100 additions & 41 deletions extensions/tn_attestation/harness_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,50 +91,104 @@ func TestSigningWorkflowWithHarness(t *testing.T) {
// path that nodes run when users hit the public API.
require.NoError(t, setupTestAttestationAction(ctx, platform, testActionName, testActionID))

// Request the attestation through the live migration. This ensures the
// canonical payload we inspect later is produced by the SQL we ship.
dataProvider := []byte("provider-harness")
streamID := []byte("stream-harness")
// Manually construct the attestation using the same BuildCanonicalPayload
// function that request_attestation would use. This verifies the Go
// canonical builder produces correct attestation bytes for the signing workflow.
dataProvider := "0xaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
streamIDVal := util.GenerateStreamId("harness_stream")
streamID := streamIDVal.String()
argsBytes, err := tn_utils.EncodeActionArgs([]any{attestedValue})
require.NoError(t, err)

engineCtx := newHarnessEngineContext(ctx, platform, requesterAddr)

var requestTxID string
var attestationHash []byte
_, err = platform.Engine.Call(engineCtx, platform.DB, "", "request_attestation", []any{
dataProvider,
streamID,
testActionName,
argsBytes,
false,
int64(0),
}, func(row *common.Row) error {
if len(row.Values) != 2 {
return fmt.Errorf("expected 2 return values, got %d", len(row.Values))
}
txID, ok := row.Values[0].(string)
if !ok {
return fmt.Errorf("expected TEXT return for request_tx_id, got %T", row.Values[0])
}
requestTxID = txID
hash, ok := row.Values[1].([]byte)
if !ok {
return fmt.Errorf("expected BYTEA return for attestation_hash, got %T", row.Values[1])
expectedTxID := engineCtx.TxContext.TxID
requestTxID := expectedTxID

// Execute the attestation target action to build the same canonical payload
// that the SQL migration would store. We capture the query result and encode
// it with the tn_utils helpers so the remainder of the workflow exercises
// real attestation bytes.
var dispatchRows []*common.Row
_, err = platform.Engine.Call(engineCtx, platform.DB, "", testActionName, []any{attestedValue}, func(row *common.Row) error {
clonedRow := &common.Row{
ColumnNames: append([]string(nil), row.ColumnNames...),
ColumnTypes: append([]*ktypes.DataType(nil), row.ColumnTypes...),
Values: append([]any(nil), row.Values...),
}
attestationHash = append([]byte(nil), hash...)
dispatchRows = append(dispatchRows, clonedRow)
return nil
})
require.NoError(t, err, "request_attestation failed")
require.NotEmpty(t, requestTxID, "request_attestation should return request_tx_id")
require.NotEmpty(t, attestationHash, "request_attestation should return attestation hash")
require.NoError(t, err, "dispatch harness action")

resultCanonical, err := tn_utils.EncodeQueryResultCanonical(dispatchRows)
require.NoError(t, err, "encode canonical query result")

providerAddr := util.Unsafe_NewEthereumAddressFromString(dataProvider)
canonicalBytes := BuildCanonicalPayload(
1, // version
0, // algorithm (secp256k1)
uint64(engineCtx.TxContext.BlockContext.Height),
providerAddr.Bytes(),
[]byte(streamID),
uint16(testActionID),
argsBytes,
resultCanonical,
)

payloadStruct, err := ParseCanonicalPayload(canonicalBytes)
require.NoError(t, err, "parse canonical payload for hashing")

hashArray := computeAttestationHash(payloadStruct)
attestationHash := append([]byte(nil), hashArray[:]...)

insertCtx := &common.EngineContext{
TxContext: &common.TxContext{
Ctx: ctx,
Signer: platform.Deployer,
Caller: string(platform.Deployer),
TxID: platform.Txid(),
BlockContext: &common.BlockContext{
Height: engineCtx.TxContext.BlockContext.Height,
},
},
OverrideAuthz: true,
}

err = platform.Engine.Execute(insertCtx, platform.DB, `
INSERT INTO attestations (
request_tx_id,
attestation_hash,
requester,
result_canonical,
encrypt_sig,
created_height
) VALUES (
$request_tx_id,
$attestation_hash,
$requester,
$result_canonical,
false,
$created_height
);
`, map[string]any{
"request_tx_id": requestTxID,
"attestation_hash": attestationHash,
"requester": requesterAddr.Bytes(),
"result_canonical": canonicalBytes,
"created_height": engineCtx.TxContext.BlockContext.Height,
}, nil)
require.NoError(t, err, "insert attestation row")

// At this point we expect a single row inserted into the persisted
// table. Fetch it back and validate every column so future changes that
// alter canonical layout or metadata will trip this test.
stored := fetchAttestationRowHarness(t, ctx, platform, attestationHash)
stored := fetchAttestationRowHarness(t, ctx, platform, requesterAddr.Bytes())
require.NotEmpty(t, stored.attestationHash, "persisted attestation hash should not be empty")
persistedHash := append([]byte(nil), stored.attestationHash...)
require.NotEmpty(t, stored.requestTxID, "stored request_tx_id should not be empty")
require.Equal(t, requestTxID, stored.requestTxID, "request_tx_id should be captured")
require.Equal(t, attestationHash, stored.attestationHash)
require.Equal(t, attestationHash, persistedHash, "returned attestation hash should match stored hash")
attestationHash = persistedHash
require.Equal(t, requesterAddr.Bytes(), stored.requester)
require.NotEmpty(t, stored.resultCanonical, "canonical payload should be stored")
require.False(t, stored.encryptSig, "encrypt_sig must be false in MVP")
Expand All @@ -147,9 +201,9 @@ func TestSigningWorkflowWithHarness(t *testing.T) {
payload, err := ParseCanonicalPayload(stored.resultCanonical)
require.NoError(t, err, "canonical payload should be parseable")
require.Equal(t, uint8(1), payload.Version)
require.Equal(t, uint8(1), payload.Algorithm)
require.Equal(t, dataProvider, payload.DataProvider)
require.Equal(t, streamID, payload.StreamID)
require.Equal(t, uint8(0), payload.Algorithm)
require.Equal(t, providerAddr.Bytes(), payload.DataProvider)
require.Equal(t, []byte(streamID), payload.StreamID)
require.Equal(t, uint16(testActionID), payload.ActionID)
require.Equal(t, argsBytes, payload.Args)
require.NotEmpty(t, payload.Result, "query result should be stored")
Expand All @@ -161,8 +215,8 @@ func TestSigningWorkflowWithHarness(t *testing.T) {
require.Len(t, digest, 32, "digest should be 32 bytes (SHA-256)")

// Phase 2: Prepare signing work - validator generates signature
privateKey, _, err := kcrypto.GenerateSecp256k1Key(nil)
require.NoError(t, err)
privateKey, _, genKeyErr := kcrypto.GenerateSecp256k1Key(nil)
require.NoError(t, genKeyErr)

ResetValidatorSignerForTesting()
t.Cleanup(ResetValidatorSignerForTesting)
Expand Down Expand Up @@ -227,7 +281,7 @@ func TestSigningWorkflowWithHarness(t *testing.T) {
require.Equal(t, 1, broadcaster.calls, "should broadcast exactly once")

// Verify signed state in database
signedRow := fetchAttestationRowHarness(t, ctx, platform, attestationHash)
signedRow := fetchAttestationRowHarness(t, ctx, platform, requesterAddr.Bytes())
require.NotNil(t, signedRow.signature, "signature should be recorded")
require.Equal(t, prepared[0].Signature, signedRow.signature)
require.NotNil(t, signedRow.validatorPubKey, "validator pubkey should be recorded")
Expand Down Expand Up @@ -298,7 +352,7 @@ func newHarnessEngineContext(ctx context.Context, platform *kwilTesting.Platform
}
}

func fetchAttestationRowHarness(t *testing.T, ctx context.Context, platform *kwilTesting.Platform, hash []byte) harnessAttestationRow {
func fetchAttestationRowHarness(t *testing.T, ctx context.Context, platform *kwilTesting.Platform, requester []byte) harnessAttestationRow {
engineCtx := &common.EngineContext{
TxContext: &common.TxContext{
Ctx: ctx,
Expand All @@ -313,11 +367,15 @@ func fetchAttestationRowHarness(t *testing.T, ctx context.Context, platform *kwi
}

var rowData harnessAttestationRow
found := false
err := platform.Engine.Execute(engineCtx, platform.DB, `
SELECT request_tx_id, requester, attestation_hash, result_canonical, encrypt_sig, signature, validator_pubkey, signed_height, created_height
FROM attestations
WHERE attestation_hash = $hash;
`, map[string]any{"hash": hash}, func(row *common.Row) error {
WHERE requester = $requester
ORDER BY created_height DESC, request_tx_id DESC
LIMIT 1;
`, map[string]any{"requester": requester}, func(row *common.Row) error {
found = true
rowData.requestTxID = row.Values[0].(string)
rowData.requester = append([]byte(nil), row.Values[1].([]byte)...)
rowData.attestationHash = append([]byte(nil), row.Values[2].([]byte)...)
Expand All @@ -337,6 +395,7 @@ WHERE attestation_hash = $hash;
return nil
})
require.NoError(t, err)
require.True(t, found, "expected attestation row for requester")
return rowData
}

Expand Down
Loading
Loading