From 0dba391d67a4cf564371c44a3a781d1e9e3507c2 Mon Sep 17 00:00:00 2001 From: Hans Leidekker Date: Fri, 13 Feb 2015 13:39:38 +0100 Subject: [PATCH] msi: Don't write streams to storage until the database is committed. Native allows streams to be created with names that exceed the maximum length allowed by OLE storages. These streams can be used normally, it's just not possible to commit such a database. --- dlls/msi/database.c | 40 ++-- dlls/msi/msipriv.h | 19 +- dlls/msi/msiquery.c | 9 +- dlls/msi/streams.c | 532 ++++++++++++++++++++++++++++++---------------------- dlls/msi/table.c | 18 +- dlls/msi/tests/db.c | 55 +++++- 6 files changed, 404 insertions(+), 269 deletions(-) diff --git a/dlls/msi/database.c b/dlls/msi/database.c index 9749d66a018..940d3d9567d 100644 --- a/dlls/msi/database.c +++ b/dlls/msi/database.c @@ -54,35 +54,6 @@ WINE_DEFAULT_DEBUG_CHANNEL(msi); #define IS_INTMSIDBOPEN(x) (((ULONG_PTR)(x) >> 16) == 0) -typedef struct tagMSITRANSFORM { - struct list entry; - IStorage *stg; -} MSITRANSFORM; - -UINT msi_get_raw_stream( MSIDATABASE *db, LPCWSTR stname, IStream **stm ) -{ - HRESULT r; - WCHAR decoded[MAX_STREAM_NAME_LEN + 1]; - - decode_streamname( stname, decoded ); - TRACE("%s -> %s\n", debugstr_w(stname), debugstr_w(decoded)); - - r = IStorage_OpenStream( db->storage, stname, NULL, STGM_READ | STGM_SHARE_EXCLUSIVE, 0, stm ); - if (FAILED( r )) - { - MSITRANSFORM *transform; - - LIST_FOR_EACH_ENTRY( transform, &db->transforms, MSITRANSFORM, entry ) - { - r = IStorage_OpenStream( transform->stg, stname, NULL, STGM_READ | STGM_SHARE_EXCLUSIVE, 0, stm ); - if (SUCCEEDED( r )) - break; - } - } - - return SUCCEEDED(r) ? ERROR_SUCCESS : ERROR_FUNCTION_FAILED; -} - static void free_transforms( MSIDATABASE *db ) { while( !list_empty( &db->transforms ) ) @@ -94,6 +65,16 @@ static void free_transforms( MSIDATABASE *db ) } } +static void free_streams( MSIDATABASE *db ) +{ + UINT i; + for (i = 0; i < db->num_streams; i++) + { + if (db->streams[i].stream) IStream_Release( db->streams[i].stream ); + } + msi_free( db->streams ); +} + void append_storage_to_db( MSIDATABASE *db, IStorage *stg ) { MSITRANSFORM *t; @@ -109,6 +90,7 @@ static VOID MSI_CloseDatabase( MSIOBJECTHDR *arg ) MSIDATABASE *db = (MSIDATABASE *) arg; msi_free(db->path); + free_streams( db ); free_cached_tables( db ); free_transforms( db ); if (db->strings) msi_destroy_stringtable( db->strings ); diff --git a/dlls/msi/msipriv.h b/dlls/msi/msipriv.h index e4a82355cbc..db2a2af215c 100644 --- a/dlls/msi/msipriv.h +++ b/dlls/msi/msipriv.h @@ -80,6 +80,18 @@ struct tagMSIOBJECTHDR #define MSI_INITIAL_MEDIA_TRANSFORM_OFFSET 10000 #define MSI_INITIAL_MEDIA_TRANSFORM_DISKID 30000 +typedef struct tagMSISTREAM +{ + UINT str_index; + IStream *stream; +} MSISTREAM; + +typedef struct tagMSITRANSFORM +{ + struct list entry; + IStorage *stg; +} MSITRANSFORM; + typedef struct tagMSIDATABASE { MSIOBJECTHDR hdr; @@ -93,6 +105,9 @@ typedef struct tagMSIDATABASE UINT media_transform_disk_id; struct list tables; struct list transforms; + MSISTREAM *streams; + UINT num_streams; + UINT num_streams_allocated; } MSIDATABASE; typedef struct tagMSIVIEW MSIVIEW; @@ -741,6 +756,7 @@ extern void msi_free_handle_table(void) DECLSPEC_HIDDEN; extern void free_cached_tables( MSIDATABASE *db ) DECLSPEC_HIDDEN; extern UINT MSI_CommitTables( MSIDATABASE *db ) DECLSPEC_HIDDEN; +extern UINT msi_commit_streams( MSIDATABASE *db ) DECLSPEC_HIDDEN; /* string table functions */ @@ -828,8 +844,7 @@ extern LPWSTR encode_streamname(BOOL bTable, LPCWSTR in) DECLSPEC_HIDDEN; extern BOOL decode_streamname(LPCWSTR in, LPWSTR out) DECLSPEC_HIDDEN; /* database internals */ -extern UINT msi_get_raw_stream( MSIDATABASE *, LPCWSTR, IStream ** ) DECLSPEC_HIDDEN; -void msi_destroy_stream( MSIDATABASE *, const WCHAR * ) DECLSPEC_HIDDEN; +extern UINT msi_get_stream( MSIDATABASE *, const WCHAR *, IStream ** ) DECLSPEC_HIDDEN; extern UINT MSI_OpenDatabaseW( LPCWSTR, LPCWSTR, MSIDATABASE ** ) DECLSPEC_HIDDEN; extern UINT MSI_DatabaseOpenViewW(MSIDATABASE *, LPCWSTR, MSIQUERY ** ) DECLSPEC_HIDDEN; extern UINT MSI_OpenQuery( MSIDATABASE *, MSIQUERY **, LPCWSTR, ... ) DECLSPEC_HIDDEN; diff --git a/dlls/msi/msiquery.c b/dlls/msi/msiquery.c index 451fdddbbb0..467cd496bee 100644 --- a/dlls/msi/msiquery.c +++ b/dlls/msi/msiquery.c @@ -840,8 +840,13 @@ UINT WINAPI MsiDatabaseCommit( MSIHANDLE hdb ) /* FIXME: lock the database */ - r = MSI_CommitTables( db ); - if (r != ERROR_SUCCESS) ERR("Failed to commit tables!\n"); + r = msi_commit_streams( db ); + if (r != ERROR_SUCCESS) ERR("Failed to commit streams!\n"); + else + { + r = MSI_CommitTables( db ); + if (r != ERROR_SUCCESS) ERR("Failed to commit tables!\n"); + } /* FIXME: unlock the database */ diff --git a/dlls/msi/streams.c b/dlls/msi/streams.c index d795c8ad023..44a322c89fb 100644 --- a/dlls/msi/streams.c +++ b/dlls/msi/streams.c @@ -2,6 +2,7 @@ * Implementation of the Microsoft Installer (msi.dll) * * Copyright 2007 James Hawkins + * Copyright 2015 Hans Leidekker for CodeWeavers * * This library is free software; you can redistribute it and/or * modify it under the terms of the GNU Lesser General Public @@ -38,54 +39,30 @@ WINE_DEFAULT_DEBUG_CHANNEL(msidb); #define NUM_STREAMS_COLS 2 -typedef struct tabSTREAM -{ - UINT str_index; - IStream *stream; -} STREAM; - typedef struct tagMSISTREAMSVIEW { MSIVIEW view; MSIDATABASE *db; - STREAM **streams; - UINT max_streams; - UINT num_rows; - UINT row_size; + UINT num_cols; } MSISTREAMSVIEW; -static BOOL streams_set_table_size(MSISTREAMSVIEW *sv, UINT size) +static BOOL streams_resize_table( MSIDATABASE *db, UINT size ) { - if (size >= sv->max_streams) + if (!db->num_streams_allocated) { - sv->max_streams *= 2; - sv->streams = msi_realloc_zero(sv->streams, sv->max_streams * sizeof(STREAM *)); - if (!sv->streams) - return FALSE; + if (!(db->streams = msi_alloc_zero( size * sizeof(MSISTREAM) ))) return FALSE; + db->num_streams_allocated = size; + return TRUE; } - - return TRUE; -} - -static STREAM *create_stream(MSISTREAMSVIEW *sv, LPCWSTR name, BOOL encoded, IStream *stm) -{ - STREAM *stream; - WCHAR decoded[MAX_STREAM_NAME_LEN + 1]; - - stream = msi_alloc(sizeof(STREAM)); - if (!stream) - return NULL; - - if (encoded) + while (size >= db->num_streams_allocated) { - decode_streamname(name, decoded); - TRACE("stream -> %s %s\n", debugstr_w(name), debugstr_w(decoded)); - name = decoded; + MSISTREAM *tmp; + UINT new_size = db->num_streams_allocated * 2; + if (!(tmp = msi_realloc_zero( db->streams, new_size * sizeof(MSISTREAM) ))) return FALSE; + db->streams = tmp; + db->num_streams_allocated = new_size; } - - stream->str_index = msi_add_string(sv->db->strings, name, -1, StringNonPersistent); - stream->stream = stm; - return stream; + return TRUE; } static UINT STREAMS_fetch_int(struct tagMSIVIEW *view, UINT row, UINT col, UINT *val) @@ -97,10 +74,10 @@ static UINT STREAMS_fetch_int(struct tagMSIVIEW *view, UINT row, UINT col, UINT if (col != 1) return ERROR_INVALID_PARAMETER; - if (row >= sv->num_rows) + if (row >= sv->db->num_streams) return ERROR_NO_MORE_ITEMS; - *val = sv->streams[row]->str_index; + *val = sv->db->streams[row].str_index; return ERROR_SUCCESS; } @@ -108,14 +85,21 @@ static UINT STREAMS_fetch_int(struct tagMSIVIEW *view, UINT row, UINT col, UINT static UINT STREAMS_fetch_stream(struct tagMSIVIEW *view, UINT row, UINT col, IStream **stm) { MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view; + LARGE_INTEGER pos; + HRESULT hr; TRACE("(%p, %d, %d, %p)\n", view, row, col, stm); - if (row >= sv->num_rows) + if (row >= sv->db->num_streams) return ERROR_FUNCTION_FAILED; - IStream_AddRef(sv->streams[row]->stream); - *stm = sv->streams[row]->stream; + pos.QuadPart = 0; + hr = IStream_Seek( sv->db->streams[row].stream, pos, STREAM_SEEK_SET, NULL ); + if (FAILED( hr )) + return ERROR_FUNCTION_FAILED; + + *stm = sv->db->streams[row].stream; + IStream_AddRef( *stm ); return ERROR_SUCCESS; } @@ -132,108 +116,94 @@ static UINT STREAMS_get_row( struct tagMSIVIEW *view, UINT row, MSIRECORD **rec static UINT STREAMS_set_row(struct tagMSIVIEW *view, UINT row, MSIRECORD *rec, UINT mask) { MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view; - STREAM *stream; - IStream *stm; - STATSTG stat; - LPWSTR encname = NULL, name = NULL; - USHORT *data = NULL; - HRESULT hr; - ULONG count; - UINT r = ERROR_FUNCTION_FAILED; TRACE("(%p, %d, %p, %08x)\n", view, row, rec, mask); - if (row > sv->num_rows) - return ERROR_FUNCTION_FAILED; - - r = MSI_RecordGetIStream(rec, 2, &stm); - if (r != ERROR_SUCCESS) - return r; + if (row > sv->db->num_streams || mask >= (1 << sv->num_cols)) + return ERROR_INVALID_PARAMETER; - hr = IStream_Stat(stm, &stat, STATFLAG_NONAME); - if (FAILED(hr)) + if (mask & 1) { - WARN("failed to stat stream: %08x\n", hr); - goto done; - } + const WCHAR *name = MSI_RecordGetString( rec, 1 ); - if (stat.cbSize.QuadPart >> 32) - { - WARN("stream too large\n"); - goto done; + if (!name) return ERROR_INVALID_PARAMETER; + sv->db->streams[row].str_index = msi_add_string( sv->db->strings, name, -1, StringNonPersistent ); } - - data = msi_alloc(stat.cbSize.QuadPart); - if (!data) - goto done; - - hr = IStream_Read(stm, data, stat.cbSize.QuadPart, &count); - if (FAILED(hr) || count != stat.cbSize.QuadPart) + if (mask & 2) { - WARN("failed to read stream: %08x\n", hr); - goto done; - } + IStream *old, *new; + HRESULT hr; + UINT r; - name = strdupW(MSI_RecordGetString(rec, 1)); - if (!name) - { - WARN("failed to retrieve stream name\n"); - goto done; - } + r = MSI_RecordGetIStream( rec, 2, &new ); + if (r != ERROR_SUCCESS) + return r; - r = write_stream_data(sv->db->storage, name, data, count, FALSE); - if (r != ERROR_SUCCESS) - { - WARN("failed to write stream data: %d\n", r); - goto done; + old = sv->db->streams[row].stream; + hr = IStream_QueryInterface( new, &IID_IStream, (void **)&sv->db->streams[row].stream ); + if (FAILED( hr )) + { + IStream_Release( new ); + return ERROR_FUNCTION_FAILED; + } + if (old) IStream_Release( old ); } - stream = create_stream(sv, name, FALSE, NULL); - if (!stream) - goto done; + return ERROR_SUCCESS; +} - encname = encode_streamname(FALSE, name); - hr = IStorage_OpenStream(sv->db->storage, encname, 0, - STGM_READ | STGM_SHARE_EXCLUSIVE, 0, &stream->stream); - if (FAILED(hr)) - { - WARN("failed to open stream: %08x\n", hr); - msi_free(stream); - goto done; - } +static UINT streams_find_row( MSISTREAMSVIEW *sv, MSIRECORD *rec, UINT *row ) +{ + const WCHAR *str; + UINT r, i, id, val; - sv->streams[row] = stream; + str = MSI_RecordGetString( rec, 1 ); + r = msi_string2id( sv->db->strings, str, -1, &id ); + if (r != ERROR_SUCCESS) + return r; -done: - msi_free(name); - msi_free(data); - msi_free(encname); + for (i = 0; i < sv->db->num_streams; i++) + { + STREAMS_fetch_int( &sv->view, i, 1, &val ); - IStream_Release(stm); + if (val == id) + { + if (row) *row = i; + return ERROR_SUCCESS; + } + } - return r; + return ERROR_FUNCTION_FAILED; } static UINT STREAMS_insert_row(struct tagMSIVIEW *view, MSIRECORD *rec, UINT row, BOOL temporary) { MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view; - UINT i; + UINT i, r, num_rows = sv->db->num_streams + 1; TRACE("(%p, %p, %d, %d)\n", view, rec, row, temporary); - if (!streams_set_table_size(sv, ++sv->num_rows)) + r = streams_find_row( sv, rec, NULL ); + if (r == ERROR_SUCCESS) + return ERROR_FUNCTION_FAILED; + + if (!streams_resize_table( sv->db, num_rows )) return ERROR_FUNCTION_FAILED; if (row == -1) - row = sv->num_rows - 1; + row = num_rows - 1; /* shift the rows to make room for the new row */ - for (i = sv->num_rows - 1; i > row; i--) + for (i = num_rows - 1; i > row; i--) { - sv->streams[i] = sv->streams[i - 1]; + sv->db->streams[i] = sv->db->streams[i - 1]; } - return STREAMS_set_row(view, row, rec, 0); + r = STREAMS_set_row( view, row, rec, (1 << sv->num_cols) - 1 ); + if (r == ERROR_SUCCESS) + sv->db->num_streams = num_rows; + + return r; } static UINT STREAMS_delete_row(struct tagMSIVIEW *view, UINT row) @@ -260,8 +230,8 @@ static UINT STREAMS_get_dimensions(struct tagMSIVIEW *view, UINT *rows, UINT *co TRACE("(%p, %p, %p)\n", view, rows, cols); - if (cols) *cols = NUM_STREAMS_COLS; - if (rows) *rows = sv->num_rows; + if (cols) *cols = sv->num_cols; + if (rows) *rows = sv->db->num_streams; return ERROR_SUCCESS; } @@ -269,10 +239,11 @@ static UINT STREAMS_get_dimensions(struct tagMSIVIEW *view, UINT *rows, UINT *co static UINT STREAMS_get_column_info( struct tagMSIVIEW *view, UINT n, LPCWSTR *name, UINT *type, BOOL *temporary, LPCWSTR *table_name ) { - TRACE("(%p, %d, %p, %p, %p, %p)\n", view, n, name, type, temporary, - table_name); + MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view; + + TRACE("(%p, %d, %p, %p, %p, %p)\n", view, n, name, type, temporary, table_name); - if (n == 0 || n > NUM_STREAMS_COLS) + if (!n || n > sv->num_cols) return ERROR_INVALID_PARAMETER; switch (n) @@ -292,30 +263,6 @@ static UINT STREAMS_get_column_info( struct tagMSIVIEW *view, UINT n, LPCWSTR *n return ERROR_SUCCESS; } -static UINT streams_find_row(MSISTREAMSVIEW *sv, MSIRECORD *rec, UINT *row) -{ - LPCWSTR str; - UINT r, i, id, data; - - str = MSI_RecordGetString(rec, 1); - r = msi_string2id(sv->db->strings, str, -1, &id); - if (r != ERROR_SUCCESS) - return r; - - for (i = 0; i < sv->num_rows; i++) - { - STREAMS_fetch_int(&sv->view, i, 1, &data); - - if (data == id) - { - *row = i; - return ERROR_SUCCESS; - } - } - - return ERROR_FUNCTION_FAILED; -} - static UINT streams_modify_update(struct tagMSIVIEW *view, MSIRECORD *rec) { MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view; @@ -325,15 +272,15 @@ static UINT streams_modify_update(struct tagMSIVIEW *view, MSIRECORD *rec) if (r != ERROR_SUCCESS) return ERROR_FUNCTION_FAILED; - return STREAMS_set_row(view, row, rec, 0); + return STREAMS_set_row( view, row, rec, (1 << sv->num_cols) - 1 ); } static UINT streams_modify_assign(struct tagMSIVIEW *view, MSIRECORD *rec) { MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view; - UINT r, row; + UINT r; - r = streams_find_row(sv, rec, &row); + r = streams_find_row( sv, rec, NULL ); if (r == ERROR_SUCCESS) return streams_modify_update(view, rec); @@ -383,23 +330,10 @@ static UINT STREAMS_modify(struct tagMSIVIEW *view, MSIMODIFY eModifyMode, MSIRE static UINT STREAMS_delete(struct tagMSIVIEW *view) { MSISTREAMSVIEW *sv = (MSISTREAMSVIEW *)view; - UINT i; TRACE("(%p)\n", view); - for (i = 0; i < sv->num_rows; i++) - { - if (sv->streams[i]) - { - if (sv->streams[i]->stream) - IStream_Release(sv->streams[i]->stream); - msi_free(sv->streams[i]); - } - } - - msi_free(sv->streams); msi_free(sv); - return ERROR_SUCCESS; } @@ -411,23 +345,22 @@ static UINT STREAMS_find_matching_rows(struct tagMSIVIEW *view, UINT col, TRACE("(%p, %d, %d, %p, %p)\n", view, col, val, row, handle); - if (col == 0 || col > NUM_STREAMS_COLS) + if (!col || col > sv->num_cols) return ERROR_INVALID_PARAMETER; - while (index < sv->num_rows) + while (index < sv->db->num_streams) { - if (sv->streams[index]->str_index == val) + if (sv->db->streams[index].str_index == val) { *row = index; break; } - index++; } *handle = UlongToPtr(++index); - if (index > sv->num_rows) + if (index > sv->db->num_streams) return ERROR_NO_MORE_ITEMS; return ERROR_SUCCESS; @@ -456,107 +389,256 @@ static const MSIVIEWOPS streams_ops = NULL, }; -static INT add_streams_to_table(MSISTREAMSVIEW *sv) +static HRESULT open_stream( MSIDATABASE *db, const WCHAR *name, IStream **stream ) { - IEnumSTATSTG *stgenum = NULL; - STATSTG stat; - STREAM *stream = NULL; HRESULT hr; - UINT r, count = 0, size; - LPWSTR encname; - hr = IStorage_EnumElements(sv->db->storage, 0, NULL, 0, &stgenum); - if (FAILED(hr)) - return -1; - - sv->max_streams = 1; - sv->streams = msi_alloc_zero(sizeof(STREAM *)); - if (!sv->streams) - return -1; - - while (TRUE) + hr = IStorage_OpenStream( db->storage, name, NULL, STGM_READ|STGM_SHARE_EXCLUSIVE, 0, stream ); + if (FAILED( hr )) { - size = 0; - hr = IEnumSTATSTG_Next(stgenum, 1, &stat, &size); - if (FAILED(hr) || !size) - break; + MSITRANSFORM *transform; - if (stat.type != STGTY_STREAM) + LIST_FOR_EACH_ENTRY( transform, &db->transforms, MSITRANSFORM, entry ) { - CoTaskMemFree(stat.pwcsName); - continue; + hr = IStorage_OpenStream( transform->stg, name, NULL, STGM_READ|STGM_SHARE_EXCLUSIVE, 0, stream ); + if (SUCCEEDED( hr )) + break; } + } + return hr; +} + +static MSISTREAM *find_stream( MSIDATABASE *db, const WCHAR *name ) +{ + UINT r, id, i; + + r = msi_string2id( db->strings, name, -1, &id ); + if (r != ERROR_SUCCESS) + return NULL; + + for (i = 0; i < db->num_streams; i++) + { + if (db->streams[i].str_index == id) return &db->streams[i]; + } + return NULL; +} + +static UINT append_stream( MSIDATABASE *db, const WCHAR *name, IStream *stream ) +{ + WCHAR decoded[MAX_STREAM_NAME_LEN + 1]; + UINT i = db->num_streams; + + if (!streams_resize_table( db, db->num_streams + 1 )) + return ERROR_OUTOFMEMORY; + + decode_streamname( name, decoded ); + db->streams[i].str_index = msi_add_string( db->strings, decoded, -1, StringNonPersistent ); + db->streams[i].stream = stream; + db->num_streams++; + + TRACE("added %s\n", debugstr_w( decoded )); + return ERROR_SUCCESS; +} + +static UINT load_streams( MSIDATABASE *db ) +{ + IEnumSTATSTG *stgenum; + STATSTG stat; + HRESULT hr; + UINT count, r = ERROR_SUCCESS; + IStream *stream; + + hr = IStorage_EnumElements( db->storage, 0, NULL, 0, &stgenum ); + if (FAILED( hr )) + return ERROR_FUNCTION_FAILED; + + for (;;) + { + count = 0; + hr = IEnumSTATSTG_Next( stgenum, 1, &stat, &count ); + if (FAILED( hr ) || !count) + break; /* table streams are not in the _Streams table */ - if (*stat.pwcsName == 0x4840) + if (stat.type != STGTY_STREAM || *stat.pwcsName == 0x4840 || + find_stream( db, stat.pwcsName )) { - CoTaskMemFree(stat.pwcsName); + CoTaskMemFree( stat.pwcsName ); continue; } + TRACE("found new stream %s\n", debugstr_w( stat.pwcsName )); - stream = create_stream(sv, stat.pwcsName, TRUE, NULL); - if (!stream) + hr = open_stream( db, stat.pwcsName, &stream ); + if (FAILED( hr )) { - count = -1; - CoTaskMemFree(stat.pwcsName); + ERR("unable to open stream %08x\n", hr); + CoTaskMemFree( stat.pwcsName ); + r = ERROR_FUNCTION_FAILED; break; } - /* these streams appear to be unencoded */ - if (*stat.pwcsName == 0x0005) - { - r = msi_get_raw_stream(sv->db, stat.pwcsName, &stream->stream); - } - else - { - encname = encode_streamname(FALSE, stat.pwcsName); - r = msi_get_raw_stream(sv->db, encname, &stream->stream); - msi_free(encname); - } - CoTaskMemFree(stat.pwcsName); - + r = append_stream( db, stat.pwcsName, stream ); + CoTaskMemFree( stat.pwcsName ); if (r != ERROR_SUCCESS) - { - WARN("unable to get stream %u\n", r); - count = -1; break; - } + } - if (!streams_set_table_size(sv, ++count)) - { - count = -1; - break; - } + TRACE("loaded %u streams\n", db->num_streams); + IEnumSTATSTG_Release( stgenum ); + return r; +} + +UINT msi_get_stream( MSIDATABASE *db, const WCHAR *name, IStream **ret ) +{ + MSISTREAM *stream; + WCHAR *encname; + HRESULT hr; + UINT r; - sv->streams[count - 1] = stream; + if ((stream = find_stream( db, name ))) + { + LARGE_INTEGER pos; + + pos.QuadPart = 0; + hr = IStream_Seek( stream->stream, pos, STREAM_SEEK_SET, NULL ); + if (FAILED( hr )) + return ERROR_FUNCTION_FAILED; + + *ret = stream->stream; + IStream_AddRef( *ret ); + return ERROR_SUCCESS; } - IEnumSTATSTG_Release(stgenum); - return count; + if (!(encname = encode_streamname( FALSE, name ))) + return ERROR_OUTOFMEMORY; + + hr = open_stream( db, encname, ret ); + msi_free( encname ); + if (FAILED( hr )) + return ERROR_FUNCTION_FAILED; + + r = append_stream( db, name, *ret ); + if (r != ERROR_SUCCESS) + { + IStream_Release( *ret ); + return r; + } + + IStream_AddRef( *ret ); + return ERROR_SUCCESS; } UINT STREAMS_CreateView(MSIDATABASE *db, MSIVIEW **view) { MSISTREAMSVIEW *sv; - INT rows; + UINT r; TRACE("(%p, %p)\n", db, view); - sv = msi_alloc_zero( sizeof(MSISTREAMSVIEW) ); - if (!sv) - return ERROR_FUNCTION_FAILED; + r = load_streams( db ); + if (r != ERROR_SUCCESS) + return r; + + if (!(sv = msi_alloc_zero( sizeof(MSISTREAMSVIEW) ))) + return ERROR_OUTOFMEMORY; sv->view.ops = &streams_ops; + sv->num_cols = NUM_STREAMS_COLS; sv->db = db; - rows = add_streams_to_table(sv); - if (rows < 0) + + *view = (MSIVIEW *)sv; + + return ERROR_SUCCESS; +} + +static HRESULT write_stream( IStream *dst, IStream *src ) +{ + HRESULT hr; + char buf[4096]; + STATSTG stat; + LARGE_INTEGER pos; + UINT count, size; + + hr = IStream_Stat( src, &stat, STATFLAG_NONAME ); + if (FAILED( hr )) return hr; + + hr = IStream_SetSize( dst, stat.cbSize ); + if (FAILED( hr )) return hr; + + pos.QuadPart = 0; + hr = IStream_Seek( dst, pos, STREAM_SEEK_SET, NULL ); + if (FAILED( hr )) return hr; + + for (;;) { - msi_free( sv ); - return ERROR_FUNCTION_FAILED; + size = min( sizeof(buf), stat.cbSize.QuadPart ); + hr = IStream_Read( src, buf, size, &count ); + if (FAILED( hr ) || count != size) + { + WARN("failed to read stream: %08x\n", hr); + return E_INVALIDARG; + } + stat.cbSize.QuadPart -= count; + if (count) + { + size = count; + hr = IStream_Write( dst, buf, size, &count ); + if (FAILED( hr ) || count != size) + { + WARN("failed to write stream: %08x\n", hr); + return E_INVALIDARG; + } + } + if (!stat.cbSize.QuadPart) break; } - sv->num_rows = rows; - *view = (MSIVIEW *)sv; + return S_OK; +} + +UINT msi_commit_streams( MSIDATABASE *db ) +{ + UINT i; + const WCHAR *name; + WCHAR *encname; + IStream *stream; + HRESULT hr; + + TRACE("got %u streams\n", db->num_streams); + + for (i = 0; i < db->num_streams; i++) + { + name = msi_string_lookup( db->strings, db->streams[i].str_index, NULL ); + if (!(encname = encode_streamname( FALSE, name ))) return ERROR_OUTOFMEMORY; + + hr = open_stream( db, encname, &stream ); + if (FAILED( hr )) /* new stream */ + { + hr = IStorage_CreateStream( db->storage, encname, STGM_WRITE|STGM_SHARE_EXCLUSIVE, 0, 0, &stream ); + if (FAILED( hr )) + { + ERR("failed to create stream %s (hr = %08x)\n", debugstr_w(encname), hr); + msi_free( encname ); + return ERROR_FUNCTION_FAILED; + } + hr = write_stream( stream, db->streams[i].stream ); + if (FAILED( hr )) + { + ERR("failed to write stream %s (hr = %08x)\n", debugstr_w(encname), hr); + msi_free( encname ); + IStream_Release( stream ); + return ERROR_FUNCTION_FAILED; + } + } + hr = IStream_Commit( stream, 0 ); + IStream_Release( stream ); + if (FAILED( hr )) + { + WARN("failed to commit stream %s (hr = %08x)\n", debugstr_w(encname), hr); + msi_free( encname ); + return ERROR_FUNCTION_FAILED; + } + msi_free( encname ); + } return ERROR_SUCCESS; } diff --git a/dlls/msi/table.c b/dlls/msi/table.c index 53f4b4be6ee..cfe5612ed14 100644 --- a/dlls/msi/table.c +++ b/dlls/msi/table.c @@ -1149,25 +1149,23 @@ static UINT TABLE_fetch_stream( struct tagMSIVIEW *view, UINT row, UINT col, ISt { MSITABLEVIEW *tv = (MSITABLEVIEW*)view; UINT r; - LPWSTR encname, full_name = NULL; + WCHAR *name; if( !view->ops->fetch_int ) return ERROR_INVALID_PARAMETER; - r = msi_stream_name( tv, row, &full_name ); - if ( r != ERROR_SUCCESS ) + r = msi_stream_name( tv, row, &name ); + if (r != ERROR_SUCCESS) { - ERR("fetching stream, error = %d\n", r); + ERR("fetching stream, error = %u\n", r); return r; } - encname = encode_streamname( FALSE, full_name ); - r = msi_get_raw_stream( tv->db, encname, stm ); - if( r ) - ERR("fetching stream %s, error = %d\n",debugstr_w(full_name), r); + r = msi_get_stream( tv->db, name, stm ); + if (r != ERROR_SUCCESS) + ERR("fetching stream %s, error = %u\n", debugstr_w(name), r); - msi_free( full_name ); - msi_free( encname ); + msi_free( name ); return r; } diff --git a/dlls/msi/tests/db.c b/dlls/msi/tests/db.c index 5ac90356dd8..3d8568d0fc1 100644 --- a/dlls/msi/tests/db.c +++ b/dlls/msi/tests/db.c @@ -1669,6 +1669,28 @@ static void test_streamtable(void) MsiViewClose( view ); MsiCloseHandle( view ); + /* try again */ + create_file( "test1.txt" ); + + rec = MsiCreateRecord( 2 ); + MsiRecordSetStringA( rec, 1, "data1" ); + + r = MsiRecordSetStreamA( rec, 2, "test1.txt" ); + ok( r == ERROR_SUCCESS, "Failed to add stream data to the record: %d\n", r ); + + DeleteFileA( "test1.txt" ); + + r = MsiDatabaseOpenViewA( hdb, + "INSERT INTO `_Streams` ( `Name`, `Data` ) VALUES ( ?, ? )", &view ); + ok( r == ERROR_SUCCESS, "Failed to open database view: %d\n", r ); + + r = MsiViewExecute( view, rec ); + ok( r == ERROR_FUNCTION_FAILED, "got %u\n", r ); + + MsiCloseHandle( rec ); + MsiViewClose( view ); + MsiCloseHandle( view ); + r = MsiDatabaseOpenViewA( hdb, "SELECT `Name`, `Data` FROM `_Streams` WHERE `Name` = 'data'", &view ); ok( r == ERROR_SUCCESS, "Failed to open database view: %d\n", r); @@ -1758,7 +1780,7 @@ static void test_streamtable(void) memset(buf, 0, MAX_PATH); r = MsiRecordReadStream( rec, 2, buf, &size ); ok( r == ERROR_SUCCESS, "Failed to get stream: %d\n", r); - todo_wine ok( !lstrcmpA(buf, "test2.txt\n"), "Expected 'test2.txt\\n', got %s\n", buf); + ok( !lstrcmpA(buf, "test2.txt\n"), "Expected 'test2.txt\\n', got %s\n", buf); MsiCloseHandle( rec ); MsiViewClose( view ); @@ -1790,10 +1812,41 @@ static void test_binary(void) ok( r == ERROR_SUCCESS, "Failed to add stream data to the record: %d\n", r); DeleteFileA( "test.txt" ); + /* try a name that exceeds maximum OLE stream name length */ + query = "INSERT INTO `Binary` ( `Name`, `ID`, `Data` ) VALUES ( 'encryption.dll.CB4E6205_F99A_4C51_ADD4_184506EFAB87', 10000, ? )"; + r = run_query( hdb, rec, query ); + ok( r == ERROR_SUCCESS, "Insert into Binary table failed: %d\n", r ); + + r = MsiCloseHandle( rec ); + ok( r == ERROR_SUCCESS , "Failed to close record handle\n" ); + + r = MsiDatabaseCommit( hdb ); + ok( r == ERROR_FUNCTION_FAILED , "got %u\n", r ); + + r = MsiCloseHandle( hdb ); + ok( r == ERROR_SUCCESS , "Failed to close database\n" ); + + r = MsiOpenDatabaseW(msifileW, MSIDBOPEN_CREATE, &hdb ); + ok( r == ERROR_SUCCESS , "Failed to open database\n" ); + + query = "CREATE TABLE `Binary` ( `Name` CHAR(72) NOT NULL, `ID` INT NOT NULL, `Data` OBJECT PRIMARY KEY `Name`, `ID`)"; + r = run_query( hdb, 0, query ); + ok( r == ERROR_SUCCESS, "Cannot create Binary table: %d\n", r ); + + create_file( "test.txt" ); + rec = MsiCreateRecord( 1 ); + r = MsiRecordSetStreamA( rec, 1, "test.txt" ); + ok( r == ERROR_SUCCESS, "Failed to add stream data to the record: %d\n", r ); + DeleteFileA( "test.txt" ); + query = "INSERT INTO `Binary` ( `Name`, `ID`, `Data` ) VALUES ( 'filename1', 1, ? )"; r = run_query( hdb, rec, query ); ok( r == ERROR_SUCCESS, "Insert into Binary table failed: %d\n", r ); + query = "INSERT INTO `Binary` ( `Name`, `ID`, `Data` ) VALUES ( 'filename1', 1, ? )"; + r = run_query( hdb, rec, query ); + ok( r == ERROR_FUNCTION_FAILED, "got %u\n", r ); + r = MsiCloseHandle( rec ); ok( r == ERROR_SUCCESS , "Failed to close record handle\n" ); -- 2.11.4.GIT