Merge 'remotes/trunk'
[0ad.git] / source / graphics / ShaderManager.cpp
blob7256ab155039f90996975908b8e701d240d49f07
1 /* Copyright (C) 2023 Wildfire Games.
2 * This file is part of 0 A.D.
4 * 0 A.D. is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 2 of the License, or
7 * (at your option) any later version.
9 * 0 A.D. is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
14 * You should have received a copy of the GNU General Public License
15 * along with 0 A.D. If not, see <http://www.gnu.org/licenses/>.
18 #include "precompiled.h"
20 #include "ShaderManager.h"
22 #include "graphics/PreprocessorWrapper.h"
23 #include "graphics/ShaderTechnique.h"
24 #include "lib/config2.h"
25 #include "lib/hash.h"
26 #include "lib/timer.h"
27 #include "lib/utf8.h"
28 #include "ps/CLogger.h"
29 #include "ps/CStrIntern.h"
30 #include "ps/CStrInternStatic.h"
31 #include "ps/Filesystem.h"
32 #include "ps/Profile.h"
33 #include "ps/XML/Xeromyces.h"
34 #include "ps/VideoMode.h"
35 #include "renderer/backend/IDevice.h"
36 #include "renderer/Renderer.h"
37 #include "renderer/RenderingOptions.h"
39 #define USE_SHADER_XML_VALIDATION 1
41 #if USE_SHADER_XML_VALIDATION
42 #include "ps/XML/RelaxNG.h"
43 #include "ps/XML/XMLWriter.h"
44 #endif
46 #include <optional>
47 #include <vector>
49 TIMER_ADD_CLIENT(tc_ShaderValidation);
51 CShaderManager::CShaderManager()
53 #if USE_SHADER_XML_VALIDATION
55 TIMER_ACCRUE(tc_ShaderValidation);
57 if (!CXeromyces::AddValidator(g_VFS, "shader", "shaders/program.rng"))
58 LOGERROR("CShaderManager: failed to load grammar shaders/program.rng");
60 #endif
62 // Allow hotloading of textures
63 RegisterFileReloadFunc(ReloadChangedFileCB, this);
66 CShaderManager::~CShaderManager()
68 UnregisterFileReloadFunc(ReloadChangedFileCB, this);
71 CShaderProgramPtr CShaderManager::LoadProgram(const CStr& name, const CShaderDefines& defines)
73 CacheKey key = { name, defines };
74 std::map<CacheKey, CShaderProgramPtr>::iterator it = m_ProgramCache.find(key);
75 if (it != m_ProgramCache.end())
76 return it->second;
78 CShaderProgramPtr program = CShaderProgram::Create(name, defines);
79 if (program)
81 for (const VfsPath& path : program->GetFileDependencies())
82 AddProgramFileDependency(program, path);
84 else
86 LOGERROR("Failed to load shader '%s'", name);
89 m_ProgramCache[key] = program;
90 return program;
93 size_t CShaderManager::EffectCacheKeyHash::operator()(const EffectCacheKey& key) const
95 size_t hash = 0;
96 hash_combine(hash, key.name.GetHash());
97 hash_combine(hash, key.defines.GetHash());
98 return hash;
101 bool CShaderManager::EffectCacheKey::operator==(const EffectCacheKey& b) const
103 return name == b.name && defines == b.defines;
106 CShaderTechniquePtr CShaderManager::LoadEffect(CStrIntern name)
108 return LoadEffect(name, CShaderDefines());
111 CShaderTechniquePtr CShaderManager::LoadEffect(CStrIntern name, const CShaderDefines& defines)
113 // Return the cached effect, if there is one
114 EffectCacheKey key = { name, defines };
115 EffectCacheMap::iterator it = m_EffectCache.find(key);
116 if (it != m_EffectCache.end())
117 return it->second;
119 // First time we've seen this key, so construct a new effect:
120 const VfsPath xmlFilename = L"shaders/effects/" + wstring_from_utf8(name.string()) + L".xml";
121 CShaderTechniquePtr tech = std::make_shared<CShaderTechnique>(
122 xmlFilename, defines, PipelineStateDescCallback{});
123 if (!LoadTechnique(tech))
125 LOGERROR("Failed to load effect '%s'", name.c_str());
126 tech = CShaderTechniquePtr();
129 m_EffectCache[key] = tech;
130 return tech;
133 CShaderTechniquePtr CShaderManager::LoadEffect(
134 CStrIntern name, const CShaderDefines& defines, const PipelineStateDescCallback& callback)
136 // We don't cache techniques with callbacks.
137 const VfsPath xmlFilename = L"shaders/effects/" + wstring_from_utf8(name.string()) + L".xml";
138 CShaderTechniquePtr technique = std::make_shared<CShaderTechnique>(xmlFilename, defines, callback);
139 if (!LoadTechnique(technique))
141 LOGERROR("Failed to load effect '%s'", name.c_str());
142 return {};
144 return technique;
147 bool CShaderManager::LoadTechnique(CShaderTechniquePtr& tech)
149 PROFILE2("loading technique");
150 PROFILE2_ATTR("name: %s", tech->GetPath().string8().c_str());
152 AddTechniqueFileDependency(tech, tech->GetPath());
154 CXeromyces XeroFile;
155 PSRETURN ret = XeroFile.Load(g_VFS, tech->GetPath());
156 if (ret != PSRETURN_OK)
157 return false;
159 Renderer::Backend::IDevice* device = g_VideoMode.GetBackendDevice();
161 // By default we assume that we have techinques for every dummy shader.
162 if (device->GetBackend() == Renderer::Backend::Backend::DUMMY)
164 CShaderProgramPtr shaderProgram = LoadProgram(str_dummy.string(), tech->GetShaderDefines());
165 std::vector<CShaderPass> techPasses;
166 Renderer::Backend::SGraphicsPipelineStateDesc passPipelineStateDesc =
167 Renderer::Backend::MakeDefaultGraphicsPipelineStateDesc();
168 passPipelineStateDesc.shaderProgram = shaderProgram->GetBackendShaderProgram();
169 techPasses.emplace_back(
170 device->CreateGraphicsPipelineState(passPipelineStateDesc), shaderProgram);
171 tech->SetPasses(std::move(techPasses));
172 return true;
175 // Define all the elements and attributes used in the XML file
176 #define EL(x) int el_##x = XeroFile.GetElementID(#x)
177 #define AT(x) int at_##x = XeroFile.GetAttributeID(#x)
178 EL(blend);
179 EL(color);
180 EL(cull);
181 EL(define);
182 EL(depth);
183 EL(pass);
184 EL(polygon);
185 EL(require);
186 EL(sort_by_distance);
187 EL(stencil);
188 AT(compare);
189 AT(constant);
190 AT(context);
191 AT(depth_fail);
192 AT(dst);
193 AT(fail);
194 AT(front_face);
195 AT(func);
196 AT(mask);
197 AT(mask_read);
198 AT(mask_red);
199 AT(mask_green);
200 AT(mask_blue);
201 AT(mask_alpha);
202 AT(mode);
203 AT(name);
204 AT(op);
205 AT(pass);
206 AT(reference);
207 AT(shader);
208 AT(shaders);
209 AT(src);
210 AT(test);
211 AT(value);
212 #undef AT
213 #undef EL
215 // Prepare the preprocessor for conditional tests
216 CPreprocessorWrapper preprocessor;
217 preprocessor.AddDefines(tech->GetShaderDefines());
219 XMBElement root = XeroFile.GetRoot();
221 // Find all the techniques that we can use, and their preference
223 std::optional<XMBElement> usableTech;
224 XERO_ITER_EL(root, technique)
226 bool isUsable = true;
227 XERO_ITER_EL(technique, child)
229 XMBAttributeList attrs = child.GetAttributes();
231 // TODO: require should be an attribute of the tech and not its child.
232 if (child.GetNodeName() == el_require)
234 if (attrs.GetNamedItem(at_shaders) == "arb")
236 if (device->GetBackend() != Renderer::Backend::Backend::GL_ARB ||
237 !device->GetCapabilities().ARBShaders)
239 isUsable = false;
242 else if (attrs.GetNamedItem(at_shaders) == "glsl")
244 if (device->GetBackend() != Renderer::Backend::Backend::GL)
245 isUsable = false;
247 else if (attrs.GetNamedItem(at_shaders) == "spirv")
249 if (device->GetBackend() != Renderer::Backend::Backend::VULKAN)
250 isUsable = false;
252 else if (!attrs.GetNamedItem(at_context).empty())
254 CStr cond = attrs.GetNamedItem(at_context);
255 if (!preprocessor.TestConditional(cond))
256 isUsable = false;
261 if (isUsable)
263 usableTech.emplace(technique);
264 break;
268 if (!usableTech.has_value())
270 debug_warn(L"Can't find a usable technique");
271 return false;
274 tech->SetSortByDistance(false);
276 CShaderDefines techDefines = tech->GetShaderDefines();
277 XERO_ITER_EL((*usableTech), Child)
279 if (Child.GetNodeName() == el_define)
281 techDefines.Add(CStrIntern(Child.GetAttributes().GetNamedItem(at_name)), CStrIntern(Child.GetAttributes().GetNamedItem(at_value)));
283 else if (Child.GetNodeName() == el_sort_by_distance)
285 tech->SetSortByDistance(true);
288 // We don't want to have a shader context depending on the order of define and
289 // pass tags.
290 // TODO: we might want to implement that in a proper way via splitting passes
291 // and tags in different groups in XML.
292 std::vector<CShaderPass> techPasses;
293 XERO_ITER_EL((*usableTech), Child)
295 if (Child.GetNodeName() == el_pass)
297 CShaderDefines passDefines = techDefines;
299 Renderer::Backend::SGraphicsPipelineStateDesc passPipelineStateDesc =
300 Renderer::Backend::MakeDefaultGraphicsPipelineStateDesc();
302 XERO_ITER_EL(Child, Element)
304 if (Element.GetNodeName() == el_define)
306 passDefines.Add(CStrIntern(Element.GetAttributes().GetNamedItem(at_name)), CStrIntern(Element.GetAttributes().GetNamedItem(at_value)));
308 else if (Element.GetNodeName() == el_blend)
310 passPipelineStateDesc.blendState.enabled = true;
311 passPipelineStateDesc.blendState.srcColorBlendFactor = passPipelineStateDesc.blendState.srcAlphaBlendFactor =
312 Renderer::Backend::ParseBlendFactor(Element.GetAttributes().GetNamedItem(at_src));
313 passPipelineStateDesc.blendState.dstColorBlendFactor = passPipelineStateDesc.blendState.dstAlphaBlendFactor =
314 Renderer::Backend::ParseBlendFactor(Element.GetAttributes().GetNamedItem(at_dst));
315 if (!Element.GetAttributes().GetNamedItem(at_op).empty())
317 passPipelineStateDesc.blendState.colorBlendOp = passPipelineStateDesc.blendState.alphaBlendOp =
318 Renderer::Backend::ParseBlendOp(Element.GetAttributes().GetNamedItem(at_op));
320 if (!Element.GetAttributes().GetNamedItem(at_constant).empty())
322 if (!passPipelineStateDesc.blendState.constant.ParseString(
323 Element.GetAttributes().GetNamedItem(at_constant)))
325 LOGERROR("Failed to parse blend constant: %s",
326 Element.GetAttributes().GetNamedItem(at_constant).c_str());
330 else if (Element.GetNodeName() == el_color)
332 passPipelineStateDesc.blendState.colorWriteMask = 0;
333 #define MASK_CHANNEL(ATTRIBUTE, VALUE) \
334 if (Element.GetAttributes().GetNamedItem(ATTRIBUTE) == "TRUE") \
335 passPipelineStateDesc.blendState.colorWriteMask |= Renderer::Backend::ColorWriteMask::VALUE
337 MASK_CHANNEL(at_mask_red, RED);
338 MASK_CHANNEL(at_mask_green, GREEN);
339 MASK_CHANNEL(at_mask_blue, BLUE);
340 MASK_CHANNEL(at_mask_alpha, ALPHA);
341 #undef MASK_CHANNEL
343 else if (Element.GetNodeName() == el_cull)
345 if (!Element.GetAttributes().GetNamedItem(at_mode).empty())
347 passPipelineStateDesc.rasterizationState.cullMode =
348 Renderer::Backend::ParseCullMode(Element.GetAttributes().GetNamedItem(at_mode));
350 if (!Element.GetAttributes().GetNamedItem(at_front_face).empty())
352 passPipelineStateDesc.rasterizationState.frontFace =
353 Renderer::Backend::ParseFrontFace(Element.GetAttributes().GetNamedItem(at_front_face));
356 else if (Element.GetNodeName() == el_depth)
358 if (!Element.GetAttributes().GetNamedItem(at_test).empty())
360 passPipelineStateDesc.depthStencilState.depthTestEnabled =
361 Element.GetAttributes().GetNamedItem(at_test) == "TRUE";
364 if (!Element.GetAttributes().GetNamedItem(at_func).empty())
366 passPipelineStateDesc.depthStencilState.depthCompareOp =
367 Renderer::Backend::ParseCompareOp(Element.GetAttributes().GetNamedItem(at_func));
370 if (!Element.GetAttributes().GetNamedItem(at_mask).empty())
372 passPipelineStateDesc.depthStencilState.depthWriteEnabled =
373 Element.GetAttributes().GetNamedItem(at_mask) == "true";
376 else if (Element.GetNodeName() == el_polygon)
378 if (!Element.GetAttributes().GetNamedItem(at_mode).empty())
380 passPipelineStateDesc.rasterizationState.polygonMode =
381 Renderer::Backend::ParsePolygonMode(Element.GetAttributes().GetNamedItem(at_mode));
384 else if (Element.GetNodeName() == el_stencil)
386 if (!Element.GetAttributes().GetNamedItem(at_test).empty())
388 passPipelineStateDesc.depthStencilState.stencilTestEnabled =
389 Element.GetAttributes().GetNamedItem(at_test) == "TRUE";
392 if (!Element.GetAttributes().GetNamedItem(at_reference).empty())
394 passPipelineStateDesc.depthStencilState.stencilReference =
395 Element.GetAttributes().GetNamedItem(at_reference).ToULong();
397 if (!Element.GetAttributes().GetNamedItem(at_mask_read).empty())
399 passPipelineStateDesc.depthStencilState.stencilReadMask =
400 Element.GetAttributes().GetNamedItem(at_mask_read).ToULong();
402 if (!Element.GetAttributes().GetNamedItem(at_mask).empty())
404 passPipelineStateDesc.depthStencilState.stencilWriteMask =
405 Element.GetAttributes().GetNamedItem(at_mask).ToULong();
408 if (!Element.GetAttributes().GetNamedItem(at_compare).empty())
410 passPipelineStateDesc.depthStencilState.stencilFrontFace.compareOp =
411 passPipelineStateDesc.depthStencilState.stencilBackFace.compareOp =
412 Renderer::Backend::ParseCompareOp(Element.GetAttributes().GetNamedItem(at_compare));
414 if (!Element.GetAttributes().GetNamedItem(at_fail).empty())
416 passPipelineStateDesc.depthStencilState.stencilFrontFace.failOp =
417 passPipelineStateDesc.depthStencilState.stencilBackFace.failOp =
418 Renderer::Backend::ParseStencilOp(Element.GetAttributes().GetNamedItem(at_fail));
420 if (!Element.GetAttributes().GetNamedItem(at_pass).empty())
422 passPipelineStateDesc.depthStencilState.stencilFrontFace.passOp =
423 passPipelineStateDesc.depthStencilState.stencilBackFace.passOp =
424 Renderer::Backend::ParseStencilOp(Element.GetAttributes().GetNamedItem(at_pass));
426 if (!Element.GetAttributes().GetNamedItem(at_depth_fail).empty())
428 passPipelineStateDesc.depthStencilState.stencilFrontFace.depthFailOp =
429 passPipelineStateDesc.depthStencilState.stencilBackFace.depthFailOp =
430 Renderer::Backend::ParseStencilOp(Element.GetAttributes().GetNamedItem(at_depth_fail));
435 // Load the shader program after we've read all the possibly-relevant <define>s.
436 CShaderProgramPtr shaderProgram =
437 LoadProgram(Child.GetAttributes().GetNamedItem(at_shader).c_str(), passDefines);
438 if (shaderProgram)
440 for (const VfsPath& shaderProgramPath : shaderProgram->GetFileDependencies())
441 AddTechniqueFileDependency(tech, shaderProgramPath);
442 if (tech->GetPipelineStateDescCallback())
443 tech->GetPipelineStateDescCallback()(passPipelineStateDesc);
444 passPipelineStateDesc.shaderProgram = shaderProgram->GetBackendShaderProgram();
445 techPasses.emplace_back(
446 device->CreateGraphicsPipelineState(passPipelineStateDesc), shaderProgram);
451 tech->SetPasses(std::move(techPasses));
453 return true;
456 size_t CShaderManager::GetNumEffectsLoaded() const
458 return m_EffectCache.size();
461 /*static*/ Status CShaderManager::ReloadChangedFileCB(void* param, const VfsPath& path)
463 return static_cast<CShaderManager*>(param)->ReloadChangedFile(path);
466 Status CShaderManager::ReloadChangedFile(const VfsPath& path)
468 // Find all shader programs using this file.
469 const auto programs = m_HotloadPrograms.find(path);
470 if (programs != m_HotloadPrograms.end())
472 // Reload all shader programs using this file.
473 for (const std::weak_ptr<CShaderProgram>& ptr : programs->second)
474 if (std::shared_ptr<CShaderProgram> program = ptr.lock())
475 program->Reload();
478 // Find all shader techinques using this file. We need to reload them after
479 // shader programs.
480 const auto techniques = m_HotloadTechniques.find(path);
481 if (techniques != m_HotloadTechniques.end())
483 // Reload all shader techinques using this file.
484 for (const std::weak_ptr<CShaderTechnique>& ptr : techniques->second)
485 if (std::shared_ptr<CShaderTechnique> technique = ptr.lock())
487 if (!LoadTechnique(technique))
488 LOGERROR("Failed to reload technique '%s'", technique->GetPath().string8().c_str());
492 return INFO::OK;
495 void CShaderManager::AddTechniqueFileDependency(const CShaderTechniquePtr& technique, const VfsPath& path)
497 m_HotloadTechniques[path].insert(technique);
500 void CShaderManager::AddProgramFileDependency(const CShaderProgramPtr& program, const VfsPath& path)
502 m_HotloadPrograms[path].insert(program);