#include <cmath>
#include <utility>

#include "scene/scene.h"
#include "shader/refractionshader.h"

/// Creates a refraction shader for an interface between two media.
///
/// @param indexInside   Refractive index used on the side the stored surface
///                      normal points away from (the "inside" of the object).
/// @param indexOutside  Refractive index used on the side the stored surface
///                      normal points towards (the "outside" of the object).
RefractionShader::RefractionShader(float indexInside, float indexOutside)
    : indexInside(indexInside), indexOutside(indexOutside) {}

/// Shades a hit on a refractive surface by tracing a refracted and a
/// reflected secondary ray and blending their colors with the Fresnel
/// reflectance for unpolarized light.
///
/// If the bounce budget of @p ray is exhausted, black is returned. If Snell's
/// law has no solution (total internal reflection), only the mirrored ray is
/// traced. Every secondary ray gets one bounce less than @p ray.
///
/// @param scene  Scene used to trace the secondary rays.
/// @param ray    Incoming ray; its origin, direction, length and normal
///               describe the hit being shaded.
/// @return The blended color of the refracted and reflected contributions.
Color RefractionShader::shade(Scene const &scene, Ray const &ray) const {
  // Stop the recursion once the ray has no bounces left.
  if (ray.getRemainingBounces() < 1) {
    return {0, 0, 0};
  }

  float eta1 = indexOutside;  // index of the medium the ray is coming from
  float eta2 = indexInside;   // index of the medium the ray is going into

  // Normalise the incident direction once so the Snell and Fresnel terms
  // below work with the same unit vector.
  Vector3d const v = normalized(ray.direction);
  Vector3d n = normalized(ray.normal);

  // If the ray travels along the normal it is leaving the object: flip the
  // normal so it opposes the ray again and exchange the two media.
  float vDotN = dotProduct(v, n);
  if (vDotN > 0) {
    n = -n;
    vDotN = -vDotN;
    std::swap(eta1, eta2);
  }

  // After the flip vDotN <= 0, so the cosine of the incident angle is >= 0.
  float const cosTheta1 = -vDotN;
  float const etaRatio = eta1 / eta2;

  // Snell's law: cos^2(theta2) = 1 - (eta1 / eta2)^2 * (1 - cos^2(theta1)).
  float const cosTheta2Squared =
      1.0f - etaRatio * etaRatio * (1.0f - cosTheta1 * cosTheta1);

  // Both secondary rays start at the hit point and share the bounce budget.
  Vector3d const hitPoint = ray.origin + ray.length * ray.direction;
  Vector3d const reflectionDirection = v - 2.0f * vDotN * n;
  int const remainingBounces = ray.getRemainingBounces() - 1;

  Ray reflectionRay = Ray(hitPoint, reflectionDirection);
  reflectionRay.setRemainingBounces(remainingBounces);

  // No real refraction angle exists: total internal reflection, so all of the
  // energy follows the mirrored ray.
  if (cosTheta2Squared <= 0) {
    return scene.traceRay(reflectionRay);
  }

  float const cosTheta2 = std::sqrt(cosTheta2Squared);

  // Refracted direction: tangential part scaled by eta1 / eta2 plus the
  // component along the (flipped) normal pointing into the new medium.
  Vector3d const refractionDirection = etaRatio * (v - vDotN * n) - n * cosTheta2;

  // Fresnel amplitude coefficients for the two polarisations; the reflectance
  // for unpolarized light is the mean of their squares.
  float const polarizationParallel =
      (eta2 * cosTheta1 - eta1 * cosTheta2) / (eta2 * cosTheta1 + eta1 * cosTheta2);
  float const polarizationOrthogonal =
      (eta1 * cosTheta1 - eta2 * cosTheta2) / (eta1 * cosTheta1 + eta2 * cosTheta2);
  float const reflectionFactor =
      (polarizationParallel * polarizationParallel +
       polarizationOrthogonal * polarizationOrthogonal) /
      2;

  Ray refractionRay = Ray(hitPoint, refractionDirection);
  refractionRay.setRemainingBounces(remainingBounces);

  // Trace both rays (refraction first, as before) and blend by reflectance.
  Color const refractionColor = scene.traceRay(refractionRay);
  Color const reflectionColor = scene.traceRay(reflectionRay);
  return (1 - reflectionFactor) * refractionColor + reflectionFactor * reflectionColor;
}

/// Reports that surfaces using this shader are transparent.
///
/// @return Always true.
bool RefractionShader::isTransparent() const { return true; }