cloudy-raytracer/shader/refractionshader.cpp
2022-11-17 22:27:10 +01:00

81 lines
2.9 KiB
C++

#include <cmath>
#include <utility>

#include "scene/scene.h"
#include "shader/refractionshader.h"
RefractionShader::RefractionShader(float indexInside, float indexOutside) : indexInside(indexInside), indexOutside(indexOutside) {}
Color RefractionShader::shade(Scene const &scene, Ray const &ray) const {
// IMPLEMENT ME
// Calculate the refracted ray using the surface normal vector and
// indexInside, indexOutside
// Also check for total internal reflection
// Send out a new refracted ray into the scene; recursively call traceRay()
if (ray.getRemainingBounces() < 1) {
return {0,0,0};
}
float eta1 = indexOutside;
float eta2 = indexInside;
Vector3d v = (ray.direction);
Vector3d n = normalized(ray.normal);
// Check whether we are entering or leaving the object
if (dotProduct(v, n) > 0) {
n = -n;
std::swap(eta1, eta2);
}
// Calculate refraction percentage
float denominator = 1.0f -
(std::pow(eta1, 2) * (1 - std::pow(dotProduct(v, n), 2)) /
std::pow(eta2, 2));
// Check whether there is any refracted ray at all or whether we have total internal reflection
if (denominator <= 0) {
Vector3d newDirection = v - 2.0f * dotProduct(n, v) * n;
Ray mirroredRay = Ray(ray.origin + ray.length * v, newDirection);
mirroredRay.setRemainingBounces(ray.getRemainingBounces() - 1);
return scene.traceRay(mirroredRay);
}
// Calculate refracted ray direction t
Vector3d t = (
eta1 / eta2 *
(v - dotProduct(v, n) * n) -
n * std::sqrt(denominator));
// Calculate reflected ray direction
auto cosTheta1 = dotProduct(-n, normalized(v));
auto cosTheta2 = std::sqrt(denominator);
auto polarizationParallel = (eta2 * cosTheta1 - eta1 * cosTheta2) /
(eta2 * cosTheta1 + eta1 * cosTheta2);
auto polarizationOrthogonal = (eta1 * cosTheta1 - eta2 * cosTheta2) /
(eta1 * cosTheta1 + eta2 * cosTheta2);
auto reflectionFactor = (polarizationParallel * polarizationParallel + polarizationOrthogonal * polarizationOrthogonal) / 2;
Vector3d reflectionVector = v - 2.0f * dotProduct(n, v) * n;
// Calculate rays
Vector3d rayOrigin = ray.origin + ray.length * ray.direction;
Ray refractionRay = Ray(rayOrigin, t);
Ray reflectionRay = Ray(rayOrigin, reflectionVector);
int remainingBounces = ray.getRemainingBounces() - 1;
refractionRay.setRemainingBounces(remainingBounces);
reflectionRay.setRemainingBounces(remainingBounces);
// Calculate Color
Color refractionColor = scene.traceRay(refractionRay);
Color reflectionColor = scene.traceRay(reflectionRay);
Color resultingColor = (1-reflectionFactor) * refractionColor + reflectionFactor * reflectionColor;
return resultingColor;
}
// This shader always reports itself as transparent.
bool RefractionShader::isTransparent() const {
  return true;
}