diff --git a/pkg/formats/mpegts/codec_ac3.go b/pkg/formats/mpegts/codec_ac3.go index f8a01f7..2599eb4 100644 --- a/pkg/formats/mpegts/codec_ac3.go +++ b/pkg/formats/mpegts/codec_ac3.go @@ -19,8 +19,7 @@ func (*CodecAC3) isCodec() {} func (c CodecAC3) marshal(pid uint16) (*astits.PMTElementaryStream, error) { return &astits.PMTElementaryStream{ - ElementaryPID: pid, - ElementaryStreamDescriptors: nil, - StreamType: astits.StreamTypeAC3Audio, + ElementaryPID: pid, + StreamType: astits.StreamTypeAC3Audio, }, nil } diff --git a/pkg/formats/mpegts/codec_h264.go b/pkg/formats/mpegts/codec_h264.go index 05b4aec..89df2d9 100644 --- a/pkg/formats/mpegts/codec_h264.go +++ b/pkg/formats/mpegts/codec_h264.go @@ -16,8 +16,7 @@ func (*CodecH264) isCodec() {} func (c CodecH264) marshal(pid uint16) (*astits.PMTElementaryStream, error) { return &astits.PMTElementaryStream{ - ElementaryPID: pid, - ElementaryStreamDescriptors: nil, - StreamType: astits.StreamTypeH264Video, + ElementaryPID: pid, + StreamType: astits.StreamTypeH264Video, }, nil } diff --git a/pkg/formats/mpegts/codec_h265.go b/pkg/formats/mpegts/codec_h265.go index 1863184..6fb6b91 100644 --- a/pkg/formats/mpegts/codec_h265.go +++ b/pkg/formats/mpegts/codec_h265.go @@ -16,8 +16,7 @@ func (*CodecH265) isCodec() {} func (c CodecH265) marshal(pid uint16) (*astits.PMTElementaryStream, error) { return &astits.PMTElementaryStream{ - ElementaryPID: pid, - ElementaryStreamDescriptors: nil, - StreamType: astits.StreamTypeH265Video, + ElementaryPID: pid, + StreamType: astits.StreamTypeH265Video, }, nil } diff --git a/pkg/formats/mpegts/codec_mpeg1_audio.go b/pkg/formats/mpegts/codec_mpeg1_audio.go index 8967350..f88cd83 100644 --- a/pkg/formats/mpegts/codec_mpeg1_audio.go +++ b/pkg/formats/mpegts/codec_mpeg1_audio.go @@ -16,8 +16,7 @@ func (*CodecMPEG1Audio) isCodec() {} func (c CodecMPEG1Audio) marshal(pid uint16) (*astits.PMTElementaryStream, error) { return &astits.PMTElementaryStream{ - ElementaryPID: pid, - ElementaryStreamDescriptors: nil, - StreamType: astits.StreamTypeMPEG1Audio, + ElementaryPID: pid, + StreamType: astits.StreamTypeMPEG1Audio, }, nil } diff --git a/pkg/formats/mpegts/codec_mpeg1_video.go b/pkg/formats/mpegts/codec_mpeg1_video.go index e6dbf09..5b03f67 100644 --- a/pkg/formats/mpegts/codec_mpeg1_video.go +++ b/pkg/formats/mpegts/codec_mpeg1_video.go @@ -16,8 +16,7 @@ func (*CodecMPEG1Video) isCodec() {} func (c CodecMPEG1Video) marshal(pid uint16) (*astits.PMTElementaryStream, error) { return &astits.PMTElementaryStream{ - ElementaryPID: pid, - ElementaryStreamDescriptors: nil, + ElementaryPID: pid, // we use MPEG-2 to notify readers that video can be either MPEG-1 or MPEG-2 StreamType: astits.StreamTypeMPEG2Video, }, nil diff --git a/pkg/formats/mpegts/codec_mpeg4_audio.go b/pkg/formats/mpegts/codec_mpeg4_audio.go index d344dbd..d616d40 100644 --- a/pkg/formats/mpegts/codec_mpeg4_audio.go +++ b/pkg/formats/mpegts/codec_mpeg4_audio.go @@ -20,8 +20,7 @@ func (*CodecMPEG4Audio) isCodec() {} func (c CodecMPEG4Audio) marshal(pid uint16) (*astits.PMTElementaryStream, error) { return &astits.PMTElementaryStream{ - ElementaryPID: pid, - ElementaryStreamDescriptors: nil, - StreamType: astits.StreamTypeAACAudio, + ElementaryPID: pid, + StreamType: astits.StreamTypeAACAudio, }, nil } diff --git a/pkg/formats/mpegts/codec_mpeg4_video.go b/pkg/formats/mpegts/codec_mpeg4_video.go index ea30bdf..471b529 100644 --- a/pkg/formats/mpegts/codec_mpeg4_video.go +++ b/pkg/formats/mpegts/codec_mpeg4_video.go @@ -16,8 +16,7 @@ func (*CodecMPEG4Video) isCodec() {} func (c CodecMPEG4Video) marshal(pid uint16) (*astits.PMTElementaryStream, error) { return &astits.PMTElementaryStream{ - ElementaryPID: pid, - ElementaryStreamDescriptors: nil, - StreamType: astits.StreamTypeMPEG4Video, + ElementaryPID: pid, + StreamType: astits.StreamTypeMPEG4Video, }, nil } diff --git a/pkg/formats/mpegts/codec_opus.go b/pkg/formats/mpegts/codec_opus.go index 3e66ee6..a031986 100644 --- a/pkg/formats/mpegts/codec_opus.go +++ b/pkg/formats/mpegts/codec_opus.go @@ -19,6 +19,7 @@ func (*CodecOpus) isCodec() {} func (c CodecOpus) marshal(pid uint16) (*astits.PMTElementaryStream, error) { return &astits.PMTElementaryStream{ ElementaryPID: pid, + StreamType: astits.StreamTypePrivateData, ElementaryStreamDescriptors: []*astits.Descriptor{ { Length: 4, @@ -36,6 +37,5 @@ func (c CodecOpus) marshal(pid uint16) (*astits.PMTElementaryStream, error) { }, }, }, - StreamType: astits.StreamTypePrivateData, }, nil } diff --git a/pkg/formats/mpegts/opus_access_unit.go b/pkg/formats/mpegts/opus_access_unit.go index db5aa90..5af6112 100644 --- a/pkg/formats/mpegts/opus_access_unit.go +++ b/pkg/formats/mpegts/opus_access_unit.go @@ -12,7 +12,7 @@ type opusAccessUnit struct { func (au *opusAccessUnit) unmarshal(buf []byte) (int, error) { n, err := au.ControlHeader.unmarshal(buf) if err != nil { - return 0, fmt.Errorf("could not decode control header: %w", err) + return 0, fmt.Errorf("invalid control header: %w", err) } buf = buf[n:] diff --git a/pkg/formats/mpegts/reader.go b/pkg/formats/mpegts/reader.go index e7dcfe7..a19dae9 100644 --- a/pkg/formats/mpegts/reader.go +++ b/pkg/formats/mpegts/reader.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "strings" "github.com/asticode/go-astits" @@ -175,7 +176,7 @@ func (r *Reader) OnDataMPEG4Audio(track *Track, cb ReaderOnDataMPEG4AudioFunc) { var pkts mpeg4audio.ADTSPackets err := pkts.Unmarshal(data) if err != nil { - r.onDecodeError(err) + r.onDecodeError(fmt.Errorf("invalid ADTS: %w", err)) return nil } @@ -249,35 +250,42 @@ func (r *Reader) OnDataAC3(track *Track, cb ReaderOnDataAC3Func) { // Read reads data. func (r *Reader) Read() error { - data, err := r.dem.NextData() - if err != nil { - return err - } + for { + data, err := r.dem.NextData() + if err != nil { + // https://github.com/asticode/go-astits/blob/b0b19247aa31633650c32638fb55f597fa6e2468/packet_buffer.go#L133C1-L133C5 + if errors.Is(err, astits.ErrNoMorePackets) || strings.Contains(err.Error(), "astits: reading ") { + return err + } + r.onDecodeError(err) + continue + } - if data.PES == nil { - return nil - } + if data.PES == nil { + return nil + } - if data.PES.Header.OptionalHeader == nil || - data.PES.Header.OptionalHeader.PTSDTSIndicator == astits.PTSDTSIndicatorNoPTSOrDTS || - data.PES.Header.OptionalHeader.PTSDTSIndicator == astits.PTSDTSIndicatorIsForbidden { - r.onDecodeError(fmt.Errorf("PTS is missing")) - return nil - } + if data.PES.Header.OptionalHeader == nil || + data.PES.Header.OptionalHeader.PTSDTSIndicator == astits.PTSDTSIndicatorNoPTSOrDTS || + data.PES.Header.OptionalHeader.PTSDTSIndicator == astits.PTSDTSIndicatorIsForbidden { + r.onDecodeError(fmt.Errorf("PTS is missing")) + return nil + } - pts := data.PES.Header.OptionalHeader.PTS.Base + pts := data.PES.Header.OptionalHeader.PTS.Base - var dts int64 - if data.PES.Header.OptionalHeader.PTSDTSIndicator == astits.PTSDTSIndicatorBothPresent { - dts = data.PES.Header.OptionalHeader.DTS.Base - } else { - dts = pts - } + var dts int64 + if data.PES.Header.OptionalHeader.PTSDTSIndicator == astits.PTSDTSIndicatorBothPresent { + dts = data.PES.Header.OptionalHeader.DTS.Base + } else { + dts = pts + } - onData, ok := r.onData[data.PID] - if !ok { - return nil - } + onData, ok := r.onData[data.PID] + if !ok { + return nil + } - return onData(pts, dts, data.PES.Data) + return onData(pts, dts, data.PES.Data) + } } diff --git a/pkg/formats/mpegts/reader_test.go b/pkg/formats/mpegts/reader_test.go index 9ab5dc6..b8059be 100644 --- a/pkg/formats/mpegts/reader_test.go +++ b/pkg/formats/mpegts/reader_test.go @@ -944,6 +944,373 @@ func TestReader(t *testing.T) { } } +func TestReaderDecodeErrors(t *testing.T) { + for _, ca := range []string{ + "missing pts", + "h26x invalid avcc", + "opus pts != dts", + "opus invalid au", + "mpeg-4 audio pts != dts", + "mpeg-4 audio invalid", + "mpeg-1 audio pts != dts", + "ac-3 pts != dts", + "garbage", + } { + t.Run(ca, func(t *testing.T) { + var buf bytes.Buffer + mux := astits.NewMuxer(context.Background(), &buf) + + switch ca { + case "missing pts", "h26x invalid avcc", "garbage": + err := mux.AddElementaryStream(astits.PMTElementaryStream{ + ElementaryPID: 123, + StreamType: astits.StreamTypeH264Video, + }) + require.NoError(t, err) + + case "opus pts != dts", "opus invalid au": + err := mux.AddElementaryStream(astits.PMTElementaryStream{ + ElementaryPID: 123, + StreamType: astits.StreamTypePrivateData, + ElementaryStreamDescriptors: []*astits.Descriptor{ + { + Length: 4, + Tag: astits.DescriptorTagRegistration, + Registration: &astits.DescriptorRegistration{ + FormatIdentifier: opusIdentifier, + }, + }, + { + Length: 2, + Tag: astits.DescriptorTagExtension, + Extension: &astits.DescriptorExtension{ + Tag: 0x80, + Unknown: &[]uint8{2}, + }, + }, + }, + }) + require.NoError(t, err) + + case "mpeg-4 audio pts != dts", "mpeg-4 audio invalid": + err := mux.AddElementaryStream(astits.PMTElementaryStream{ + ElementaryPID: 123, + StreamType: astits.StreamTypeAACAudio, + }) + require.NoError(t, err) + + case "mpeg-1 audio pts != dts": + err := mux.AddElementaryStream(astits.PMTElementaryStream{ + ElementaryPID: 123, + StreamType: astits.StreamTypeMPEG1Audio, + }) + require.NoError(t, err) + + case "ac-3 pts != dts": + err := mux.AddElementaryStream(astits.PMTElementaryStream{ + ElementaryPID: 123, + StreamType: astits.StreamTypeAC3Audio, + }) + require.NoError(t, err) + } + + mux.SetPCRPID(123) + + switch ca { + case "missing pts": + _, err := mux.WriteData(&astits.MuxerData{ + PID: 123, + PES: &astits.PESData{ + Header: &astits.PESHeader{ + OptionalHeader: &astits.PESOptionalHeader{ + MarkerBits: 2, + PTSDTSIndicator: astits.PTSDTSIndicatorNoPTSOrDTS, + }, + StreamID: streamIDVideo, + }, + Data: []byte{1, 2, 3, 4}, + }, + }) + require.NoError(t, err) + + case "h26x invalid avcc", "opus invalid au": + _, err := mux.WriteData(&astits.MuxerData{ + PID: 123, + PES: &astits.PESData{ + Header: &astits.PESHeader{ + OptionalHeader: &astits.PESOptionalHeader{ + MarkerBits: 2, + PTSDTSIndicator: astits.PTSDTSIndicatorOnlyPTS, + PTS: &astits.ClockReference{Base: 90000}, + }, + StreamID: streamIDVideo, + }, + Data: []byte{1, 2, 3, 4}, + }, + }) + require.NoError(t, err) + + case "opus pts != dts", "mpeg-1 audio pts != dts": + _, err := mux.WriteData(&astits.MuxerData{ + PID: 123, + PES: &astits.PESData{ + Header: &astits.PESHeader{ + OptionalHeader: &astits.PESOptionalHeader{ + MarkerBits: 2, + PTSDTSIndicator: astits.PTSDTSIndicatorBothPresent, + PTS: &astits.ClockReference{Base: 90000}, + DTS: &astits.ClockReference{Base: 180000}, + }, + StreamID: streamIDVideo, + }, + Data: []byte{1, 2, 3, 4}, + }, + }) + require.NoError(t, err) + + case "mpeg-4 audio pts != dts": + data, _ := mpeg4audio.ADTSPackets{{ + Type: mpeg4audio.ObjectTypeAACLC, + SampleRate: 44100, + ChannelCount: 1, + AU: []byte{1, 2, 3, 4}, + }}.Marshal() + + _, err := mux.WriteData(&astits.MuxerData{ + PID: 123, + PES: &astits.PESData{ + Header: &astits.PESHeader{ + OptionalHeader: &astits.PESOptionalHeader{ + MarkerBits: 2, + PTSDTSIndicator: astits.PTSDTSIndicatorBothPresent, + PTS: &astits.ClockReference{Base: 90000}, + DTS: &astits.ClockReference{Base: 180000}, + }, + StreamID: streamIDVideo, + }, + Data: data, + }, + }) + require.NoError(t, err) + + _, err = mux.WriteData(&astits.MuxerData{ + PID: 123, + PES: &astits.PESData{ + Header: &astits.PESHeader{ + OptionalHeader: &astits.PESOptionalHeader{ + MarkerBits: 2, + PTSDTSIndicator: astits.PTSDTSIndicatorBothPresent, + PTS: &astits.ClockReference{Base: 90000}, + DTS: &astits.ClockReference{Base: 180000}, + }, + StreamID: streamIDVideo, + }, + Data: []byte{1, 2, 3, 4}, + }, + }) + require.NoError(t, err) + + case "mpeg-4 audio invalid": + data, _ := mpeg4audio.ADTSPackets{{ + Type: mpeg4audio.ObjectTypeAACLC, + SampleRate: 44100, + ChannelCount: 1, + AU: []byte{1, 2, 3, 4}, + }}.Marshal() + + _, err := mux.WriteData(&astits.MuxerData{ + PID: 123, + PES: &astits.PESData{ + Header: &astits.PESHeader{ + OptionalHeader: &astits.PESOptionalHeader{ + MarkerBits: 2, + PTSDTSIndicator: astits.PTSDTSIndicatorOnlyPTS, + PTS: &astits.ClockReference{Base: 90000}, + }, + StreamID: streamIDVideo, + }, + Data: data, + }, + }) + require.NoError(t, err) + + _, err = mux.WriteData(&astits.MuxerData{ + PID: 123, + PES: &astits.PESData{ + Header: &astits.PESHeader{ + OptionalHeader: &astits.PESOptionalHeader{ + MarkerBits: 2, + PTSDTSIndicator: astits.PTSDTSIndicatorOnlyPTS, + PTS: &astits.ClockReference{Base: 90000}, + }, + StreamID: streamIDVideo, + }, + Data: []byte{1, 2, 3, 4}, + }, + }) + require.NoError(t, err) + + case "ac-3 pts != dts": + _, err := mux.WriteData(&astits.MuxerData{ + PID: 123, + PES: &astits.PESData{ + Header: &astits.PESHeader{ + OptionalHeader: &astits.PESOptionalHeader{ + MarkerBits: 2, + PTSDTSIndicator: astits.PTSDTSIndicatorBothPresent, + PTS: &astits.ClockReference{Base: 90000}, + DTS: &astits.ClockReference{Base: 180000}, + }, + StreamID: streamIDVideo, + }, + Data: []byte{ + 0x0b, 0x77, 0x47, 0x11, 0x0c, 0x40, 0x2f, 0x84, + 0x2b, 0xc1, 0x07, 0x7a, 0xb0, 0xfa, 0xbb, 0xea, + 0xef, 0x9f, 0x57, 0x7c, 0xf9, 0xf3, 0xf7, 0xcf, + 0x9f, 0x3e, 0x32, 0xfe, 0xd5, 0xc1, 0x50, 0xde, + 0xc5, 0x1e, 0x73, 0xd2, 0x6c, 0xa6, 0x94, 0x46, + }, + }, + }) + require.NoError(t, err) + + _, err = mux.WriteData(&astits.MuxerData{ + PID: 123, + PES: &astits.PESData{ + Header: &astits.PESHeader{ + OptionalHeader: &astits.PESOptionalHeader{ + MarkerBits: 2, + PTSDTSIndicator: astits.PTSDTSIndicatorBothPresent, + PTS: &astits.ClockReference{Base: 90000}, + DTS: &astits.ClockReference{Base: 180000}, + }, + StreamID: streamIDVideo, + }, + Data: []byte{1, 2, 3, 4}, + }, + }) + require.NoError(t, err) + + case "garbage": + _, err := mux.WriteData(&astits.MuxerData{ + PID: 123, + PES: &astits.PESData{ + Header: &astits.PESHeader{ + OptionalHeader: &astits.PESOptionalHeader{ + MarkerBits: 2, + PTSDTSIndicator: astits.PTSDTSIndicatorOnlyPTS, + PTS: &astits.ClockReference{Base: 90000}, + }, + StreamID: streamIDVideo, + }, + Data: []byte{0, 0, 0, 1, 1, 2, 3, 4}, + }, + }) + require.NoError(t, err) + + buf.Write(bytes.Repeat([]byte{1, 2, 3, 4}, 188/4)) + + _, err = mux.WriteData(&astits.MuxerData{ + PID: 123, + PES: &astits.PESData{ + Header: &astits.PESHeader{ + OptionalHeader: &astits.PESOptionalHeader{ + MarkerBits: 2, + PTSDTSIndicator: astits.PTSDTSIndicatorOnlyPTS, + PTS: &astits.ClockReference{Base: 90000}, + }, + StreamID: streamIDVideo, + }, + Data: []byte{0, 0, 0, 1, 1, 2, 3, 4}, + }, + }) + require.NoError(t, err) + } + + r, err := NewReader(bytes.NewReader(buf.Bytes())) + require.NoError(t, err) + + dataRecv := false + + switch ca { + case "missing pts", "h26x invalid avcc": + r.OnDataH26x(r.Tracks()[0], func(pts, dts int64, au [][]byte) error { + return nil + }) + + case "opus pts != dts", "opus invalid au": + r.OnDataOpus(r.Tracks()[0], func(pts int64, packets [][]byte) error { + return nil + }) + + case "mpeg-4 audio pts != dts", "mpeg-4 audio invalid": + r.OnDataMPEG4Audio(r.Tracks()[0], func(pts int64, aus [][]byte) error { + return nil + }) + + case "mpeg-1 audio pts != dts": + r.OnDataMPEG1Audio(r.Tracks()[0], func(pts int64, aus [][]byte) error { + return nil + }) + + case "ac-3 pts != dts": + r.OnDataAC3(r.Tracks()[0], func(pts int64, frame []byte) error { + return nil + }) + + case "garbage": + counter := 0 + r.OnDataH26x(r.Tracks()[0], func(pts, dts int64, au [][]byte) error { + counter++ + if counter == 2 { + dataRecv = true + } + return nil + }) + } + + decodeErrRecv := false + + r.OnDecodeError(func(err error) { + switch ca { + case "missing pts": + require.EqualError(t, err, "PTS is missing") + + case "h26x invalid avcc": + require.EqualError(t, err, "initial delimiter not found") + + case "opus pts != dts", "mpeg-4 audio pts != dts", "mpeg-1 audio pts != dts", "ac-3 pts != dts": + require.EqualError(t, err, "PTS is not equal to DTS") + + case "opus invalid au": + require.EqualError(t, err, "invalid control header: invalid prefix") + + case "mpeg-4 audio invalid": + require.EqualError(t, err, "invalid ADTS: invalid length") + + case "garbage": + require.EqualError(t, err, "astits: fetching next packet failed: astits: fetching next packet from buffer failed: astits: building packet failed: astits: packet must start with a sync byte") + } + decodeErrRecv = true + }) + + for { + err := r.Read() + if err != nil { + require.Equal(t, astits.ErrNoMorePackets, err) + break + } + } + + require.Equal(t, true, decodeErrRecv) + + if ca == "garbage" { + require.Equal(t, true, dataRecv) + } + }) + } +} + func FuzzReader(f *testing.F) { for _, ca := range casesReadWriter { var buf bytes.Buffer