81 lines
2.9 KiB
C++
81 lines
2.9 KiB
C++
#include "scene/scene.h"
|
|
#include "shader/refractionshader.h"
|
|
|
|
/// Creates a refraction shader for a transparent object.
///
/// @param indexInside  Refractive index of the medium inside the object.
/// @param indexOutside Refractive index of the medium surrounding the object.
RefractionShader::RefractionShader(float indexInside, float indexOutside)
    : indexInside(indexInside),
      indexOutside(indexOutside) {
}
|
|
|
|
/// Shades a hit on a transparent (dielectric) surface.
///
/// Spawns a refracted and a mirrored secondary ray at the hit point and blends
/// their colours with the unpolarised Fresnel reflectance. When Snell's law has
/// no solution (total internal reflection) only the mirrored ray is traced.
///
/// @param scene Scene used to trace the secondary rays (via traceRay()).
/// @param ray   Incoming ray; its origin, direction, length and normal describe
///              the hit that is being shaded.
/// @return Black if the ray has no bounces left; otherwise the colour gathered
///         from the secondary ray(s).
Color RefractionShader::shade(Scene const &scene, Ray const &ray) const {
  // Stop the recursion once the bounce budget is used up.
  if (ray.getRemainingBounces() < 1) {
    return {0,0,0};
  }

  // eta1: medium the ray is coming from, eta2: medium the ray is going into.
  float eta1 = indexOutside;
  float eta2 = indexInside;

  // Snell's law and the Fresnel equations below assume unit vectors, so the
  // incident direction is normalised once and reused everywhere.
  Vector3d const v = normalized(ray.direction);
  Vector3d n = normalized(ray.normal);

  // If the ray travels along the normal it is leaving the object: flip the
  // normal so that it opposes the ray and exchange the two media.
  if (dotProduct(v, n) > 0) {
    n = -n;
    std::swap(eta1, eta2);
  }

  // After the flip above this is <= 0, i.e. -cos(theta1).
  auto const vDotN = dotProduct(v, n);
  float const etaRatio = eta1 / eta2;

  // cos^2(theta2) from Snell's law: 1 - (eta1/eta2)^2 * sin^2(theta1).
  float const cosTheta2Squared = 1.0f - etaRatio * etaRatio * (1.0f - vDotN * vDotN);

  // Both secondary rays start at the hit point and inherit the reduced budget.
  // NOTE(review): the rays start exactly on the surface; if the intersection
  // routines do not already reject hits at distance ~0, a small offset along
  // the new direction may be needed to avoid self-intersection - confirm.
  Vector3d const hitPoint = ray.origin + ray.length * ray.direction;
  int const remainingBounces = ray.getRemainingBounces() - 1;

  // Mirror direction about the (flipped) normal.
  Vector3d const reflectionVector = v - 2.0f * vDotN * n;
  Ray reflectionRay = Ray(hitPoint, reflectionVector);
  reflectionRay.setRemainingBounces(remainingBounces);

  // Total internal reflection: no refracted ray exists, everything is mirrored.
  if (cosTheta2Squared <= 0) {
    return scene.traceRay(reflectionRay);
  }

  auto const cosTheta1 = -vDotN;
  float const cosTheta2 = std::sqrt(cosTheta2Squared);

  // Refracted direction: scaled tangential part of v plus the component along
  // the negated normal.
  Vector3d const t = etaRatio * (v - vDotN * n) - n * cosTheta2;
  Ray refractionRay = Ray(hitPoint, t);
  refractionRay.setRemainingBounces(remainingBounces);

  // Fresnel amplitude coefficients for the two polarisations.
  auto const polarizationParallel = (eta2 * cosTheta1 - eta1 * cosTheta2) /
                                    (eta2 * cosTheta1 + eta1 * cosTheta2);
  auto const polarizationOrthogonal = (eta1 * cosTheta1 - eta2 * cosTheta2) /
                                      (eta1 * cosTheta1 + eta2 * cosTheta2);

  // Unpolarised reflectance is the mean of the two squared coefficients.
  auto const reflectionFactor = (polarizationParallel * polarizationParallel +
                                 polarizationOrthogonal * polarizationOrthogonal) / 2;

  // Trace both rays (refraction first, as before) and blend the results.
  Color const refractionColor = scene.traceRay(refractionRay);
  Color const reflectionColor = scene.traceRay(reflectionRay);

  return (1 - reflectionFactor) * refractionColor + reflectionFactor * reflectionColor;
}
|
|
|
|
/// Reports whether this shader is transparent.
///
/// @return Always true for the refraction shader.
bool RefractionShader::isTransparent() const {
  return true;
}
|