1 /* Copyright (C) 2023 Wildfire Games.
2 * This file is part of 0 A.D.
4 * 0 A.D. is free software: you can redistribute it and/or modify
5 * it under the terms of the GNU General Public License as published by
6 * the Free Software Foundation, either version 2 of the License, or
7 * (at your option) any later version.
9 * 0 A.D. is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
14 * You should have received a copy of the GNU General Public License
15 * along with 0 A.D. If not, see <http://www.gnu.org/licenses/>.
18 #include "precompiled.h"
20 #include "ShaderManager.h"
22 #include "graphics/PreprocessorWrapper.h"
23 #include "graphics/ShaderTechnique.h"
24 #include "lib/config2.h"
26 #include "lib/timer.h"
28 #include "ps/CLogger.h"
29 #include "ps/CStrIntern.h"
30 #include "ps/CStrInternStatic.h"
31 #include "ps/Filesystem.h"
32 #include "ps/Profile.h"
33 #include "ps/XML/Xeromyces.h"
34 #include "ps/VideoMode.h"
35 #include "renderer/backend/IDevice.h"
36 #include "renderer/Renderer.h"
37 #include "renderer/RenderingOptions.h"
39 #define USE_SHADER_XML_VALIDATION 1
41 #if USE_SHADER_XML_VALIDATION
42 #include "ps/XML/RelaxNG.h"
43 #include "ps/XML/XMLWriter.h"
// Timer client 'tc_ShaderValidation'; the constructor accrues into it while
// registering the shader XML validator.
49 TIMER_ADD_CLIENT(tc_ShaderValidation
);
// Constructor.
// When USE_SHADER_XML_VALIDATION is enabled, it registers the "shader" XML
// validator from shaders/program.rng. That step is timed under
// tc_ShaderValidation and logs an error if the grammar fails to load.
// It then registers ReloadChangedFileCB (with `this` as its parameter) so
// changed files are hotloaded.
// NOTE(review): the braces and the matching #endif are not visible in this
// listing.
51 CShaderManager::CShaderManager()
53 #if USE_SHADER_XML_VALIDATION
55 TIMER_ACCRUE(tc_ShaderValidation
);
57 if (!CXeromyces::AddValidator(g_VFS
, "shader", "shaders/program.rng"))
58 LOGERROR("CShaderManager: failed to load grammar shaders/program.rng");
62 // Allow hotloading of shader programs and effects (see ReloadChangedFile)
63 RegisterFileReloadFunc(ReloadChangedFileCB
, this);
// Destructor: unregisters the file-reload callback (ReloadChangedFileCB, this)
// that the constructor registered.
66 CShaderManager::~CShaderManager()
68 UnregisterFileReloadFunc(ReloadChangedFileCB
, this);
// Returns the shader program for the pair (name, defines).
// On a cache miss it:
//  - creates the program with CShaderProgram::Create,
//  - registers each of the program's file dependencies for hotloading,
//  - stores the result in m_ProgramCache under that key.
// NOTE(review): these lines are not visible in this listing:
//  - the cache-hit return;
//  - the branch that chooses between the dependency loop and the error log;
//  - the final return.
71 CShaderProgramPtr
CShaderManager::LoadProgram(const CStr
& name
, const CShaderDefines
& defines
)
// Look up a previously created program for this (name, defines) pair.
73 CacheKey key
= { name
, defines
};
74 std::map
<CacheKey
, CShaderProgramPtr
>::iterator it
= m_ProgramCache
.find(key
);
75 if (it
!= m_ProgramCache
.end())
// Cache miss: create the program and register its files for hotloading.
78 CShaderProgramPtr program
= CShaderProgram::Create(name
, defines
);
81 for (const VfsPath
& path
: program
->GetFileDependencies())
82 AddProgramFileDependency(program
, path
);
// NOTE(review): the condition guarding this error (presumably a null
// program) is not visible here.
86 LOGERROR("Failed to load shader '%s'", name
);
89 m_ProgramCache
[key
] = program
;
// Hash functor for EffectCacheKey.
// It folds the hashes of the effect name and of its defines into `hash` using
// hash_combine. These are the same two fields EffectCacheKey::operator==
// compares.
// NOTE(review): the declaration/initialisation of `hash` and the return
// statement are not visible in this listing.
93 size_t CShaderManager::EffectCacheKeyHash::operator()(const EffectCacheKey
& key
) const
96 hash_combine(hash
, key
.name
.GetHash());
97 hash_combine(hash
, key
.defines
.GetHash());
// Two effect cache keys are equal when both their effect names and their
// define sets are equal.
101 bool CShaderManager::EffectCacheKey::operator==(const EffectCacheKey
& b
) const
103 return name
== b
.name
&& defines
== b
.defines
;
// Convenience overload: loads (and caches) effect `name` with an empty
// CShaderDefines set.
106 CShaderTechniquePtr
CShaderManager::LoadEffect(CStrIntern name
)
108 return LoadEffect(name
, CShaderDefines());
// Returns the effect technique for (name, defines).
// On first use it loads "shaders/effects/<name>.xml" and caches the result in
// m_EffectCache. The technique is created without a pipeline-state callback.
// A failed load is logged, and an empty CShaderTechniquePtr is stored (and
// presumably returned) for this key.
// NOTE(review): braces and return statements are not visible in this listing.
111 CShaderTechniquePtr
CShaderManager::LoadEffect(CStrIntern name
, const CShaderDefines
& defines
)
113 // Return the cached effect, if there is one
114 EffectCacheKey key
= { name
, defines
};
115 EffectCacheMap::iterator it
= m_EffectCache
.find(key
);
116 if (it
!= m_EffectCache
.end())
119 // First time we've seen this key, so construct a new effect:
120 const VfsPath xmlFilename
= L
"shaders/effects/" + wstring_from_utf8(name
.string()) + L
".xml";
121 CShaderTechniquePtr tech
= std::make_shared
<CShaderTechnique
>(
122 xmlFilename
, defines
, PipelineStateDescCallback
{});
123 if (!LoadTechnique(tech
))
125 LOGERROR("Failed to load effect '%s'", name
.c_str());
126 tech
= CShaderTechniquePtr();
// Store the result (possibly empty after a failed load) for this key.
129 m_EffectCache
[key
] = tech
;
// Loads "shaders/effects/<name>.xml" with the given defines and a
// pipeline-state callback. The callback is applied to each pass's pipeline
// state inside LoadTechnique.
// The result is not stored in m_EffectCache.
// NOTE(review): what this overload returns after a failed load is not visible
// here. The caching overload stores an empty pointer in that case; confirm the
// two overloads agree.
133 CShaderTechniquePtr
CShaderManager::LoadEffect(
134 CStrIntern name
, const CShaderDefines
& defines
, const PipelineStateDescCallback
& callback
)
136 // We don't cache techniques with callbacks.
137 const VfsPath xmlFilename
= L
"shaders/effects/" + wstring_from_utf8(name
.string()) + L
".xml";
138 CShaderTechniquePtr technique
= std::make_shared
<CShaderTechnique
>(xmlFilename
, defines
, callback
);
139 if (!LoadTechnique(technique
))
141 LOGERROR("Failed to load effect '%s'", name
.c_str());
// Parses the effect XML at tech->GetPath() and fills `tech` with its passes.
// - The effect file is registered as a hotload dependency of `tech`. Every
//   file of each pass's shader program is registered as well.
// - On the DUMMY backend, a single pass is built from the "dummy" program.
// - Otherwise it:
//   - selects a <technique> whose <require> children match the current
//     backend, capabilities and preprocessor context;
//   - applies that technique's <define> and <sort_by_distance> children;
//   - builds one pipeline state per <pass>, from the pass's <define>, <blend>,
//     <color>, <cull>, <depth>, <polygon> and <stencil> children.
// NOTE(review): this listing omits many original lines: braces, return
// statements, the XeroFile declaration, most EL/AT declarations and the
// #undefs. The comments below describe only the visible statements.
147 bool CShaderManager::LoadTechnique(CShaderTechniquePtr
& tech
)
149 PROFILE2("loading technique");
150 PROFILE2_ATTR("name: %s", tech
->GetPath().string8().c_str());
// Track the effect file itself so edits to it trigger a reload.
152 AddTechniqueFileDependency(tech
, tech
->GetPath());
// Parse the effect XML; the handling of a non-OK result is on an omitted line.
155 PSRETURN ret
= XeroFile
.Load(g_VFS
, tech
->GetPath());
156 if (ret
!= PSRETURN_OK
)
159 Renderer::Backend::IDevice
* device
= g_VideoMode
.GetBackendDevice();
161 // By default we assume that we have techniques for every dummy shader.
// On the DUMMY backend, build a single pass from the "dummy" program with the
// default pipeline state (the early return after SetPasses is not visible).
162 if (device
->GetBackend() == Renderer::Backend::Backend::DUMMY
)
164 CShaderProgramPtr shaderProgram
= LoadProgram(str_dummy
.string(), tech
->GetShaderDefines());
165 std::vector
<CShaderPass
> techPasses
;
166 Renderer::Backend::SGraphicsPipelineStateDesc passPipelineStateDesc
=
167 Renderer::Backend::MakeDefaultGraphicsPipelineStateDesc();
168 passPipelineStateDesc
.shaderProgram
= shaderProgram
->GetBackendShaderProgram();
169 techPasses
.emplace_back(
170 device
->CreateGraphicsPipelineState(passPipelineStateDesc
), shaderProgram
);
171 tech
->SetPasses(std::move(techPasses
));
175 // Define all the elements and attributes used in the XML file
176 #define EL(x) int el_##x = XeroFile.GetElementID(#x)
177 #define AT(x) int at_##x = XeroFile.GetAttributeID(#x)
// NOTE(review): only EL(sort_by_distance) is visible in this listing. The
// other el_*/at_* IDs used below are declared on omitted lines.
186 EL(sort_by_distance
);
215 // Prepare the preprocessor for conditional tests
216 CPreprocessorWrapper preprocessor
;
217 preprocessor
.AddDefines(tech
->GetShaderDefines());
219 XMBElement root
= XeroFile
.GetRoot();
221 // Find all the techniques that we can use, and their preference
// A technique's <require> children can each disqualify it, either by backend
// ("arb" needs GL_ARB plus ARBShaders, "glsl" needs GL, "spirv" needs VULKAN)
// or by a `context` condition tested with the preprocessor.
223 std::optional
<XMBElement
> usableTech
;
224 XERO_ITER_EL(root
, technique
)
226 bool isUsable
= true;
227 XERO_ITER_EL(technique
, child
)
229 XMBAttributeList attrs
= child
.GetAttributes();
231 // TODO: require should be an attribute of the tech and not its child.
232 if (child
.GetNodeName() == el_require
)
234 if (attrs
.GetNamedItem(at_shaders
) == "arb")
236 if (device
->GetBackend() != Renderer::Backend::Backend::GL_ARB
||
237 !device
->GetCapabilities().ARBShaders
)
242 else if (attrs
.GetNamedItem(at_shaders
) == "glsl")
244 if (device
->GetBackend() != Renderer::Backend::Backend::GL
)
247 else if (attrs
.GetNamedItem(at_shaders
) == "spirv")
249 if (device
->GetBackend() != Renderer::Backend::Backend::VULKAN
)
252 else if (!attrs
.GetNamedItem(at_context
).empty())
254 CStr cond
= attrs
.GetNamedItem(at_context
);
255 if (!preprocessor
.TestConditional(cond
))
// NOTE(review): std::optional::emplace replaces any earlier value. Whether the
// first or the last usable technique wins therefore depends on omitted lines
// (presumably an isUsable check and a break); confirm.
263 usableTech
.emplace(technique
);
268 if (!usableTech
.has_value())
270 debug_warn(L
"Can't find a usable technique");
// Technique-level settings: sorting is off unless <sort_by_distance> is
// present, and each <define> is added to the technique's defines.
274 tech
->SetSortByDistance(false);
276 CShaderDefines techDefines
= tech
->GetShaderDefines();
277 XERO_ITER_EL((*usableTech
), Child
)
279 if (Child
.GetNodeName() == el_define
)
281 techDefines
.Add(CStrIntern(Child
.GetAttributes().GetNamedItem(at_name
)), CStrIntern(Child
.GetAttributes().GetNamedItem(at_value
)));
283 else if (Child
.GetNodeName() == el_sort_by_distance
)
285 tech
->SetSortByDistance(true);
288 // We don't want to have a shader context depending on the order of define and
290 // TODO: we might want to implement that in a proper way via splitting passes
291 // and tags in different groups in XML.
// Each pass below starts from techDefines, so technique-level <define>s apply
// whatever their position relative to the <pass> elements.
292 std::vector
<CShaderPass
> techPasses
;
293 XERO_ITER_EL((*usableTech
), Child
)
295 if (Child
.GetNodeName() == el_pass
)
// Each <pass> starts from the default pipeline state and from its own copy of
// the technique defines.
297 CShaderDefines passDefines
= techDefines
;
299 Renderer::Backend::SGraphicsPipelineStateDesc passPipelineStateDesc
=
300 Renderer::Backend::MakeDefaultGraphicsPipelineStateDesc();
302 XERO_ITER_EL(Child
, Element
)
304 if (Element
.GetNodeName() == el_define
)
306 passDefines
.Add(CStrIntern(Element
.GetAttributes().GetNamedItem(at_name
)), CStrIntern(Element
.GetAttributes().GetNamedItem(at_value
)));
// <blend>: enables blending. `src` and `dst` set both the color and alpha
// factors; the optional `op` sets both blend ops; the optional `constant` is
// parsed into the blend constant.
308 else if (Element
.GetNodeName() == el_blend
)
310 passPipelineStateDesc
.blendState
.enabled
= true;
311 passPipelineStateDesc
.blendState
.srcColorBlendFactor
= passPipelineStateDesc
.blendState
.srcAlphaBlendFactor
=
312 Renderer::Backend::ParseBlendFactor(Element
.GetAttributes().GetNamedItem(at_src
));
313 passPipelineStateDesc
.blendState
.dstColorBlendFactor
= passPipelineStateDesc
.blendState
.dstAlphaBlendFactor
=
314 Renderer::Backend::ParseBlendFactor(Element
.GetAttributes().GetNamedItem(at_dst
));
315 if (!Element
.GetAttributes().GetNamedItem(at_op
).empty())
317 passPipelineStateDesc
.blendState
.colorBlendOp
= passPipelineStateDesc
.blendState
.alphaBlendOp
=
318 Renderer::Backend::ParseBlendOp(Element
.GetAttributes().GetNamedItem(at_op
));
320 if (!Element
.GetAttributes().GetNamedItem(at_constant
).empty())
322 if (!passPipelineStateDesc
.blendState
.constant
.ParseString(
323 Element
.GetAttributes().GetNamedItem(at_constant
)))
325 LOGERROR("Failed to parse blend constant: %s",
326 Element
.GetAttributes().GetNamedItem(at_constant
).c_str());
// <color>: clears the write mask, then enables each channel whose mask_*
// attribute equals "TRUE".
330 else if (Element
.GetNodeName() == el_color
)
332 passPipelineStateDesc
.blendState
.colorWriteMask
= 0;
333 #define MASK_CHANNEL(ATTRIBUTE, VALUE) \
334 if (Element.GetAttributes().GetNamedItem(ATTRIBUTE) == "TRUE") \
335 passPipelineStateDesc.blendState.colorWriteMask |= Renderer::Backend::ColorWriteMask::VALUE
337 MASK_CHANNEL(at_mask_red
, RED
);
338 MASK_CHANNEL(at_mask_green
, GREEN
);
339 MASK_CHANNEL(at_mask_blue
, BLUE
);
340 MASK_CHANNEL(at_mask_alpha
, ALPHA
);
// NOTE(review): no #undef MASK_CHANNEL is visible in this listing.
// <cull>: optional `mode` (cull mode) and `front_face` attributes.
343 else if (Element
.GetNodeName() == el_cull
)
345 if (!Element
.GetAttributes().GetNamedItem(at_mode
).empty())
347 passPipelineStateDesc
.rasterizationState
.cullMode
=
348 Renderer::Backend::ParseCullMode(Element
.GetAttributes().GetNamedItem(at_mode
));
350 if (!Element
.GetAttributes().GetNamedItem(at_front_face
).empty())
352 passPipelineStateDesc
.rasterizationState
.frontFace
=
353 Renderer::Backend::ParseFrontFace(Element
.GetAttributes().GetNamedItem(at_front_face
));
// <depth>: optional `test`, `func` and `mask` attributes.
356 else if (Element
.GetNodeName() == el_depth
)
358 if (!Element
.GetAttributes().GetNamedItem(at_test
).empty())
360 passPipelineStateDesc
.depthStencilState
.depthTestEnabled
=
361 Element
.GetAttributes().GetNamedItem(at_test
) == "TRUE";
364 if (!Element
.GetAttributes().GetNamedItem(at_func
).empty())
366 passPipelineStateDesc
.depthStencilState
.depthCompareOp
=
367 Renderer::Backend::ParseCompareOp(Element
.GetAttributes().GetNamedItem(at_func
));
370 if (!Element
.GetAttributes().GetNamedItem(at_mask
).empty())
// NOTE(review): depth `mask` is compared against lowercase "true", while
// `test` and the color mask_* attributes are compared against "TRUE".
// Confirm this matches the effect XML files.
372 passPipelineStateDesc
.depthStencilState
.depthWriteEnabled
=
373 Element
.GetAttributes().GetNamedItem(at_mask
) == "true";
376 else if (Element
.GetNodeName() == el_polygon
)
378 if (!Element
.GetAttributes().GetNamedItem(at_mode
).empty())
380 passPipelineStateDesc
.rasterizationState
.polygonMode
=
381 Renderer::Backend::ParsePolygonMode(Element
.GetAttributes().GetNamedItem(at_mode
));
// <stencil>:
//  - `reference`, `mask_read` and `mask` are parsed with ToULong;
//  - `compare`, `fail`, `pass` and `depth_fail` are applied to both the front
//    and the back faces.
384 else if (Element
.GetNodeName() == el_stencil
)
386 if (!Element
.GetAttributes().GetNamedItem(at_test
).empty())
388 passPipelineStateDesc
.depthStencilState
.stencilTestEnabled
=
389 Element
.GetAttributes().GetNamedItem(at_test
) == "TRUE";
392 if (!Element
.GetAttributes().GetNamedItem(at_reference
).empty())
394 passPipelineStateDesc
.depthStencilState
.stencilReference
=
395 Element
.GetAttributes().GetNamedItem(at_reference
).ToULong();
397 if (!Element
.GetAttributes().GetNamedItem(at_mask_read
).empty())
399 passPipelineStateDesc
.depthStencilState
.stencilReadMask
=
400 Element
.GetAttributes().GetNamedItem(at_mask_read
).ToULong();
402 if (!Element
.GetAttributes().GetNamedItem(at_mask
).empty())
404 passPipelineStateDesc
.depthStencilState
.stencilWriteMask
=
405 Element
.GetAttributes().GetNamedItem(at_mask
).ToULong();
408 if (!Element
.GetAttributes().GetNamedItem(at_compare
).empty())
410 passPipelineStateDesc
.depthStencilState
.stencilFrontFace
.compareOp
=
411 passPipelineStateDesc
.depthStencilState
.stencilBackFace
.compareOp
=
412 Renderer::Backend::ParseCompareOp(Element
.GetAttributes().GetNamedItem(at_compare
));
414 if (!Element
.GetAttributes().GetNamedItem(at_fail
).empty())
416 passPipelineStateDesc
.depthStencilState
.stencilFrontFace
.failOp
=
417 passPipelineStateDesc
.depthStencilState
.stencilBackFace
.failOp
=
418 Renderer::Backend::ParseStencilOp(Element
.GetAttributes().GetNamedItem(at_fail
));
420 if (!Element
.GetAttributes().GetNamedItem(at_pass
).empty())
422 passPipelineStateDesc
.depthStencilState
.stencilFrontFace
.passOp
=
423 passPipelineStateDesc
.depthStencilState
.stencilBackFace
.passOp
=
424 Renderer::Backend::ParseStencilOp(Element
.GetAttributes().GetNamedItem(at_pass
));
426 if (!Element
.GetAttributes().GetNamedItem(at_depth_fail
).empty())
428 passPipelineStateDesc
.depthStencilState
.stencilFrontFace
.depthFailOp
=
429 passPipelineStateDesc
.depthStencilState
.stencilBackFace
.depthFailOp
=
430 Renderer::Backend::ParseStencilOp(Element
.GetAttributes().GetNamedItem(at_depth_fail
));
435 // Load the shader program after we've read all the possibly-relevant <define>s.
436 CShaderProgramPtr shaderProgram
=
437 LoadProgram(Child
.GetAttributes().GetNamedItem(at_shader
).c_str(), passDefines
);
// NOTE(review): no null check on shaderProgram is visible before the
// dereferences below. LoadProgram logs on failure, so confirm that an omitted
// line guards this.
// The pass's program files also become hotload dependencies of the technique.
440 for (const VfsPath
& shaderProgramPath
: shaderProgram
->GetFileDependencies())
441 AddTechniqueFileDependency(tech
, shaderProgramPath
);
// Let the caller-supplied callback (if any) adjust the pipeline state before
// it is created.
442 if (tech
->GetPipelineStateDescCallback())
443 tech
->GetPipelineStateDescCallback()(passPipelineStateDesc
);
444 passPipelineStateDesc
.shaderProgram
= shaderProgram
->GetBackendShaderProgram();
445 techPasses
.emplace_back(
446 device
->CreateGraphicsPipelineState(passPipelineStateDesc
), shaderProgram
);
// Replace the technique's passes with the newly built list.
451 tech
->SetPasses(std::move(techPasses
));
// Returns the number of entries in m_EffectCache.
// Techniques loaded through the callback overload of LoadEffect are not
// cached, so they are not counted.
456 size_t CShaderManager::GetNumEffectsLoaded() const
458 return m_EffectCache
.size();
// Static trampoline for RegisterFileReloadFunc.
// `param` is the CShaderManager `this` pointer passed when the callback was
// registered. The call is forwarded to ReloadChangedFile(path), and its Status
// is returned.
461 /*static*/ Status
CShaderManager::ReloadChangedFileCB(void* param
, const VfsPath
& path
)
463 return static_cast<CShaderManager
*>(param
)->ReloadChangedFile(path
);
// Hotload handler for a changed file `path`.
// - First it visits every still-alive shader program registered as depending
//   on `path`.
// - Then it reloads, via LoadTechnique, every still-alive technique registered
//   as depending on `path`, and logs techniques that fail to reload.
// Expired weak_ptrs are skipped by lock().
// NOTE(review): these lines are not visible in this listing:
//  - the per-program reload statement;
//  - the return value;
//  - any removal of expired entries.
466 Status
CShaderManager::ReloadChangedFile(const VfsPath
& path
)
468 // Find all shader programs using this file.
469 const auto programs
= m_HotloadPrograms
.find(path
);
470 if (programs
!= m_HotloadPrograms
.end())
472 // Reload all shader programs using this file.
473 for (const std::weak_ptr
<CShaderProgram
>& ptr
: programs
->second
)
474 if (std::shared_ptr
<CShaderProgram
> program
= ptr
.lock())
478 // Find all shader techniques using this file. We need to reload them after
// the shader programs handled above.
480 const auto techniques
= m_HotloadTechniques
.find(path
);
481 if (techniques
!= m_HotloadTechniques
.end())
483 // Reload all shader techniques using this file.
484 for (const std::weak_ptr
<CShaderTechnique
>& ptr
: techniques
->second
)
485 if (std::shared_ptr
<CShaderTechnique
> technique
= ptr
.lock())
487 if (!LoadTechnique(technique
))
488 LOGERROR("Failed to reload technique '%s'", technique
->GetPath().string8().c_str());
// Records that `technique` depends on `path`, so that ReloadChangedFile(path)
// reloads it.
// The entry is inserted into the per-path collection in m_HotloadTechniques.
// ReloadChangedFile iterates that collection as std::weak_ptr<CShaderTechnique>;
// its exact container type is not visible here.
495 void CShaderManager::AddTechniqueFileDependency(const CShaderTechniquePtr
& technique
, const VfsPath
& path
)
497 m_HotloadTechniques
[path
].insert(technique
);
// Records that `program` depends on `path`, so that ReloadChangedFile(path)
// visits it.
// The entry is inserted into the per-path collection in m_HotloadPrograms.
// ReloadChangedFile iterates that collection as std::weak_ptr<CShaderProgram>;
// its exact container type is not visible here.
500 void CShaderManager::AddProgramFileDependency(const CShaderProgramPtr
& program
, const VfsPath
& path
)
502 m_HotloadPrograms
[path
].insert(program
);