Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wire: cache the non-witness serialization of MsgTx to memoize part of… #1376

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions wire/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,18 @@ func TestMessage(t *testing.T) {
spew.Sdump(msg))
continue
}

// Blank out the cached encoding for transactions to ensure the
// deep equality check doesn't fail.
if tx, ok := msg.(*MsgTx); ok {
tx.cachedSeralizedNoWitness = nil
}
if block, ok := msg.(*MsgBlock); ok {
for _, tx := range block.Transactions {
tx.cachedSeralizedNoWitness = nil
}
}

if !reflect.DeepEqual(msg, test.out) {
t.Errorf("ReadMessage #%d\n got: %v want: %v", i,
spew.Sdump(msg), spew.Sdump(test.out))
Expand Down Expand Up @@ -170,6 +182,18 @@ func TestMessage(t *testing.T) {
spew.Sdump(msg))
continue
}

// Blank out the cached encoding for transactions to ensure the
// deep equality check doesn't fail.
if tx, ok := msg.(*MsgTx); ok {
tx.cachedSeralizedNoWitness = nil
}
if block, ok := msg.(*MsgBlock); ok {
for _, tx := range block.Transactions {
tx.cachedSeralizedNoWitness = nil
}
}

if !reflect.DeepEqual(msg, test.out) {
t.Errorf("ReadMessage #%d\n got: %v want: %v", i,
spew.Sdump(msg), spew.Sdump(test.out))
Expand Down
74 changes: 56 additions & 18 deletions wire/msgtx.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,12 @@ type MsgTx struct {
TxIn []*TxIn
TxOut []*TxOut
LockTime uint32

// cachedSeralizedNoWitness is a cached version of the serialization of
// this transaction without witness data. When we decode a transaction,
// we'll write out the non-witness bytes to this so we can quickly
// calculate the TxHash later if needed.
cachedSeralizedNoWitness []byte
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: s/cachedSeralizedNoWitness/cachedSerializedNoWitness/

}

// AddTxIn adds a transaction input to the message.
Expand All @@ -357,13 +363,19 @@ func (msg *MsgTx) AddTxOut(to *TxOut) {

// TxHash generates the Hash for the transaction.
func (msg *MsgTx) TxHash() chainhash.Hash {
// Encode the transaction and calculate double sha256 on the result.
// Ignore the error returns since the only way the encode could fail
// is being out of memory or due to nil pointers, both of which would
// cause a run-time panic.
buf := bytes.NewBuffer(make([]byte, 0, msg.SerializeSizeStripped()))
_ = msg.SerializeNoWitness(buf)
return chainhash.DoubleHashH(buf.Bytes())
if msg.cachedSeralizedNoWitness == nil {
// Encode the transaction and calculate double sha256 on the
// result. Ignore the error returns since the only way the
// encode could fail is being out of memory or due to nil
// pointers, both of which would cause a run-time panic.
strippedSize := msg.SerializeSizeStripped()
buf := bytes.NewBuffer(make([]byte, 0, strippedSize))
_ = msg.SerializeNoWitness(buf)

msg.cachedSeralizedNoWitness = buf.Bytes()
}

return chainhash.DoubleHashH(msg.cachedSeralizedNoWitness)
}

// WitnessHash generates the hash of the transaction serialized according to
Expand Down Expand Up @@ -461,7 +473,14 @@ func (msg *MsgTx) Copy() *MsgTx {
// See Deserialize for decoding transactions stored to disk, such as in a
// database, as opposed to decoding transactions from the wire.
func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error {
version, err := binarySerializer.Uint32(r, littleEndian)
// We'll use a tee reader in order to incrementally cache the raw
// non-witness serialization of this transaction. We'll then later
// cache this value as it allow to compute the TxHash more quickly, as
// we don't need to re-serialize the entire transaction.
var rawTxBuf bytes.Buffer
rawTxTeeReader := io.TeeReader(r, &rawTxBuf)

version, err := binarySerializer.Uint32(rawTxTeeReader, littleEndian)
if err != nil {
return err
}
Expand All @@ -472,12 +491,15 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
return err
}

// A count of zero (meaning no TxIn's to the uninitiated) means that the
// value is a TxFlagMarker, and hence indicates the presence of a flag.
var flag [1]TxFlag
// A count of zero (meaning no TxIn's to the uninitiated) indicates
// this is a transaction with witness data. Notice that we don't use
// the rawTxTeeReader here, as these are segwit specific bytes.
var (
flag [1]byte
hasWitneess bool
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/hasWitneess/hasWitness

)
if count == TxFlagMarker && enc == WitnessEncoding {
// The count varint was in fact the flag marker byte. Next, we need to
// read the flag value, which is a single byte.
// Next, we need to read the flag, which is a single byte.
if _, err = io.ReadFull(r, flag[:]); err != nil {
return err
}
Expand All @@ -495,6 +517,15 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
if err != nil {
return err
}

hasWitneess = true
}

// Write out the actual number of inputs as this won't be the very byte
// series after the versino of segwit transactions.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: s/versino/version

if WriteVarInt(&rawTxBuf, pver, count); err != nil {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to add err := WriteVarInt().

str := fmt.Sprintf("unable to write txin count: %v", err)
return messageError("MsgTx.BtcDecode", str)
}

// Prevent more input transactions than could possibly fit into a
Expand Down Expand Up @@ -545,15 +576,15 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
// and needs to be returned to the pool on error.
ti := &txIns[i]
msg.TxIn[i] = ti
err = readTxIn(r, pver, msg.Version, ti)
err = readTxIn(rawTxTeeReader, pver, msg.Version, ti)
if err != nil {
returnScriptBuffers()
return err
}
totalScriptSize += uint64(len(ti.SignatureScript))
}

count, err = ReadVarInt(r, pver)
count, err = ReadVarInt(rawTxTeeReader, pver)
if err != nil {
returnScriptBuffers()
return err
Expand All @@ -578,7 +609,7 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
// and needs to be returned to the pool on error.
to := &txOuts[i]
msg.TxOut[i] = to
err = ReadTxOut(r, pver, msg.Version, to)
err = ReadTxOut(rawTxTeeReader, pver, msg.Version, to)
if err != nil {
returnScriptBuffers()
return err
Expand All @@ -588,7 +619,7 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error

// If the transaction's flag byte isn't 0x00 at this point, then one or
// more of its inputs has accompanying witness data.
if flag[0] != 0 && enc == WitnessEncoding {
if hasWitneess && enc == WitnessEncoding {
for _, txin := range msg.TxIn {
// For each input, the witness is encoded as a stack
// with one or more items. Therefore, we first read a
Expand Down Expand Up @@ -626,7 +657,9 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
}
}

msg.LockTime, err = binarySerializer.Uint32(r, littleEndian)
msg.LockTime, err = binarySerializer.Uint32(
rawTxTeeReader, littleEndian,
)
if err != nil {
returnScriptBuffers()
return err
Expand Down Expand Up @@ -700,6 +733,11 @@ func (msg *MsgTx) BtcDecode(r io.Reader, pver uint32, enc MessageEncoding) error
scriptPool.Return(pkScript)
}

// Now that we've decoded the entire transaction without any issues,
// we'll cache the non-witness serialization so we can more quickly
// calculate the TxHash in the future.
msg.cachedSeralizedNoWitness = rawTxBuf.Bytes()

return nil
}

Expand Down
53 changes: 53 additions & 0 deletions wire/msgtx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,13 @@ func TestTxHash(t *testing.T) {
t.Errorf("TxHash: wrong hash - got %v, want %v",
spew.Sprint(txHash), spew.Sprint(wantHash))
}

// Compute it again to ensure any cached elements, are valid.
txHash = msgTx.TxHash()
if !txHash.IsEqual(wantHash) {
t.Errorf("TxHash: wrong hash - got %v, want %v",
spew.Sprint(txHash), spew.Sprint(wantHash))
}
}

// TestTxSha tests the ability to generate the wtxid, and txid of a transaction
Expand Down Expand Up @@ -258,6 +265,18 @@ func TestWTxSha(t *testing.T) {
t.Errorf("WTxSha: wrong hash - got %v, want %v",
spew.Sprint(wtxid), spew.Sprint(wantHashWTxid))
}

// Compute the values again to ensure any cached elements are valid.
txid = msgTx.TxHash()
if !txid.IsEqual(wantHashTxid) {
t.Errorf("TxSha: wrong hash - got %v, want %v",
spew.Sprint(txid), spew.Sprint(wantHashTxid))
}
wtxid = msgTx.WitnessHash()
if !wtxid.IsEqual(wantHashWTxid) {
t.Errorf("WTxSha: wrong hash - got %v, want %v",
spew.Sprint(wtxid), spew.Sprint(wantHashWTxid))
}
}

// TestTxWire tests the MsgTx wire encode and decode for various numbers
Expand Down Expand Up @@ -393,6 +412,23 @@ func TestTxWire(t *testing.T) {
t.Errorf("BtcDecode #%d error %v", i, err)
continue
}

// If this is the base encoding, then ensure that the cached
// serialization properly matches the raw encoding.
if test.enc == BaseEncoding {
if !bytes.Equal(
test.buf, msg.cachedSeralizedNoWitness,
) {
t.Errorf("BtcdDecode #%d: cached encoding "+
"is wrong, expected %x got %x", i,
test.buf,
msg.cachedSeralizedNoWitness)
continue
}
}

msg.cachedSeralizedNoWitness = nil

if !reflect.DeepEqual(&msg, test.out) {
t.Errorf("BtcDecode #%d\n got: %s want: %s", i,
spew.Sdump(&msg), spew.Sdump(test.out))
Expand Down Expand Up @@ -539,6 +575,23 @@ func TestTxSerialize(t *testing.T) {
t.Errorf("Deserialize #%d error %v", i, err)
continue
}

// Ensure that the raw non-witness encoding matches the cached
// non-witness encoding bytes.
var b bytes.Buffer
if err := tx.SerializeNoWitness(&b); err != nil {
t.Errorf("Deserialize #%d: unable to encode: %v", i, err)
}
if !bytes.Equal(b.Bytes(), tx.cachedSeralizedNoWitness) {
t.Errorf("Deserialize #%d: cached encoding "+
"is wrong, expected %x got %x", i,
b.Bytes(),
tx.cachedSeralizedNoWitness)
continue
}

tx.cachedSeralizedNoWitness = nil

if !reflect.DeepEqual(&tx, test.out) {
t.Errorf("Deserialize #%d\n got: %s want: %s", i,
spew.Sdump(&tx), spew.Sdump(test.out))
Expand Down
Loading